From ec5d87e1a5684871f6a8f62ddf1855f26daf3c40 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 1 Jan 2025 14:01:43 -0500 Subject: [PATCH] custom discreteMaxProduct --- gtsam/hybrid/HybridBayesNet.cpp | 23 ++++++++++++++++++++++- gtsam/hybrid/HybridBayesNet.h | 4 ++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 7691bb209..66e4011dc 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -19,7 +19,9 @@ #include #include #include +#include #include +#include #include #include @@ -119,6 +121,23 @@ GaussianBayesNet HybridBayesNet::choose( return gbn; } +DiscreteValues HybridBayesNet::discreteMaxProduct( + const DiscreteFactorGraph &dfg) const { + TableFactor product = TableProductAndNormalize(dfg); + + uint64_t maxIdx = 0; + double maxValue = 0.0; + Eigen::SparseVector sparseTable = product.sparseTable(); + for (TableFactor::SparseIt it(sparseTable); it; ++it) { + if (it.value() > maxValue) { + maxIdx = it.index(); + } + } + + DiscreteValues assignment = product.findAssignments(maxIdx); + return assignment; +} + /* ************************************************************************* */ HybridValues HybridBayesNet::optimize() const { // Collect all the discrete factors to compute MPE @@ -131,7 +150,7 @@ HybridValues HybridBayesNet::optimize() const { } // Solve for the MPE - DiscreteValues mpe = discrete_fg.optimize(); + DiscreteValues mpe = this->discreteMaxProduct(discrete_fg); // Given the MPE, compute the optimal continuous values. return HybridValues(optimize(mpe), mpe); @@ -191,6 +210,8 @@ AlgebraicDecisionTree HybridBayesNet::errorTree( // Iterate over each conditional. for (auto &&conditional : *this) { + conditional->print(); + conditional->errorTree(continuousValues).print("errorTre", DefaultKeyFormatter); result = result + conditional->errorTree(continuousValues); } diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 3e07c71ce..263922636 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -268,6 +268,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// @} private: + /// Helper method to compute the max product assignment + /// given a DiscreteFactorGraph + DiscreteValues discreteMaxProduct(const DiscreteFactorGraph &dfg) const; + #if GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access;