diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index f1b79b123..f05dfd423 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -298,7 +298,6 @@ static std::shared_ptr createDiscreteFactor( // exp(-factor->error(kEmpty)) / conditional->normalizationConstant(); // We take negative of the logNormalizationConstant `log(1/k)` // to get `log(k)`. - // factor->print("Discrete Separator"); return -factor->error(kEmpty) + (-conditional->logNormalizationConstant()); }; diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 006a3b026..c03df4d61 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -161,6 +161,35 @@ TEST(MixtureFactor, DifferentMeans) { DiscreteValues{{M(1), 1}, {M(2), 0}}); EXPECT(assert_equal(expected, actual)); + + { + DiscreteValues dv{{M(1), 0}, {M(2), 0}}; + VectorValues cont = bn->optimize(dv); + double error = bn->error(HybridValues(cont, dv)); + // regression + EXPECT_DOUBLES_EQUAL(1.77418393408, error, 1e-9); + } + { + DiscreteValues dv{{M(1), 0}, {M(2), 1}}; + VectorValues cont = bn->optimize(dv); + double error = bn->error(HybridValues(cont, dv)); + // regression + EXPECT_DOUBLES_EQUAL(1.77418393408, error, 1e-9); + } + { + DiscreteValues dv{{M(1), 1}, {M(2), 0}}; + VectorValues cont = bn->optimize(dv); + double error = bn->error(HybridValues(cont, dv)); + // regression + EXPECT_DOUBLES_EQUAL(1.10751726741, error, 1e-9); + } + { + DiscreteValues dv{{M(1), 1}, {M(2), 1}}; + VectorValues cont = bn->optimize(dv); + double error = bn->error(HybridValues(cont, dv)); + // regression + EXPECT_DOUBLES_EQUAL(1.10751726741, error, 1e-9); + } } /* ************************************************************************* */ @@ -217,15 +246,35 @@ TEST(MixtureFactor, DifferentCovariances) { auto hbn = mixture_fg.eliminateSequential(); - HybridValues actual_values = hbn->optimize(); - VectorValues cv; cv.insert(X(1), Vector1(0.0)); cv.insert(X(2), Vector1(0.0)); + + // Check that we get different error values at the MLE point μ. + AlgebraicDecisionTree errorTree = hbn->errorTree(cv); + auto cond0 = hbn->at(0)->asMixture(); + auto cond1 = hbn->at(1)->asMixture(); + auto discrete_cond = hbn->at(2)->asDiscrete(); + + HybridValues hv0(cv, DiscreteValues{{M(1), 0}}); + HybridValues hv1(cv, DiscreteValues{{M(1), 1}}); + AlgebraicDecisionTree expectedErrorTree( + m1, + cond0->error(hv0) // cond0(0)->logNormalizationConstant() + // - cond0(1)->logNormalizationConstant + + cond1->error(hv0) + discrete_cond->error(DiscreteValues{{M(1), 0}}), + cond0->error(hv1) // cond1(0)->logNormalizationConstant() + // - cond1(1)->logNormalizationConstant + + cond1->error(hv1) + + discrete_cond->error(DiscreteValues{{M(1), 0}})); + EXPECT(assert_equal(expectedErrorTree, errorTree)); + DiscreteValues dv; dv.insert({M(1), 1}); HybridValues expected_values(cv, dv); + HybridValues actual_values = hbn->optimize(); + EXPECT(assert_equal(expected_values, actual_values)); }