diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index f72743206..d01c91840 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -64,6 +64,9 @@ namespace gtsam { */ size_t nrAssignments_; + /// Default constructor for serialization. + Leaf() {} + /// Constructor from constant Leaf(const Y& constant, size_t nrAssignments = 1) : constant_(constant), nrAssignments_(nrAssignments) {} @@ -154,6 +157,18 @@ namespace gtsam { } bool isLeaf() const override { return true; } + + private: + using Base = DecisionTree::Node; + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_NVP(constant_); + ar& BOOST_SERIALIZATION_NVP(nrAssignments_); + } }; // Leaf /****************************************************************************/ @@ -177,6 +192,9 @@ namespace gtsam { using ChoicePtr = boost::shared_ptr; public: + /// Default constructor for serialization. + Choice() {} + ~Choice() override { #ifdef DT_DEBUG_MEMORY std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() @@ -428,6 +446,19 @@ namespace gtsam { r->push_back(branch->choose(label, index)); return Unique(r); } + + private: + using Base = DecisionTree::Node; + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_NVP(label_); + ar& BOOST_SERIALIZATION_NVP(branches_); + ar& BOOST_SERIALIZATION_NVP(allSame_); + } }; // Choice /****************************************************************************/ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 957a4eb48..a8764a98f 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -19,9 +19,11 @@ #pragma once +#include #include #include +#include #include #include #include @@ -113,6 +115,12 @@ namespace gtsam { virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; virtual Ptr choose(const L& label, size_t index) const = 0; virtual bool isLeaf() const = 0; + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) {} }; /** ------------------------ Node base class --------------------------- */ @@ -236,7 +244,7 @@ namespace gtsam { /** * @brief Visit all leaves in depth-first fashion. * - * @param f (side-effect) Function taking a value. + * @param f (side-effect) Function taking the value of the leaf node. * * @note Due to pruning, the number of leaves may not be the same as the * number of assignments. E.g. if we have a tree on 2 binary variables with @@ -245,7 +253,7 @@ namespace gtsam { * Example: * int sum = 0; * auto visitor = [&](int y) { sum += y; }; - * tree.visitWith(visitor); + * tree.visit(visitor); */ template void visit(Func f) const; @@ -261,8 +269,8 @@ namespace gtsam { * * Example: * int sum = 0; - * auto visitor = [&](int y) { sum += y; }; - * tree.visitWith(visitor); + * auto visitor = [&](const Leaf& leaf) { sum += leaf.constant(); }; + * tree.visitLeaf(visitor); */ template void visitLeaf(Func f) const; @@ -364,8 +372,19 @@ namespace gtsam { compose(Iterator begin, Iterator end, const L& label) const; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_NVP(root_); + } }; // DecisionTree + template + struct traits> : public Testable> {}; + /** free versions of apply */ /// Apply unary operator `op` to DecisionTree `f`. diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 4e16fc689..7f604086c 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -156,9 +156,9 @@ namespace gtsam { std::vector> DecisionTreeFactor::enumerate() const { // Get all possible assignments - std::vector> pairs = discreteKeys(); + DiscreteKeys pairs = discreteKeys(); // Reverse to make cartesian product output a more natural ordering. - std::vector> rpairs(pairs.rbegin(), pairs.rend()); + DiscreteKeys rpairs(pairs.rbegin(), pairs.rend()); const auto assignments = DiscreteValues::CartesianProduct(rpairs); // Construct unordered_map with values diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index bb9ddbd96..f1df7ae03 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -231,6 +231,16 @@ namespace gtsam { const Names& names = {}) const override; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT); + ar& BOOST_SERIALIZATION_NVP(cardinalities_); + } }; // traits diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 6a286633d..b68953eb5 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional /// Internal version of choose DiscreteConditional::ADT choose(const DiscreteValues& given, bool forceComplete) const; + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + } }; // DiscreteConditional diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 63b14b05e..46e6c3813 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -20,12 +20,11 @@ // #define DT_DEBUG_MEMORY // #define GTSAM_DT_NO_PRUNING #define DISABLE_DOT -#include - -#include -#include - #include +#include +#include +#include +#include using namespace std; using namespace gtsam; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 1829db034..869b3c630 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 9098f7a1d..3df4bf9e6 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -17,13 +17,14 @@ * @date Feb 14, 2011 */ -#include - #include +#include #include #include #include +#include + using namespace std; using namespace gtsam; @@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) { DiscreteConditional conditional(A | B = "2/2 3/1"); DiscreteConditional prior(B % "1/2"); DiscreteConditional pAB = prior * conditional; - GTSAM_PRINT(pAB); // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8 // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 DiscreteConditional actualA = pAB.marginal(A.first); diff --git a/gtsam/discrete/tests/testSerializationDiscrete.cpp b/gtsam/discrete/tests/testSerializationDiscrete.cpp new file mode 100644 index 000000000..df7df0b7e --- /dev/null +++ b/gtsam/discrete/tests/testSerializationDiscrete.cpp @@ -0,0 +1,105 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testSerializtionDiscrete.cpp + * + * @date January 2023 + * @author Varun Agrawal + */ + +#include +#include +#include +#include +#include + +using namespace std; +using namespace gtsam; + +using Tree = gtsam::DecisionTree; + +BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt") +BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf") +BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice") + +BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); + +using ADT = AlgebraicDecisionTree; +BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree"); +BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf") +BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice") + +/* ****************************************************************************/ +// Test DecisionTree serialization. +TEST(DiscreteSerialization, DecisionTree) { + Tree tree({{"A", 2}}, std::vector{1, 2}); + + using namespace serializationTestHelpers; + + // Object roundtrip + Tree outputObj = create(); + roundtrip(tree, outputObj); + EXPECT(tree.equals(outputObj)); + + // XML roundtrip + Tree outputXml = create(); + roundtripXML(tree, outputXml); + EXPECT(tree.equals(outputXml)); + + // Binary roundtrip + Tree outputBinary = create(); + roundtripBinary(tree, outputBinary); + EXPECT(tree.equals(outputBinary)); +} + +/* ************************************************************************* */ +// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor +TEST(DiscreteSerialization, DecisionTreeFactor) { + using namespace serializationTestHelpers; + + DiscreteKey A(1, 2), B(2, 2), C(3, 2); + + DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8"); + EXPECT(equalsObj(tree)); + EXPECT(equalsXML(tree)); + EXPECT(equalsBinary(tree)); + + DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8"); + EXPECT(equalsObj(f)); + EXPECT(equalsXML(f)); + EXPECT(equalsBinary(f)); +} + +/* ************************************************************************* */ +// Check serialization for DiscreteConditional & DiscreteDistribution +TEST(DiscreteSerialization, DiscreteConditional) { + using namespace serializationTestHelpers; + + DiscreteKey A(Symbol('x', 1), 3); + DiscreteConditional conditional(A % "1/2/2"); + + EXPECT(equalsObj(conditional)); + EXPECT(equalsXML(conditional)); + EXPECT(equalsBinary(conditional)); + + DiscreteDistribution P(A % "3/2/1"); + EXPECT(equalsObj(P)); + EXPECT(equalsXML(P)); + EXPECT(equalsBinary(P)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 892a07d2d..12e88f81d 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -51,24 +51,28 @@ GaussianMixture::GaussianMixture( Conditionals(discreteParents, conditionalsList)) {} /* *******************************************************************************/ -GaussianMixture::Sum GaussianMixture::add( - const GaussianMixture::Sum &sum) const { - using Y = GaussianFactorGraph; +GaussianFactorGraphTree GaussianMixture::add( + const GaussianFactorGraphTree &sum) const { + using Y = GraphAndConstant; auto add = [](const Y &graph1, const Y &graph2) { - auto result = graph1; - result.push_back(graph2); - return result; + auto result = graph1.graph; + result.push_back(graph2.graph); + return Y(result, graph1.constant + graph2.constant); }; - const Sum tree = asGaussianFactorGraphTree(); + const auto tree = asGaussianFactorGraphTree(); return sum.empty() ? tree : sum.apply(tree, add); } /* *******************************************************************************/ -GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const { - auto lambda = [](const GaussianFactor::shared_ptr &factor) { +GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const { + auto lambda = [](const GaussianConditional::shared_ptr &conditional) { GaussianFactorGraph result; - result.push_back(factor); - return result; + result.push_back(conditional); + if (conditional) { + return GraphAndConstant(result, conditional->logNormalizationConstant()); + } else { + return GraphAndConstant(result, 0.0); + } }; return {conditionals_, lambda}; } @@ -98,7 +102,19 @@ GaussianConditional::shared_ptr GaussianMixture::operator()( /* *******************************************************************************/ bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { const This *e = dynamic_cast(&lf); - return e != nullptr && BaseFactor::equals(*e, tol); + if (e == nullptr) return false; + + // This will return false if either conditionals_ is empty or e->conditionals_ + // is empty, but not if both are empty or both are not empty: + if (conditionals_.empty() ^ e->conditionals_.empty()) return false; + + // Check the base and the factors: + return BaseFactor::equals(*e, tol) && + conditionals_.equals(e->conditionals_, + [tol](const GaussianConditional::shared_ptr &f1, + const GaussianConditional::shared_ptr &f2) { + return f1->equals(*(f2), tol); + }); } /* *******************************************************************************/ @@ -146,7 +162,13 @@ KeyVector GaussianMixture::continuousParents() const { /* ************************************************************************* */ boost::shared_ptr GaussianMixture::likelihood( const VectorValues &frontals) const { - // TODO(dellaert): check that values has all frontals + // Check that values has all frontals + for (auto &&kv : frontals) { + if (frontals.find(kv.first) == frontals.end()) { + throw std::runtime_error("GaussianMixture: frontals missing factor key."); + } + } + const DiscreteKeys discreteParentKeys = discreteKeys(); const KeyVector continuousParentKeys = continuousParents(); const GaussianMixtureFactor::Factors likelihoods( diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index a9b05f250..ba84b5ade 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -23,13 +23,13 @@ #include #include #include +#include #include #include #include namespace gtsam { -class GaussianMixtureFactor; class HybridValues; /** @@ -59,9 +59,6 @@ class GTSAM_EXPORT GaussianMixture using BaseFactor = HybridFactor; using BaseConditional = Conditional; - /// Alias for DecisionTree of GaussianFactorGraphs - using Sum = DecisionTree; - /// typedef for Decision Tree of Gaussian Conditionals using Conditionals = DecisionTree; @@ -71,7 +68,7 @@ class GTSAM_EXPORT GaussianMixture /** * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. */ - Sum asGaussianFactorGraphTree() const; + GaussianFactorGraphTree asGaussianFactorGraphTree() const; /** * @brief Helper function to get the pruner functor. @@ -172,6 +169,16 @@ class GTSAM_EXPORT GaussianMixture */ double error(const HybridValues &values) const override; + // /// Calculate probability density for given values `x`. + // double evaluate(const HybridValues &values) const; + + // /// Evaluate probability density, sugar. + // double operator()(const HybridValues &values) const { return + // evaluate(values); } + + // /// Calculate log-density for given values `x`. + // double logDensity(const HybridValues &values) const; + /** * @brief Prune the decision tree of Gaussian factors as per the discrete * `decisionTree`. @@ -186,10 +193,20 @@ class GTSAM_EXPORT GaussianMixture * maintaining the decision tree structure. * * @param sum Decision Tree of Gaussian Factor Graphs - * @return Sum + * @return GaussianFactorGraphTree */ - Sum add(const Sum &sum) const; + GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + ar &BOOST_SERIALIZATION_NVP(conditionals_); + } }; /// Return the DiscreteKey vector as a set. diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index e60368717..57f42e6f1 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -81,32 +81,36 @@ void GaussianMixtureFactor::print(const std::string &s, } /* *******************************************************************************/ -const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const { - return Mixture(factors_, [](const FactorAndConstant &factor_z) { - return factor_z.factor; - }); +GaussianFactor::shared_ptr GaussianMixtureFactor::factor( + const DiscreteValues &assignment) const { + return factors_(assignment).factor; } /* *******************************************************************************/ -GaussianMixtureFactor::Sum GaussianMixtureFactor::add( - const GaussianMixtureFactor::Sum &sum) const { - using Y = GaussianFactorGraph; +double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const { + return factors_(assignment).constant; +} + +/* *******************************************************************************/ +GaussianFactorGraphTree GaussianMixtureFactor::add( + const GaussianFactorGraphTree &sum) const { + using Y = GraphAndConstant; auto add = [](const Y &graph1, const Y &graph2) { - auto result = graph1; - result.push_back(graph2); - return result; + auto result = graph1.graph; + result.push_back(graph2.graph); + return Y(result, graph1.constant + graph2.constant); }; - const Sum tree = asGaussianFactorGraphTree(); + const auto tree = asGaussianFactorGraphTree(); return sum.empty() ? tree : sum.apply(tree, add); } /* *******************************************************************************/ -GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() +GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree() const { auto wrap = [](const FactorAndConstant &factor_z) { GaussianFactorGraph result; result.push_back(factor_z.factor); - return result; + return GraphAndConstant(result, factor_z.constant); }; return {factors_, wrap}; } diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index ce011fecc..01de2f0f7 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -25,10 +25,10 @@ #include #include #include +#include namespace gtsam { -class GaussianFactorGraph; class HybridValues; class DiscreteValues; class VectorValues; @@ -50,7 +50,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { using This = GaussianMixtureFactor; using shared_ptr = boost::shared_ptr; - using Sum = DecisionTree; using sharedFactor = boost::shared_ptr; /// Gaussian factor and log of normalizing constant. @@ -60,8 +59,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { // Return error with constant correction. double error(const VectorValues &values) const { - // Note minus sign: constant is log of normalization constant for probabilities. - // Errors is the negative log-likelihood, hence we subtract the constant here. + // Note: constant is log of normalization constant for probabilities. + // Errors is the negative log-likelihood, + // hence we subtract the constant here. + if (!factor) return 0.0; // If nullptr, return 0.0 error return factor->error(values) - constant; } @@ -69,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { bool operator==(const FactorAndConstant &other) const { return factor == other.factor && constant == other.constant; } + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_NVP(factor); + ar &BOOST_SERIALIZATION_NVP(constant); + } }; /// typedef for Decision Tree of Gaussian factors and log-constant. @@ -83,9 +93,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * @brief Helper function to return factors and functional to create a * DecisionTree of Gaussian Factor Graphs. * - * @return Sum (DecisionTree) + * @return GaussianFactorGraphTree */ - Sum asGaussianFactorGraphTree() const; + GaussianFactorGraphTree asGaussianFactorGraphTree() const; public: /// @name Constructors @@ -135,12 +145,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { void print( const std::string &s = "GaussianMixtureFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override; + /// @} /// @name Standard API /// @{ - /// Getter for the underlying Gaussian Factor Decision Tree. - const Mixture factors() const; + /// Get factor at a given discrete assignment. + sharedFactor factor(const DiscreteValues &assignment) const; + + /// Get constant at a given discrete assignment. + double constant(const DiscreteValues &assignment) const; /** * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while @@ -150,7 +164,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { * variables. * @return Sum */ - Sum add(const Sum &sum) const; + GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const; /** * @brief Compute error of the GaussianMixtureFactor as a tree. @@ -168,11 +182,21 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { double error(const HybridValues &values) const override; /// Add MixtureFactor to a Sum, syntactic sugar. - friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { + friend GaussianFactorGraphTree &operator+=( + GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) { sum = factor.add(sum); return sum; } /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar &BOOST_SERIALIZATION_NVP(factors_); + } }; // traits diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e471cb02f..4404ccfdc 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -26,6 +26,17 @@ static std::mt19937_64 kRandomNumberGenerator(42); namespace gtsam { +/* ************************************************************************* */ +void HybridBayesNet::print(const std::string &s, + const KeyFormatter &formatter) const { + Base::print(s, formatter); +} + +/* ************************************************************************* */ +bool HybridBayesNet::equals(const This &bn, double tol) const { + return Base::equals(bn, tol); +} + /* ************************************************************************* */ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { AlgebraicDecisionTree decisionTree; @@ -271,12 +282,15 @@ double HybridBayesNet::evaluate(const HybridValues &values) const { // Iterate over each conditional. for (auto &&conditional : *this) { + // TODO: should be delegated to derived classes. if (auto gm = conditional->asMixture()) { const auto component = (*gm)(discreteValues); logDensity += component->logDensity(continuousValues); + } else if (auto gc = conditional->asGaussian()) { // If continuous only, evaluate the probability and multiply. logDensity += gc->logDensity(continuousValues); + } else if (auto dc = conditional->asDiscrete()) { // Conditional is discrete-only, so return its probability. probability *= dc->operator()(discreteValues); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 0d2c337b7..dcdf3a8e5 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -50,18 +50,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// @name Testable /// @{ - /** Check equality */ - bool equals(const This &bn, double tol = 1e-9) const { - return Base::equals(bn, tol); - } - - /// print graph + /// GTSAM-style printing void print( const std::string &s = "", - const KeyFormatter &formatter = DefaultKeyFormatter) const override { - Base::print(s, formatter); - } + const KeyFormatter &formatter = DefaultKeyFormatter) const override; + /// GTSAM-style equals + bool equals(const This& fg, double tol = 1e-9) const; + /// @} /// @name Standard Interface /// @{ diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 8e071532d..0bfcfec4d 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -102,7 +103,37 @@ void HybridConditional::print(const std::string &s, /* ************************************************************************ */ bool HybridConditional::equals(const HybridFactor &other, double tol) const { const This *e = dynamic_cast(&other); - return e != nullptr && BaseFactor::equals(*e, tol); + if (e == nullptr) return false; + if (auto gm = asMixture()) { + auto other = e->asMixture(); + return other != nullptr && gm->equals(*other, tol); + } + if (auto gc = asGaussian()) { + auto other = e->asGaussian(); + return other != nullptr && gc->equals(*other, tol); + } + if (auto dc = asDiscrete()) { + auto other = e->asDiscrete(); + return other != nullptr && dc->equals(*other, tol); + } + + return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false) + : !(e->inner_); +} + +/* ************************************************************************ */ +double HybridConditional::error(const HybridValues &values) const { + if (auto gm = asMixture()) { + return gm->error(values); + } + if (auto gc = asGaussian()) { + return gc->error(values.continuous()); + } + if (auto dc = asDiscrete()) { + return -log((*dc)(values.discrete())); + } + throw std::runtime_error( + "HybridConditional::error: conditional type not handled"); } } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index e949fb865..021ca1361 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -176,15 +176,7 @@ class GTSAM_EXPORT HybridConditional boost::shared_ptr inner() const { return inner_; } /// Return the error of the underlying conditional. - /// Currently only implemented for Gaussian mixture. - double error(const HybridValues& values) const override { - if (auto gm = asMixture()) { - return gm->error(values); - } else { - throw std::runtime_error( - "HybridConditional::error: only implemented for Gaussian mixture"); - } - } + double error(const HybridValues& values) const override; /// @} @@ -195,6 +187,20 @@ class GTSAM_EXPORT HybridConditional void serialize(Archive& ar, const unsigned int /*version*/) { ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + ar& BOOST_SERIALIZATION_NVP(inner_); + + // register the various casts based on the type of inner_ + // https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting + if (isDiscrete()) { + boost::serialization::void_cast_register( + static_cast(NULL), static_cast(NULL)); + } else if (isContinuous()) { + boost::serialization::void_cast_register( + static_cast(NULL), static_cast(NULL)); + } else { + boost::serialization::void_cast_register( + static_cast(NULL), static_cast(NULL)); + } } }; // HybridConditional diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp index 605ea5738..afdb6472a 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -26,7 +26,6 @@ namespace gtsam { /* ************************************************************************ */ -// TODO(fan): THIS IS VERY VERY DIRTY! We need to get DiscreteFactor right! HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other) : Base(boost::dynamic_pointer_cast(other) ->discreteKeys()), @@ -40,8 +39,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf) /* ************************************************************************ */ bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { const This *e = dynamic_cast(&lf); - // TODO(Varun) How to compare inner_ when they are abstract types? - return e != nullptr && Base::equals(*e, tol); + if (e == nullptr) return false; + if (!Base::equals(*e, tol)) return false; + return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false) + : !(e->inner_); } /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h index 7ac97443a..7a43ab3a5 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.h +++ b/gtsam/hybrid/HybridDiscreteFactor.h @@ -45,6 +45,9 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { /// @name Constructors /// @{ + /// Default constructor - for serialization. + HybridDiscreteFactor() = default; + // Implicit conversion from a shared ptr of DF HybridDiscreteFactor(DiscreteFactor::shared_ptr other); @@ -70,6 +73,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { /// Return the error of the underlying Discrete Factor. double error(const HybridValues &values) const override; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar &BOOST_SERIALIZATION_NVP(inner_); + } }; // traits diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index a28fee8ed..8c1b0dad3 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include @@ -28,6 +30,36 @@ namespace gtsam { class HybridValues; +/// Gaussian factor graph and log of normalizing constant. +struct GraphAndConstant { + GaussianFactorGraph graph; + double constant; + + GraphAndConstant(const GaussianFactorGraph &graph, double constant) + : graph(graph), constant(constant) {} + + // Check pointer equality. + bool operator==(const GraphAndConstant &other) const { + return graph == other.graph && constant == other.constant; + } + + // Implement GTSAM-style print: + void print(const std::string &s = "Graph: ", + const KeyFormatter &formatter = DefaultKeyFormatter) const { + graph.print(s, formatter); + std::cout << "Constant: " << constant << std::endl; + } + + // Implement GTSAM-style equals: + bool equals(const GraphAndConstant &other, double tol = 1e-9) const { + return graph.equals(other.graph, tol) && + fabs(constant - other.constant) < tol; + } +}; + +/// Alias for DecisionTree of GaussianFactorGraphs +using GaussianFactorGraphTree = DecisionTree; + KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); @@ -160,4 +192,7 @@ class GTSAM_EXPORT HybridFactor : public Factor { template <> struct traits : public Testable {}; +template <> +struct traits : public Testable {}; + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 5a89a04a8..4fe18bea7 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -44,15 +44,21 @@ HybridGaussianFactor::HybridGaussianFactor(HessianFactor &&hf) /* ************************************************************************* */ bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const { const This *e = dynamic_cast(&other); - // TODO(Varun) How to compare inner_ when they are abstract types? - return e != nullptr && Base::equals(*e, tol); + if (e == nullptr) return false; + if (!Base::equals(*e, tol)) return false; + return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false) + : !(e->inner_); } /* ************************************************************************* */ void HybridGaussianFactor::print(const std::string &s, const KeyFormatter &formatter) const { HybridFactor::print(s, formatter); - inner_->print("\n", formatter); + if (inner_) { + inner_->print("\n", formatter); + } else { + std::cout << "\nGaussian: nullptr" << std::endl; + } }; /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index dc2f62857..6bb022396 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -43,14 +43,17 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { using This = HybridGaussianFactor; using shared_ptr = boost::shared_ptr; + /// @name Constructors + /// @{ + + /// Default constructor - for serialization. HybridGaussianFactor() = default; /** * Constructor from shared_ptr of GaussianFactor. * Example: - * boost::shared_ptr ptr = - * boost::make_shared(...); - * + * auto ptr = boost::make_shared(...); + * HybridGaussianFactor factor(ptr); */ explicit HybridGaussianFactor(const boost::shared_ptr &ptr); @@ -80,7 +83,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { */ explicit HybridGaussianFactor(HessianFactor &&hf); - public: + /// @} /// @name Testable /// @{ @@ -99,9 +102,18 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { /// Return pointer to the internal Gaussian factor. GaussianFactor::shared_ptr inner() const { return inner_; } - /// Return the error of the underlying Discrete Factor. + /// Return the error of the underlying Gaussian factor. double error(const HybridValues &values) const override; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar &BOOST_SERIALIZATION_NVP(inner_); + } }; // traits diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 6af0fb1a9..f6b713a76 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -59,51 +59,50 @@ namespace gtsam { template class EliminateableFactorGraph; /* ************************************************************************ */ -static GaussianMixtureFactor::Sum &addGaussian( - GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { - using Y = GaussianFactorGraph; +static GaussianFactorGraphTree addGaussian( + const GaussianFactorGraphTree &gfgTree, + const GaussianFactor::shared_ptr &factor) { // If the decision tree is not initialized, then initialize it. - if (sum.empty()) { + if (gfgTree.empty()) { GaussianFactorGraph result; result.push_back(factor); - sum = GaussianMixtureFactor::Sum(result); + return GaussianFactorGraphTree(GraphAndConstant(result, 0.0)); } else { - auto add = [&factor](const Y &graph) { - auto result = graph; + auto add = [&factor](const GraphAndConstant &graph_z) { + auto result = graph_z.graph; result.push_back(factor); - return result; + return GraphAndConstant(result, graph_z.constant); }; - sum = sum.apply(add); + return gfgTree.apply(add); } - return sum; } /* ************************************************************************ */ -GaussianMixtureFactor::Sum sumFrontals( - const HybridGaussianFactorGraph &factors) { - // sum out frontals, this is the factor on the separator - gttic(sum); +// TODO(dellaert): Implementation-wise, it's probably more efficient to first +// collect the discrete keys, and then loop over all assignments to populate a +// vector. +GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { + gttic(assembleGraphTree); - GaussianMixtureFactor::Sum sum; - std::vector deferredFactors; + GaussianFactorGraphTree result; - for (auto &f : factors) { + for (auto &f : factors_) { + // TODO(dellaert): just use a virtual method defined in HybridFactor. if (f->isHybrid()) { - // TODO(dellaert): just use a virtual method defined in HybridFactor. if (auto gm = boost::dynamic_pointer_cast(f)) { - sum = gm->add(sum); + result = gm->add(result); } if (auto gm = boost::dynamic_pointer_cast(f)) { - sum = gm->asMixture()->add(sum); + result = gm->asMixture()->add(result); } } else if (f->isContinuous()) { if (auto gf = boost::dynamic_pointer_cast(f)) { - deferredFactors.push_back(gf->inner()); + result = addGaussian(result, gf->inner()); } if (auto cg = boost::dynamic_pointer_cast(f)) { - deferredFactors.push_back(cg->asGaussian()); + result = addGaussian(result, cg->asGaussian()); } } else if (f->isDiscrete()) { @@ -125,17 +124,13 @@ GaussianMixtureFactor::Sum sumFrontals( } } - for (auto &f : deferredFactors) { - sum = addGaussian(sum, f); - } + gttoc(assembleGraphTree); - gttoc(sum); - - return sum; + return result; } /* ************************************************************************ */ -std::pair +static std::pair continuousElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys) { GaussianFactorGraph gfg; @@ -156,7 +151,7 @@ continuousElimination(const HybridGaussianFactorGraph &factors, } /* ************************************************************************ */ -std::pair +static std::pair discreteElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys) { DiscreteFactorGraph dfg; @@ -173,48 +168,53 @@ discreteElimination(const HybridGaussianFactorGraph &factors, } } - auto result = EliminateForMPE(dfg, frontalKeys); + // NOTE: This does sum-product. For max-product, use EliminateForMPE. + auto result = EliminateDiscrete(dfg, frontalKeys); return {boost::make_shared(result.first), boost::make_shared(result.second)}; } /* ************************************************************************ */ -std::pair +// If any GaussianFactorGraph in the decision tree contains a nullptr, convert +// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will +// otherwise create a GFG with a single (null) factor. +GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) { + auto emptyGaussian = [](const GraphAndConstant &graph_z) { + bool hasNull = + std::any_of(graph_z.graph.begin(), graph_z.graph.end(), + [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }); + return hasNull ? GraphAndConstant{GaussianFactorGraph(), 0.0} : graph_z; + }; + return GaussianFactorGraphTree(sum, emptyGaussian); +} + +/* ************************************************************************ */ +static std::pair hybridElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys, - const KeySet &continuousSeparator, + const KeyVector &continuousSeparator, const std::set &discreteSeparatorSet) { // NOTE: since we use the special JunctionTree, // only possibility is continuous conditioned on discrete. DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), discreteSeparatorSet.end()); - // sum out frontals, this is the factor 𝜏 on the separator - GaussianMixtureFactor::Sum sum = sumFrontals(factors); + // Collect all the factors to create a set of Gaussian factor graphs in a + // decision tree indexed by all discrete keys involved. + GaussianFactorGraphTree sum = factors.assembleGraphTree(); - // If a tree leaf contains nullptr, - // convert that leaf to an empty GaussianFactorGraph. - // Needed since the DecisionTree will otherwise create - // a GFG with a single (null) factor. - auto emptyGaussian = [](const GaussianFactorGraph &gfg) { - bool hasNull = - std::any_of(gfg.begin(), gfg.end(), - [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }); - - return hasNull ? GaussianFactorGraph() : gfg; - }; - sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); + // Convert factor graphs with a nullptr to an empty factor graph. + // This is done after assembly since it is non-trivial to keep track of which + // FG has a nullptr as we're looping over the factors. + sum = removeEmpty(sum); using EliminationPair = std::pair, GaussianMixtureFactor::FactorAndConstant>; - KeyVector keysOfEliminated; // Not the ordering - KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? - // This is the elimination method on the leaf nodes - auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair { - if (graph.empty()) { + auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair { + if (graph_z.graph.empty()) { return {nullptr, {nullptr, 0.0}}; } @@ -222,24 +222,34 @@ hybridElimination(const HybridGaussianFactorGraph &factors, gttic_(hybrid_eliminate); #endif - std::pair, - boost::shared_ptr> - conditional_factor = EliminatePreferCholesky(graph, frontalKeys); + boost::shared_ptr conditional; + boost::shared_ptr newFactor; + boost::tie(conditional, newFactor) = + EliminatePreferCholesky(graph_z.graph, frontalKeys); - // Initialize the keysOfEliminated to be the keys of the - // eliminated GaussianConditional - keysOfEliminated = conditional_factor.first->keys(); - keysOfSeparator = conditional_factor.second->keys(); + // Get the log of the log normalization constant inverse and + // add it to the previous constant. + const double logZ = + graph_z.constant - conditional->logNormalizationConstant(); + // Get the log of the log normalization constant inverse. + // double logZ = -conditional->logNormalizationConstant(); + // // IF this is the last continuous variable to eliminated, we need to + // // calculate the error here: the value of all factors at the mean, see + // // ml_map_rao.pdf. + // if (continuousSeparator.empty()) { + // const auto posterior_mean = conditional->solve(VectorValues()); + // logZ += graph_z.graph.error(posterior_mean); + // } #ifdef HYBRID_TIMING gttoc_(hybrid_eliminate); #endif - return {conditional_factor.first, {conditional_factor.second, 0.0}}; + return {conditional, {newFactor, logZ}}; }; // Perform elimination! - DecisionTree eliminationResults(sum, eliminate); + DecisionTree eliminationResults(sum, eliminateFunc); #ifdef HYBRID_TIMING tictoc_print_(); @@ -247,46 +257,50 @@ hybridElimination(const HybridGaussianFactorGraph &factors, #endif // Separate out decision tree into conditionals and remaining factors. - auto pair = unzip(eliminationResults); - const auto &separatorFactors = pair.second; + GaussianMixture::Conditionals conditionals; + GaussianMixtureFactor::Factors newFactors; + std::tie(conditionals, newFactors) = unzip(eliminationResults); // Create the GaussianMixture from the conditionals - auto conditional = boost::make_shared( - frontalKeys, keysOfSeparator, discreteSeparator, pair.first); + auto gaussianMixture = boost::make_shared( + frontalKeys, continuousSeparator, discreteSeparator, conditionals); - // If there are no more continuous parents, then we should create here a - // DiscreteFactor, with the error for each discrete choice. - if (keysOfSeparator.empty()) { - VectorValues empty_values; + // If there are no more continuous parents, then we should create a + // DiscreteFactor here, with the error for each discrete choice. + if (continuousSeparator.empty()) { auto factorProb = [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { - GaussianFactor::shared_ptr factor = factor_z.factor; - if (!factor) { - return 0.0; // If nullptr, return 0.0 probability - } else { - // This is the probability q(μ) at the MLE point. - double error = - 0.5 * std::abs(factor->augmentedInformation().determinant()) + - factor_z.constant; - return std::exp(-error); - } + // This is the probability q(μ) at the MLE point. + // factor_z.factor is a factor without keys, + // just containing the residual. + return exp(-factor_z.error(VectorValues())); }; - DecisionTree fdt(separatorFactors, factorProb); - auto discreteFactor = + const DecisionTree fdt(newFactors, factorProb); + // // Normalize the values of decision tree to be valid probabilities + // double sum = 0.0; + // auto visitor = [&](double y) { sum += y; }; + // fdt.visit(visitor); + // // Check if sum is 0, and update accordingly. + // if (sum == 0) { + // sum = 1.0; + // } + // fdt = DecisionTree(fdt, + // [sum](const double &x) { return x / sum; + // }); + const auto discreteFactor = boost::make_shared(discreteSeparator, fdt); - return {boost::make_shared(conditional), + return {boost::make_shared(gaussianMixture), boost::make_shared(discreteFactor)}; - } else { // Create a resulting GaussianMixtureFactor on the separator. - auto factor = boost::make_shared( - KeyVector(continuousSeparator.begin(), continuousSeparator.end()), - discreteSeparator, separatorFactors); - return {boost::make_shared(conditional), factor}; + return {boost::make_shared(gaussianMixture), + boost::make_shared( + continuousSeparator, discreteSeparator, newFactors)}; } } + /* ************************************************************************ * Function to eliminate variables **under the following assumptions**: * 1. When the ordering is fully continuous, and the graph only contains @@ -383,12 +397,12 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, // Fill in discrete discrete separator keys and continuous separator keys. std::set discreteSeparatorSet; - KeySet continuousSeparator; + KeyVector continuousSeparator; for (auto &k : separatorKeys) { if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k)); } else { - continuousSeparator.insert(k); + continuousSeparator.push_back(k); } } @@ -463,15 +477,8 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( // If factor is hybrid, select based on assignment. GaussianMixtureFactor::shared_ptr gaussianMixture = boost::static_pointer_cast(factors_.at(idx)); - // Compute factor error. - factor_error = gaussianMixture->error(continuousValues); - - // If first factor, assign error, else add it. - if (idx == 0) { - error_tree = factor_error; - } else { - error_tree = error_tree + factor_error; - } + // Compute factor error and add it. + error_tree = error_tree + gaussianMixture->error(continuousValues); } else if (factors_.at(idx)->isContinuous()) { // If continuous only, get the (double) error diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index c851adfe5..144d144bb 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -18,6 +18,7 @@ #pragma once +#include #include #include #include @@ -118,14 +119,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph : Base(graph) {} /// @} + /// @name Adding factors. + /// @{ - using Base::empty; - using Base::reserve; - using Base::size; - using Base::operator[]; using Base::add; using Base::push_back; - using Base::resize; + using Base::reserve; /// Add a Jacobian factor to the factor graph. void add(JacobianFactor&& factor); @@ -172,6 +171,25 @@ class GTSAM_EXPORT HybridGaussianFactorGraph } } + /// @} + /// @name Testable + /// @{ + + // TODO(dellaert): customize print and equals. + // void print(const std::string& s = "HybridGaussianFactorGraph", + // const KeyFormatter& keyFormatter = DefaultKeyFormatter) const + // override; + // bool equals(const This& fg, double tol = 1e-9) const override; + + /// @} + /// @name Standard Interface + /// @{ + + using Base::empty; + using Base::size; + using Base::operator[]; + using Base::resize; + /** * @brief Compute error for each discrete assignment, * and return as a tree. @@ -217,6 +235,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @return const Ordering */ const Ordering getHybridOrdering() const; + + /** + * @brief Create a decision tree of factor graphs out of this hybrid factor + * graph. + * + * For example, if there are two mixture factors, one with a discrete key A + * and one with a discrete key B, then the decision tree will have two levels, + * one for A and one for B. The leaves of the tree will be the Gaussian + * factors that have only continuous keys. + */ + GaussianFactorGraphTree assembleGraphTree() const; + + /// @} }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridNonlinearISAM.cpp b/gtsam/hybrid/HybridNonlinearISAM.cpp index 57e0daf8d..d6b83e30d 100644 --- a/gtsam/hybrid/HybridNonlinearISAM.cpp +++ b/gtsam/hybrid/HybridNonlinearISAM.cpp @@ -99,9 +99,11 @@ void HybridNonlinearISAM::print(const string& s, const KeyFormatter& keyFormatter) const { cout << s << "ReorderInterval: " << reorderInterval_ << " Current Count: " << reorderCounter_ << endl; - isam_.print("HybridGaussianISAM:\n", keyFormatter); + std::cout << "HybridGaussianISAM:" << std::endl; + isam_.print("", keyFormatter); linPoint_.print("Linearization Point:\n", keyFormatter); - factors_.print("Nonlinear Graph:\n", keyFormatter); + std::cout << "Nonlinear Graph:" << std::endl; + factors_.print("", keyFormatter); } /* ************************************************************************* */ diff --git a/gtsam/hybrid/HybridNonlinearISAM.h b/gtsam/hybrid/HybridNonlinearISAM.h index 47aa81c55..53bacb0ff 100644 --- a/gtsam/hybrid/HybridNonlinearISAM.h +++ b/gtsam/hybrid/HybridNonlinearISAM.h @@ -90,7 +90,7 @@ class GTSAM_EXPORT HybridNonlinearISAM { const Values& getLinearizationPoint() const { return linPoint_; } /** Return the current discrete assignment */ - const DiscreteValues& getAssignment() const { return assignment_; } + const DiscreteValues& assignment() const { return assignment_; } /** get underlying nonlinear graph */ const HybridNonlinearFactorGraph& getFactorsUnsafe() const { diff --git a/gtsam/hybrid/HybridValues.h b/gtsam/hybrid/HybridValues.h index efe65bc31..4c4f5fa1e 100644 --- a/gtsam/hybrid/HybridValues.h +++ b/gtsam/hybrid/HybridValues.h @@ -168,6 +168,15 @@ class GTSAM_EXPORT HybridValues { return *this; } + /// Extract continuous values with given keys. + VectorValues continuousSubset(const KeyVector& keys) const { + VectorValues measurements; + for (const auto& key : keys) { + measurements.insert(key, continuous_.at(key)); + } + return measurements; + } + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index fc1a9a2b8..5285dd191 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -162,14 +162,20 @@ class MixtureFactor : public HybridFactor { } /// Error for HybridValues is not provided for nonlinear hybrid factor. - double error(const HybridValues &values) const override { + double error(const HybridValues& values) const override { throw std::runtime_error( "MixtureFactor::error(HybridValues) not implemented."); } + /** + * @brief Get the dimension of the factor (number of rows on linearization). + * Returns the dimension of the first component factor. + * @return size_t + */ size_t dim() const { - // TODO(Varun) - throw std::runtime_error("MixtureFactor::dim not implemented."); + const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_); + auto factor = factors_(assignments.at(0)); + return factor->dim(); } /// Testable diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 21a496eee..87aa3bc14 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -40,6 +40,15 @@ virtual class HybridFactor { bool empty() const; size_t size() const; gtsam::KeyVector keys() const; + + // Standard interface: + double error(const gtsam::HybridValues &values) const; + bool isDiscrete() const; + bool isContinuous() const; + bool isHybrid() const; + size_t nrContinuous() const; + gtsam::DiscreteKeys discreteKeys() const; + gtsam::KeyVector continuousKeys() const; }; #include @@ -50,7 +59,13 @@ virtual class HybridConditional { bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const; size_t nrFrontals() const; size_t nrParents() const; + + // Standard interface: + gtsam::GaussianMixture* asMixture() const; + gtsam::GaussianConditional* asGaussian() const; + gtsam::DiscreteConditional* asDiscrete() const; gtsam::Factor* inner(); + double error(const gtsam::HybridValues& values) const; }; #include @@ -61,6 +76,7 @@ virtual class HybridDiscreteFactor { gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const; gtsam::Factor* inner(); + double error(const gtsam::HybridValues &values) const; }; #include diff --git a/gtsam/hybrid/tests/TinyHybridExample.h b/gtsam/hybrid/tests/TinyHybridExample.h new file mode 100644 index 000000000..ba04263f8 --- /dev/null +++ b/gtsam/hybrid/tests/TinyHybridExample.h @@ -0,0 +1,96 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010-2023, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * @file TinyHybridExample.h + * @date December, 2022 + * @author Frank Dellaert + */ + +#include +#include +#include +#pragma once + +namespace gtsam { +namespace tiny { + +using symbol_shorthand::M; +using symbol_shorthand::X; +using symbol_shorthand::Z; + +// Create mode key: 0 is low-noise, 1 is high-noise. +const DiscreteKey mode{M(0), 2}; + +/** + * Create a tiny two variable hybrid model which represents + * the generative probability P(z,x,mode) = P(z|x,mode)P(x)P(mode). + */ +inline HybridBayesNet createHybridBayesNet(int num_measurements = 1) { + HybridBayesNet bayesNet; + + // Create Gaussian mixture z_i = x0 + noise for each measurement. + for (int i = 0; i < num_measurements; i++) { + const auto conditional0 = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 0.5)); + const auto conditional1 = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 3)); + GaussianMixture gm({Z(i)}, {X(0)}, {mode}, {conditional0, conditional1}); + bayesNet.emplaceMixture(gm); // copy :-( + } + + // Create prior on X(0). + const auto prior_on_x0 = + GaussianConditional::FromMeanAndStddev(X(0), Vector1(5.0), 0.5); + bayesNet.emplaceGaussian(prior_on_x0); // copy :-( + + // Add prior on mode. + bayesNet.emplaceDiscrete(mode, "4/6"); + + return bayesNet; +} + +/** + * Convert a hybrid Bayes net to a hybrid Gaussian factor graph. + */ +inline HybridGaussianFactorGraph convertBayesNet( + const HybridBayesNet& bayesNet, const VectorValues& measurements) { + HybridGaussianFactorGraph fg; + int num_measurements = bayesNet.size() - 2; + for (int i = 0; i < num_measurements; i++) { + auto conditional = bayesNet.atMixture(i); + auto factor = conditional->likelihood({{Z(i), measurements.at(Z(i))}}); + fg.push_back(factor); + } + fg.push_back(bayesNet.atGaussian(num_measurements)); + fg.push_back(bayesNet.atDiscrete(num_measurements + 1)); + return fg; +} + +/** + * Create a tiny two variable hybrid factor graph which represents a discrete + * mode and a continuous variable x0, given a number of measurements of the + * continuous variable x0. If no measurements are given, they are sampled from + * the generative Bayes net model HybridBayesNet::Example(num_measurements) + */ +inline HybridGaussianFactorGraph createHybridGaussianFactorGraph( + int num_measurements = 1, + boost::optional measurements = boost::none) { + auto bayesNet = createHybridBayesNet(num_measurements); + if (measurements) { + return convertBayesNet(bayesNet, *measurements); + } else { + return convertBayesNet(bayesNet, bayesNet.sample().continuous()); + } +} + +} // namespace tiny +} // namespace gtsam diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index d17968a3a..16b33a0d5 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -80,7 +80,7 @@ TEST(GaussianMixtureFactor, Sum) { // Create sum of two mixture factors: it will be a decision tree now on both // discrete variables m1 and m2: - GaussianMixtureFactor::Sum sum; + GaussianFactorGraphTree sum; sum += mixtureFactorA; sum += mixtureFactorB; @@ -89,8 +89,8 @@ TEST(GaussianMixtureFactor, Sum) { mode[m1.first] = 1; mode[m2.first] = 2; auto actual = sum(mode); - EXPECT(actual.at(0) == f11); - EXPECT(actual.at(1) == f22); + EXPECT(actual.graph.at(0) == f11); + EXPECT(actual.graph.at(1) == f22); } TEST(GaussianMixtureFactor, Printing) { diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index fe8cdcd64..ef552bd92 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -18,19 +18,18 @@ * @date December 2021 */ -#include #include #include #include #include "Switching.h" +#include "TinyHybridExample.h" // Include for test suite #include using namespace std; using namespace gtsam; -using namespace gtsam::serializationTestHelpers; using noiseModel::Isotropic; using symbol_shorthand::M; @@ -63,7 +62,7 @@ TEST(HybridBayesNet, Add) { /* ****************************************************************************/ // Test evaluate for a pure discrete Bayes net P(Asia). -TEST(HybridBayesNet, evaluatePureDiscrete) { +TEST(HybridBayesNet, EvaluatePureDiscrete) { HybridBayesNet bayesNet; bayesNet.emplaceDiscrete(Asia, "99/1"); HybridValues values; @@ -71,6 +70,13 @@ TEST(HybridBayesNet, evaluatePureDiscrete) { EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9); } +/* ****************************************************************************/ +// Test creation of a tiny hybrid Bayes net. +TEST(HybridBayesNet, Tiny) { + auto bayesNet = tiny::createHybridBayesNet(); + EXPECT_LONGS_EQUAL(3, bayesNet.size()); +} + /* ****************************************************************************/ // Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia). TEST(HybridBayesNet, evaluateHybrid) { @@ -180,7 +186,7 @@ TEST(HybridBayesNet, OptimizeAssignment) { /* ****************************************************************************/ // Test Bayes net optimize TEST(HybridBayesNet, Optimize) { - Switching s(4); + Switching s(4, 1.0, 0.1, {0, 1, 2, 3}, "1/1 1/1"); Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); HybridBayesNet::shared_ptr hybridBayesNet = @@ -188,25 +194,24 @@ TEST(HybridBayesNet, Optimize) { HybridValues delta = hybridBayesNet->optimize(); - // TODO(Varun) The expectedAssignment should be 111, not 101 + // NOTE: The true assignment is 111, but the discrete priors cause 101 DiscreteValues expectedAssignment; expectedAssignment[M(0)] = 1; - expectedAssignment[M(1)] = 0; + expectedAssignment[M(1)] = 1; expectedAssignment[M(2)] = 1; EXPECT(assert_equal(expectedAssignment, delta.discrete())); - // TODO(Varun) This should be all -Vector1::Ones() VectorValues expectedValues; - expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); - expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); - expectedValues.insert(X(2), -1.00971 * Vector1::Ones()); - expectedValues.insert(X(3), -1.0001 * Vector1::Ones()); + expectedValues.insert(X(0), -Vector1::Ones()); + expectedValues.insert(X(1), -Vector1::Ones()); + expectedValues.insert(X(2), -Vector1::Ones()); + expectedValues.insert(X(3), -Vector1::Ones()); EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } /* ****************************************************************************/ -// Test bayes net error +// Test Bayes net error TEST(HybridBayesNet, Error) { Switching s(3); @@ -237,7 +242,7 @@ TEST(HybridBayesNet, Error) { EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9)); // Verify error computation and check for specific error value - DiscreteValues discrete_values {{M(0), 1}, {M(1), 1}}; + DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}}; double total_error = 0; for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { @@ -323,18 +328,6 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { discrete_conditional_tree->apply(checker); } -/* ****************************************************************************/ -// Test HybridBayesNet serialization. -TEST(HybridBayesNet, Serialization) { - Switching s(4); - Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); - HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); - - EXPECT(equalsObj(hbn)); - EXPECT(equalsXML(hbn)); - EXPECT(equalsBinary(hbn)); -} - /* ****************************************************************************/ // Test HybridBayesNet sampling. TEST(HybridBayesNet, Sampling) { diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index b4d049210..b957a67d0 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -155,7 +155,7 @@ TEST(HybridBayesTree, Optimize) { dfg.push_back( boost::dynamic_pointer_cast(factor->inner())); } - + // Add the probabilities for each branch DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}}; vector probs = {0.012519475, 0.041280228, 0.075018647, 0.081663656, @@ -211,29 +211,15 @@ TEST(HybridBayesTree, Choose) { ordering += M(0); ordering += M(1); ordering += M(2); - - //TODO(Varun) get segfault if ordering not provided + + // TODO(Varun) get segfault if ordering not provided auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering); - + auto expected_gbt = bayesTree->choose(assignment); EXPECT(assert_equal(expected_gbt, gbt)); } -/* ****************************************************************************/ -// Test HybridBayesTree serialization. -TEST(HybridBayesTree, Serialization) { - Switching s(4); - Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); - HybridBayesTree hbt = - *(s.linearizedFactorGraph.eliminateMultifrontal(ordering)); - - using namespace gtsam::serializationTestHelpers; - EXPECT(equalsObj(hbt)); - EXPECT(equalsXML(hbt)); - EXPECT(equalsBinary(hbt)); -} - /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index 660cb3317..4a53a3141 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -280,11 +280,10 @@ AlgebraicDecisionTree getProbPrimeTree( return probPrimeTree; } -/****************************************************************************/ -/** +/********************************************************************************* * Test for correctness of different branches of the P'(Continuous | Discrete). * The values should match those of P'(Continuous) for each discrete mode. - */ + ********************************************************************************/ TEST(HybridEstimation, Probability) { constexpr size_t K = 4; std::vector measurements = {0, 1, 2, 2}; @@ -441,18 +440,30 @@ static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() { * Do hybrid elimination and do regression test on discrete conditional. ********************************************************************************/ TEST(HybridEstimation, eliminateSequentialRegression) { - // 1. Create the factor graph from the nonlinear factor graph. + // Create the factor graph from the nonlinear factor graph. HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph(); - // 2. Eliminate into BN - const Ordering ordering = fg->getHybridOrdering(); - HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering); - // GTSAM_PRINT(*bn); + // Create expected discrete conditional on m0. + DiscreteKey m(M(0), 2); + DiscreteConditional expected(m % "0.51341712/1"); // regression - // TODO(dellaert): dc should be discrete conditional on m0, but it is an - // unnormalized factor? DiscreteKey m(M(0), 2); DiscreteConditional expected(m - // % "0.51341712/1"); auto dc = bn->back()->asDiscreteConditional(); - // EXPECT(assert_equal(expected, *dc, 1e-9)); + // Eliminate into BN using one ordering + Ordering ordering1; + ordering1 += X(0), X(1), M(0); + HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1); + + // Check that the discrete conditional matches the expected. + auto dc1 = bn1->back()->asDiscrete(); + EXPECT(assert_equal(expected, *dc1, 1e-9)); + + // Eliminate into BN using a different ordering + Ordering ordering2; + ordering2 += X(0), X(1), M(0); + HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2); + + // Check that the discrete conditional matches the expected. + auto dc2 = bn2->back()->asDiscrete(); + EXPECT(assert_equal(expected, *dc2, 1e-9)); } /********************************************************************************* @@ -467,45 +478,35 @@ TEST(HybridEstimation, eliminateSequentialRegression) { ********************************************************************************/ TEST(HybridEstimation, CorrectnessViaSampling) { // 1. Create the factor graph from the nonlinear factor graph. - HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph(); + const auto fg = createHybridGaussianFactorGraph(); // 2. Eliminate into BN const Ordering ordering = fg->getHybridOrdering(); - HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering); + const HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering); // Set up sampling std::mt19937_64 rng(11); - // 3. Do sampling - int num_samples = 10; - - // Functor to compute the ratio between the - // Bayes net and the factor graph. - auto compute_ratio = - [](const HybridBayesNet::shared_ptr& bayesNet, - const HybridGaussianFactorGraph::shared_ptr& factorGraph, - const HybridValues& sample) -> double { - const DiscreteValues assignment = sample.discrete(); - // Compute in log form for numerical stability - double log_ratio = bayesNet->error({sample.continuous(), assignment}) - - factorGraph->error({sample.continuous(), assignment}); - double ratio = exp(-log_ratio); - return ratio; + // Compute the log-ratio between the Bayes net and the factor graph. + auto compute_ratio = [&](const HybridValues& sample) -> double { + return bn->evaluate(sample) / fg->probPrime(sample); }; // The error evaluated by the factor graph and the Bayes net should differ by // the normalizing term computed via the Bayes net determinant. const HybridValues sample = bn->sample(&rng); - double ratio = compute_ratio(bn, fg, sample); + double expected_ratio = compute_ratio(sample); // regression - EXPECT_DOUBLES_EQUAL(1.0, ratio, 1e-9); + EXPECT_DOUBLES_EQUAL(0.728588, expected_ratio, 1e-6); - // 4. Check that all samples == constant + // 3. Do sampling + constexpr int num_samples = 10; for (size_t i = 0; i < num_samples; i++) { // Sample from the bayes net const HybridValues sample = bn->sample(&rng); - EXPECT_DOUBLES_EQUAL(ratio, compute_ratio(bn, fg, sample), 1e-9); + // 4. Check that the ratio is constant. + EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(sample), 1e-6); } } diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 171b91d51..fa371cf16 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -47,6 +47,7 @@ #include #include "Switching.h" +#include "TinyHybridExample.h" using namespace std; using namespace gtsam; @@ -133,7 +134,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) { auto dc = result->at(2)->asDiscrete(); DiscreteValues dv; dv[M(1)] = 0; - EXPECT_DOUBLES_EQUAL(1, dc->operator()(dv), 1e-3); + // Regression test + EXPECT_DOUBLES_EQUAL(0.62245933120185448, dc->operator()(dv), 1e-3); } /* ************************************************************************* */ @@ -612,6 +614,108 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) { EXPECT(assert_equal(expected_probs, probs, 1e-7)); } +/* ****************************************************************************/ +// Check that assembleGraphTree assembles Gaussian factor graphs for each +// assignment. +TEST(HybridGaussianFactorGraph, assembleGraphTree) { + using symbol_shorthand::Z; + const int num_measurements = 1; + auto fg = tiny::createHybridGaussianFactorGraph( + num_measurements, VectorValues{{Z(0), Vector1(5.0)}}); + EXPECT_LONGS_EQUAL(3, fg.size()); + + auto sum = fg.assembleGraphTree(); + + // Get mixture factor: + auto mixture = boost::dynamic_pointer_cast(fg.at(0)); + using GF = GaussianFactor::shared_ptr; + + // Get prior factor: + const GF prior = + boost::dynamic_pointer_cast(fg.at(1))->inner(); + + // Create DiscreteValues for both 0 and 1: + DiscreteValues d0{{M(0), 0}}, d1{{M(0), 1}}; + + // Expected decision tree with two factor graphs: + // f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0) + GaussianFactorGraphTree expectedSum{ + M(0), + {GaussianFactorGraph(std::vector{mixture->factor(d0), prior}), + mixture->constant(d0)}, + {GaussianFactorGraph(std::vector{mixture->factor(d1), prior}), + mixture->constant(d1)}}; + + EXPECT(assert_equal(expectedSum(d0), sum(d0), 1e-5)); + EXPECT(assert_equal(expectedSum(d1), sum(d1), 1e-5)); +} + +/* ****************************************************************************/ +// Check that eliminating tiny net with 1 measurement yields correct result. +TEST(HybridGaussianFactorGraph, EliminateTiny1) { + using symbol_shorthand::Z; + const int num_measurements = 1; + auto fg = tiny::createHybridGaussianFactorGraph( + num_measurements, VectorValues{{Z(0), Vector1(5.0)}}); + + // Create expected Bayes Net: + HybridBayesNet expectedBayesNet; + + // Create Gaussian mixture on X(0). + using tiny::mode; + // regression, but mean checked to be 5.0 in both cases: + const auto conditional0 = boost::make_shared( + X(0), Vector1(14.1421), I_1x1 * 2.82843), + conditional1 = boost::make_shared( + X(0), Vector1(10.1379), I_1x1 * 2.02759); + GaussianMixture gm({X(0)}, {}, {mode}, {conditional0, conditional1}); + expectedBayesNet.emplaceMixture(gm); // copy :-( + + // Add prior on mode. + expectedBayesNet.emplaceDiscrete(mode, "74/26"); + + // Test elimination + Ordering ordering; + ordering.push_back(X(0)); + ordering.push_back(M(0)); + const auto posterior = fg.eliminateSequential(ordering); + EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); +} + +/* ****************************************************************************/ +// Check that eliminating tiny net with 2 measurements yields correct result. +TEST(HybridGaussianFactorGraph, EliminateTiny2) { + // Create factor graph with 2 measurements such that posterior mean = 5.0. + using symbol_shorthand::Z; + const int num_measurements = 2; + auto fg = tiny::createHybridGaussianFactorGraph( + num_measurements, + VectorValues{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}}); + + // Create expected Bayes Net: + HybridBayesNet expectedBayesNet; + + // Create Gaussian mixture on X(0). + using tiny::mode; + // regression, but mean checked to be 5.0 in both cases: + const auto conditional0 = boost::make_shared( + X(0), Vector1(17.3205), I_1x1 * 3.4641), + conditional1 = boost::make_shared( + X(0), Vector1(10.274), I_1x1 * 2.0548); + GaussianMixture gm({X(0)}, {}, {mode}, {conditional0, conditional1}); + expectedBayesNet.emplaceMixture(gm); // copy :-( + + // Add prior on mode. + expectedBayesNet.emplaceDiscrete(mode, "23/77"); + + // Test elimination + Ordering ordering; + ordering.push_back(X(0)); + ordering.push_back(M(0)); + const auto posterior = fg.eliminateSequential(ordering); + EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp index 14f9db8e4..1ce10b452 100644 --- a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp @@ -177,19 +177,19 @@ TEST(HybridGaussianElimination, IncrementalInference) { // Test the probability values with regression tests. DiscreteValues assignment; - EXPECT(assert_equal(0.0619233, m00_prob, 1e-5)); + EXPECT(assert_equal(0.0952922, m00_prob, 1e-5)); assignment[M(0)] = 0; assignment[M(1)] = 0; - EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5)); + EXPECT(assert_equal(0.0952922, (*discreteConditional)(assignment), 1e-5)); assignment[M(0)] = 1; assignment[M(1)] = 0; - EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5)); + EXPECT(assert_equal(0.282758, (*discreteConditional)(assignment), 1e-5)); assignment[M(0)] = 0; assignment[M(1)] = 1; - EXPECT(assert_equal(0.204159, (*discreteConditional)(assignment), 1e-5)); + EXPECT(assert_equal(0.314175, (*discreteConditional)(assignment), 1e-5)); assignment[M(0)] = 1; assignment[M(1)] = 1; - EXPECT(assert_equal(0.2, (*discreteConditional)(assignment), 1e-5)); + EXPECT(assert_equal(0.307775, (*discreteConditional)(assignment), 1e-5)); // Check if the clique conditional generated from incremental elimination // matches that of batch elimination. @@ -199,10 +199,10 @@ TEST(HybridGaussianElimination, IncrementalInference) { isam[M(1)]->conditional()->inner()); // Account for the probability terms from evaluating continuous FGs DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}}; - vector probs = {0.061923317, 0.20415914, 0.18374323, 0.2}; + vector probs = {0.095292197, 0.31417524, 0.28275772, 0.30777485}; auto expectedConditional = boost::make_shared(discrete_keys, probs); - EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6)); + EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6)); } /* ****************************************************************************/ diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index a4de4a1ae..d84f4b352 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -443,7 +443,7 @@ TEST(HybridFactorGraph, Full_Elimination) { ordering.clear(); for (size_t k = 0; k < self.K - 1; k++) ordering += M(k); discreteBayesNet = - *discrete_fg.eliminateSequential(ordering, EliminateForMPE); + *discrete_fg.eliminateSequential(ordering, EliminateDiscrete); } // Create ordering. @@ -638,22 +638,30 @@ conditional 2: Hybrid P( x2 | m0 m1) 0 0 Leaf p(x2) R = [ 10.0494 ] d = [ -10.1489 ] + mean: 1 elements + x2: -1.0099 No noise model 0 1 Leaf p(x2) R = [ 10.0494 ] d = [ -10.1479 ] + mean: 1 elements + x2: -1.0098 No noise model 1 Choice(m0) 1 0 Leaf p(x2) R = [ 10.0494 ] d = [ -10.0504 ] + mean: 1 elements + x2: -1.0001 No noise model 1 1 Leaf p(x2) R = [ 10.0494 ] d = [ -10.0494 ] + mean: 1 elements + x2: -1 No noise model )"; diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp index c1689b6ab..68a55abdd 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp @@ -191,24 +191,23 @@ TEST(HybridNonlinearISAM, IncrementalInference) { *(*discreteBayesTree)[M(1)]->conditional()->asDiscrete(); double m00_prob = decisionTree(m00); - auto discreteConditional = - bayesTree[M(1)]->conditional()->asDiscrete(); + auto discreteConditional = bayesTree[M(1)]->conditional()->asDiscrete(); // Test the probability values with regression tests. DiscreteValues assignment; - EXPECT(assert_equal(0.0619233, m00_prob, 1e-5)); + EXPECT(assert_equal(0.0952922, m00_prob, 1e-5)); assignment[M(0)] = 0; assignment[M(1)] = 0; - EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5)); + EXPECT(assert_equal(0.0952922, (*discreteConditional)(assignment), 1e-5)); assignment[M(0)] = 1; assignment[M(1)] = 0; - EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5)); + EXPECT(assert_equal(0.282758, (*discreteConditional)(assignment), 1e-5)); assignment[M(0)] = 0; assignment[M(1)] = 1; - EXPECT(assert_equal(0.204159, (*discreteConditional)(assignment), 1e-5)); + EXPECT(assert_equal(0.314175, (*discreteConditional)(assignment), 1e-5)); assignment[M(0)] = 1; assignment[M(1)] = 1; - EXPECT(assert_equal(0.2, (*discreteConditional)(assignment), 1e-5)); + EXPECT(assert_equal(0.307775, (*discreteConditional)(assignment), 1e-5)); // Check if the clique conditional generated from incremental elimination // matches that of batch elimination. @@ -217,10 +216,10 @@ TEST(HybridNonlinearISAM, IncrementalInference) { bayesTree[M(1)]->conditional()->inner()); // Account for the probability terms from evaluating continuous FGs DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}}; - vector probs = {0.061923317, 0.20415914, 0.18374323, 0.2}; + vector probs = {0.095292197, 0.31417524, 0.28275772, 0.30777485}; auto expectedConditional = boost::make_shared(discrete_keys, probs); - EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6)); + EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6)); } /* ****************************************************************************/ @@ -358,10 +357,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) { // Run update with pruning size_t maxComponents = 5; incrementalHybrid.update(graph1, initial); + incrementalHybrid.prune(maxComponents); HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree(); - bayesTree.prune(maxComponents); - // Check if we have a bayes tree with 4 hybrid nodes, // each with 2, 4, 8, and 5 (pruned) leaves respetively. EXPECT_LONGS_EQUAL(4, bayesTree.size()); @@ -383,10 +381,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) { // Run update with pruning a second time. incrementalHybrid.update(graph2, initial); + incrementalHybrid.prune(maxComponents); bayesTree = incrementalHybrid.bayesTree(); - bayesTree.prune(maxComponents); - // Check if we have a bayes tree with pruned hybrid nodes, // with 5 (pruned) leaves. CHECK_EQUAL(5, bayesTree.size()); diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index fe3212eda..9e4d66bf2 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -70,8 +70,7 @@ MixtureFactor } /* ************************************************************************* */ -// Test the error of the MixtureFactor -TEST(MixtureFactor, Error) { +static MixtureFactor getMixtureFactor() { DiscreteKey m1(1, 2); double between0 = 0.0; @@ -86,7 +85,13 @@ TEST(MixtureFactor, Error) { boost::make_shared>(X(1), X(2), between1, model); std::vector factors{f0, f1}; - MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); + return MixtureFactor({X(1), X(2)}, {m1}, factors); +} + +/* ************************************************************************* */ +// Test the error of the MixtureFactor +TEST(MixtureFactor, Error) { + auto mixtureFactor = getMixtureFactor(); Values continuousValues; continuousValues.insert(X(1), 0); @@ -94,6 +99,7 @@ TEST(MixtureFactor, Error) { AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); + DiscreteKey m1(1, 2); std::vector discrete_keys = {m1}; std::vector errors = {0.5, 0}; AlgebraicDecisionTree expected_error(discrete_keys, errors); @@ -101,6 +107,13 @@ TEST(MixtureFactor, Error) { EXPECT(assert_equal(expected_error, error_tree)); } +/* ************************************************************************* */ +// Test dim of the MixtureFactor +TEST(MixtureFactor, Dim) { + auto mixtureFactor = getMixtureFactor(); + EXPECT_LONGS_EQUAL(1, mixtureFactor.dim()); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testSerializationHybrid.cpp b/gtsam/hybrid/tests/testSerializationHybrid.cpp new file mode 100644 index 000000000..941a1cdb3 --- /dev/null +++ b/gtsam/hybrid/tests/testSerializationHybrid.cpp @@ -0,0 +1,179 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testSerializationHybrid.cpp + * @brief Unit tests for hybrid serialization + * @author Varun Agrawal + * @date January 2023 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Switching.h" + +// Include for test suite +#include + +using namespace std; +using namespace gtsam; +using symbol_shorthand::M; +using symbol_shorthand::X; +using symbol_shorthand::Z; + +using namespace serializationTestHelpers; + +BOOST_CLASS_EXPORT_GUID(Factor, "gtsam_Factor"); +BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor"); +BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor"); +BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional"); +BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional"); + +BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); +using ADT = AlgebraicDecisionTree; +BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree"); +BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf"); +BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice") + +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor"); +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors, + "gtsam_GaussianMixtureFactor_Factors"); +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Leaf, + "gtsam_GaussianMixtureFactor_Factors_Leaf"); +BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Choice, + "gtsam_GaussianMixtureFactor_Factors_Choice"); + +BOOST_CLASS_EXPORT_GUID(GaussianMixture, "gtsam_GaussianMixture"); +BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals, + "gtsam_GaussianMixture_Conditionals"); +BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Leaf, + "gtsam_GaussianMixture_Conditionals_Leaf"); +BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice, + "gtsam_GaussianMixture_Conditionals_Choice"); +// Needed since GaussianConditional::FromMeanAndStddev uses it +BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); + +BOOST_CLASS_EXPORT_GUID(HybridBayesNet, "gtsam_HybridBayesNet"); + +/* ****************************************************************************/ +// Test HybridGaussianFactor serialization. +TEST(HybridSerialization, HybridGaussianFactor) { + const HybridGaussianFactor factor(JacobianFactor(X(0), I_3x3, Z_3x1)); + + EXPECT(equalsObj(factor)); + EXPECT(equalsXML(factor)); + EXPECT(equalsBinary(factor)); +} + +/* ****************************************************************************/ +// Test HybridDiscreteFactor serialization. +TEST(HybridSerialization, HybridDiscreteFactor) { + DiscreteKeys discreteKeys{{M(0), 2}}; + const HybridDiscreteFactor factor( + DecisionTreeFactor(discreteKeys, std::vector{0.4, 0.6})); + + EXPECT(equalsObj(factor)); + EXPECT(equalsXML(factor)); + EXPECT(equalsBinary(factor)); +} + +/* ****************************************************************************/ +// Test GaussianMixtureFactor serialization. +TEST(HybridSerialization, GaussianMixtureFactor) { + KeyVector continuousKeys{X(0)}; + DiscreteKeys discreteKeys{{M(0), 2}}; + + auto A = Matrix::Zero(2, 1); + auto b0 = Matrix::Zero(2, 1); + auto b1 = Matrix::Ones(2, 1); + auto f0 = boost::make_shared(X(0), A, b0); + auto f1 = boost::make_shared(X(0), A, b1); + std::vector factors{f0, f1}; + + const GaussianMixtureFactor factor(continuousKeys, discreteKeys, factors); + + EXPECT(equalsObj(factor)); + EXPECT(equalsXML(factor)); + EXPECT(equalsBinary(factor)); +} + +/* ****************************************************************************/ +// Test HybridConditional serialization. +TEST(HybridSerialization, HybridConditional) { + const DiscreteKey mode(M(0), 2); + Matrix1 I = Matrix1::Identity(); + const auto conditional = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5)); + const HybridConditional hc(conditional); + + EXPECT(equalsObj(hc)); + EXPECT(equalsXML(hc)); + EXPECT(equalsBinary(hc)); +} + +/* ****************************************************************************/ +// Test GaussianMixture serialization. +TEST(HybridSerialization, GaussianMixture) { + const DiscreteKey mode(M(0), 2); + Matrix1 I = Matrix1::Identity(); + const auto conditional0 = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5)); + const auto conditional1 = boost::make_shared( + GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3)); + const GaussianMixture gm({Z(0)}, {X(0)}, {mode}, + {conditional0, conditional1}); + + EXPECT(equalsObj(gm)); + EXPECT(equalsXML(gm)); + EXPECT(equalsBinary(gm)); +} + +/* ****************************************************************************/ +// Test HybridBayesNet serialization. +TEST(HybridSerialization, HybridBayesNet) { + Switching s(2); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); + + EXPECT(equalsObj(hbn)); + EXPECT(equalsXML(hbn)); + EXPECT(equalsBinary(hbn)); +} + +/* ****************************************************************************/ +// Test HybridBayesTree serialization. +TEST(HybridSerialization, HybridBayesTree) { + Switching s(2); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesTree hbt = + *(s.linearizedFactorGraph.eliminateMultifrontal(ordering)); + + EXPECT(equalsObj(hbt)); + EXPECT(equalsXML(hbt)); + EXPECT(equalsBinary(hbt)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index ecfa02282..39a21a617 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -67,7 +67,7 @@ namespace gtsam { GaussianConditional GaussianConditional::FromMeanAndStddev(Key key, const Vector& mu, double sigma) { - // |Rx - d| = |x-(Ay + b)|/sigma + // |Rx - d| = |x - mu|/sigma const Matrix R = Matrix::Identity(mu.size(), mu.size()); const Vector& d = mu; return GaussianConditional(key, d, R, @@ -120,6 +120,10 @@ namespace gtsam { << endl; } cout << formatMatrixIndented(" d = ", getb(), true) << "\n"; + if (nrParents() == 0) { + const auto mean = solve({}); // solve for mean. + mean.print(" mean"); + } if (model_) model_->print(" Noise model: "); else @@ -189,7 +193,7 @@ double GaussianConditional::logNormalizationConstant() const { /* ************************************************************************* */ // density = k exp(-error(x)) -// log = log(k) -error(x) - 0.5 * n*log(2*pi) +// log = log(k) -error(x) double GaussianConditional::logDensity(const VectorValues& x) const { return logNormalizationConstant() - error(x); } diff --git a/gtsam/linear/tests/testGaussianConditional.cpp b/gtsam/linear/tests/testGaussianConditional.cpp index 20d730856..d8c3f9455 100644 --- a/gtsam/linear/tests/testGaussianConditional.cpp +++ b/gtsam/linear/tests/testGaussianConditional.cpp @@ -466,6 +466,31 @@ TEST(GaussianConditional, sample) { // EXPECT(assert_equal(Vector2(31.0111856, 64.9850775), actual2[X(0)], 1e-5)); } +/* ************************************************************************* */ +TEST(GaussianConditional, LogNormalizationConstant) { + // Create univariate standard gaussian conditional + auto std_gaussian = + GaussianConditional::FromMeanAndStddev(X(0), Vector1::Zero(), 1.0); + VectorValues values; + values.insert(X(0), Vector1::Zero()); + double logDensity = std_gaussian.logDensity(values); + + // Regression. + // These values were computed by hand for a univariate standard gaussian. + EXPECT_DOUBLES_EQUAL(-0.9189385332046727, logDensity, 1e-9); + EXPECT_DOUBLES_EQUAL(0.3989422804014327, exp(logDensity), 1e-9); + + // Similar test for multivariate gaussian but with sigma 2.0 + double sigma = 2.0; + auto conditional = GaussianConditional::FromMeanAndStddev(X(0), Vector3::Zero(), sigma); + VectorValues x; + x.insert(X(0), Vector3::Zero()); + Matrix3 Sigma = I_3x3 * sigma * sigma; + double expectedLogNormalizingConstant = log(1 / sqrt((2 * M_PI * Sigma).determinant())); + + EXPECT_DOUBLES_EQUAL(expectedLogNormalizingConstant, conditional.logNormalizationConstant(), 1e-9); +} + /* ************************************************************************* */ TEST(GaussianConditional, Print) { Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished(); @@ -482,6 +507,8 @@ TEST(GaussianConditional, Print) { " R = [ 1 0 ]\n" " [ 0 1 ]\n" " d = [ 20 40 ]\n" + " mean: 1 elements\n" + " x0: 20 40\n" "isotropic dim=2 sigma=3\n"; EXPECT(assert_print_equal(expected, conditional, "GaussianConditional")); diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 5398160dc..f83b95442 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -18,9 +18,9 @@ from gtsam.utils.test_case import GtsamTestCase import gtsam from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, - GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues, - HybridGaussianFactorGraph, JacobianFactor, Ordering, - noiseModel) + GaussianMixture, GaussianMixtureFactor, HybridBayesNet, + HybridGaussianFactorGraph, HybridValues, JacobianFactor, + Ordering, noiseModel) class TestHybridGaussianFactorGraph(GtsamTestCase): @@ -82,10 +82,12 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): self.assertEqual(hv.atDiscrete(C(0)), 1) @staticmethod - def tiny(num_measurements: int = 1) -> HybridBayesNet: + def tiny(num_measurements: int = 1, prior_mean: float = 5.0, + prior_sigma: float = 0.5) -> HybridBayesNet: """ Create a tiny two variable hybrid model which represents - the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). + the generative probability P(Z, x0, mode) = P(Z|x0, mode)P(x0)P(mode). + num_measurements: number of measurements in Z = {z0, z1...} """ # Create hybrid Bayes net. bayesNet = HybridBayesNet() @@ -94,23 +96,24 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): mode = (M(0), 2) # Create Gaussian mixture Z(0) = X(0) + noise for each measurement. - I = np.eye(1) + I_1x1 = np.eye(1) keys = DiscreteKeys() keys.push_back(mode) for i in range(num_measurements): conditional0 = GaussianConditional.FromMeanAndStddev(Z(i), - I, + I_1x1, X(0), [0], sigma=0.5) conditional1 = GaussianConditional.FromMeanAndStddev(Z(i), - I, + I_1x1, X(0), [0], sigma=3) bayesNet.emplaceMixture([Z(i)], [X(0)], keys, [conditional0, conditional1]) # Create prior on X(0). - prior_on_x0 = GaussianConditional.FromMeanAndStddev(X(0), [5.0], 5.0) + prior_on_x0 = GaussianConditional.FromMeanAndStddev( + X(0), [prior_mean], prior_sigma) bayesNet.addGaussian(prior_on_x0) # Add prior on mode. @@ -118,8 +121,41 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): return bayesNet + def test_evaluate(self): + """Test evaluate with two different prior noise models.""" + # TODO(dellaert): really a HBN test + # Create a tiny Bayes net P(x0) P(m0) P(z0|x0) + bayesNet1 = self.tiny(prior_sigma=0.5, num_measurements=1) + bayesNet2 = self.tiny(prior_sigma=5.0, num_measurements=1) + # bn1: # 1/sqrt(2*pi*0.5^2) + # bn2: # 1/sqrt(2*pi*5.0^2) + expected_ratio = np.sqrt(2*np.pi*5.0**2)/np.sqrt(2*np.pi*0.5**2) + mean0 = HybridValues() + mean0.insert(X(0), [5.0]) + mean0.insert(Z(0), [5.0]) + mean0.insert(M(0), 0) + self.assertAlmostEqual(bayesNet1.evaluate(mean0) / + bayesNet2.evaluate(mean0), expected_ratio, + delta=1e-9) + mean1 = HybridValues() + mean1.insert(X(0), [5.0]) + mean1.insert(Z(0), [5.0]) + mean1.insert(M(0), 1) + self.assertAlmostEqual(bayesNet1.evaluate(mean1) / + bayesNet2.evaluate(mean1), expected_ratio, + delta=1e-9) + @staticmethod - def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues): + def measurements(sample: HybridValues, indices) -> gtsam.VectorValues: + """Create measurements from a sample, grabbing Z(i) for indices.""" + measurements = gtsam.VectorValues() + for i in indices: + measurements.insert(Z(i), sample.at(Z(i))) + return measurements + + @classmethod + def factor_graph_from_bayes_net(cls, bayesNet: HybridBayesNet, + sample: HybridValues): """Create a factor graph from the Bayes net with sampled measurements. The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` and thus represents the same joint probability as the Bayes net. @@ -128,31 +164,27 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): num_measurements = bayesNet.size() - 2 for i in range(num_measurements): conditional = bayesNet.atMixture(i) - measurement = gtsam.VectorValues() - measurement.insert(Z(i), sample.at(Z(i))) - factor = conditional.likelihood(measurement) + factor = conditional.likelihood(cls.measurements(sample, [i])) fg.push_back(factor) fg.push_back(bayesNet.atGaussian(num_measurements)) fg.push_back(bayesNet.atDiscrete(num_measurements+1)) return fg @classmethod - def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000): - """Do importance sampling to get an estimate of the discrete marginal P(mode).""" - # Use prior on x0, mode as proposal density. - prior = cls.tiny(num_measurements=0) # just P(x0)P(mode) - - # Allocate space for marginals. + def estimate_marginals(cls, target, proposal_density: HybridBayesNet, + N=10000): + """Do importance sampling to estimate discrete marginal P(mode).""" + # Allocate space for marginals on mode. marginals = np.zeros((2,)) # Do importance sampling. - num_measurements = bayesNet.size() - 2 for s in range(N): - proposed = prior.sample() - for i in range(num_measurements): - z_i = sample.at(Z(i)) - proposed.insert(Z(i), z_i) - weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) + proposed = proposal_density.sample() # sample from proposal + target_proposed = target(proposed) # evaluate target + # print(target_proposed, proposal_density.evaluate(proposed)) + weight = target_proposed / proposal_density.evaluate(proposed) + # print weight: + # print(f"weight: {weight}") marginals[proposed.atDiscrete(M(0))] += weight # print marginals: @@ -161,72 +193,146 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): def test_tiny(self): """Test a tiny two variable hybrid model.""" - bayesNet = self.tiny() - sample = bayesNet.sample() - # print(sample) + # P(x0)P(mode)P(z0|x0,mode) + prior_sigma = 0.5 + bayesNet = self.tiny(prior_sigma=prior_sigma) + + # Deterministic values exactly at the mean, for both x and Z: + values = HybridValues() + values.insert(X(0), [5.0]) + values.insert(M(0), 0) # low-noise, standard deviation 0.5 + z0: float = 5.0 + values.insert(Z(0), [z0]) + + def unnormalized_posterior(x): + """Posterior is proportional to joint, centered at 5.0 as well.""" + x.insert(Z(0), [z0]) + return bayesNet.evaluate(x) + + # Create proposal density on (x0, mode), making sure it has same mean: + posterior_information = 1/(prior_sigma**2) + 1/(0.5**2) + posterior_sigma = posterior_information**(-0.5) + proposal_density = self.tiny( + num_measurements=0, prior_mean=5.0, prior_sigma=posterior_sigma) # Estimate marginals using importance sampling. - marginals = self.estimate_marginals(bayesNet, sample) - # print(f"True mode: {sample.atDiscrete(M(0))}") + marginals = self.estimate_marginals( + target=unnormalized_posterior, proposal_density=proposal_density) + # print(f"True mode: {values.atDiscrete(M(0))}") + # print(f"P(mode=0; Z) = {marginals[0]}") + # print(f"P(mode=1; Z) = {marginals[1]}") + + # Check that the estimate is close to the true value. + self.assertAlmostEqual(marginals[0], 0.74, delta=0.01) + self.assertAlmostEqual(marginals[1], 0.26, delta=0.01) + + fg = self.factor_graph_from_bayes_net(bayesNet, values) + self.assertEqual(fg.size(), 3) + + # Test elimination. + ordering = gtsam.Ordering() + ordering.push_back(X(0)) + ordering.push_back(M(0)) + posterior = fg.eliminateSequential(ordering) + + def true_posterior(x): + """Posterior from elimination.""" + x.insert(Z(0), [z0]) + return posterior.evaluate(x) + + # Estimate marginals using importance sampling. + marginals = self.estimate_marginals( + target=true_posterior, proposal_density=proposal_density) + # print(f"True mode: {values.atDiscrete(M(0))}") # print(f"P(mode=0; z0) = {marginals[0]}") # print(f"P(mode=1; z0) = {marginals[1]}") # Check that the estimate is close to the true value. - self.assertAlmostEqual(marginals[0], 0.4, delta=0.1) - self.assertAlmostEqual(marginals[1], 0.6, delta=0.1) - - fg = self.factor_graph_from_bayes_net(bayesNet, sample) - self.assertEqual(fg.size(), 3) + self.assertAlmostEqual(marginals[0], 0.74, delta=0.01) + self.assertAlmostEqual(marginals[1], 0.26, delta=0.01) @staticmethod def calculate_ratio(bayesNet: HybridBayesNet, fg: HybridGaussianFactorGraph, sample: HybridValues): - """Calculate ratio between Bayes net probability and the factor graph.""" - return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0 + """Calculate ratio between Bayes net and factor graph.""" + return bayesNet.evaluate(sample) / fg.probPrime(sample) if \ + fg.probPrime(sample) > 0 else 0 def test_ratio(self): """ - Given a tiny two variable hybrid model, with 2 measurements, - test the ratio of the bayes net model representing P(z, x, n)=P(z|x, n)P(x)P(n) + Given a tiny two variable hybrid model, with 2 measurements, test the + ratio of the bayes net model representing P(z,x,n)=P(z|x, n)P(x)P(n) and the factor graph P(x, n | z)=P(x | n, z)P(n|z), both of which represent the same posterior. """ - # Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) - bayesNet = self.tiny(num_measurements=2) - # Sample from the Bayes net. - sample: HybridValues = bayesNet.sample() - # print(sample) + # Create generative model P(z, x, n)=P(z|x, n)P(x)P(n) + prior_sigma = 0.5 + bayesNet = self.tiny(prior_sigma=prior_sigma, num_measurements=2) + + # Deterministic values exactly at the mean, for both x and Z: + values = HybridValues() + values.insert(X(0), [5.0]) + values.insert(M(0), 0) # high-noise, standard deviation 3 + measurements = gtsam.VectorValues() + measurements.insert(Z(0), [4.0]) + measurements.insert(Z(1), [6.0]) + values.insert(measurements) + + def unnormalized_posterior(x): + """Posterior is proportional to joint, centered at 5.0 as well.""" + x.insert(measurements) + return bayesNet.evaluate(x) + + # Create proposal density on (x0, mode), making sure it has same mean: + posterior_information = 1/(prior_sigma**2) + 2.0/(3.0**2) + posterior_sigma = posterior_information**(-0.5) + proposal_density = self.tiny( + num_measurements=0, prior_mean=5.0, prior_sigma=posterior_sigma) # Estimate marginals using importance sampling. - marginals = self.estimate_marginals(bayesNet, sample) - # print(f"True mode: {sample.atDiscrete(M(0))}") - # print(f"P(mode=0; z0, z1) = {marginals[0]}") - # print(f"P(mode=1; z0, z1) = {marginals[1]}") + marginals = self.estimate_marginals( + target=unnormalized_posterior, proposal_density=proposal_density) + # print(f"True mode: {values.atDiscrete(M(0))}") + # print(f"P(mode=0; Z) = {marginals[0]}") + # print(f"P(mode=1; Z) = {marginals[1]}") - # Check marginals based on sampled mode. - if sample.atDiscrete(M(0)) == 0: - self.assertGreater(marginals[0], marginals[1]) - else: - self.assertGreater(marginals[1], marginals[0]) + # Check that the estimate is close to the true value. + self.assertAlmostEqual(marginals[0], 0.23, delta=0.01) + self.assertAlmostEqual(marginals[1], 0.77, delta=0.01) - fg = self.factor_graph_from_bayes_net(bayesNet, sample) + # Convert to factor graph using measurements. + fg = self.factor_graph_from_bayes_net(bayesNet, values) self.assertEqual(fg.size(), 4) # Calculate ratio between Bayes net probability and the factor graph: - expected_ratio = self.calculate_ratio(bayesNet, fg, sample) + expected_ratio = self.calculate_ratio(bayesNet, fg, values) # print(f"expected_ratio: {expected_ratio}\n") - # Create measurements from the sample. - measurements = gtsam.VectorValues() - for i in range(2): - measurements.insert(Z(i), sample.at(Z(i))) - # Check with a number of other samples. for i in range(10): - other = bayesNet.sample() - other.update(measurements) - ratio = self.calculate_ratio(bayesNet, fg, other) + samples = bayesNet.sample() + samples.update(measurements) + ratio = self.calculate_ratio(bayesNet, fg, samples) + # print(f"Ratio: {ratio}\n") + if (ratio > 0): + self.assertAlmostEqual(ratio, expected_ratio) + + # Test elimination. + ordering = gtsam.Ordering() + ordering.push_back(X(0)) + ordering.push_back(M(0)) + posterior = fg.eliminateSequential(ordering) + + # Calculate ratio between Bayes net probability and the factor graph: + expected_ratio = self.calculate_ratio(posterior, fg, values) + # print(f"expected_ratio: {expected_ratio}\n") + + # Check with a number of other samples. + for i in range(10): + samples = posterior.sample() + samples.insert(measurements) + ratio = self.calculate_ratio(posterior, fg, samples) # print(f"Ratio: {ratio}\n") if (ratio > 0): self.assertAlmostEqual(ratio, expected_ratio)