diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 99f29b8e5..ffb8d7c69 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -39,10 +39,10 @@ #include #include -using boost::assign::operator+=; - namespace gtsam { + using boost::assign::operator+=; + /****************************************************************************/ // Node /****************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 000057518..b04800d21 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -19,7 +19,7 @@ */ #include -#include +#include #include #include #include @@ -36,8 +36,7 @@ GaussianMixture::GaussianMixture( conditionals_(conditionals) {} /* *******************************************************************************/ -const GaussianMixture::Conditionals & -GaussianMixture::conditionals() { +const GaussianMixture::Conditionals &GaussianMixture::conditionals() { return conditionals_; } @@ -48,8 +47,8 @@ GaussianMixture GaussianMixture::FromConditionals( const std::vector &conditionalsList) { Conditionals dt(discreteParents, conditionalsList); - return GaussianMixture(continuousFrontals, continuousParents, - discreteParents, dt); + return GaussianMixture(continuousFrontals, continuousParents, discreteParents, + dt); } /* *******************************************************************************/ @@ -66,8 +65,7 @@ GaussianMixture::Sum GaussianMixture::add( } /* *******************************************************************************/ -GaussianMixture::Sum -GaussianMixture::asGaussianFactorGraphTree() const { +GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const { auto lambda = [](const GaussianFactor::shared_ptr &factor) { GaussianFactorGraph result; result.push_back(factor); @@ -77,21 +75,42 @@ GaussianMixture::asGaussianFactorGraphTree() const { } /* *******************************************************************************/ -bool GaussianMixture::equals(const HybridFactor &lf, - double tol) const { +size_t GaussianMixture::nrComponents() const { + size_t total = 0; + conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) { + if (node) total += 1; + }); + return total; +} + +/* *******************************************************************************/ +GaussianConditional::shared_ptr GaussianMixture::operator()( + const DiscreteValues &discreteVals) const { + auto &ptr = conditionals_(discreteVals); + if (!ptr) return nullptr; + auto conditional = boost::dynamic_pointer_cast(ptr); + if (conditional) + return conditional; + else + throw std::logic_error( + "A GaussianMixture unexpectedly contained a non-conditional"); +} + +/* *******************************************************************************/ +bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { const This *e = dynamic_cast(&lf); return e != nullptr && BaseFactor::equals(*e, tol); } /* *******************************************************************************/ void GaussianMixture::print(const std::string &s, - const KeyFormatter &formatter) const { + const KeyFormatter &formatter) const { std::cout << s; if (isContinuous()) std::cout << "Continuous "; if (isDiscrete()) std::cout << "Discrete "; if (isHybrid()) std::cout << "Hybrid "; BaseConditional::print("", formatter); - std::cout << "\nDiscrete Keys = "; + std::cout << " Discrete Keys = "; for (auto &dk : discreteKeys()) { std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), "; } diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index e85506715..fc1eb0f06 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -19,7 +19,9 @@ #pragma once +#include #include +#include #include #include #include @@ -99,6 +101,16 @@ class GTSAM_EXPORT GaussianMixture const DiscreteKeys &discreteParents, const std::vector &conditionals); + /// @} + /// @name Standard API + /// @{ + + GaussianConditional::shared_ptr operator()( + const DiscreteValues &discreteVals) const; + + /// Returns the total number of continuous components + size_t nrComponents() const; + /// @} /// @name Testable /// @{ diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index a81cf341d..8f832d8ea 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -51,16 +51,19 @@ GaussianMixtureFactor GaussianMixtureFactor::FromFactors( void GaussianMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const { HybridFactor::print(s, formatter); + std::cout << "]{\n"; factors_.print( - "mixture = ", [&](Key k) { return formatter(k); }, + "", [&](Key k) { return formatter(k); }, [&](const GaussianFactor::shared_ptr &gf) -> std::string { RedirectCout rd; - if (!gf->empty()) + std::cout << ":\n"; + if (gf) gf->print("", formatter); else return {"nullptr"}; return rd.str(); }); + std::cout << "}" << std::endl; } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 21770f836..6c90ee6a7 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -84,6 +84,19 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { const DiscreteKeys &discreteKeys, const Factors &factors); + /** + * @brief Construct a new GaussianMixtureFactor object using a vector of + * GaussianFactor shared pointers. + * + * @param keys Vector of keys for continuous factors. + * @param discreteKeys Vector of discrete keys. + * @param factors Vector of gaussian factor shared pointers. + */ + GaussianMixtureFactor(const KeyVector &keys, const DiscreteKeys &discreteKeys, + const std::vector &factors) + : GaussianMixtureFactor(keys, discreteKeys, + Factors(discreteKeys, factors)) {} + static This FromFactors( const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const std::vector &factors); @@ -111,6 +124,12 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @return Sum */ Sum add(const Sum &sum) const; + + /// Add MixtureFactor to a Sum, syntactic sugar. + friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { + sum = factor.add(sum); + return sum; + } }; // traits diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp index 2bdcdee8c..0455e1e90 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -47,7 +47,7 @@ bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { void HybridDiscreteFactor::print(const std::string &s, const KeyFormatter &formatter) const { HybridFactor::print(s, formatter); - inner_->print("inner: ", formatter); + inner_->print("\n", formatter); }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index 127c9761c..8df2d524f 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -50,7 +50,10 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &keys) - : Base(keys), isContinuous_(true), nrContinuous_(keys.size()) {} + : Base(keys), + isContinuous_(true), + nrContinuous_(keys.size()), + continuousKeys_(keys) {} /* ************************************************************************ */ HybridFactor::HybridFactor(const KeyVector &continuousKeys, @@ -60,13 +63,15 @@ HybridFactor::HybridFactor(const KeyVector &continuousKeys, isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)), isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)), nrContinuous_(continuousKeys.size()), - discreteKeys_(discreteKeys) {} + discreteKeys_(discreteKeys), + continuousKeys_(continuousKeys) {} /* ************************************************************************ */ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), isDiscrete_(true), - discreteKeys_(discreteKeys) {} + discreteKeys_(discreteKeys), + continuousKeys_({}) {} /* ************************************************************************ */ bool HybridFactor::equals(const HybridFactor &lf, double tol) const { @@ -83,7 +88,17 @@ void HybridFactor::print(const std::string &s, if (isContinuous_) std::cout << "Continuous "; if (isDiscrete_) std::cout << "Discrete "; if (isHybrid_) std::cout << "Hybrid "; - this->printKeys("", formatter); + for (size_t c=0; cprint("inner: ", formatter); + inner_->print("\n", formatter); }; } // namespace gtsam diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp new file mode 100644 index 000000000..420e22315 --- /dev/null +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -0,0 +1,95 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testGaussianMixture.cpp + * @brief Unit tests for GaussianMixture class + * @author Varun Agrawal + * @author Fan Jiang + * @author Frank Dellaert + * @date December 2021 + */ + +#include +#include +#include +#include + +#include + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; +using noiseModel::Isotropic; +using symbol_shorthand::M; +using symbol_shorthand::X; + +/* ************************************************************************* */ +/* Check construction of GaussianMixture P(x1 | x2, m1) as well as accessing a + * specific mode i.e. P(x1 | x2, m1=1). + */ +TEST(GaussianMixture, Equals) { + // create a conditional gaussian node + Matrix S1(2, 2); + S1(0, 0) = 1; + S1(1, 0) = 2; + S1(0, 1) = 3; + S1(1, 1) = 4; + + Matrix S2(2, 2); + S2(0, 0) = 6; + S2(1, 0) = 0.2; + S2(0, 1) = 8; + S2(1, 1) = 0.4; + + Matrix R1(2, 2); + R1(0, 0) = 0.1; + R1(1, 0) = 0.3; + R1(0, 1) = 0.0; + R1(1, 1) = 0.34; + + Matrix R2(2, 2); + R2(0, 0) = 0.1; + R2(1, 0) = 0.3; + R2(0, 1) = 0.0; + R2(1, 1) = 0.34; + + SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); + + Vector2 d1(0.2, 0.5), d2(0.5, 0.2); + + auto conditional0 = boost::make_shared(X(1), d1, R1, + X(2), S1, model), + conditional1 = boost::make_shared(X(1), d2, R2, + X(2), S2, model); + + // Create decision tree + DiscreteKey m1(1, 2); + GaussianMixture::Conditionals conditionals( + {m1}, + vector{conditional0, conditional1}); + GaussianMixture mixtureFactor({X(1)}, {X(2)}, {m1}, conditionals); + + // Let's check that this worked: + DiscreteValues mode; + mode[m1.first] = 1; + auto actual = mixtureFactor(mode); + EXPECT(actual == conditional1); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp new file mode 100644 index 000000000..36477218b --- /dev/null +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -0,0 +1,159 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GaussianMixtureFactor.cpp + * @brief Unit tests for GaussianMixtureFactor + * @author Varun Agrawal + * @author Fan Jiang + * @author Frank Dellaert + * @date December 2021 + */ + +#include +#include +#include +#include +#include +#include + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; +using noiseModel::Isotropic; +using symbol_shorthand::M; +using symbol_shorthand::X; + +/* ************************************************************************* */ +// Check iterators of empty mixture. +TEST(GaussianMixtureFactor, Constructor) { + GaussianMixtureFactor factor; + GaussianMixtureFactor::const_iterator const_it = factor.begin(); + CHECK(const_it == factor.end()); + GaussianMixtureFactor::iterator it = factor.begin(); + CHECK(it == factor.end()); +} + +/* ************************************************************************* */ +// "Add" two mixture factors together. +TEST(GaussianMixtureFactor, Sum) { + DiscreteKey m1(1, 2), m2(2, 3); + + auto A1 = Matrix::Zero(2, 1); + auto A2 = Matrix::Zero(2, 2); + auto A3 = Matrix::Zero(2, 3); + auto b = Matrix::Zero(2, 1); + Vector2 sigmas; + sigmas << 1, 2; + auto model = noiseModel::Diagonal::Sigmas(sigmas, true); + + auto f10 = boost::make_shared(X(1), A1, X(2), A2, b); + auto f11 = boost::make_shared(X(1), A1, X(2), A2, b); + auto f20 = boost::make_shared(X(1), A1, X(3), A3, b); + auto f21 = boost::make_shared(X(1), A1, X(3), A3, b); + auto f22 = boost::make_shared(X(1), A1, X(3), A3, b); + std::vector factorsA{f10, f11}; + std::vector factorsB{f20, f21, f22}; + + // TODO(Frank): why specify keys at all? And: keys in factor should be *all* + // keys, deviating from Kevin's scheme. Should we index DT on DiscreteKey? + // Design review! + GaussianMixtureFactor mixtureFactorA({X(1), X(2)}, {m1}, factorsA); + GaussianMixtureFactor mixtureFactorB({X(1), X(3)}, {m2}, factorsB); + + // Check that number of keys is 3 + EXPECT_LONGS_EQUAL(3, mixtureFactorA.keys().size()); + + // Check that number of discrete keys is 1 // TODO(Frank): should not exist? + EXPECT_LONGS_EQUAL(1, mixtureFactorA.discreteKeys().size()); + + // Create sum of two mixture factors: it will be a decision tree now on both + // discrete variables m1 and m2: + GaussianMixtureFactor::Sum sum; + sum += mixtureFactorA; + sum += mixtureFactorB; + + // Let's check that this worked: + Assignment mode; + mode[m1.first] = 1; + mode[m2.first] = 2; + auto actual = sum(mode); + EXPECT(actual.at(0) == f11); + EXPECT(actual.at(1) == f22); +} + +TEST(GaussianMixtureFactor, Printing) { + DiscreteKey m1(1, 2); + auto A1 = Matrix::Zero(2, 1); + auto A2 = Matrix::Zero(2, 2); + auto b = Matrix::Zero(2, 1); + auto f10 = boost::make_shared(X(1), A1, X(2), A2, b); + auto f11 = boost::make_shared(X(1), A1, X(2), A2, b); + std::vector factors{f10, f11}; + + GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); + + std::string expected = + R"(Hybrid x1 x2; 1 ]{ + Choice(1) + 0 Leaf : + A[x1] = [ + 0; + 0 +] + A[x2] = [ + 0, 0; + 0, 0 +] + b = [ 0 0 ] + No noise model + + 1 Leaf : + A[x1] = [ + 0; + 0 +] + A[x2] = [ + 0, 0; + 0, 0 +] + b = [ 0 0 ] + No noise model + +} +)"; + EXPECT(assert_print_equal(expected, mixtureFactor)); +} + +TEST_UNSAFE(GaussianMixtureFactor, GaussianMixture) { + KeyVector keys; + keys.push_back(X(0)); + keys.push_back(X(1)); + + DiscreteKeys dKeys; + dKeys.emplace_back(M(0), 2); + dKeys.emplace_back(M(1), 2); + + auto gaussians = boost::make_shared(); + GaussianMixture::Conditionals conditionals(gaussians); + GaussianMixture gm({}, keys, dKeys, conditionals); + + EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size()); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ \ No newline at end of file