custom discreteMaxProduct

release/4.3a0
Varun Agrawal 2025-01-01 14:01:43 -05:00
parent e854d15033
commit ec5d87e1a5
2 changed files with 26 additions and 1 deletions

View File

@ -19,7 +19,9 @@
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteTableConditional.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
#include <memory> #include <memory>
@ -119,6 +121,23 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn; return gbn;
} }
DiscreteValues HybridBayesNet::discreteMaxProduct(
const DiscreteFactorGraph &dfg) const {
TableFactor product = TableProductAndNormalize(dfg);
uint64_t maxIdx = 0;
double maxValue = 0.0;
Eigen::SparseVector<double> sparseTable = product.sparseTable();
for (TableFactor::SparseIt it(sparseTable); it; ++it) {
if (it.value() > maxValue) {
maxIdx = it.index();
}
}
DiscreteValues assignment = product.findAssignments(maxIdx);
return assignment;
}
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const { HybridValues HybridBayesNet::optimize() const {
// Collect all the discrete factors to compute MPE // Collect all the discrete factors to compute MPE
@ -131,7 +150,7 @@ HybridValues HybridBayesNet::optimize() const {
} }
// Solve for the MPE // Solve for the MPE
DiscreteValues mpe = discrete_fg.optimize(); DiscreteValues mpe = this->discreteMaxProduct(discrete_fg);
// Given the MPE, compute the optimal continuous values. // Given the MPE, compute the optimal continuous values.
return HybridValues(optimize(mpe), mpe); return HybridValues(optimize(mpe), mpe);
@ -191,6 +210,8 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
// Iterate over each conditional. // Iterate over each conditional.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
conditional->print();
conditional->errorTree(continuousValues).print("errorTre", DefaultKeyFormatter);
result = result + conditional->errorTree(continuousValues); result = result + conditional->errorTree(continuousValues);
} }

View File

@ -268,6 +268,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @} /// @}
private: private:
/// Helper method to compute the max product assignment
/// given a DiscreteFactorGraph
DiscreteValues discreteMaxProduct(const DiscreteFactorGraph &dfg) const;
#if GTSAM_ENABLE_BOOST_SERIALIZATION #if GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */ /** Serialization function */
friend class boost::serialization::access; friend class boost::serialization::access;