toFactorGraph method in HybridBayesNet

release/4.3a0
Frank Dellaert 2023-01-04 19:52:37 -08:00
parent ee7a7e0bcf
commit e30813e81e
6 changed files with 44 additions and 29 deletions

View File

@ -377,4 +377,27 @@ AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
return error_tree.apply([](double error) { return exp(-error); }); return error_tree.apply([](double error) { return exp(-error); });
} }
/* ************************************************************************* */
HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
const VectorValues &measurements) const {
HybridGaussianFactorGraph fg;
// For all nodes in the Bayes net, if its frontal variable is in measurements,
// replace it by a likelihood factor:
for (auto &&conditional : *this) {
if (conditional->frontalsIn(measurements)) {
if (auto gc = conditional->asGaussian())
fg.push_back(gc->likelihood(measurements));
else if (auto gm = conditional->asMixture())
fg.push_back(gm->likelihood(measurements));
else {
throw std::runtime_error("Unknown conditional type");
}
} else {
fg.push_back(conditional);
}
}
return fg;
}
} // namespace gtsam } // namespace gtsam

View File

@ -229,6 +229,12 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
AlgebraicDecisionTree<Key> probPrime( AlgebraicDecisionTree<Key> probPrime(
const VectorValues &continuousValues) const; const VectorValues &continuousValues) const;
/**
* Convert a hybrid Bayes net to a hybrid Gaussian factor graph by converting
* all conditionals with instantiated measurements into likelihood factors.
*/
HybridGaussianFactorGraph toFactorGraph(
const VectorValues &measurements) const;
/// @} /// @}
private: private:

View File

@ -178,6 +178,16 @@ class GTSAM_EXPORT HybridConditional
/// Return the error of the underlying conditional. /// Return the error of the underlying conditional.
double error(const HybridValues& values) const override; double error(const HybridValues& values) const override;
/// Check if VectorValues `measurements` contains all frontal keys.
bool frontalsIn(const VectorValues& measurements) const {
for (Key key : frontals()) {
if (!measurements.exists(key)) {
return false;
}
}
return true;
}
/// @} /// @}
private: private:

View File

@ -64,30 +64,6 @@ inline HybridBayesNet createHybridBayesNet(int numMeasurements = 1,
return bayesNet; return bayesNet;
} }
/**
* Convert a hybrid Bayes net to a hybrid Gaussian factor graph.
*/
inline HybridGaussianFactorGraph convertBayesNet(
const HybridBayesNet& bayesNet, const VectorValues& measurements) {
HybridGaussianFactorGraph fg;
// For all nodes in the Bayes net, if its frontal variable is in measurements,
// replace it by a likelihood factor:
for (const HybridConditional::shared_ptr& conditional : bayesNet) {
if (measurements.exists(conditional->firstFrontalKey())) {
if (auto gc = conditional->asGaussian())
fg.push_back(gc->likelihood(measurements));
else if (auto gm = conditional->asMixture())
fg.push_back(gm->likelihood(measurements));
else {
throw std::runtime_error("Unknown conditional type");
}
} else {
fg.push_back(conditional);
}
}
return fg;
}
/** /**
* Create a tiny two variable hybrid factor graph which represents a discrete * Create a tiny two variable hybrid factor graph which represents a discrete
* mode and a continuous variable x0, given a number of measurements of the * mode and a continuous variable x0, given a number of measurements of the
@ -101,10 +77,10 @@ inline HybridGaussianFactorGraph createHybridGaussianFactorGraph(
auto bayesNet = createHybridBayesNet(numMeasurements, manyModes); auto bayesNet = createHybridBayesNet(numMeasurements, manyModes);
if (measurements) { if (measurements) {
// Use the measurements to create a hybrid factor graph. // Use the measurements to create a hybrid factor graph.
return convertBayesNet(bayesNet, *measurements); return bayesNet.toFactorGraph(*measurements);
} else { } else {
// Sample from the generative model to create a hybrid factor graph. // Sample from the generative model to create a hybrid factor graph.
return convertBayesNet(bayesNet, bayesNet.sample().continuous()); return bayesNet.toFactorGraph(bayesNet.sample().continuous());
} }
} }

View File

@ -735,7 +735,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) {
// Create Bayes net and convert to factor graph. // Create Bayes net and convert to factor graph.
auto bn = tiny::createHybridBayesNet(numMeasurements, manyModes); auto bn = tiny::createHybridBayesNet(numMeasurements, manyModes);
const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}}; const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
auto fg = tiny::convertBayesNet(bn, measurements); auto fg = bn.toFactorGraph(measurements);
EXPECT_LONGS_EQUAL(5, fg.size()); EXPECT_LONGS_EQUAL(5, fg.size());
// Test elimination // Test elimination

View File

@ -263,7 +263,7 @@ double GaussianConditional::evaluate(const VectorValues& x) const {
Vector frontalVec = gy.vector(KeyVector(beginFrontals(), endFrontals())); Vector frontalVec = gy.vector(KeyVector(beginFrontals(), endFrontals()));
frontalVec = R().transpose().triangularView<Eigen::Lower>().solve(frontalVec); frontalVec = R().transpose().triangularView<Eigen::Lower>().solve(frontalVec);
// Check for indeterminant solution // Check for indeterminate solution
if (frontalVec.hasNaN()) throw IndeterminantLinearSystemException(this->keys().front()); if (frontalVec.hasNaN()) throw IndeterminantLinearSystemException(this->keys().front());
for (const_iterator it = beginParents(); it!= endParents(); it++) for (const_iterator it = beginParents(); it!= endParents(); it++)