diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index c031b9729..09b592bd6 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -153,12 +153,14 @@ std::pair discreteElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys) { DiscreteFactorGraph dfg; - for (auto &fp : factors) { - if (auto ptr = boost::dynamic_pointer_cast(fp)) { - dfg.push_back(ptr->inner()); - } else if (auto p = - boost::static_pointer_cast(fp)->inner()) { - dfg.push_back(boost::static_pointer_cast(p)); + + for (auto &factor : factors) { + if (auto p = boost::dynamic_pointer_cast(factor)) { + dfg.push_back(p->inner()); + } else if (auto p = boost::static_pointer_cast(factor)) { + auto discrete_conditional = + boost::static_pointer_cast(p->inner()); + dfg.push_back(discrete_conditional); } else { // It is an orphan wrapper } @@ -244,6 +246,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, return exp(-factor->error(empty_values)); }; DecisionTree fdt(separatorFactors, factorError); + auto discreteFactor = boost::make_shared(discreteSeparator, fdt); diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index 3cd21e32e..5e7337d0c 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -100,11 +100,23 @@ class MixtureFactor : public HybridFactor { bool normalized = false) : Base(keys, discreteKeys), normalized_(normalized) { std::vector nonlinear_factors; + KeySet continuous_keys_set(keys.begin(), keys.end()); + KeySet factor_keys_set; for (auto&& f : factors) { + // Insert all factor continuous keys in the continuous keys set. + std::copy(f->keys().begin(), f->keys().end(), + std::inserter(factor_keys_set, factor_keys_set.end())); + nonlinear_factors.push_back( boost::dynamic_pointer_cast(f)); } factors_ = Factors(discreteKeys, nonlinear_factors); + + if (continuous_keys_set != factor_keys_set) { + throw std::runtime_error( + "The specified continuous keys and the keys in the factors don't " + "match!"); + } } ~MixtureFactor() = default; diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 3723e6078..6b689b4e3 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -147,6 +147,32 @@ TEST(HybridGaussianFactorGraph, Resize) { EXPECT_LONGS_EQUAL(gfg.size(), 0); } +/*************************************************************************** + * Test that the MixtureFactor reports correctly if the number of continuous + * keys provided do not match the keys in the factors. + */ +TEST(HybridGaussianFactorGraph, MixtureFactor) { + auto nonlinearFactor = boost::make_shared>( + X(0), X(1), 0.0, Isotropic::Sigma(1, 0.1)); + auto discreteFactor = boost::make_shared(); + + auto noise_model = noiseModel::Isotropic::Sigma(1, 1.0); + auto still = boost::make_shared(X(0), X(1), 0.0, noise_model), + moving = boost::make_shared(X(0), X(1), 1.0, noise_model); + + std::vector components = {still, moving}; + + // Check for exception when number of continuous keys are under-specified. + KeyVector contKeys = {X(0)}; + THROWS_EXCEPTION(boost::make_shared( + contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components)); + + // Check for exception when number of continuous keys are too many. + contKeys = {X(0), X(1), X(2)}; + THROWS_EXCEPTION(boost::make_shared( + contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, components)); +} + /***************************************************************************** * Test push_back on HFG makes the correct distinction. */ @@ -348,7 +374,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); EXPECT(discreteFactor->root_->isLeaf() == false); - //TODO(Varun) Test emplace_discrete + // TODO(Varun) Test emplace_discrete } /****************************************************************************