custom discreteMaxProduct
parent
e854d15033
commit
ec5d87e1a5
|
|
@ -19,7 +19,9 @@
|
|||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteTableConditional.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <memory>
|
||||
|
|
@ -119,6 +121,23 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
return gbn;
|
||||
}
|
||||
|
||||
DiscreteValues HybridBayesNet::discreteMaxProduct(
|
||||
const DiscreteFactorGraph &dfg) const {
|
||||
TableFactor product = TableProductAndNormalize(dfg);
|
||||
|
||||
uint64_t maxIdx = 0;
|
||||
double maxValue = 0.0;
|
||||
Eigen::SparseVector<double> sparseTable = product.sparseTable();
|
||||
for (TableFactor::SparseIt it(sparseTable); it; ++it) {
|
||||
if (it.value() > maxValue) {
|
||||
maxIdx = it.index();
|
||||
}
|
||||
}
|
||||
|
||||
DiscreteValues assignment = product.findAssignments(maxIdx);
|
||||
return assignment;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesNet::optimize() const {
|
||||
// Collect all the discrete factors to compute MPE
|
||||
|
|
@ -131,7 +150,7 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
}
|
||||
|
||||
// Solve for the MPE
|
||||
DiscreteValues mpe = discrete_fg.optimize();
|
||||
DiscreteValues mpe = this->discreteMaxProduct(discrete_fg);
|
||||
|
||||
// Given the MPE, compute the optimal continuous values.
|
||||
return HybridValues(optimize(mpe), mpe);
|
||||
|
|
@ -191,6 +210,8 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
|
|||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
conditional->print();
|
||||
conditional->errorTree(continuousValues).print("errorTre", DefaultKeyFormatter);
|
||||
result = result + conditional->errorTree(continuousValues);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -268,6 +268,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
/// @}
|
||||
|
||||
private:
|
||||
/// Helper method to compute the max product assignment
|
||||
/// given a DiscreteFactorGraph
|
||||
DiscreteValues discreteMaxProduct(const DiscreteFactorGraph &dfg) const;
|
||||
|
||||
#if GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
|
|||
Loading…
Reference in New Issue