custom discreteMaxProduct
parent
e854d15033
commit
ec5d87e1a5
|
|
@ -19,7 +19,9 @@
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
|
#include <gtsam/discrete/DiscreteTableConditional.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
@ -119,6 +121,23 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
return gbn;
|
return gbn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DiscreteValues HybridBayesNet::discreteMaxProduct(
|
||||||
|
const DiscreteFactorGraph &dfg) const {
|
||||||
|
TableFactor product = TableProductAndNormalize(dfg);
|
||||||
|
|
||||||
|
uint64_t maxIdx = 0;
|
||||||
|
double maxValue = 0.0;
|
||||||
|
Eigen::SparseVector<double> sparseTable = product.sparseTable();
|
||||||
|
for (TableFactor::SparseIt it(sparseTable); it; ++it) {
|
||||||
|
if (it.value() > maxValue) {
|
||||||
|
maxIdx = it.index();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues assignment = product.findAssignments(maxIdx);
|
||||||
|
return assignment;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesNet::optimize() const {
|
HybridValues HybridBayesNet::optimize() const {
|
||||||
// Collect all the discrete factors to compute MPE
|
// Collect all the discrete factors to compute MPE
|
||||||
|
|
@ -131,7 +150,7 @@ HybridValues HybridBayesNet::optimize() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Solve for the MPE
|
// Solve for the MPE
|
||||||
DiscreteValues mpe = discrete_fg.optimize();
|
DiscreteValues mpe = this->discreteMaxProduct(discrete_fg);
|
||||||
|
|
||||||
// Given the MPE, compute the optimal continuous values.
|
// Given the MPE, compute the optimal continuous values.
|
||||||
return HybridValues(optimize(mpe), mpe);
|
return HybridValues(optimize(mpe), mpe);
|
||||||
|
|
@ -191,6 +210,8 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
|
||||||
|
|
||||||
// Iterate over each conditional.
|
// Iterate over each conditional.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
|
conditional->print();
|
||||||
|
conditional->errorTree(continuousValues).print("errorTre", DefaultKeyFormatter);
|
||||||
result = result + conditional->errorTree(continuousValues);
|
result = result + conditional->errorTree(continuousValues);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -268,6 +268,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
/// Helper method to compute the max product assignment
|
||||||
|
/// given a DiscreteFactorGraph
|
||||||
|
DiscreteValues discreteMaxProduct(const DiscreteFactorGraph &dfg) const;
|
||||||
|
|
||||||
#if GTSAM_ENABLE_BOOST_SERIALIZATION
|
#if GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
friend class boost::serialization::access;
|
friend class boost::serialization::access;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue