Merge pull request #1360 from borglab/hybrid/elimination

release/4.3a0
Varun Agrawal 2023-01-04 13:29:43 -05:00 committed by GitHub
commit 1c411eb5a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 1304 additions and 365 deletions

View File

@ -64,6 +64,9 @@ namespace gtsam {
*/ */
size_t nrAssignments_; size_t nrAssignments_;
/// Default constructor for serialization.
Leaf() {}
/// Constructor from constant /// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1) Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {} : constant_(constant), nrAssignments_(nrAssignments) {}
@ -154,6 +157,18 @@ namespace gtsam {
} }
bool isLeaf() const override { return true; } bool isLeaf() const override { return true; }
private:
using Base = DecisionTree<L, Y>::Node;
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(constant_);
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
}
}; // Leaf }; // Leaf
/****************************************************************************/ /****************************************************************************/
@ -177,6 +192,9 @@ namespace gtsam {
using ChoicePtr = boost::shared_ptr<const Choice>; using ChoicePtr = boost::shared_ptr<const Choice>;
public: public:
/// Default constructor for serialization.
Choice() {}
~Choice() override { ~Choice() override {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
@ -428,6 +446,19 @@ namespace gtsam {
r->push_back(branch->choose(label, index)); r->push_back(branch->choose(label, index));
return Unique(r); return Unique(r);
} }
private:
using Base = DecisionTree<L, Y>::Node;
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(label_);
ar& BOOST_SERIALIZATION_NVP(branches_);
ar& BOOST_SERIALIZATION_NVP(allSame_);
}
}; // Choice }; // Choice
/****************************************************************************/ /****************************************************************************/

View File

@ -19,9 +19,11 @@
#pragma once #pragma once
#include <gtsam/base/Testable.h>
#include <gtsam/base/types.h> #include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <boost/serialization/nvp.hpp>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
@ -113,6 +115,12 @@ namespace gtsam {
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
virtual Ptr choose(const L& label, size_t index) const = 0; virtual Ptr choose(const L& label, size_t index) const = 0;
virtual bool isLeaf() const = 0; virtual bool isLeaf() const = 0;
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
}; };
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
@ -236,7 +244,7 @@ namespace gtsam {
/** /**
* @brief Visit all leaves in depth-first fashion. * @brief Visit all leaves in depth-first fashion.
* *
* @param f (side-effect) Function taking a value. * @param f (side-effect) Function taking the value of the leaf node.
* *
* @note Due to pruning, the number of leaves may not be the same as the * @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with * number of assignments. E.g. if we have a tree on 2 binary variables with
@ -245,7 +253,7 @@ namespace gtsam {
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](int y) { sum += y; }; * auto visitor = [&](int y) { sum += y; };
* tree.visitWith(visitor); * tree.visit(visitor);
*/ */
template <typename Func> template <typename Func>
void visit(Func f) const; void visit(Func f) const;
@ -261,8 +269,8 @@ namespace gtsam {
* *
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](int y) { sum += y; }; * auto visitor = [&](const Leaf& leaf) { sum += leaf.constant(); };
* tree.visitWith(visitor); * tree.visitLeaf(visitor);
*/ */
template <typename Func> template <typename Func>
void visitLeaf(Func f) const; void visitLeaf(Func f) const;
@ -364,8 +372,19 @@ namespace gtsam {
compose(Iterator begin, Iterator end, const L& label) const; compose(Iterator begin, Iterator end, const L& label) const;
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_NVP(root_);
}
}; // DecisionTree }; // DecisionTree
template <class L, class Y>
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
/** free versions of apply */ /** free versions of apply */
/// Apply unary operator `op` to DecisionTree `f`. /// Apply unary operator `op` to DecisionTree `f`.

View File

@ -156,9 +156,9 @@ namespace gtsam {
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const { const {
// Get all possible assignments // Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs = discreteKeys(); DiscreteKeys pairs = discreteKeys();
// Reverse to make cartesian product output a more natural ordering. // Reverse to make cartesian product output a more natural ordering.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend()); DiscreteKeys rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = DiscreteValues::CartesianProduct(rpairs); const auto assignments = DiscreteValues::CartesianProduct(rpairs);
// Construct unordered_map with values // Construct unordered_map with values

View File

@ -231,6 +231,16 @@ namespace gtsam {
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
}; };
// traits // traits

View File

@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional
/// Internal version of choose /// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given, DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const; bool forceComplete) const;
private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}
}; };
// DiscreteConditional // DiscreteConditional

View File

@ -20,12 +20,11 @@
// #define DT_DEBUG_MEMORY // #define DT_DEBUG_MEMORY
// #define GTSAM_DT_NO_PRUNING // #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT #define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;

View File

@ -19,6 +19,7 @@
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>

View File

@ -17,13 +17,14 @@
* @date Feb 14, 2011 * @date Feb 14, 2011
*/ */
#include <boost/make_shared.hpp>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h> #include <gtsam/inference/Symbol.h>
#include <boost/make_shared.hpp>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) {
DiscreteConditional conditional(A | B = "2/2 3/1"); DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2"); DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional; DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8 // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first); DiscreteConditional actualA = pAB.marginal(A.first);

View File

@ -0,0 +1,105 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* testSerializtionDiscrete.cpp
*
* @date January 2023
* @author Varun Agrawal
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/inference/Symbol.h>
using namespace std;
using namespace gtsam;
using Tree = gtsam::DecisionTree<string, int>;
BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf")
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
/* ****************************************************************************/
// Test DecisionTree serialization.
TEST(DiscreteSerialization, DecisionTree) {
Tree tree({{"A", 2}}, std::vector<int>{1, 2});
using namespace serializationTestHelpers;
// Object roundtrip
Tree outputObj = create<Tree>();
roundtrip<Tree>(tree, outputObj);
EXPECT(tree.equals(outputObj));
// XML roundtrip
Tree outputXml = create<Tree>();
roundtripXML<Tree>(tree, outputXml);
EXPECT(tree.equals(outputXml));
// Binary roundtrip
Tree outputBinary = create<Tree>();
roundtripBinary<Tree>(tree, outputBinary);
EXPECT(tree.equals(outputBinary));
}
/* ************************************************************************* */
// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor
TEST(DiscreteSerialization, DecisionTreeFactor) {
using namespace serializationTestHelpers;
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsXML<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsBinary<DecisionTreeFactor::ADT>(tree));
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor>(f));
EXPECT(equalsXML<DecisionTreeFactor>(f));
EXPECT(equalsBinary<DecisionTreeFactor>(f));
}
/* ************************************************************************* */
// Check serialization for DiscreteConditional & DiscreteDistribution
TEST(DiscreteSerialization, DiscreteConditional) {
using namespace serializationTestHelpers;
DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/2/2");
EXPECT(equalsObj<DiscreteConditional>(conditional));
EXPECT(equalsXML<DiscreteConditional>(conditional));
EXPECT(equalsBinary<DiscreteConditional>(conditional));
DiscreteDistribution P(A % "3/2/1");
EXPECT(equalsObj<DiscreteDistribution>(P));
EXPECT(equalsXML<DiscreteDistribution>(P));
EXPECT(equalsBinary<DiscreteDistribution>(P));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -51,24 +51,28 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionalsList)) {} Conditionals(discreteParents, conditionalsList)) {}
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add( GaussianFactorGraphTree GaussianMixture::add(
const GaussianMixture::Sum &sum) const { const GaussianFactorGraphTree &sum) const {
using Y = GaussianFactorGraph; using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1; auto result = graph1.graph;
result.push_back(graph2); result.push_back(graph2.graph);
return result; return Y(result, graph1.constant + graph2.constant);
}; };
const Sum tree = asGaussianFactorGraphTree(); const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add); return sum.empty() ? tree : sum.apply(tree, add);
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const { GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianFactor::shared_ptr &factor) { auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(conditional);
return result; if (conditional) {
return GraphAndConstant(result, conditional->logNormalizationConstant());
} else {
return GraphAndConstant(result, 0.0);
}
}; };
return {conditionals_, lambda}; return {conditionals_, lambda};
} }
@ -98,7 +102,19 @@ GaussianConditional::shared_ptr GaussianMixture::operator()(
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol); if (e == nullptr) return false;
// This will return false if either conditionals_ is empty or e->conditionals_
// is empty, but not if both are empty or both are not empty:
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return f1->equals(*(f2), tol);
});
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -146,7 +162,13 @@ KeyVector GaussianMixture::continuousParents() const {
/* ************************************************************************* */ /* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood( boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &frontals) const { const VectorValues &frontals) const {
// TODO(dellaert): check that values has all frontals // Check that values has all frontals
for (auto &&kv : frontals) {
if (frontals.find(kv.first) == frontals.end()) {
throw std::runtime_error("GaussianMixture: frontals missing factor key.");
}
}
const DiscreteKeys discreteParentKeys = discreteKeys(); const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents(); const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods( const GaussianMixtureFactor::Factors likelihoods(

View File

@ -23,13 +23,13 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
namespace gtsam { namespace gtsam {
class GaussianMixtureFactor;
class HybridValues; class HybridValues;
/** /**
@ -59,9 +59,6 @@ class GTSAM_EXPORT GaussianMixture
using BaseFactor = HybridFactor; using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, GaussianMixture>; using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
/// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianFactorGraph>;
/// typedef for Decision Tree of Gaussian Conditionals /// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>; using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
@ -71,7 +68,7 @@ class GTSAM_EXPORT GaussianMixture
/** /**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/ */
Sum asGaussianFactorGraphTree() const; GaussianFactorGraphTree asGaussianFactorGraphTree() const;
/** /**
* @brief Helper function to get the pruner functor. * @brief Helper function to get the pruner functor.
@ -172,6 +169,16 @@ class GTSAM_EXPORT GaussianMixture
*/ */
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
// /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const;
// /// Evaluate probability density, sugar.
// double operator()(const HybridValues &values) const { return
// evaluate(values); }
// /// Calculate log-density for given values `x`.
// double logDensity(const HybridValues &values) const;
/** /**
* @brief Prune the decision tree of Gaussian factors as per the discrete * @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`. * `decisionTree`.
@ -186,10 +193,20 @@ class GTSAM_EXPORT GaussianMixture
* maintaining the decision tree structure. * maintaining the decision tree structure.
* *
* @param sum Decision Tree of Gaussian Factor Graphs * @param sum Decision Tree of Gaussian Factor Graphs
* @return Sum * @return GaussianFactorGraphTree
*/ */
Sum add(const Sum &sum) const; GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar &BOOST_SERIALIZATION_NVP(conditionals_);
}
}; };
/// Return the DiscreteKey vector as a set. /// Return the DiscreteKey vector as a set.

View File

@ -81,32 +81,36 @@ void GaussianMixtureFactor::print(const std::string &s,
} }
/* *******************************************************************************/ /* *******************************************************************************/
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const { GaussianFactor::shared_ptr GaussianMixtureFactor::factor(
return Mixture(factors_, [](const FactorAndConstant &factor_z) { const DiscreteValues &assignment) const {
return factor_z.factor; return factors_(assignment).factor;
});
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::add( double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
const GaussianMixtureFactor::Sum &sum) const { return factors_(assignment).constant;
using Y = GaussianFactorGraph; }
/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixtureFactor::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1; auto result = graph1.graph;
result.push_back(graph2); result.push_back(graph2.graph);
return result; return Y(result, graph1.constant + graph2.constant);
}; };
const Sum tree = asGaussianFactorGraphTree(); const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add); return sum.empty() ? tree : sum.apply(tree, add);
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
const { const {
auto wrap = [](const FactorAndConstant &factor_z) { auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor_z.factor); result.push_back(factor_z.factor);
return result; return GraphAndConstant(result, factor_z.constant);
}; };
return {factors_, wrap}; return {factors_, wrap};
} }

View File

@ -25,10 +25,10 @@
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianFactor.h> #include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam { namespace gtsam {
class GaussianFactorGraph;
class HybridValues; class HybridValues;
class DiscreteValues; class DiscreteValues;
class VectorValues; class VectorValues;
@ -50,7 +50,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using This = GaussianMixtureFactor; using This = GaussianMixtureFactor;
using shared_ptr = boost::shared_ptr<This>; using shared_ptr = boost::shared_ptr<This>;
using Sum = DecisionTree<Key, GaussianFactorGraph>;
using sharedFactor = boost::shared_ptr<GaussianFactor>; using sharedFactor = boost::shared_ptr<GaussianFactor>;
/// Gaussian factor and log of normalizing constant. /// Gaussian factor and log of normalizing constant.
@ -60,8 +59,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
// Return error with constant correction. // Return error with constant correction.
double error(const VectorValues &values) const { double error(const VectorValues &values) const {
// Note minus sign: constant is log of normalization constant for probabilities. // Note: constant is log of normalization constant for probabilities.
// Errors is the negative log-likelihood, hence we subtract the constant here. // Errors is the negative log-likelihood,
// hence we subtract the constant here.
if (!factor) return 0.0; // If nullptr, return 0.0 error
return factor->error(values) - constant; return factor->error(values) - constant;
} }
@ -69,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
bool operator==(const FactorAndConstant &other) const { bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant; return factor == other.factor && constant == other.constant;
} }
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_NVP(factor);
ar &BOOST_SERIALIZATION_NVP(constant);
}
}; };
/// typedef for Decision Tree of Gaussian factors and log-constant. /// typedef for Decision Tree of Gaussian factors and log-constant.
@ -83,9 +93,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Helper function to return factors and functional to create a * @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs. * DecisionTree of Gaussian Factor Graphs.
* *
* @return Sum (DecisionTree<Key, GaussianFactorGraph>) * @return GaussianFactorGraphTree
*/ */
Sum asGaussianFactorGraphTree() const; GaussianFactorGraphTree asGaussianFactorGraphTree() const;
public: public:
/// @name Constructors /// @name Constructors
@ -135,12 +145,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
void print( void print(
const std::string &s = "GaussianMixtureFactor\n", const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard API /// @name Standard API
/// @{ /// @{
/// Getter for the underlying Gaussian Factor Decision Tree. /// Get factor at a given discrete assignment.
const Mixture factors() const; sharedFactor factor(const DiscreteValues &assignment) const;
/// Get constant at a given discrete assignment.
double constant(const DiscreteValues &assignment) const;
/** /**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
@ -150,7 +164,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* variables. * variables.
* @return Sum * @return Sum
*/ */
Sum add(const Sum &sum) const; GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/** /**
* @brief Compute error of the GaussianMixtureFactor as a tree. * @brief Compute error of the GaussianMixtureFactor as a tree.
@ -168,11 +182,21 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
/// Add MixtureFactor to a Sum, syntactic sugar. /// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { friend GaussianFactorGraphTree &operator+=(
GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum); sum = factor.add(sum);
return sum; return sum;
} }
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(factors_);
}
}; };
// traits // traits

View File

@ -26,6 +26,17 @@ static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
void HybridBayesNet::print(const std::string &s,
const KeyFormatter &formatter) const {
Base::print(s, formatter);
}
/* ************************************************************************* */
bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree; AlgebraicDecisionTree<Key> decisionTree;
@ -271,12 +282,15 @@ double HybridBayesNet::evaluate(const HybridValues &values) const {
// Iterate over each conditional. // Iterate over each conditional.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
// TODO: should be delegated to derived classes.
if (auto gm = conditional->asMixture()) { if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues); const auto component = (*gm)(discreteValues);
logDensity += component->logDensity(continuousValues); logDensity += component->logDensity(continuousValues);
} else if (auto gc = conditional->asGaussian()) { } else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply. // If continuous only, evaluate the probability and multiply.
logDensity += gc->logDensity(continuousValues); logDensity += gc->logDensity(continuousValues);
} else if (auto dc = conditional->asDiscrete()) { } else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability. // Conditional is discrete-only, so return its probability.
probability *= dc->operator()(discreteValues); probability *= dc->operator()(discreteValues);

View File

@ -50,17 +50,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @name Testable /// @name Testable
/// @{ /// @{
/** Check equality */ /// GTSAM-style printing
bool equals(const This &bn, double tol = 1e-9) const {
return Base::equals(bn, tol);
}
/// print graph
void print( void print(
const std::string &s = "", const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override { const KeyFormatter &formatter = DefaultKeyFormatter) const override;
Base::print(s, formatter);
} /// GTSAM-style equals
bool equals(const This& fg, double tol = 1e-9) const;
/// @} /// @}
/// @name Standard Interface /// @name Standard Interface

View File

@ -17,6 +17,7 @@
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/inference/Key.h> #include <gtsam/inference/Key.h>
@ -102,7 +103,37 @@ void HybridConditional::print(const std::string &s,
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const { bool HybridConditional::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other); const This *e = dynamic_cast<const This *>(&other);
return e != nullptr && BaseFactor::equals(*e, tol); if (e == nullptr) return false;
if (auto gm = asMixture()) {
auto other = e->asMixture();
return other != nullptr && gm->equals(*other, tol);
}
if (auto gc = asGaussian()) {
auto other = e->asGaussian();
return other != nullptr && gc->equals(*other, tol);
}
if (auto dc = asDiscrete()) {
auto other = e->asDiscrete();
return other != nullptr && dc->equals(*other, tol);
}
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}
/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
if (auto gm = asMixture()) {
return gm->error(values);
}
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
}
if (auto dc = asDiscrete()) {
return -log((*dc)(values.discrete()));
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
} }
} // namespace gtsam } // namespace gtsam

View File

@ -176,15 +176,7 @@ class GTSAM_EXPORT HybridConditional
boost::shared_ptr<Factor> inner() const { return inner_; } boost::shared_ptr<Factor> inner() const { return inner_; }
/// Return the error of the underlying conditional. /// Return the error of the underlying conditional.
/// Currently only implemented for Gaussian mixture. double error(const HybridValues& values) const override;
double error(const HybridValues& values) const override {
if (auto gm = asMixture()) {
return gm->error(values);
} else {
throw std::runtime_error(
"HybridConditional::error: only implemented for Gaussian mixture");
}
}
/// @} /// @}
@ -195,6 +187,20 @@ class GTSAM_EXPORT HybridConditional
void serialize(Archive& ar, const unsigned int /*version*/) { void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(inner_);
// register the various casts based on the type of inner_
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
if (isDiscrete()) {
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
} else if (isContinuous()) {
boost::serialization::void_cast_register<GaussianConditional, Factor>(
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
} else {
boost::serialization::void_cast_register<GaussianMixture, Factor>(
static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
}
} }
}; // HybridConditional }; // HybridConditional

View File

@ -26,7 +26,6 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
// TODO(fan): THIS IS VERY VERY DIRTY! We need to get DiscreteFactor right!
HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other) HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
: Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other) : Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other)
->discreteKeys()), ->discreteKeys()),
@ -40,8 +39,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf); const This *e = dynamic_cast<const This *>(&lf);
// TODO(Varun) How to compare inner_ when they are abstract types? if (e == nullptr) return false;
return e != nullptr && Base::equals(*e, tol); if (!Base::equals(*e, tol)) return false;
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -45,6 +45,9 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
/// @name Constructors /// @name Constructors
/// @{ /// @{
/// Default constructor - for serialization.
HybridDiscreteFactor() = default;
// Implicit conversion from a shared ptr of DF // Implicit conversion from a shared ptr of DF
HybridDiscreteFactor(DiscreteFactor::shared_ptr other); HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
@ -70,6 +73,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
/// Return the error of the underlying Discrete Factor. /// Return the error of the underlying Discrete Factor.
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(inner_);
}
}; };
// traits // traits

View File

@ -21,6 +21,8 @@
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Factor.h> #include <gtsam/inference/Factor.h>
#include <gtsam/nonlinear/Values.h> #include <gtsam/nonlinear/Values.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/discrete/DecisionTree.h>
#include <cstddef> #include <cstddef>
#include <string> #include <string>
@ -28,6 +30,36 @@ namespace gtsam {
class HybridValues; class HybridValues;
/// Gaussian factor graph and log of normalizing constant.
struct GraphAndConstant {
GaussianFactorGraph graph;
double constant;
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
: graph(graph), constant(constant) {}
// Check pointer equality.
bool operator==(const GraphAndConstant &other) const {
return graph == other.graph && constant == other.constant;
}
// Implement GTSAM-style print:
void print(const std::string &s = "Graph: ",
const KeyFormatter &formatter = DefaultKeyFormatter) const {
graph.print(s, formatter);
std::cout << "Constant: " << constant << std::endl;
}
// Implement GTSAM-style equals:
bool equals(const GraphAndConstant &other, double tol = 1e-9) const {
return graph.equals(other.graph, tol) &&
fabs(constant - other.constant) < tol;
}
};
/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GraphAndConstant>;
KeyVector CollectKeys(const KeyVector &continuousKeys, KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys); const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
@ -160,4 +192,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
template <> template <>
struct traits<HybridFactor> : public Testable<HybridFactor> {}; struct traits<HybridFactor> : public Testable<HybridFactor> {};
template <>
struct traits<GraphAndConstant> : public Testable<GraphAndConstant> {};
} // namespace gtsam } // namespace gtsam

View File

@ -44,15 +44,21 @@ HybridGaussianFactor::HybridGaussianFactor(HessianFactor &&hf)
/* ************************************************************************* */ /* ************************************************************************* */
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const { bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other); const This *e = dynamic_cast<const This *>(&other);
// TODO(Varun) How to compare inner_ when they are abstract types? if (e == nullptr) return false;
return e != nullptr && Base::equals(*e, tol); if (!Base::equals(*e, tol)) return false;
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
void HybridGaussianFactor::print(const std::string &s, void HybridGaussianFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter); HybridFactor::print(s, formatter);
inner_->print("\n", formatter); if (inner_) {
inner_->print("\n", formatter);
} else {
std::cout << "\nGaussian: nullptr" << std::endl;
}
}; };
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -43,14 +43,17 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
using This = HybridGaussianFactor; using This = HybridGaussianFactor;
using shared_ptr = boost::shared_ptr<This>; using shared_ptr = boost::shared_ptr<This>;
/// @name Constructors
/// @{
/// Default constructor - for serialization.
HybridGaussianFactor() = default; HybridGaussianFactor() = default;
/** /**
* Constructor from shared_ptr of GaussianFactor. * Constructor from shared_ptr of GaussianFactor.
* Example: * Example:
* boost::shared_ptr<GaussianFactor> ptr = * auto ptr = boost::make_shared<JacobianFactor>(...);
* boost::make_shared<JacobianFactor>(...); * HybridGaussianFactor factor(ptr);
*
*/ */
explicit HybridGaussianFactor(const boost::shared_ptr<GaussianFactor> &ptr); explicit HybridGaussianFactor(const boost::shared_ptr<GaussianFactor> &ptr);
@ -80,7 +83,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/ */
explicit HybridGaussianFactor(HessianFactor &&hf); explicit HybridGaussianFactor(HessianFactor &&hf);
public: /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -99,9 +102,18 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// Return pointer to the internal Gaussian factor. /// Return pointer to the internal Gaussian factor.
GaussianFactor::shared_ptr inner() const { return inner_; } GaussianFactor::shared_ptr inner() const { return inner_; }
/// Return the error of the underlying Discrete Factor. /// Return the error of the underlying Gaussian factor.
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(inner_);
}
}; };
// traits // traits

View File

@ -59,51 +59,50 @@ namespace gtsam {
template class EliminateableFactorGraph<HybridGaussianFactorGraph>; template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianMixtureFactor::Sum &addGaussian( static GaussianFactorGraphTree addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { const GaussianFactorGraphTree &gfgTree,
using Y = GaussianFactorGraph; const GaussianFactor::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it. // If the decision tree is not initialized, then initialize it.
if (sum.empty()) { if (gfgTree.empty()) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);
sum = GaussianMixtureFactor::Sum(result); return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
} else { } else {
auto add = [&factor](const Y &graph) { auto add = [&factor](const GraphAndConstant &graph_z) {
auto result = graph; auto result = graph_z.graph;
result.push_back(factor); result.push_back(factor);
return result; return GraphAndConstant(result, graph_z.constant);
}; };
sum = sum.apply(add); return gfgTree.apply(add);
} }
return sum;
} }
/* ************************************************************************ */ /* ************************************************************************ */
GaussianMixtureFactor::Sum sumFrontals( // TODO(dellaert): Implementation-wise, it's probably more efficient to first
const HybridGaussianFactorGraph &factors) { // collect the discrete keys, and then loop over all assignments to populate a
// sum out frontals, this is the factor on the separator // vector.
gttic(sum); GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
gttic(assembleGraphTree);
GaussianMixtureFactor::Sum sum; GaussianFactorGraphTree result;
std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors) { for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (f->isHybrid()) { if (f->isHybrid()) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = gm->add(sum); result = gm->add(result);
} }
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) { if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
sum = gm->asMixture()->add(sum); result = gm->asMixture()->add(result);
} }
} else if (f->isContinuous()) { } else if (f->isContinuous()) {
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) { if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
deferredFactors.push_back(gf->inner()); result = addGaussian(result, gf->inner());
} }
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) { if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
deferredFactors.push_back(cg->asGaussian()); result = addGaussian(result, cg->asGaussian());
} }
} else if (f->isDiscrete()) { } else if (f->isDiscrete()) {
@ -125,17 +124,13 @@ GaussianMixtureFactor::Sum sumFrontals(
} }
} }
for (auto &f : deferredFactors) { gttoc(assembleGraphTree);
sum = addGaussian(sum, f);
}
gttoc(sum); return result;
return sum;
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
continuousElimination(const HybridGaussianFactorGraph &factors, continuousElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
GaussianFactorGraph gfg; GaussianFactorGraph gfg;
@ -156,7 +151,7 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
discreteElimination(const HybridGaussianFactorGraph &factors, discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
@ -173,48 +168,53 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} }
} }
auto result = EliminateForMPE(dfg, frontalKeys); // NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
return {boost::make_shared<HybridConditional>(result.first), return {boost::make_shared<HybridConditional>(result.first),
boost::make_shared<HybridDiscreteFactor>(result.second)}; boost::make_shared<HybridDiscreteFactor>(result.second)};
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> // If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor.
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
auto emptyGaussian = [](const GraphAndConstant &graph_z) {
bool hasNull =
std::any_of(graph_z.graph.begin(), graph_z.graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GraphAndConstant{GaussianFactorGraph(), 0.0} : graph_z;
};
return GaussianFactorGraphTree(sum, emptyGaussian);
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
hybridElimination(const HybridGaussianFactorGraph &factors, hybridElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys, const Ordering &frontalKeys,
const KeySet &continuousSeparator, const KeyVector &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) { const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree, // NOTE: since we use the special JunctionTree,
// only possibility is continuous conditioned on discrete. // only possibility is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end()); discreteSeparatorSet.end());
// sum out frontals, this is the factor 𝜏 on the separator // Collect all the factors to create a set of Gaussian factor graphs in a
GaussianMixtureFactor::Sum sum = sumFrontals(factors); // decision tree indexed by all discrete keys involved.
GaussianFactorGraphTree sum = factors.assembleGraphTree();
// If a tree leaf contains nullptr, // Convert factor graphs with a nullptr to an empty factor graph.
// convert that leaf to an empty GaussianFactorGraph. // This is done after assembly since it is non-trivial to keep track of which
// Needed since the DecisionTree will otherwise create // FG has a nullptr as we're looping over the factors.
// a GFG with a single (null) factor. sum = removeEmpty(sum);
auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
bool hasNull =
std::any_of(gfg.begin(), gfg.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianFactorGraph() : gfg;
};
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>, using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::FactorAndConstant>; GaussianMixtureFactor::FactorAndConstant>;
KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair { auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair {
if (graph.empty()) { if (graph_z.graph.empty()) {
return {nullptr, {nullptr, 0.0}}; return {nullptr, {nullptr, 0.0}};
} }
@ -222,24 +222,34 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
gttic_(hybrid_eliminate); gttic_(hybrid_eliminate);
#endif #endif
std::pair<boost::shared_ptr<GaussianConditional>, boost::shared_ptr<GaussianConditional> conditional;
boost::shared_ptr<GaussianFactor>> boost::shared_ptr<GaussianFactor> newFactor;
conditional_factor = EliminatePreferCholesky(graph, frontalKeys); boost::tie(conditional, newFactor) =
EliminatePreferCholesky(graph_z.graph, frontalKeys);
// Initialize the keysOfEliminated to be the keys of the // Get the log of the log normalization constant inverse and
// eliminated GaussianConditional // add it to the previous constant.
keysOfEliminated = conditional_factor.first->keys(); const double logZ =
keysOfSeparator = conditional_factor.second->keys(); graph_z.constant - conditional->logNormalizationConstant();
// Get the log of the log normalization constant inverse.
// double logZ = -conditional->logNormalizationConstant();
// // IF this is the last continuous variable to eliminated, we need to
// // calculate the error here: the value of all factors at the mean, see
// // ml_map_rao.pdf.
// if (continuousSeparator.empty()) {
// const auto posterior_mean = conditional->solve(VectorValues());
// logZ += graph_z.graph.error(posterior_mean);
// }
#ifdef HYBRID_TIMING #ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate); gttoc_(hybrid_eliminate);
#endif #endif
return {conditional_factor.first, {conditional_factor.second, 0.0}}; return {conditional, {newFactor, logZ}};
}; };
// Perform elimination! // Perform elimination!
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate); DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminateFunc);
#ifdef HYBRID_TIMING #ifdef HYBRID_TIMING
tictoc_print_(); tictoc_print_();
@ -247,46 +257,50 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
#endif #endif
// Separate out decision tree into conditionals and remaining factors. // Separate out decision tree into conditionals and remaining factors.
auto pair = unzip(eliminationResults); GaussianMixture::Conditionals conditionals;
const auto &separatorFactors = pair.second; GaussianMixtureFactor::Factors newFactors;
std::tie(conditionals, newFactors) = unzip(eliminationResults);
// Create the GaussianMixture from the conditionals // Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>( auto gaussianMixture = boost::make_shared<GaussianMixture>(
frontalKeys, keysOfSeparator, discreteSeparator, pair.first); frontalKeys, continuousSeparator, discreteSeparator, conditionals);
// If there are no more continuous parents, then we should create here a // If there are no more continuous parents, then we should create a
// DiscreteFactor, with the error for each discrete choice. // DiscreteFactor here, with the error for each discrete choice.
if (keysOfSeparator.empty()) { if (continuousSeparator.empty()) {
VectorValues empty_values;
auto factorProb = auto factorProb =
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
GaussianFactor::shared_ptr factor = factor_z.factor; // This is the probability q(μ) at the MLE point.
if (!factor) { // factor_z.factor is a factor without keys,
return 0.0; // If nullptr, return 0.0 probability // just containing the residual.
} else { return exp(-factor_z.error(VectorValues()));
// This is the probability q(μ) at the MLE point.
double error =
0.5 * std::abs(factor->augmentedInformation().determinant()) +
factor_z.constant;
return std::exp(-error);
}
}; };
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
auto discreteFactor = const DecisionTree<Key, double> fdt(newFactors, factorProb);
// // Normalize the values of decision tree to be valid probabilities
// double sum = 0.0;
// auto visitor = [&](double y) { sum += y; };
// fdt.visit(visitor);
// // Check if sum is 0, and update accordingly.
// if (sum == 0) {
// sum = 1.0;
// }
// fdt = DecisionTree<Key, double>(fdt,
// [sum](const double &x) { return x / sum;
// });
const auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt); boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
return {boost::make_shared<HybridConditional>(conditional), return {boost::make_shared<HybridConditional>(gaussianMixture),
boost::make_shared<HybridDiscreteFactor>(discreteFactor)}; boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
} else { } else {
// Create a resulting GaussianMixtureFactor on the separator. // Create a resulting GaussianMixtureFactor on the separator.
auto factor = boost::make_shared<GaussianMixtureFactor>( return {boost::make_shared<HybridConditional>(gaussianMixture),
KeyVector(continuousSeparator.begin(), continuousSeparator.end()), boost::make_shared<GaussianMixtureFactor>(
discreteSeparator, separatorFactors); continuousSeparator, discreteSeparator, newFactors)};
return {boost::make_shared<HybridConditional>(conditional), factor};
} }
} }
/* ************************************************************************ /* ************************************************************************
* Function to eliminate variables **under the following assumptions**: * Function to eliminate variables **under the following assumptions**:
* 1. When the ordering is fully continuous, and the graph only contains * 1. When the ordering is fully continuous, and the graph only contains
@ -383,12 +397,12 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// Fill in discrete discrete separator keys and continuous separator keys. // Fill in discrete discrete separator keys and continuous separator keys.
std::set<DiscreteKey> discreteSeparatorSet; std::set<DiscreteKey> discreteSeparatorSet;
KeySet continuousSeparator; KeyVector continuousSeparator;
for (auto &k : separatorKeys) { for (auto &k : separatorKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k)); discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
} else { } else {
continuousSeparator.insert(k); continuousSeparator.push_back(k);
} }
} }
@ -463,15 +477,8 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// If factor is hybrid, select based on assignment. // If factor is hybrid, select based on assignment.
GaussianMixtureFactor::shared_ptr gaussianMixture = GaussianMixtureFactor::shared_ptr gaussianMixture =
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx)); boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
// Compute factor error. // Compute factor error and add it.
factor_error = gaussianMixture->error(continuousValues); error_tree = error_tree + gaussianMixture->error(continuousValues);
// If first factor, assign error, else add it.
if (idx == 0) {
error_tree = factor_error;
} else {
error_tree = error_tree + factor_error;
}
} else if (factors_.at(idx)->isContinuous()) { } else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error // If continuous only, get the (double) error

View File

@ -18,6 +18,7 @@
#pragma once #pragma once
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h> #include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
@ -118,14 +119,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
: Base(graph) {} : Base(graph) {}
/// @} /// @}
/// @name Adding factors.
/// @{
using Base::empty;
using Base::reserve;
using Base::size;
using Base::operator[];
using Base::add; using Base::add;
using Base::push_back; using Base::push_back;
using Base::resize; using Base::reserve;
/// Add a Jacobian factor to the factor graph. /// Add a Jacobian factor to the factor graph.
void add(JacobianFactor&& factor); void add(JacobianFactor&& factor);
@ -172,6 +171,25 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
} }
} }
/// @}
/// @name Testable
/// @{
// TODO(dellaert): customize print and equals.
// void print(const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const
// override;
// bool equals(const This& fg, double tol = 1e-9) const override;
/// @}
/// @name Standard Interface
/// @{
using Base::empty;
using Base::size;
using Base::operator[];
using Base::resize;
/** /**
* @brief Compute error for each discrete assignment, * @brief Compute error for each discrete assignment,
* and return as a tree. * and return as a tree.
@ -217,6 +235,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @return const Ordering * @return const Ordering
*/ */
const Ordering getHybridOrdering() const; const Ordering getHybridOrdering() const;
/**
* @brief Create a decision tree of factor graphs out of this hybrid factor
* graph.
*
* For example, if there are two mixture factors, one with a discrete key A
* and one with a discrete key B, then the decision tree will have two levels,
* one for A and one for B. The leaves of the tree will be the Gaussian
* factors that have only continuous keys.
*/
GaussianFactorGraphTree assembleGraphTree() const;
/// @}
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -99,9 +99,11 @@ void HybridNonlinearISAM::print(const string& s,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter) const {
cout << s << "ReorderInterval: " << reorderInterval_ cout << s << "ReorderInterval: " << reorderInterval_
<< " Current Count: " << reorderCounter_ << endl; << " Current Count: " << reorderCounter_ << endl;
isam_.print("HybridGaussianISAM:\n", keyFormatter); std::cout << "HybridGaussianISAM:" << std::endl;
isam_.print("", keyFormatter);
linPoint_.print("Linearization Point:\n", keyFormatter); linPoint_.print("Linearization Point:\n", keyFormatter);
factors_.print("Nonlinear Graph:\n", keyFormatter); std::cout << "Nonlinear Graph:" << std::endl;
factors_.print("", keyFormatter);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -90,7 +90,7 @@ class GTSAM_EXPORT HybridNonlinearISAM {
const Values& getLinearizationPoint() const { return linPoint_; } const Values& getLinearizationPoint() const { return linPoint_; }
/** Return the current discrete assignment */ /** Return the current discrete assignment */
const DiscreteValues& getAssignment() const { return assignment_; } const DiscreteValues& assignment() const { return assignment_; }
/** get underlying nonlinear graph */ /** get underlying nonlinear graph */
const HybridNonlinearFactorGraph& getFactorsUnsafe() const { const HybridNonlinearFactorGraph& getFactorsUnsafe() const {

View File

@ -168,6 +168,15 @@ class GTSAM_EXPORT HybridValues {
return *this; return *this;
} }
/// Extract continuous values with given keys.
VectorValues continuousSubset(const KeyVector& keys) const {
VectorValues measurements;
for (const auto& key : keys) {
measurements.insert(key, continuous_.at(key));
}
return measurements;
}
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -162,14 +162,20 @@ class MixtureFactor : public HybridFactor {
} }
/// Error for HybridValues is not provided for nonlinear hybrid factor. /// Error for HybridValues is not provided for nonlinear hybrid factor.
double error(const HybridValues &values) const override { double error(const HybridValues& values) const override {
throw std::runtime_error( throw std::runtime_error(
"MixtureFactor::error(HybridValues) not implemented."); "MixtureFactor::error(HybridValues) not implemented.");
} }
/**
* @brief Get the dimension of the factor (number of rows on linearization).
* Returns the dimension of the first component factor.
* @return size_t
*/
size_t dim() const { size_t dim() const {
// TODO(Varun) const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_);
throw std::runtime_error("MixtureFactor::dim not implemented."); auto factor = factors_(assignments.at(0));
return factor->dim();
} }
/// Testable /// Testable

View File

@ -40,6 +40,15 @@ virtual class HybridFactor {
bool empty() const; bool empty() const;
size_t size() const; size_t size() const;
gtsam::KeyVector keys() const; gtsam::KeyVector keys() const;
// Standard interface:
double error(const gtsam::HybridValues &values) const;
bool isDiscrete() const;
bool isContinuous() const;
bool isHybrid() const;
size_t nrContinuous() const;
gtsam::DiscreteKeys discreteKeys() const;
gtsam::KeyVector continuousKeys() const;
}; };
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
@ -50,7 +59,13 @@ virtual class HybridConditional {
bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const; bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const;
size_t nrFrontals() const; size_t nrFrontals() const;
size_t nrParents() const; size_t nrParents() const;
// Standard interface:
gtsam::GaussianMixture* asMixture() const;
gtsam::GaussianConditional* asGaussian() const;
gtsam::DiscreteConditional* asDiscrete() const;
gtsam::Factor* inner(); gtsam::Factor* inner();
double error(const gtsam::HybridValues& values) const;
}; };
#include <gtsam/hybrid/HybridDiscreteFactor.h> #include <gtsam/hybrid/HybridDiscreteFactor.h>
@ -61,6 +76,7 @@ virtual class HybridDiscreteFactor {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const; bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const;
gtsam::Factor* inner(); gtsam::Factor* inner();
double error(const gtsam::HybridValues &values) const;
}; };
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>

View File

@ -0,0 +1,96 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010-2023, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* @file TinyHybridExample.h
* @date December, 2022
* @author Frank Dellaert
*/
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Symbol.h>
#pragma once
namespace gtsam {
namespace tiny {
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;
// Create mode key: 0 is low-noise, 1 is high-noise.
const DiscreteKey mode{M(0), 2};
/**
* Create a tiny two variable hybrid model which represents
* the generative probability P(z,x,mode) = P(z|x,mode)P(x)P(mode).
*/
inline HybridBayesNet createHybridBayesNet(int num_measurements = 1) {
HybridBayesNet bayesNet;
// Create Gaussian mixture z_i = x0 + noise for each measurement.
for (int i = 0; i < num_measurements; i++) {
const auto conditional0 = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 0.5));
const auto conditional1 = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 3));
GaussianMixture gm({Z(i)}, {X(0)}, {mode}, {conditional0, conditional1});
bayesNet.emplaceMixture(gm); // copy :-(
}
// Create prior on X(0).
const auto prior_on_x0 =
GaussianConditional::FromMeanAndStddev(X(0), Vector1(5.0), 0.5);
bayesNet.emplaceGaussian(prior_on_x0); // copy :-(
// Add prior on mode.
bayesNet.emplaceDiscrete(mode, "4/6");
return bayesNet;
}
/**
* Convert a hybrid Bayes net to a hybrid Gaussian factor graph.
*/
inline HybridGaussianFactorGraph convertBayesNet(
const HybridBayesNet& bayesNet, const VectorValues& measurements) {
HybridGaussianFactorGraph fg;
int num_measurements = bayesNet.size() - 2;
for (int i = 0; i < num_measurements; i++) {
auto conditional = bayesNet.atMixture(i);
auto factor = conditional->likelihood({{Z(i), measurements.at(Z(i))}});
fg.push_back(factor);
}
fg.push_back(bayesNet.atGaussian(num_measurements));
fg.push_back(bayesNet.atDiscrete(num_measurements + 1));
return fg;
}
/**
* Create a tiny two variable hybrid factor graph which represents a discrete
* mode and a continuous variable x0, given a number of measurements of the
* continuous variable x0. If no measurements are given, they are sampled from
* the generative Bayes net model HybridBayesNet::Example(num_measurements)
*/
inline HybridGaussianFactorGraph createHybridGaussianFactorGraph(
int num_measurements = 1,
boost::optional<VectorValues> measurements = boost::none) {
auto bayesNet = createHybridBayesNet(num_measurements);
if (measurements) {
return convertBayesNet(bayesNet, *measurements);
} else {
return convertBayesNet(bayesNet, bayesNet.sample().continuous());
}
}
} // namespace tiny
} // namespace gtsam

View File

@ -80,7 +80,7 @@ TEST(GaussianMixtureFactor, Sum) {
// Create sum of two mixture factors: it will be a decision tree now on both // Create sum of two mixture factors: it will be a decision tree now on both
// discrete variables m1 and m2: // discrete variables m1 and m2:
GaussianMixtureFactor::Sum sum; GaussianFactorGraphTree sum;
sum += mixtureFactorA; sum += mixtureFactorA;
sum += mixtureFactorB; sum += mixtureFactorB;
@ -89,8 +89,8 @@ TEST(GaussianMixtureFactor, Sum) {
mode[m1.first] = 1; mode[m1.first] = 1;
mode[m2.first] = 2; mode[m2.first] = 2;
auto actual = sum(mode); auto actual = sum(mode);
EXPECT(actual.at(0) == f11); EXPECT(actual.graph.at(0) == f11);
EXPECT(actual.at(1) == f22); EXPECT(actual.graph.at(1) == f22);
} }
TEST(GaussianMixtureFactor, Printing) { TEST(GaussianMixtureFactor, Printing) {

View File

@ -18,19 +18,18 @@
* @date December 2021 * @date December 2021
*/ */
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h> #include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include "Switching.h" #include "Switching.h"
#include "TinyHybridExample.h"
// Include for test suite // Include for test suite
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using namespace gtsam::serializationTestHelpers;
using noiseModel::Isotropic; using noiseModel::Isotropic;
using symbol_shorthand::M; using symbol_shorthand::M;
@ -63,7 +62,7 @@ TEST(HybridBayesNet, Add) {
/* ****************************************************************************/ /* ****************************************************************************/
// Test evaluate for a pure discrete Bayes net P(Asia). // Test evaluate for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, evaluatePureDiscrete) { TEST(HybridBayesNet, EvaluatePureDiscrete) {
HybridBayesNet bayesNet; HybridBayesNet bayesNet;
bayesNet.emplaceDiscrete(Asia, "99/1"); bayesNet.emplaceDiscrete(Asia, "99/1");
HybridValues values; HybridValues values;
@ -71,6 +70,13 @@ TEST(HybridBayesNet, evaluatePureDiscrete) {
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9); EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
} }
/* ****************************************************************************/
// Test creation of a tiny hybrid Bayes net.
TEST(HybridBayesNet, Tiny) {
auto bayesNet = tiny::createHybridBayesNet();
EXPECT_LONGS_EQUAL(3, bayesNet.size());
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia). // Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
TEST(HybridBayesNet, evaluateHybrid) { TEST(HybridBayesNet, evaluateHybrid) {
@ -180,7 +186,7 @@ TEST(HybridBayesNet, OptimizeAssignment) {
/* ****************************************************************************/ /* ****************************************************************************/
// Test Bayes net optimize // Test Bayes net optimize
TEST(HybridBayesNet, Optimize) { TEST(HybridBayesNet, Optimize) {
Switching s(4); Switching s(4, 1.0, 0.1, {0, 1, 2, 3}, "1/1 1/1");
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet = HybridBayesNet::shared_ptr hybridBayesNet =
@ -188,25 +194,24 @@ TEST(HybridBayesNet, Optimize) {
HybridValues delta = hybridBayesNet->optimize(); HybridValues delta = hybridBayesNet->optimize();
// TODO(Varun) The expectedAssignment should be 111, not 101 // NOTE: The true assignment is 111, but the discrete priors cause 101
DiscreteValues expectedAssignment; DiscreteValues expectedAssignment;
expectedAssignment[M(0)] = 1; expectedAssignment[M(0)] = 1;
expectedAssignment[M(1)] = 0; expectedAssignment[M(1)] = 1;
expectedAssignment[M(2)] = 1; expectedAssignment[M(2)] = 1;
EXPECT(assert_equal(expectedAssignment, delta.discrete())); EXPECT(assert_equal(expectedAssignment, delta.discrete()));
// TODO(Varun) This should be all -Vector1::Ones()
VectorValues expectedValues; VectorValues expectedValues;
expectedValues.insert(X(0), -0.999904 * Vector1::Ones()); expectedValues.insert(X(0), -Vector1::Ones());
expectedValues.insert(X(1), -0.99029 * Vector1::Ones()); expectedValues.insert(X(1), -Vector1::Ones());
expectedValues.insert(X(2), -1.00971 * Vector1::Ones()); expectedValues.insert(X(2), -Vector1::Ones());
expectedValues.insert(X(3), -1.0001 * Vector1::Ones()); expectedValues.insert(X(3), -Vector1::Ones());
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
} }
/* ****************************************************************************/ /* ****************************************************************************/
// Test bayes net error // Test Bayes net error
TEST(HybridBayesNet, Error) { TEST(HybridBayesNet, Error) {
Switching s(3); Switching s(3);
@ -237,7 +242,7 @@ TEST(HybridBayesNet, Error) {
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9)); EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9));
// Verify error computation and check for specific error value // Verify error computation and check for specific error value
DiscreteValues discrete_values {{M(0), 1}, {M(1), 1}}; DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
double total_error = 0; double total_error = 0;
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
@ -323,18 +328,6 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
discrete_conditional_tree->apply(checker); discrete_conditional_tree->apply(checker);
} }
/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test HybridBayesNet sampling. // Test HybridBayesNet sampling.
TEST(HybridBayesNet, Sampling) { TEST(HybridBayesNet, Sampling) {

View File

@ -212,7 +212,7 @@ TEST(HybridBayesTree, Choose) {
ordering += M(1); ordering += M(1);
ordering += M(2); ordering += M(2);
//TODO(Varun) get segfault if ordering not provided // TODO(Varun) get segfault if ordering not provided
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering); auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);
auto expected_gbt = bayesTree->choose(assignment); auto expected_gbt = bayesTree->choose(assignment);
@ -220,20 +220,6 @@ TEST(HybridBayesTree, Choose) {
EXPECT(assert_equal(expected_gbt, gbt)); EXPECT(assert_equal(expected_gbt, gbt));
} }
/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridBayesTree, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
using namespace gtsam::serializationTestHelpers;
EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
EXPECT(equalsBinary<HybridBayesTree>(hbt));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -280,11 +280,10 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
return probPrimeTree; return probPrimeTree;
} }
/****************************************************************************/ /*********************************************************************************
/**
* Test for correctness of different branches of the P'(Continuous | Discrete). * Test for correctness of different branches of the P'(Continuous | Discrete).
* The values should match those of P'(Continuous) for each discrete mode. * The values should match those of P'(Continuous) for each discrete mode.
*/ ********************************************************************************/
TEST(HybridEstimation, Probability) { TEST(HybridEstimation, Probability) {
constexpr size_t K = 4; constexpr size_t K = 4;
std::vector<double> measurements = {0, 1, 2, 2}; std::vector<double> measurements = {0, 1, 2, 2};
@ -441,18 +440,30 @@ static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
* Do hybrid elimination and do regression test on discrete conditional. * Do hybrid elimination and do regression test on discrete conditional.
********************************************************************************/ ********************************************************************************/
TEST(HybridEstimation, eliminateSequentialRegression) { TEST(HybridEstimation, eliminateSequentialRegression) {
// 1. Create the factor graph from the nonlinear factor graph. // Create the factor graph from the nonlinear factor graph.
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph(); HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
// 2. Eliminate into BN // Create expected discrete conditional on m0.
const Ordering ordering = fg->getHybridOrdering(); DiscreteKey m(M(0), 2);
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering); DiscreteConditional expected(m % "0.51341712/1"); // regression
// GTSAM_PRINT(*bn);
// TODO(dellaert): dc should be discrete conditional on m0, but it is an // Eliminate into BN using one ordering
// unnormalized factor? DiscreteKey m(M(0), 2); DiscreteConditional expected(m Ordering ordering1;
// % "0.51341712/1"); auto dc = bn->back()->asDiscreteConditional(); ordering1 += X(0), X(1), M(0);
// EXPECT(assert_equal(expected, *dc, 1e-9)); HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
// Check that the discrete conditional matches the expected.
auto dc1 = bn1->back()->asDiscrete();
EXPECT(assert_equal(expected, *dc1, 1e-9));
// Eliminate into BN using a different ordering
Ordering ordering2;
ordering2 += X(0), X(1), M(0);
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
// Check that the discrete conditional matches the expected.
auto dc2 = bn2->back()->asDiscrete();
EXPECT(assert_equal(expected, *dc2, 1e-9));
} }
/********************************************************************************* /*********************************************************************************
@ -467,45 +478,35 @@ TEST(HybridEstimation, eliminateSequentialRegression) {
********************************************************************************/ ********************************************************************************/
TEST(HybridEstimation, CorrectnessViaSampling) { TEST(HybridEstimation, CorrectnessViaSampling) {
// 1. Create the factor graph from the nonlinear factor graph. // 1. Create the factor graph from the nonlinear factor graph.
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph(); const auto fg = createHybridGaussianFactorGraph();
// 2. Eliminate into BN // 2. Eliminate into BN
const Ordering ordering = fg->getHybridOrdering(); const Ordering ordering = fg->getHybridOrdering();
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering); const HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
// Set up sampling // Set up sampling
std::mt19937_64 rng(11); std::mt19937_64 rng(11);
// 3. Do sampling // Compute the log-ratio between the Bayes net and the factor graph.
int num_samples = 10; auto compute_ratio = [&](const HybridValues& sample) -> double {
return bn->evaluate(sample) / fg->probPrime(sample);
// Functor to compute the ratio between the
// Bayes net and the factor graph.
auto compute_ratio =
[](const HybridBayesNet::shared_ptr& bayesNet,
const HybridGaussianFactorGraph::shared_ptr& factorGraph,
const HybridValues& sample) -> double {
const DiscreteValues assignment = sample.discrete();
// Compute in log form for numerical stability
double log_ratio = bayesNet->error({sample.continuous(), assignment}) -
factorGraph->error({sample.continuous(), assignment});
double ratio = exp(-log_ratio);
return ratio;
}; };
// The error evaluated by the factor graph and the Bayes net should differ by // The error evaluated by the factor graph and the Bayes net should differ by
// the normalizing term computed via the Bayes net determinant. // the normalizing term computed via the Bayes net determinant.
const HybridValues sample = bn->sample(&rng); const HybridValues sample = bn->sample(&rng);
double ratio = compute_ratio(bn, fg, sample); double expected_ratio = compute_ratio(sample);
// regression // regression
EXPECT_DOUBLES_EQUAL(1.0, ratio, 1e-9); EXPECT_DOUBLES_EQUAL(0.728588, expected_ratio, 1e-6);
// 4. Check that all samples == constant // 3. Do sampling
constexpr int num_samples = 10;
for (size_t i = 0; i < num_samples; i++) { for (size_t i = 0; i < num_samples; i++) {
// Sample from the bayes net // Sample from the bayes net
const HybridValues sample = bn->sample(&rng); const HybridValues sample = bn->sample(&rng);
EXPECT_DOUBLES_EQUAL(ratio, compute_ratio(bn, fg, sample), 1e-9); // 4. Check that the ratio is constant.
EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(sample), 1e-6);
} }
} }

View File

@ -47,6 +47,7 @@
#include <vector> #include <vector>
#include "Switching.h" #include "Switching.h"
#include "TinyHybridExample.h"
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -133,7 +134,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
auto dc = result->at(2)->asDiscrete(); auto dc = result->at(2)->asDiscrete();
DiscreteValues dv; DiscreteValues dv;
dv[M(1)] = 0; dv[M(1)] = 0;
EXPECT_DOUBLES_EQUAL(1, dc->operator()(dv), 1e-3); // Regression test
EXPECT_DOUBLES_EQUAL(0.62245933120185448, dc->operator()(dv), 1e-3);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -612,6 +614,108 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
EXPECT(assert_equal(expected_probs, probs, 1e-7)); EXPECT(assert_equal(expected_probs, probs, 1e-7));
} }
/* ****************************************************************************/
// Check that assembleGraphTree assembles Gaussian factor graphs for each
// assignment.
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
using symbol_shorthand::Z;
const int num_measurements = 1;
auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
EXPECT_LONGS_EQUAL(3, fg.size());
auto sum = fg.assembleGraphTree();
// Get mixture factor:
auto mixture = boost::dynamic_pointer_cast<GaussianMixtureFactor>(fg.at(0));
using GF = GaussianFactor::shared_ptr;
// Get prior factor:
const GF prior =
boost::dynamic_pointer_cast<HybridGaussianFactor>(fg.at(1))->inner();
// Create DiscreteValues for both 0 and 1:
DiscreteValues d0{{M(0), 0}}, d1{{M(0), 1}};
// Expected decision tree with two factor graphs:
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
GaussianFactorGraphTree expectedSum{
M(0),
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
mixture->constant(d0)},
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d1), prior}),
mixture->constant(d1)}};
EXPECT(assert_equal(expectedSum(d0), sum(d0), 1e-5));
EXPECT(assert_equal(expectedSum(d1), sum(d1), 1e-5));
}
/* ****************************************************************************/
// Check that eliminating tiny net with 1 measurement yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
using symbol_shorthand::Z;
const int num_measurements = 1;
auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
// Create expected Bayes Net:
HybridBayesNet expectedBayesNet;
// Create Gaussian mixture on X(0).
using tiny::mode;
// regression, but mean checked to be 5.0 in both cases:
const auto conditional0 = boost::make_shared<GaussianConditional>(
X(0), Vector1(14.1421), I_1x1 * 2.82843),
conditional1 = boost::make_shared<GaussianConditional>(
X(0), Vector1(10.1379), I_1x1 * 2.02759);
GaussianMixture gm({X(0)}, {}, {mode}, {conditional0, conditional1});
expectedBayesNet.emplaceMixture(gm); // copy :-(
// Add prior on mode.
expectedBayesNet.emplaceDiscrete(mode, "74/26");
// Test elimination
Ordering ordering;
ordering.push_back(X(0));
ordering.push_back(M(0));
const auto posterior = fg.eliminateSequential(ordering);
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
}
/* ****************************************************************************/
// Check that eliminating tiny net with 2 measurements yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny2) {
// Create factor graph with 2 measurements such that posterior mean = 5.0.
using symbol_shorthand::Z;
const int num_measurements = 2;
auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements,
VectorValues{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}});
// Create expected Bayes Net:
HybridBayesNet expectedBayesNet;
// Create Gaussian mixture on X(0).
using tiny::mode;
// regression, but mean checked to be 5.0 in both cases:
const auto conditional0 = boost::make_shared<GaussianConditional>(
X(0), Vector1(17.3205), I_1x1 * 3.4641),
conditional1 = boost::make_shared<GaussianConditional>(
X(0), Vector1(10.274), I_1x1 * 2.0548);
GaussianMixture gm({X(0)}, {}, {mode}, {conditional0, conditional1});
expectedBayesNet.emplaceMixture(gm); // copy :-(
// Add prior on mode.
expectedBayesNet.emplaceDiscrete(mode, "23/77");
// Test elimination
Ordering ordering;
ordering.push_back(X(0));
ordering.push_back(M(0));
const auto posterior = fg.eliminateSequential(ordering);
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -177,19 +177,19 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// Test the probability values with regression tests. // Test the probability values with regression tests.
DiscreteValues assignment; DiscreteValues assignment;
EXPECT(assert_equal(0.0619233, m00_prob, 1e-5)); EXPECT(assert_equal(0.0952922, m00_prob, 1e-5));
assignment[M(0)] = 0; assignment[M(0)] = 0;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.0952922, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 1; assignment[M(0)] = 1;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.282758, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 0; assignment[M(0)] = 0;
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT(assert_equal(0.204159, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.314175, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 1; assignment[M(0)] = 1;
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT(assert_equal(0.2, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.307775, (*discreteConditional)(assignment), 1e-5));
// Check if the clique conditional generated from incremental elimination // Check if the clique conditional generated from incremental elimination
// matches that of batch elimination. // matches that of batch elimination.
@ -199,10 +199,10 @@ TEST(HybridGaussianElimination, IncrementalInference) {
isam[M(1)]->conditional()->inner()); isam[M(1)]->conditional()->inner());
// Account for the probability terms from evaluating continuous FGs // Account for the probability terms from evaluating continuous FGs
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}}; DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}};
vector<double> probs = {0.061923317, 0.20415914, 0.18374323, 0.2}; vector<double> probs = {0.095292197, 0.31417524, 0.28275772, 0.30777485};
auto expectedConditional = auto expectedConditional =
boost::make_shared<DecisionTreeFactor>(discrete_keys, probs); boost::make_shared<DecisionTreeFactor>(discrete_keys, probs);
EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6)); EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6));
} }
/* ****************************************************************************/ /* ****************************************************************************/

View File

@ -443,7 +443,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
ordering.clear(); ordering.clear();
for (size_t k = 0; k < self.K - 1; k++) ordering += M(k); for (size_t k = 0; k < self.K - 1; k++) ordering += M(k);
discreteBayesNet = discreteBayesNet =
*discrete_fg.eliminateSequential(ordering, EliminateForMPE); *discrete_fg.eliminateSequential(ordering, EliminateDiscrete);
} }
// Create ordering. // Create ordering.
@ -638,22 +638,30 @@ conditional 2: Hybrid P( x2 | m0 m1)
0 0 Leaf p(x2) 0 0 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.1489 ] d = [ -10.1489 ]
mean: 1 elements
x2: -1.0099
No noise model No noise model
0 1 Leaf p(x2) 0 1 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.1479 ] d = [ -10.1479 ]
mean: 1 elements
x2: -1.0098
No noise model No noise model
1 Choice(m0) 1 Choice(m0)
1 0 Leaf p(x2) 1 0 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.0504 ] d = [ -10.0504 ]
mean: 1 elements
x2: -1.0001
No noise model No noise model
1 1 Leaf p(x2) 1 1 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.0494 ] d = [ -10.0494 ]
mean: 1 elements
x2: -1
No noise model No noise model
)"; )";

View File

@ -191,24 +191,23 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
*(*discreteBayesTree)[M(1)]->conditional()->asDiscrete(); *(*discreteBayesTree)[M(1)]->conditional()->asDiscrete();
double m00_prob = decisionTree(m00); double m00_prob = decisionTree(m00);
auto discreteConditional = auto discreteConditional = bayesTree[M(1)]->conditional()->asDiscrete();
bayesTree[M(1)]->conditional()->asDiscrete();
// Test the probability values with regression tests. // Test the probability values with regression tests.
DiscreteValues assignment; DiscreteValues assignment;
EXPECT(assert_equal(0.0619233, m00_prob, 1e-5)); EXPECT(assert_equal(0.0952922, m00_prob, 1e-5));
assignment[M(0)] = 0; assignment[M(0)] = 0;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.0952922, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 1; assignment[M(0)] = 1;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.282758, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 0; assignment[M(0)] = 0;
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT(assert_equal(0.204159, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.314175, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 1; assignment[M(0)] = 1;
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT(assert_equal(0.2, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.307775, (*discreteConditional)(assignment), 1e-5));
// Check if the clique conditional generated from incremental elimination // Check if the clique conditional generated from incremental elimination
// matches that of batch elimination. // matches that of batch elimination.
@ -217,10 +216,10 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
bayesTree[M(1)]->conditional()->inner()); bayesTree[M(1)]->conditional()->inner());
// Account for the probability terms from evaluating continuous FGs // Account for the probability terms from evaluating continuous FGs
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}}; DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}};
vector<double> probs = {0.061923317, 0.20415914, 0.18374323, 0.2}; vector<double> probs = {0.095292197, 0.31417524, 0.28275772, 0.30777485};
auto expectedConditional = auto expectedConditional =
boost::make_shared<DecisionTreeFactor>(discrete_keys, probs); boost::make_shared<DecisionTreeFactor>(discrete_keys, probs);
EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6)); EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6));
} }
/* ****************************************************************************/ /* ****************************************************************************/
@ -358,10 +357,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
// Run update with pruning // Run update with pruning
size_t maxComponents = 5; size_t maxComponents = 5;
incrementalHybrid.update(graph1, initial); incrementalHybrid.update(graph1, initial);
incrementalHybrid.prune(maxComponents);
HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree(); HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
bayesTree.prune(maxComponents);
// Check if we have a bayes tree with 4 hybrid nodes, // Check if we have a bayes tree with 4 hybrid nodes,
// each with 2, 4, 8, and 5 (pruned) leaves respetively. // each with 2, 4, 8, and 5 (pruned) leaves respetively.
EXPECT_LONGS_EQUAL(4, bayesTree.size()); EXPECT_LONGS_EQUAL(4, bayesTree.size());
@ -383,10 +381,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
// Run update with pruning a second time. // Run update with pruning a second time.
incrementalHybrid.update(graph2, initial); incrementalHybrid.update(graph2, initial);
incrementalHybrid.prune(maxComponents);
bayesTree = incrementalHybrid.bayesTree(); bayesTree = incrementalHybrid.bayesTree();
bayesTree.prune(maxComponents);
// Check if we have a bayes tree with pruned hybrid nodes, // Check if we have a bayes tree with pruned hybrid nodes,
// with 5 (pruned) leaves. // with 5 (pruned) leaves.
CHECK_EQUAL(5, bayesTree.size()); CHECK_EQUAL(5, bayesTree.size());

View File

@ -70,8 +70,7 @@ MixtureFactor
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Test the error of the MixtureFactor static MixtureFactor getMixtureFactor() {
TEST(MixtureFactor, Error) {
DiscreteKey m1(1, 2); DiscreteKey m1(1, 2);
double between0 = 0.0; double between0 = 0.0;
@ -86,7 +85,13 @@ TEST(MixtureFactor, Error) {
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between1, model); boost::make_shared<BetweenFactor<double>>(X(1), X(2), between1, model);
std::vector<NonlinearFactor::shared_ptr> factors{f0, f1}; std::vector<NonlinearFactor::shared_ptr> factors{f0, f1};
MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); return MixtureFactor({X(1), X(2)}, {m1}, factors);
}
/* ************************************************************************* */
// Test the error of the MixtureFactor
TEST(MixtureFactor, Error) {
auto mixtureFactor = getMixtureFactor();
Values continuousValues; Values continuousValues;
continuousValues.insert<double>(X(1), 0); continuousValues.insert<double>(X(1), 0);
@ -94,6 +99,7 @@ TEST(MixtureFactor, Error) {
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues); AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
DiscreteKey m1(1, 2);
std::vector<DiscreteKey> discrete_keys = {m1}; std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> errors = {0.5, 0}; std::vector<double> errors = {0.5, 0};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors); AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
@ -101,6 +107,13 @@ TEST(MixtureFactor, Error) {
EXPECT(assert_equal(expected_error, error_tree)); EXPECT(assert_equal(expected_error, error_tree));
} }
/* ************************************************************************* */
// Test dim of the MixtureFactor
TEST(MixtureFactor, Dim) {
auto mixtureFactor = getMixtureFactor();
EXPECT_LONGS_EQUAL(1, mixtureFactor.dim());
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -0,0 +1,179 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testSerializationHybrid.cpp
* @brief Unit tests for hybrid serialization
* @author Varun Agrawal
* @date January 2023
*/
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include "Switching.h"
// Include for test suite
#include <CppUnitLite/TestHarness.h>
using namespace std;
using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;
using namespace serializationTestHelpers;
BOOST_CLASS_EXPORT_GUID(Factor, "gtsam_Factor");
BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors,
"gtsam_GaussianMixtureFactor_Factors");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Leaf,
"gtsam_GaussianMixtureFactor_Factors_Leaf");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Choice,
"gtsam_GaussianMixtureFactor_Factors_Choice");
BOOST_CLASS_EXPORT_GUID(GaussianMixture, "gtsam_GaussianMixture");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals,
"gtsam_GaussianMixture_Conditionals");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Leaf,
"gtsam_GaussianMixture_Conditionals_Leaf");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice,
"gtsam_GaussianMixture_Conditionals_Choice");
// Needed since GaussianConditional::FromMeanAndStddev uses it
BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic");
BOOST_CLASS_EXPORT_GUID(HybridBayesNet, "gtsam_HybridBayesNet");
/* ****************************************************************************/
// Test HybridGaussianFactor serialization.
TEST(HybridSerialization, HybridGaussianFactor) {
const HybridGaussianFactor factor(JacobianFactor(X(0), I_3x3, Z_3x1));
EXPECT(equalsObj<HybridGaussianFactor>(factor));
EXPECT(equalsXML<HybridGaussianFactor>(factor));
EXPECT(equalsBinary<HybridGaussianFactor>(factor));
}
/* ****************************************************************************/
// Test HybridDiscreteFactor serialization.
TEST(HybridSerialization, HybridDiscreteFactor) {
DiscreteKeys discreteKeys{{M(0), 2}};
const HybridDiscreteFactor factor(
DecisionTreeFactor(discreteKeys, std::vector<double>{0.4, 0.6}));
EXPECT(equalsObj<HybridDiscreteFactor>(factor));
EXPECT(equalsXML<HybridDiscreteFactor>(factor));
EXPECT(equalsBinary<HybridDiscreteFactor>(factor));
}
/* ****************************************************************************/
// Test GaussianMixtureFactor serialization.
TEST(HybridSerialization, GaussianMixtureFactor) {
KeyVector continuousKeys{X(0)};
DiscreteKeys discreteKeys{{M(0), 2}};
auto A = Matrix::Zero(2, 1);
auto b0 = Matrix::Zero(2, 1);
auto b1 = Matrix::Ones(2, 1);
auto f0 = boost::make_shared<JacobianFactor>(X(0), A, b0);
auto f1 = boost::make_shared<JacobianFactor>(X(0), A, b1);
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
const GaussianMixtureFactor factor(continuousKeys, discreteKeys, factors);
EXPECT(equalsObj<GaussianMixtureFactor>(factor));
EXPECT(equalsXML<GaussianMixtureFactor>(factor));
EXPECT(equalsBinary<GaussianMixtureFactor>(factor));
}
/* ****************************************************************************/
// Test HybridConditional serialization.
TEST(HybridSerialization, HybridConditional) {
const DiscreteKey mode(M(0), 2);
Matrix1 I = Matrix1::Identity();
const auto conditional = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
const HybridConditional hc(conditional);
EXPECT(equalsObj<HybridConditional>(hc));
EXPECT(equalsXML<HybridConditional>(hc));
EXPECT(equalsBinary<HybridConditional>(hc));
}
/* ****************************************************************************/
// Test GaussianMixture serialization.
TEST(HybridSerialization, GaussianMixture) {
const DiscreteKey mode(M(0), 2);
Matrix1 I = Matrix1::Identity();
const auto conditional0 = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
const auto conditional1 = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3));
const GaussianMixture gm({Z(0)}, {X(0)}, {mode},
{conditional0, conditional1});
EXPECT(equalsObj<GaussianMixture>(gm));
EXPECT(equalsXML<GaussianMixture>(gm));
EXPECT(equalsBinary<GaussianMixture>(gm));
}
/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridSerialization, HybridBayesNet) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}
/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridSerialization, HybridBayesTree) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
EXPECT(equalsBinary<HybridBayesTree>(hbt));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -67,7 +67,7 @@ namespace gtsam {
GaussianConditional GaussianConditional::FromMeanAndStddev(Key key, GaussianConditional GaussianConditional::FromMeanAndStddev(Key key,
const Vector& mu, const Vector& mu,
double sigma) { double sigma) {
// |Rx - d| = |x-(Ay + b)|/sigma // |Rx - d| = |x - mu|/sigma
const Matrix R = Matrix::Identity(mu.size(), mu.size()); const Matrix R = Matrix::Identity(mu.size(), mu.size());
const Vector& d = mu; const Vector& d = mu;
return GaussianConditional(key, d, R, return GaussianConditional(key, d, R,
@ -120,6 +120,10 @@ namespace gtsam {
<< endl; << endl;
} }
cout << formatMatrixIndented(" d = ", getb(), true) << "\n"; cout << formatMatrixIndented(" d = ", getb(), true) << "\n";
if (nrParents() == 0) {
const auto mean = solve({}); // solve for mean.
mean.print(" mean");
}
if (model_) if (model_)
model_->print(" Noise model: "); model_->print(" Noise model: ");
else else
@ -189,7 +193,7 @@ double GaussianConditional::logNormalizationConstant() const {
/* ************************************************************************* */ /* ************************************************************************* */
// density = k exp(-error(x)) // density = k exp(-error(x))
// log = log(k) -error(x) - 0.5 * n*log(2*pi) // log = log(k) -error(x)
double GaussianConditional::logDensity(const VectorValues& x) const { double GaussianConditional::logDensity(const VectorValues& x) const {
return logNormalizationConstant() - error(x); return logNormalizationConstant() - error(x);
} }

View File

@ -466,6 +466,31 @@ TEST(GaussianConditional, sample) {
// EXPECT(assert_equal(Vector2(31.0111856, 64.9850775), actual2[X(0)], 1e-5)); // EXPECT(assert_equal(Vector2(31.0111856, 64.9850775), actual2[X(0)], 1e-5));
} }
/* ************************************************************************* */
TEST(GaussianConditional, LogNormalizationConstant) {
// Create univariate standard gaussian conditional
auto std_gaussian =
GaussianConditional::FromMeanAndStddev(X(0), Vector1::Zero(), 1.0);
VectorValues values;
values.insert(X(0), Vector1::Zero());
double logDensity = std_gaussian.logDensity(values);
// Regression.
// These values were computed by hand for a univariate standard gaussian.
EXPECT_DOUBLES_EQUAL(-0.9189385332046727, logDensity, 1e-9);
EXPECT_DOUBLES_EQUAL(0.3989422804014327, exp(logDensity), 1e-9);
// Similar test for multivariate gaussian but with sigma 2.0
double sigma = 2.0;
auto conditional = GaussianConditional::FromMeanAndStddev(X(0), Vector3::Zero(), sigma);
VectorValues x;
x.insert(X(0), Vector3::Zero());
Matrix3 Sigma = I_3x3 * sigma * sigma;
double expectedLogNormalizingConstant = log(1 / sqrt((2 * M_PI * Sigma).determinant()));
EXPECT_DOUBLES_EQUAL(expectedLogNormalizingConstant, conditional.logNormalizationConstant(), 1e-9);
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST(GaussianConditional, Print) { TEST(GaussianConditional, Print) {
Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished(); Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished();
@ -482,6 +507,8 @@ TEST(GaussianConditional, Print) {
" R = [ 1 0 ]\n" " R = [ 1 0 ]\n"
" [ 0 1 ]\n" " [ 0 1 ]\n"
" d = [ 20 40 ]\n" " d = [ 20 40 ]\n"
" mean: 1 elements\n"
" x0: 20 40\n"
"isotropic dim=2 sigma=3\n"; "isotropic dim=2 sigma=3\n";
EXPECT(assert_print_equal(expected, conditional, "GaussianConditional")); EXPECT(assert_print_equal(expected, conditional, "GaussianConditional"));

View File

@ -18,9 +18,9 @@ from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues, GaussianMixture, GaussianMixtureFactor, HybridBayesNet,
HybridGaussianFactorGraph, JacobianFactor, Ordering, HybridGaussianFactorGraph, HybridValues, JacobianFactor,
noiseModel) Ordering, noiseModel)
class TestHybridGaussianFactorGraph(GtsamTestCase): class TestHybridGaussianFactorGraph(GtsamTestCase):
@ -82,10 +82,12 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
self.assertEqual(hv.atDiscrete(C(0)), 1) self.assertEqual(hv.atDiscrete(C(0)), 1)
@staticmethod @staticmethod
def tiny(num_measurements: int = 1) -> HybridBayesNet: def tiny(num_measurements: int = 1, prior_mean: float = 5.0,
prior_sigma: float = 0.5) -> HybridBayesNet:
""" """
Create a tiny two variable hybrid model which represents Create a tiny two variable hybrid model which represents
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). the generative probability P(Z, x0, mode) = P(Z|x0, mode)P(x0)P(mode).
num_measurements: number of measurements in Z = {z0, z1...}
""" """
# Create hybrid Bayes net. # Create hybrid Bayes net.
bayesNet = HybridBayesNet() bayesNet = HybridBayesNet()
@ -94,23 +96,24 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
mode = (M(0), 2) mode = (M(0), 2)
# Create Gaussian mixture Z(0) = X(0) + noise for each measurement. # Create Gaussian mixture Z(0) = X(0) + noise for each measurement.
I = np.eye(1) I_1x1 = np.eye(1)
keys = DiscreteKeys() keys = DiscreteKeys()
keys.push_back(mode) keys.push_back(mode)
for i in range(num_measurements): for i in range(num_measurements):
conditional0 = GaussianConditional.FromMeanAndStddev(Z(i), conditional0 = GaussianConditional.FromMeanAndStddev(Z(i),
I, I_1x1,
X(0), [0], X(0), [0],
sigma=0.5) sigma=0.5)
conditional1 = GaussianConditional.FromMeanAndStddev(Z(i), conditional1 = GaussianConditional.FromMeanAndStddev(Z(i),
I, I_1x1,
X(0), [0], X(0), [0],
sigma=3) sigma=3)
bayesNet.emplaceMixture([Z(i)], [X(0)], keys, bayesNet.emplaceMixture([Z(i)], [X(0)], keys,
[conditional0, conditional1]) [conditional0, conditional1])
# Create prior on X(0). # Create prior on X(0).
prior_on_x0 = GaussianConditional.FromMeanAndStddev(X(0), [5.0], 5.0) prior_on_x0 = GaussianConditional.FromMeanAndStddev(
X(0), [prior_mean], prior_sigma)
bayesNet.addGaussian(prior_on_x0) bayesNet.addGaussian(prior_on_x0)
# Add prior on mode. # Add prior on mode.
@ -118,8 +121,41 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
return bayesNet return bayesNet
def test_evaluate(self):
"""Test evaluate with two different prior noise models."""
# TODO(dellaert): really a HBN test
# Create a tiny Bayes net P(x0) P(m0) P(z0|x0)
bayesNet1 = self.tiny(prior_sigma=0.5, num_measurements=1)
bayesNet2 = self.tiny(prior_sigma=5.0, num_measurements=1)
# bn1: # 1/sqrt(2*pi*0.5^2)
# bn2: # 1/sqrt(2*pi*5.0^2)
expected_ratio = np.sqrt(2*np.pi*5.0**2)/np.sqrt(2*np.pi*0.5**2)
mean0 = HybridValues()
mean0.insert(X(0), [5.0])
mean0.insert(Z(0), [5.0])
mean0.insert(M(0), 0)
self.assertAlmostEqual(bayesNet1.evaluate(mean0) /
bayesNet2.evaluate(mean0), expected_ratio,
delta=1e-9)
mean1 = HybridValues()
mean1.insert(X(0), [5.0])
mean1.insert(Z(0), [5.0])
mean1.insert(M(0), 1)
self.assertAlmostEqual(bayesNet1.evaluate(mean1) /
bayesNet2.evaluate(mean1), expected_ratio,
delta=1e-9)
@staticmethod @staticmethod
def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues): def measurements(sample: HybridValues, indices) -> gtsam.VectorValues:
"""Create measurements from a sample, grabbing Z(i) for indices."""
measurements = gtsam.VectorValues()
for i in indices:
measurements.insert(Z(i), sample.at(Z(i)))
return measurements
@classmethod
def factor_graph_from_bayes_net(cls, bayesNet: HybridBayesNet,
sample: HybridValues):
"""Create a factor graph from the Bayes net with sampled measurements. """Create a factor graph from the Bayes net with sampled measurements.
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
and thus represents the same joint probability as the Bayes net. and thus represents the same joint probability as the Bayes net.
@ -128,31 +164,27 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
num_measurements = bayesNet.size() - 2 num_measurements = bayesNet.size() - 2
for i in range(num_measurements): for i in range(num_measurements):
conditional = bayesNet.atMixture(i) conditional = bayesNet.atMixture(i)
measurement = gtsam.VectorValues() factor = conditional.likelihood(cls.measurements(sample, [i]))
measurement.insert(Z(i), sample.at(Z(i)))
factor = conditional.likelihood(measurement)
fg.push_back(factor) fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(num_measurements)) fg.push_back(bayesNet.atGaussian(num_measurements))
fg.push_back(bayesNet.atDiscrete(num_measurements+1)) fg.push_back(bayesNet.atDiscrete(num_measurements+1))
return fg return fg
@classmethod @classmethod
def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000): def estimate_marginals(cls, target, proposal_density: HybridBayesNet,
"""Do importance sampling to get an estimate of the discrete marginal P(mode).""" N=10000):
# Use prior on x0, mode as proposal density. """Do importance sampling to estimate discrete marginal P(mode)."""
prior = cls.tiny(num_measurements=0) # just P(x0)P(mode) # Allocate space for marginals on mode.
# Allocate space for marginals.
marginals = np.zeros((2,)) marginals = np.zeros((2,))
# Do importance sampling. # Do importance sampling.
num_measurements = bayesNet.size() - 2
for s in range(N): for s in range(N):
proposed = prior.sample() proposed = proposal_density.sample() # sample from proposal
for i in range(num_measurements): target_proposed = target(proposed) # evaluate target
z_i = sample.at(Z(i)) # print(target_proposed, proposal_density.evaluate(proposed))
proposed.insert(Z(i), z_i) weight = target_proposed / proposal_density.evaluate(proposed)
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) # print weight:
# print(f"weight: {weight}")
marginals[proposed.atDiscrete(M(0))] += weight marginals[proposed.atDiscrete(M(0))] += weight
# print marginals: # print marginals:
@ -161,72 +193,146 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
def test_tiny(self): def test_tiny(self):
"""Test a tiny two variable hybrid model.""" """Test a tiny two variable hybrid model."""
bayesNet = self.tiny() # P(x0)P(mode)P(z0|x0,mode)
sample = bayesNet.sample() prior_sigma = 0.5
# print(sample) bayesNet = self.tiny(prior_sigma=prior_sigma)
# Deterministic values exactly at the mean, for both x and Z:
values = HybridValues()
values.insert(X(0), [5.0])
values.insert(M(0), 0) # low-noise, standard deviation 0.5
z0: float = 5.0
values.insert(Z(0), [z0])
def unnormalized_posterior(x):
"""Posterior is proportional to joint, centered at 5.0 as well."""
x.insert(Z(0), [z0])
return bayesNet.evaluate(x)
# Create proposal density on (x0, mode), making sure it has same mean:
posterior_information = 1/(prior_sigma**2) + 1/(0.5**2)
posterior_sigma = posterior_information**(-0.5)
proposal_density = self.tiny(
num_measurements=0, prior_mean=5.0, prior_sigma=posterior_sigma)
# Estimate marginals using importance sampling. # Estimate marginals using importance sampling.
marginals = self.estimate_marginals(bayesNet, sample) marginals = self.estimate_marginals(
# print(f"True mode: {sample.atDiscrete(M(0))}") target=unnormalized_posterior, proposal_density=proposal_density)
# print(f"True mode: {values.atDiscrete(M(0))}")
# print(f"P(mode=0; Z) = {marginals[0]}")
# print(f"P(mode=1; Z) = {marginals[1]}")
# Check that the estimate is close to the true value.
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
self.assertAlmostEqual(marginals[1], 0.26, delta=0.01)
fg = self.factor_graph_from_bayes_net(bayesNet, values)
self.assertEqual(fg.size(), 3)
# Test elimination.
ordering = gtsam.Ordering()
ordering.push_back(X(0))
ordering.push_back(M(0))
posterior = fg.eliminateSequential(ordering)
def true_posterior(x):
"""Posterior from elimination."""
x.insert(Z(0), [z0])
return posterior.evaluate(x)
# Estimate marginals using importance sampling.
marginals = self.estimate_marginals(
target=true_posterior, proposal_density=proposal_density)
# print(f"True mode: {values.atDiscrete(M(0))}")
# print(f"P(mode=0; z0) = {marginals[0]}") # print(f"P(mode=0; z0) = {marginals[0]}")
# print(f"P(mode=1; z0) = {marginals[1]}") # print(f"P(mode=1; z0) = {marginals[1]}")
# Check that the estimate is close to the true value. # Check that the estimate is close to the true value.
self.assertAlmostEqual(marginals[0], 0.4, delta=0.1) self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
self.assertAlmostEqual(marginals[1], 0.6, delta=0.1) self.assertAlmostEqual(marginals[1], 0.26, delta=0.01)
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
self.assertEqual(fg.size(), 3)
@staticmethod @staticmethod
def calculate_ratio(bayesNet: HybridBayesNet, def calculate_ratio(bayesNet: HybridBayesNet,
fg: HybridGaussianFactorGraph, fg: HybridGaussianFactorGraph,
sample: HybridValues): sample: HybridValues):
"""Calculate ratio between Bayes net probability and the factor graph.""" """Calculate ratio between Bayes net and factor graph."""
return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0 return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
fg.probPrime(sample) > 0 else 0
def test_ratio(self): def test_ratio(self):
""" """
Given a tiny two variable hybrid model, with 2 measurements, Given a tiny two variable hybrid model, with 2 measurements, test the
test the ratio of the bayes net model representing P(z, x, n)=P(z|x, n)P(x)P(n) ratio of the bayes net model representing P(z,x,n)=P(z|x, n)P(x)P(n)
and the factor graph P(x, n | z)=P(x | n, z)P(n|z), and the factor graph P(x, n | z)=P(x | n, z)P(n|z),
both of which represent the same posterior. both of which represent the same posterior.
""" """
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) # Create generative model P(z, x, n)=P(z|x, n)P(x)P(n)
bayesNet = self.tiny(num_measurements=2) prior_sigma = 0.5
# Sample from the Bayes net. bayesNet = self.tiny(prior_sigma=prior_sigma, num_measurements=2)
sample: HybridValues = bayesNet.sample()
# print(sample) # Deterministic values exactly at the mean, for both x and Z:
values = HybridValues()
values.insert(X(0), [5.0])
values.insert(M(0), 0) # high-noise, standard deviation 3
measurements = gtsam.VectorValues()
measurements.insert(Z(0), [4.0])
measurements.insert(Z(1), [6.0])
values.insert(measurements)
def unnormalized_posterior(x):
"""Posterior is proportional to joint, centered at 5.0 as well."""
x.insert(measurements)
return bayesNet.evaluate(x)
# Create proposal density on (x0, mode), making sure it has same mean:
posterior_information = 1/(prior_sigma**2) + 2.0/(3.0**2)
posterior_sigma = posterior_information**(-0.5)
proposal_density = self.tiny(
num_measurements=0, prior_mean=5.0, prior_sigma=posterior_sigma)
# Estimate marginals using importance sampling. # Estimate marginals using importance sampling.
marginals = self.estimate_marginals(bayesNet, sample) marginals = self.estimate_marginals(
# print(f"True mode: {sample.atDiscrete(M(0))}") target=unnormalized_posterior, proposal_density=proposal_density)
# print(f"P(mode=0; z0, z1) = {marginals[0]}") # print(f"True mode: {values.atDiscrete(M(0))}")
# print(f"P(mode=1; z0, z1) = {marginals[1]}") # print(f"P(mode=0; Z) = {marginals[0]}")
# print(f"P(mode=1; Z) = {marginals[1]}")
# Check marginals based on sampled mode. # Check that the estimate is close to the true value.
if sample.atDiscrete(M(0)) == 0: self.assertAlmostEqual(marginals[0], 0.23, delta=0.01)
self.assertGreater(marginals[0], marginals[1]) self.assertAlmostEqual(marginals[1], 0.77, delta=0.01)
else:
self.assertGreater(marginals[1], marginals[0])
fg = self.factor_graph_from_bayes_net(bayesNet, sample) # Convert to factor graph using measurements.
fg = self.factor_graph_from_bayes_net(bayesNet, values)
self.assertEqual(fg.size(), 4) self.assertEqual(fg.size(), 4)
# Calculate ratio between Bayes net probability and the factor graph: # Calculate ratio between Bayes net probability and the factor graph:
expected_ratio = self.calculate_ratio(bayesNet, fg, sample) expected_ratio = self.calculate_ratio(bayesNet, fg, values)
# print(f"expected_ratio: {expected_ratio}\n") # print(f"expected_ratio: {expected_ratio}\n")
# Create measurements from the sample.
measurements = gtsam.VectorValues()
for i in range(2):
measurements.insert(Z(i), sample.at(Z(i)))
# Check with a number of other samples. # Check with a number of other samples.
for i in range(10): for i in range(10):
other = bayesNet.sample() samples = bayesNet.sample()
other.update(measurements) samples.update(measurements)
ratio = self.calculate_ratio(bayesNet, fg, other) ratio = self.calculate_ratio(bayesNet, fg, samples)
# print(f"Ratio: {ratio}\n")
if (ratio > 0):
self.assertAlmostEqual(ratio, expected_ratio)
# Test elimination.
ordering = gtsam.Ordering()
ordering.push_back(X(0))
ordering.push_back(M(0))
posterior = fg.eliminateSequential(ordering)
# Calculate ratio between Bayes net probability and the factor graph:
expected_ratio = self.calculate_ratio(posterior, fg, values)
# print(f"expected_ratio: {expected_ratio}\n")
# Check with a number of other samples.
for i in range(10):
samples = posterior.sample()
samples.insert(measurements)
ratio = self.calculate_ratio(posterior, fg, samples)
# print(f"Ratio: {ratio}\n") # print(f"Ratio: {ratio}\n")
if (ratio > 0): if (ratio > 0):
self.assertAlmostEqual(ratio, expected_ratio) self.assertAlmostEqual(ratio, expected_ratio)