Merge pull request #1373 from borglab/default-ordering-func
commit
690a23043b
|
@ -62,9 +62,17 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
|
|||
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
|
||||
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
|
||||
/// The default dense elimination function
|
||||
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
|
||||
static std::pair<boost::shared_ptr<ConditionalType>,
|
||||
boost::shared_ptr<FactorType> >
|
||||
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
||||
return EliminateDiscrete(factors, keys); }
|
||||
return EliminateDiscrete(factors, keys);
|
||||
}
|
||||
/// The default ordering generation function
|
||||
static Ordering DefaultOrderingFunc(
|
||||
const FactorGraphType& graph,
|
||||
boost::optional<const VariableIndex&> variableIndex) {
|
||||
return Ordering::Colamd(*variableIndex);
|
||||
}
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -58,6 +58,21 @@ namespace gtsam {
|
|||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||
|
||||
/* ************************************************************************ */
|
||||
const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
|
||||
KeySet discrete_keys = graph.discreteKeys();
|
||||
for (auto &factor : graph) {
|
||||
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
}
|
||||
}
|
||||
|
||||
const VariableIndex index(graph);
|
||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||
return ordering;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static GaussianFactorGraphTree addGaussian(
|
||||
const GaussianFactorGraphTree &gfgTree,
|
||||
|
@ -448,21 +463,6 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
|
|||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
||||
KeySet discrete_keys = discreteKeys();
|
||||
for (auto &factor : factors_) {
|
||||
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
}
|
||||
}
|
||||
|
||||
const VariableIndex index(factors_);
|
||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||
return ordering;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
|
|
|
@ -53,6 +53,15 @@ GTSAM_EXPORT
|
|||
std::pair<boost::shared_ptr<HybridConditional>, HybridFactor::shared_ptr>
|
||||
EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys);
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
* eliminated after the continuous keys.
|
||||
*
|
||||
* @return const Ordering
|
||||
*/
|
||||
GTSAM_EXPORT const Ordering
|
||||
HybridOrdering(const HybridGaussianFactorGraph& graph);
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <>
|
||||
struct EliminationTraits<HybridGaussianFactorGraph> {
|
||||
|
@ -74,6 +83,12 @@ struct EliminationTraits<HybridGaussianFactorGraph> {
|
|||
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
||||
return EliminateHybrid(factors, keys);
|
||||
}
|
||||
/// The default ordering generation function
|
||||
static Ordering DefaultOrderingFunc(
|
||||
const FactorGraphType& graph,
|
||||
boost::optional<const VariableIndex&> variableIndex) {
|
||||
return HybridOrdering(graph);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -228,14 +243,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
*/
|
||||
double probPrime(const HybridValues& values) const;
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
* eliminated after the continuous keys.
|
||||
*
|
||||
* @return const Ordering
|
||||
*/
|
||||
const Ordering getHybridOrdering() const;
|
||||
|
||||
/**
|
||||
* @brief Create a decision tree of factor graphs out of this hybrid factor
|
||||
* graph.
|
||||
|
|
|
@ -185,9 +185,8 @@ TEST(HybridBayesNet, OptimizeAssignment) {
|
|||
TEST(HybridBayesNet, Optimize) {
|
||||
Switching s(4, 1.0, 0.1, {0, 1, 2, 3}, "1/1 1/1");
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
s.linearizedFactorGraph.eliminateSequential();
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
|
@ -212,9 +211,8 @@ TEST(HybridBayesNet, Optimize) {
|
|||
TEST(HybridBayesNet, Error) {
|
||||
Switching s(3);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
s.linearizedFactorGraph.eliminateSequential();
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = hybridBayesNet->error(delta.continuous());
|
||||
|
@ -266,9 +264,8 @@ TEST(HybridBayesNet, Error) {
|
|||
TEST(HybridBayesNet, Prune) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
s.linearizedFactorGraph.eliminateSequential();
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
|
@ -284,9 +281,8 @@ TEST(HybridBayesNet, Prune) {
|
|||
TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
s.linearizedFactorGraph.eliminateSequential();
|
||||
|
||||
size_t maxNrLeaves = 3;
|
||||
auto discreteConditionals = hybridBayesNet->discreteConditionals();
|
||||
|
@ -353,8 +349,7 @@ TEST(HybridBayesNet, Sampling) {
|
|||
// Create the factor graph from the nonlinear factor graph.
|
||||
HybridGaussianFactorGraph::shared_ptr fg = nfg.linearize(initial);
|
||||
// Eliminate into BN
|
||||
Ordering ordering = fg->getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
|
||||
HybridBayesNet::shared_ptr bn = fg->eliminateSequential();
|
||||
|
||||
// Set up sampling
|
||||
std::mt19937_64 gen(11);
|
||||
|
|
|
@ -37,9 +37,8 @@ using symbol_shorthand::X;
|
|||
TEST(HybridBayesTree, OptimizeMultifrontal) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesTree::shared_ptr hybridBayesTree =
|
||||
s.linearizedFactorGraph.eliminateMultifrontal(hybridOrdering);
|
||||
s.linearizedFactorGraph.eliminateMultifrontal();
|
||||
HybridValues delta = hybridBayesTree->optimize();
|
||||
|
||||
VectorValues expectedValues;
|
||||
|
@ -203,16 +202,8 @@ TEST(HybridBayesTree, Choose) {
|
|||
|
||||
GaussianBayesTree gbt = isam.choose(assignment);
|
||||
|
||||
Ordering ordering;
|
||||
ordering += X(0);
|
||||
ordering += X(1);
|
||||
ordering += X(2);
|
||||
ordering += X(3);
|
||||
ordering += M(0);
|
||||
ordering += M(1);
|
||||
ordering += M(2);
|
||||
|
||||
// TODO(Varun) get segfault if ordering not provided
|
||||
// Specify ordering so it matches that of HybridGaussianISAM.
|
||||
Ordering ordering(KeyVector{X(0), X(1), X(2), X(3), M(0), M(1), M(2)});
|
||||
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);
|
||||
|
||||
auto expected_gbt = bayesTree->choose(assignment);
|
||||
|
|
|
@ -90,7 +90,7 @@ TEST(HybridEstimation, Full) {
|
|||
}
|
||||
|
||||
HybridBayesNet::shared_ptr bayesNet =
|
||||
graph.eliminateSequential(hybridOrdering);
|
||||
graph.eliminateSequential();
|
||||
|
||||
EXPECT_LONGS_EQUAL(2 * K - 1, bayesNet->size());
|
||||
|
||||
|
@ -481,8 +481,7 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
|
|||
const auto fg = createHybridGaussianFactorGraph();
|
||||
|
||||
// 2. Eliminate into BN
|
||||
const Ordering ordering = fg->getHybridOrdering();
|
||||
const HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
|
||||
const HybridBayesNet::shared_ptr bn = fg->eliminateSequential();
|
||||
|
||||
// Set up sampling
|
||||
std::mt19937_64 rng(11);
|
||||
|
|
|
@ -130,8 +130,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
|
|||
|
||||
hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));
|
||||
|
||||
auto result =
|
||||
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)}));
|
||||
auto result = hfg.eliminateSequential();
|
||||
|
||||
auto dc = result->at(2)->asDiscrete();
|
||||
DiscreteValues dv;
|
||||
|
@ -161,8 +160,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialSimple) {
|
|||
// Joint discrete probability table for c1, c2
|
||||
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
||||
|
||||
HybridBayesNet::shared_ptr result = hfg.eliminateSequential(
|
||||
Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)}));
|
||||
HybridBayesNet::shared_ptr result = hfg.eliminateSequential();
|
||||
|
||||
// There are 4 variables (2 continuous + 2 discrete) in the bayes net.
|
||||
EXPECT_LONGS_EQUAL(4, result->size());
|
||||
|
@ -187,8 +185,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
|
|||
// variable throws segfault
|
||||
// hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
||||
|
||||
HybridBayesTree::shared_ptr result =
|
||||
hfg.eliminateMultifrontal(hfg.getHybridOrdering());
|
||||
HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal();
|
||||
|
||||
// The bayes tree should have 3 cliques
|
||||
EXPECT_LONGS_EQUAL(3, result->size());
|
||||
|
@ -218,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
|
|||
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));
|
||||
|
||||
// Get a constrained ordering keeping c1 last
|
||||
auto ordering_full = hfg.getHybridOrdering();
|
||||
auto ordering_full = HybridOrdering(hfg);
|
||||
|
||||
// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
|
||||
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
|
||||
|
@ -518,8 +515,7 @@ TEST(HybridGaussianFactorGraph, optimize) {
|
|||
|
||||
hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));
|
||||
|
||||
auto result =
|
||||
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {C(1)}));
|
||||
auto result = hfg.eliminateSequential();
|
||||
|
||||
HybridValues hv = result->optimize();
|
||||
|
||||
|
@ -572,9 +568,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
|
|||
|
||||
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
||||
|
||||
Ordering hybridOrdering = graph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
graph.eliminateSequential(hybridOrdering);
|
||||
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
|
||||
|
||||
const HybridValues delta = hybridBayesNet->optimize();
|
||||
const double error = graph.error(delta);
|
||||
|
@ -593,9 +587,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
|||
|
||||
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
||||
|
||||
Ordering hybridOrdering = graph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
graph.eliminateSequential(hybridOrdering);
|
||||
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = graph.error(delta.continuous());
|
||||
|
@ -684,10 +676,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
|||
expectedBayesNet.emplace_back(new DiscreteConditional(mode, "74/26"));
|
||||
|
||||
// Test elimination
|
||||
Ordering ordering;
|
||||
ordering.push_back(X(0));
|
||||
ordering.push_back(M(0));
|
||||
const auto posterior = fg.eliminateSequential(ordering);
|
||||
const auto posterior = fg.eliminateSequential();
|
||||
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||
}
|
||||
|
||||
|
@ -719,10 +708,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
|||
expectedBayesNet.emplace_back(new DiscreteConditional(mode, "23/77"));
|
||||
|
||||
// Test elimination
|
||||
Ordering ordering;
|
||||
ordering.push_back(X(0));
|
||||
ordering.push_back(M(0));
|
||||
const auto posterior = fg.eliminateSequential(ordering);
|
||||
const auto posterior = fg.eliminateSequential();
|
||||
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||
}
|
||||
|
||||
|
@ -741,11 +727,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) {
|
|||
EXPECT_LONGS_EQUAL(5, fg.size());
|
||||
|
||||
// Test elimination
|
||||
Ordering ordering;
|
||||
ordering.push_back(X(0));
|
||||
ordering.push_back(M(0));
|
||||
ordering.push_back(M(1));
|
||||
const auto posterior = fg.eliminateSequential(ordering);
|
||||
const auto posterior = fg.eliminateSequential();
|
||||
|
||||
// Compute the log-ratio between the Bayes net and the factor graph.
|
||||
auto compute_ratio = [&](HybridValues *sample) -> double {
|
||||
|
|
|
@ -150,8 +150,7 @@ TEST(HybridSerialization, GaussianMixture) {
|
|||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridSerialization, HybridBayesNet) {
|
||||
Switching s(2);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential());
|
||||
|
||||
EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||
|
@ -162,9 +161,7 @@ TEST(HybridSerialization, HybridBayesNet) {
|
|||
// Test HybridBayesTree serialization.
|
||||
TEST(HybridSerialization, HybridBayesTree) {
|
||||
Switching s(2);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesTree hbt =
|
||||
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
|
||||
HybridBayesTree hbt = *(s.linearizedFactorGraph.eliminateMultifrontal());
|
||||
|
||||
EXPECT(equalsObj<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsXML<HybridBayesTree>(hbt));
|
||||
|
|
|
@ -44,9 +44,16 @@ namespace gtsam {
|
|||
if (orderingType == Ordering::METIS) {
|
||||
Ordering computedOrdering = Ordering::Metis(asDerived());
|
||||
return eliminateSequential(computedOrdering, function, variableIndex);
|
||||
} else {
|
||||
} else if (orderingType == Ordering::COLAMD) {
|
||||
Ordering computedOrdering = Ordering::Colamd(*variableIndex);
|
||||
return eliminateSequential(computedOrdering, function, variableIndex);
|
||||
} else if (orderingType == Ordering::NATURAL) {
|
||||
Ordering computedOrdering = Ordering::Natural(asDerived());
|
||||
return eliminateSequential(computedOrdering, function, variableIndex);
|
||||
} else {
|
||||
Ordering computedOrdering = EliminationTraitsType::DefaultOrderingFunc(
|
||||
asDerived(), variableIndex);
|
||||
return eliminateSequential(computedOrdering, function, variableIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -100,9 +107,16 @@ namespace gtsam {
|
|||
if (orderingType == Ordering::METIS) {
|
||||
Ordering computedOrdering = Ordering::Metis(asDerived());
|
||||
return eliminateMultifrontal(computedOrdering, function, variableIndex);
|
||||
} else {
|
||||
} else if (orderingType == Ordering::COLAMD) {
|
||||
Ordering computedOrdering = Ordering::Colamd(*variableIndex);
|
||||
return eliminateMultifrontal(computedOrdering, function, variableIndex);
|
||||
} else if (orderingType == Ordering::NATURAL) {
|
||||
Ordering computedOrdering = Ordering::Natural(asDerived());
|
||||
return eliminateMultifrontal(computedOrdering, function, variableIndex);
|
||||
} else {
|
||||
Ordering computedOrdering = EliminationTraitsType::DefaultOrderingFunc(
|
||||
asDerived(), variableIndex);
|
||||
return eliminateMultifrontal(computedOrdering, function, variableIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,6 +54,12 @@ namespace gtsam {
|
|||
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
|
||||
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
||||
return EliminatePreferCholesky(factors, keys); }
|
||||
/// The default ordering generation function
|
||||
static Ordering DefaultOrderingFunc(
|
||||
const FactorGraphType& graph,
|
||||
boost::optional<const VariableIndex&> variableIndex) {
|
||||
return Ordering::Colamd(*variableIndex);
|
||||
}
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -46,6 +46,12 @@ namespace gtsam {
|
|||
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
|
||||
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
||||
return EliminateSymbolic(factors, keys); }
|
||||
/// The default ordering generation function
|
||||
static Ordering DefaultOrderingFunc(
|
||||
const FactorGraphType& graph,
|
||||
boost::optional<const VariableIndex&> variableIndex) {
|
||||
return Ordering::Colamd(*variableIndex);
|
||||
}
|
||||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -25,7 +25,6 @@ from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
|||
|
||||
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||
"""Unit tests for HybridGaussianFactorGraph."""
|
||||
|
||||
def test_create(self):
|
||||
"""Test construction of hybrid factor graph."""
|
||||
model = noiseModel.Unit.Create(3)
|
||||
|
@ -42,9 +41,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
hfg.push_back(jf2)
|
||||
hfg.push_back(gmf)
|
||||
|
||||
hbn = hfg.eliminateSequential(
|
||||
Ordering.ColamdConstrainedLastHybridGaussianFactorGraph(
|
||||
hfg, [C(0)]))
|
||||
hbn = hfg.eliminateSequential()
|
||||
|
||||
self.assertEqual(hbn.size(), 2)
|
||||
|
||||
|
@ -74,15 +71,14 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
dtf = gtsam.DecisionTreeFactor([(C(0), 2)], "0 1")
|
||||
hfg.push_back(dtf)
|
||||
|
||||
hbn = hfg.eliminateSequential(
|
||||
Ordering.ColamdConstrainedLastHybridGaussianFactorGraph(
|
||||
hfg, [C(0)]))
|
||||
hbn = hfg.eliminateSequential()
|
||||
|
||||
hv = hbn.optimize()
|
||||
self.assertEqual(hv.atDiscrete(C(0)), 1)
|
||||
|
||||
@staticmethod
|
||||
def tiny(num_measurements: int = 1, prior_mean: float = 5.0,
|
||||
def tiny(num_measurements: int = 1,
|
||||
prior_mean: float = 5.0,
|
||||
prior_sigma: float = 0.5) -> HybridBayesNet:
|
||||
"""
|
||||
Create a tiny two variable hybrid model which represents
|
||||
|
@ -129,20 +125,23 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
bayesNet2 = self.tiny(prior_sigma=5.0, num_measurements=1)
|
||||
# bn1: # 1/sqrt(2*pi*0.5^2)
|
||||
# bn2: # 1/sqrt(2*pi*5.0^2)
|
||||
expected_ratio = np.sqrt(2*np.pi*5.0**2)/np.sqrt(2*np.pi*0.5**2)
|
||||
expected_ratio = np.sqrt(2 * np.pi * 5.0**2) / np.sqrt(
|
||||
2 * np.pi * 0.5**2)
|
||||
mean0 = HybridValues()
|
||||
mean0.insert(X(0), [5.0])
|
||||
mean0.insert(Z(0), [5.0])
|
||||
mean0.insert(M(0), 0)
|
||||
self.assertAlmostEqual(bayesNet1.evaluate(mean0) /
|
||||
bayesNet2.evaluate(mean0), expected_ratio,
|
||||
bayesNet2.evaluate(mean0),
|
||||
expected_ratio,
|
||||
delta=1e-9)
|
||||
mean1 = HybridValues()
|
||||
mean1.insert(X(0), [5.0])
|
||||
mean1.insert(Z(0), [5.0])
|
||||
mean1.insert(M(0), 1)
|
||||
self.assertAlmostEqual(bayesNet1.evaluate(mean1) /
|
||||
bayesNet2.evaluate(mean1), expected_ratio,
|
||||
bayesNet2.evaluate(mean1),
|
||||
expected_ratio,
|
||||
delta=1e-9)
|
||||
|
||||
@staticmethod
|
||||
|
@ -171,11 +170,13 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
return fg
|
||||
|
||||
@classmethod
|
||||
def estimate_marginals(cls, target, proposal_density: HybridBayesNet,
|
||||
def estimate_marginals(cls,
|
||||
target,
|
||||
proposal_density: HybridBayesNet,
|
||||
N=10000):
|
||||
"""Do importance sampling to estimate discrete marginal P(mode)."""
|
||||
# Allocate space for marginals on mode.
|
||||
marginals = np.zeros((2,))
|
||||
marginals = np.zeros((2, ))
|
||||
|
||||
# Do importance sampling.
|
||||
for s in range(N):
|
||||
|
@ -210,14 +211,15 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
return bayesNet.evaluate(x)
|
||||
|
||||
# Create proposal density on (x0, mode), making sure it has same mean:
|
||||
posterior_information = 1/(prior_sigma**2) + 1/(0.5**2)
|
||||
posterior_information = 1 / (prior_sigma**2) + 1 / (0.5**2)
|
||||
posterior_sigma = posterior_information**(-0.5)
|
||||
proposal_density = self.tiny(
|
||||
num_measurements=0, prior_mean=5.0, prior_sigma=posterior_sigma)
|
||||
proposal_density = self.tiny(num_measurements=0,
|
||||
prior_mean=5.0,
|
||||
prior_sigma=posterior_sigma)
|
||||
|
||||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(
|
||||
target=unnormalized_posterior, proposal_density=proposal_density)
|
||||
marginals = self.estimate_marginals(target=unnormalized_posterior,
|
||||
proposal_density=proposal_density)
|
||||
# print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
# print(f"P(mode=0; Z) = {marginals[0]}")
|
||||
# print(f"P(mode=1; Z) = {marginals[1]}")
|
||||
|
@ -230,10 +232,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
self.assertEqual(fg.size(), 3)
|
||||
|
||||
# Test elimination.
|
||||
ordering = gtsam.Ordering()
|
||||
ordering.push_back(X(0))
|
||||
ordering.push_back(M(0))
|
||||
posterior = fg.eliminateSequential(ordering)
|
||||
posterior = fg.eliminateSequential()
|
||||
|
||||
def true_posterior(x):
|
||||
"""Posterior from elimination."""
|
||||
|
@ -241,8 +240,8 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
return posterior.evaluate(x)
|
||||
|
||||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(
|
||||
target=true_posterior, proposal_density=proposal_density)
|
||||
marginals = self.estimate_marginals(target=true_posterior,
|
||||
proposal_density=proposal_density)
|
||||
# print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
# print(f"P(mode=0; z0) = {marginals[0]}")
|
||||
# print(f"P(mode=1; z0) = {marginals[1]}")
|
||||
|
@ -253,8 +252,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
|
||||
@staticmethod
|
||||
def calculate_ratio(bayesNet: HybridBayesNet,
|
||||
fg: HybridGaussianFactorGraph,
|
||||
sample: HybridValues):
|
||||
fg: HybridGaussianFactorGraph, sample: HybridValues):
|
||||
"""Calculate ratio between Bayes net and factor graph."""
|
||||
return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
|
||||
fg.probPrime(sample) > 0 else 0
|
||||
|
@ -285,14 +283,15 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
return bayesNet.evaluate(x)
|
||||
|
||||
# Create proposal density on (x0, mode), making sure it has same mean:
|
||||
posterior_information = 1/(prior_sigma**2) + 2.0/(3.0**2)
|
||||
posterior_information = 1 / (prior_sigma**2) + 2.0 / (3.0**2)
|
||||
posterior_sigma = posterior_information**(-0.5)
|
||||
proposal_density = self.tiny(
|
||||
num_measurements=0, prior_mean=5.0, prior_sigma=posterior_sigma)
|
||||
proposal_density = self.tiny(num_measurements=0,
|
||||
prior_mean=5.0,
|
||||
prior_sigma=posterior_sigma)
|
||||
|
||||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(
|
||||
target=unnormalized_posterior, proposal_density=proposal_density)
|
||||
marginals = self.estimate_marginals(target=unnormalized_posterior,
|
||||
proposal_density=proposal_density)
|
||||
# print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
# print(f"P(mode=0; Z) = {marginals[0]}")
|
||||
# print(f"P(mode=1; Z) = {marginals[1]}")
|
||||
|
@ -319,10 +318,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
self.assertAlmostEqual(ratio, expected_ratio)
|
||||
|
||||
# Test elimination.
|
||||
ordering = gtsam.Ordering()
|
||||
ordering.push_back(X(0))
|
||||
ordering.push_back(M(0))
|
||||
posterior = fg.eliminateSequential(ordering)
|
||||
posterior = fg.eliminateSequential()
|
||||
|
||||
# Calculate ratio between Bayes net probability and the factor graph:
|
||||
expected_ratio = self.calculate_ratio(posterior, fg, values)
|
||||
|
|
|
@ -14,22 +14,27 @@ from __future__ import print_function
|
|||
|
||||
import unittest
|
||||
|
||||
import gtsam
|
||||
import numpy as np
|
||||
from gtsam.symbol_shorthand import C, X
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
import gtsam
|
||||
|
||||
|
||||
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||
"""Unit tests for HybridGaussianFactorGraph."""
|
||||
|
||||
def test_nonlinear_hybrid(self):
|
||||
nlfg = gtsam.HybridNonlinearFactorGraph()
|
||||
dk = gtsam.DiscreteKeys()
|
||||
dk.push_back((10, 2))
|
||||
nlfg.add(gtsam.BetweenFactorPoint3(1, 2, gtsam.Point3(1, 2, 3), gtsam.noiseModel.Diagonal.Variances([1, 1, 1])))
|
||||
nlfg.add(
|
||||
gtsam.PriorFactorPoint3(2, gtsam.Point3(1, 2, 3), gtsam.noiseModel.Diagonal.Variances([0.5, 0.5, 0.5])))
|
||||
gtsam.BetweenFactorPoint3(
|
||||
1, 2, gtsam.Point3(1, 2, 3),
|
||||
gtsam.noiseModel.Diagonal.Variances([1, 1, 1])))
|
||||
nlfg.add(
|
||||
gtsam.PriorFactorPoint3(
|
||||
2, gtsam.Point3(1, 2, 3),
|
||||
gtsam.noiseModel.Diagonal.Variances([0.5, 0.5, 0.5])))
|
||||
nlfg.push_back(
|
||||
gtsam.MixtureFactor([1], dk, [
|
||||
gtsam.PriorFactorPoint3(1, gtsam.Point3(0, 0, 0),
|
||||
|
@ -42,11 +47,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
values.insert_point3(1, gtsam.Point3(0, 0, 0))
|
||||
values.insert_point3(2, gtsam.Point3(2, 3, 1))
|
||||
hfg = nlfg.linearize(values)
|
||||
o = gtsam.Ordering()
|
||||
o.push_back(1)
|
||||
o.push_back(2)
|
||||
o.push_back(10)
|
||||
hbn = hfg.eliminateSequential(o)
|
||||
hbn = hfg.eliminateSequential()
|
||||
hbv = hbn.optimize()
|
||||
self.assertEqual(hbv.atDiscrete(10), 0)
|
||||
|
||||
|
|
Loading…
Reference in New Issue