toFactorGraph method in HybridBayesNet
parent
ee7a7e0bcf
commit
e30813e81e
|
@ -377,4 +377,27 @@ AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
|
||||||
return error_tree.apply([](double error) { return exp(-error); });
|
return error_tree.apply([](double error) { return exp(-error); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
|
||||||
|
const VectorValues &measurements) const {
|
||||||
|
HybridGaussianFactorGraph fg;
|
||||||
|
|
||||||
|
// For all nodes in the Bayes net, if its frontal variable is in measurements,
|
||||||
|
// replace it by a likelihood factor:
|
||||||
|
for (auto &&conditional : *this) {
|
||||||
|
if (conditional->frontalsIn(measurements)) {
|
||||||
|
if (auto gc = conditional->asGaussian())
|
||||||
|
fg.push_back(gc->likelihood(measurements));
|
||||||
|
else if (auto gm = conditional->asMixture())
|
||||||
|
fg.push_back(gm->likelihood(measurements));
|
||||||
|
else {
|
||||||
|
throw std::runtime_error("Unknown conditional type");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fg.push_back(conditional);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -229,6 +229,12 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
AlgebraicDecisionTree<Key> probPrime(
|
AlgebraicDecisionTree<Key> probPrime(
|
||||||
const VectorValues &continuousValues) const;
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a hybrid Bayes net to a hybrid Gaussian factor graph by converting
|
||||||
|
* all conditionals with instantiated measurements into likelihood factors.
|
||||||
|
*/
|
||||||
|
HybridGaussianFactorGraph toFactorGraph(
|
||||||
|
const VectorValues &measurements) const;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -178,6 +178,16 @@ class GTSAM_EXPORT HybridConditional
|
||||||
/// Return the error of the underlying conditional.
|
/// Return the error of the underlying conditional.
|
||||||
double error(const HybridValues& values) const override;
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
|
/// Check if VectorValues `measurements` contains all frontal keys.
|
||||||
|
bool frontalsIn(const VectorValues& measurements) const {
|
||||||
|
for (Key key : frontals()) {
|
||||||
|
if (!measurements.exists(key)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -64,30 +64,6 @@ inline HybridBayesNet createHybridBayesNet(int numMeasurements = 1,
|
||||||
return bayesNet;
|
return bayesNet;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Convert a hybrid Bayes net to a hybrid Gaussian factor graph.
|
|
||||||
*/
|
|
||||||
inline HybridGaussianFactorGraph convertBayesNet(
|
|
||||||
const HybridBayesNet& bayesNet, const VectorValues& measurements) {
|
|
||||||
HybridGaussianFactorGraph fg;
|
|
||||||
// For all nodes in the Bayes net, if its frontal variable is in measurements,
|
|
||||||
// replace it by a likelihood factor:
|
|
||||||
for (const HybridConditional::shared_ptr& conditional : bayesNet) {
|
|
||||||
if (measurements.exists(conditional->firstFrontalKey())) {
|
|
||||||
if (auto gc = conditional->asGaussian())
|
|
||||||
fg.push_back(gc->likelihood(measurements));
|
|
||||||
else if (auto gm = conditional->asMixture())
|
|
||||||
fg.push_back(gm->likelihood(measurements));
|
|
||||||
else {
|
|
||||||
throw std::runtime_error("Unknown conditional type");
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fg.push_back(conditional);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fg;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a tiny two variable hybrid factor graph which represents a discrete
|
* Create a tiny two variable hybrid factor graph which represents a discrete
|
||||||
* mode and a continuous variable x0, given a number of measurements of the
|
* mode and a continuous variable x0, given a number of measurements of the
|
||||||
|
@ -101,10 +77,10 @@ inline HybridGaussianFactorGraph createHybridGaussianFactorGraph(
|
||||||
auto bayesNet = createHybridBayesNet(numMeasurements, manyModes);
|
auto bayesNet = createHybridBayesNet(numMeasurements, manyModes);
|
||||||
if (measurements) {
|
if (measurements) {
|
||||||
// Use the measurements to create a hybrid factor graph.
|
// Use the measurements to create a hybrid factor graph.
|
||||||
return convertBayesNet(bayesNet, *measurements);
|
return bayesNet.toFactorGraph(*measurements);
|
||||||
} else {
|
} else {
|
||||||
// Sample from the generative model to create a hybrid factor graph.
|
// Sample from the generative model to create a hybrid factor graph.
|
||||||
return convertBayesNet(bayesNet, bayesNet.sample().continuous());
|
return bayesNet.toFactorGraph(bayesNet.sample().continuous());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -735,7 +735,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) {
|
||||||
// Create Bayes net and convert to factor graph.
|
// Create Bayes net and convert to factor graph.
|
||||||
auto bn = tiny::createHybridBayesNet(numMeasurements, manyModes);
|
auto bn = tiny::createHybridBayesNet(numMeasurements, manyModes);
|
||||||
const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
|
const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
|
||||||
auto fg = tiny::convertBayesNet(bn, measurements);
|
auto fg = bn.toFactorGraph(measurements);
|
||||||
EXPECT_LONGS_EQUAL(5, fg.size());
|
EXPECT_LONGS_EQUAL(5, fg.size());
|
||||||
|
|
||||||
// Test elimination
|
// Test elimination
|
||||||
|
|
|
@ -263,7 +263,7 @@ double GaussianConditional::evaluate(const VectorValues& x) const {
|
||||||
Vector frontalVec = gy.vector(KeyVector(beginFrontals(), endFrontals()));
|
Vector frontalVec = gy.vector(KeyVector(beginFrontals(), endFrontals()));
|
||||||
frontalVec = R().transpose().triangularView<Eigen::Lower>().solve(frontalVec);
|
frontalVec = R().transpose().triangularView<Eigen::Lower>().solve(frontalVec);
|
||||||
|
|
||||||
// Check for indeterminant solution
|
// Check for indeterminate solution
|
||||||
if (frontalVec.hasNaN()) throw IndeterminantLinearSystemException(this->keys().front());
|
if (frontalVec.hasNaN()) throw IndeterminantLinearSystemException(this->keys().front());
|
||||||
|
|
||||||
for (const_iterator it = beginParents(); it!= endParents(); it++)
|
for (const_iterator it = beginParents(); it!= endParents(); it++)
|
||||||
|
|
Loading…
Reference in New Issue