Merge pull request #1360 from borglab/hybrid/elimination
commit
1c411eb5a2
|
|
@ -64,6 +64,9 @@ namespace gtsam {
|
|||
*/
|
||||
size_t nrAssignments_;
|
||||
|
||||
/// Default constructor for serialization.
|
||||
Leaf() {}
|
||||
|
||||
/// Constructor from constant
|
||||
Leaf(const Y& constant, size_t nrAssignments = 1)
|
||||
: constant_(constant), nrAssignments_(nrAssignments) {}
|
||||
|
|
@ -154,6 +157,18 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
bool isLeaf() const override { return true; }
|
||||
|
||||
private:
|
||||
using Base = DecisionTree<L, Y>::Node;
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_NVP(constant_);
|
||||
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
|
||||
}
|
||||
}; // Leaf
|
||||
|
||||
/****************************************************************************/
|
||||
|
|
@ -177,6 +192,9 @@ namespace gtsam {
|
|||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||
|
||||
public:
|
||||
/// Default constructor for serialization.
|
||||
Choice() {}
|
||||
|
||||
~Choice() override {
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||
|
|
@ -428,6 +446,19 @@ namespace gtsam {
|
|||
r->push_back(branch->choose(label, index));
|
||||
return Unique(r);
|
||||
}
|
||||
|
||||
private:
|
||||
using Base = DecisionTree<L, Y>::Node;
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_NVP(label_);
|
||||
ar& BOOST_SERIALIZATION_NVP(branches_);
|
||||
ar& BOOST_SERIALIZATION_NVP(allSame_);
|
||||
}
|
||||
}; // Choice
|
||||
|
||||
/****************************************************************************/
|
||||
|
|
|
|||
|
|
@ -19,9 +19,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/types.h>
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
|
@ -113,6 +115,12 @@ namespace gtsam {
|
|||
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
||||
virtual Ptr choose(const L& label, size_t index) const = 0;
|
||||
virtual bool isLeaf() const = 0;
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
|
||||
};
|
||||
/** ------------------------ Node base class --------------------------- */
|
||||
|
||||
|
|
@ -236,7 +244,7 @@ namespace gtsam {
|
|||
/**
|
||||
* @brief Visit all leaves in depth-first fashion.
|
||||
*
|
||||
* @param f (side-effect) Function taking a value.
|
||||
* @param f (side-effect) Function taking the value of the leaf node.
|
||||
*
|
||||
* @note Due to pruning, the number of leaves may not be the same as the
|
||||
* number of assignments. E.g. if we have a tree on 2 binary variables with
|
||||
|
|
@ -245,7 +253,7 @@ namespace gtsam {
|
|||
* Example:
|
||||
* int sum = 0;
|
||||
* auto visitor = [&](int y) { sum += y; };
|
||||
* tree.visitWith(visitor);
|
||||
* tree.visit(visitor);
|
||||
*/
|
||||
template <typename Func>
|
||||
void visit(Func f) const;
|
||||
|
|
@ -261,8 +269,8 @@ namespace gtsam {
|
|||
*
|
||||
* Example:
|
||||
* int sum = 0;
|
||||
* auto visitor = [&](int y) { sum += y; };
|
||||
* tree.visitWith(visitor);
|
||||
* auto visitor = [&](const Leaf& leaf) { sum += leaf.constant(); };
|
||||
* tree.visitLeaf(visitor);
|
||||
*/
|
||||
template <typename Func>
|
||||
void visitLeaf(Func f) const;
|
||||
|
|
@ -364,8 +372,19 @@ namespace gtsam {
|
|||
compose(Iterator begin, Iterator end, const L& label) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_NVP(root_);
|
||||
}
|
||||
}; // DecisionTree
|
||||
|
||||
template <class L, class Y>
|
||||
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
|
||||
|
||||
/** free versions of apply */
|
||||
|
||||
/// Apply unary operator `op` to DecisionTree `f`.
|
||||
|
|
|
|||
|
|
@ -156,9 +156,9 @@ namespace gtsam {
|
|||
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
|
||||
const {
|
||||
// Get all possible assignments
|
||||
std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
|
||||
DiscreteKeys pairs = discreteKeys();
|
||||
// Reverse to make cartesian product output a more natural ordering.
|
||||
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
|
||||
DiscreteKeys rpairs(pairs.rbegin(), pairs.rend());
|
||||
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
|
||||
|
||||
// Construct unordered_map with values
|
||||
|
|
|
|||
|
|
@ -231,6 +231,16 @@ namespace gtsam {
|
|||
const Names& names = {}) const override;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
|
||||
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
/// Internal version of choose
|
||||
DiscreteConditional::ADT choose(const DiscreteValues& given,
|
||||
bool forceComplete) const;
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class Archive>
|
||||
void serialize(Archive& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||
}
|
||||
};
|
||||
// DiscreteConditional
|
||||
|
||||
|
|
|
|||
|
|
@ -20,12 +20,11 @@
|
|||
// #define DT_DEBUG_MEMORY
|
||||
// #define GTSAM_DT_NO_PRUNING
|
||||
#define DISABLE_DOT
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
|
|
|||
|
|
@ -17,13 +17,14 @@
|
|||
* @date Feb 14, 2011
|
||||
*/
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
|
|
@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) {
|
|||
DiscreteConditional conditional(A | B = "2/2 3/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
DiscreteConditional pAB = prior * conditional;
|
||||
GTSAM_PRINT(pAB);
|
||||
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
|
||||
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,105 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* testSerializtionDiscrete.cpp
|
||||
*
|
||||
* @date January 2023
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
using Tree = gtsam::DecisionTree<string, int>;
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
|
||||
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTreeStringInt_Leaf")
|
||||
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTreeStringInt_Choice")
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
|
||||
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
|
||||
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf")
|
||||
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test DecisionTree serialization.
|
||||
TEST(DiscreteSerialization, DecisionTree) {
|
||||
Tree tree({{"A", 2}}, std::vector<int>{1, 2});
|
||||
|
||||
using namespace serializationTestHelpers;
|
||||
|
||||
// Object roundtrip
|
||||
Tree outputObj = create<Tree>();
|
||||
roundtrip<Tree>(tree, outputObj);
|
||||
EXPECT(tree.equals(outputObj));
|
||||
|
||||
// XML roundtrip
|
||||
Tree outputXml = create<Tree>();
|
||||
roundtripXML<Tree>(tree, outputXml);
|
||||
EXPECT(tree.equals(outputXml));
|
||||
|
||||
// Binary roundtrip
|
||||
Tree outputBinary = create<Tree>();
|
||||
roundtripBinary<Tree>(tree, outputBinary);
|
||||
EXPECT(tree.equals(outputBinary));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor
|
||||
TEST(DiscreteSerialization, DecisionTreeFactor) {
|
||||
using namespace serializationTestHelpers;
|
||||
|
||||
DiscreteKey A(1, 2), B(2, 2), C(3, 2);
|
||||
|
||||
DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8");
|
||||
EXPECT(equalsObj<DecisionTreeFactor::ADT>(tree));
|
||||
EXPECT(equalsXML<DecisionTreeFactor::ADT>(tree));
|
||||
EXPECT(equalsBinary<DecisionTreeFactor::ADT>(tree));
|
||||
|
||||
DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
|
||||
EXPECT(equalsObj<DecisionTreeFactor>(f));
|
||||
EXPECT(equalsXML<DecisionTreeFactor>(f));
|
||||
EXPECT(equalsBinary<DecisionTreeFactor>(f));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check serialization for DiscreteConditional & DiscreteDistribution
|
||||
TEST(DiscreteSerialization, DiscreteConditional) {
|
||||
using namespace serializationTestHelpers;
|
||||
|
||||
DiscreteKey A(Symbol('x', 1), 3);
|
||||
DiscreteConditional conditional(A % "1/2/2");
|
||||
|
||||
EXPECT(equalsObj<DiscreteConditional>(conditional));
|
||||
EXPECT(equalsXML<DiscreteConditional>(conditional));
|
||||
EXPECT(equalsBinary<DiscreteConditional>(conditional));
|
||||
|
||||
DiscreteDistribution P(A % "3/2/1");
|
||||
EXPECT(equalsObj<DiscreteDistribution>(P));
|
||||
EXPECT(equalsXML<DiscreteDistribution>(P));
|
||||
EXPECT(equalsBinary<DiscreteDistribution>(P));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -51,24 +51,28 @@ GaussianMixture::GaussianMixture(
|
|||
Conditionals(discreteParents, conditionalsList)) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianMixture::Sum GaussianMixture::add(
|
||||
const GaussianMixture::Sum &sum) const {
|
||||
using Y = GaussianFactorGraph;
|
||||
GaussianFactorGraphTree GaussianMixture::add(
|
||||
const GaussianFactorGraphTree &sum) const {
|
||||
using Y = GraphAndConstant;
|
||||
auto add = [](const Y &graph1, const Y &graph2) {
|
||||
auto result = graph1;
|
||||
result.push_back(graph2);
|
||||
return result;
|
||||
auto result = graph1.graph;
|
||||
result.push_back(graph2.graph);
|
||||
return Y(result, graph1.constant + graph2.constant);
|
||||
};
|
||||
const Sum tree = asGaussianFactorGraphTree();
|
||||
const auto tree = asGaussianFactorGraphTree();
|
||||
return sum.empty() ? tree : sum.apply(tree, add);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
|
||||
auto lambda = [](const GaussianFactor::shared_ptr &factor) {
|
||||
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
|
||||
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor);
|
||||
return result;
|
||||
result.push_back(conditional);
|
||||
if (conditional) {
|
||||
return GraphAndConstant(result, conditional->logNormalizationConstant());
|
||||
} else {
|
||||
return GraphAndConstant(result, 0.0);
|
||||
}
|
||||
};
|
||||
return {conditionals_, lambda};
|
||||
}
|
||||
|
|
@ -98,7 +102,19 @@ GaussianConditional::shared_ptr GaussianMixture::operator()(
|
|||
/* *******************************************************************************/
|
||||
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
|
||||
const This *e = dynamic_cast<const This *>(&lf);
|
||||
return e != nullptr && BaseFactor::equals(*e, tol);
|
||||
if (e == nullptr) return false;
|
||||
|
||||
// This will return false if either conditionals_ is empty or e->conditionals_
|
||||
// is empty, but not if both are empty or both are not empty:
|
||||
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
|
||||
|
||||
// Check the base and the factors:
|
||||
return BaseFactor::equals(*e, tol) &&
|
||||
conditionals_.equals(e->conditionals_,
|
||||
[tol](const GaussianConditional::shared_ptr &f1,
|
||||
const GaussianConditional::shared_ptr &f2) {
|
||||
return f1->equals(*(f2), tol);
|
||||
});
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
@ -146,7 +162,13 @@ KeyVector GaussianMixture::continuousParents() const {
|
|||
/* ************************************************************************* */
|
||||
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
||||
const VectorValues &frontals) const {
|
||||
// TODO(dellaert): check that values has all frontals
|
||||
// Check that values has all frontals
|
||||
for (auto &&kv : frontals) {
|
||||
if (frontals.find(kv.first) == frontals.end()) {
|
||||
throw std::runtime_error("GaussianMixture: frontals missing factor key.");
|
||||
}
|
||||
}
|
||||
|
||||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||
const KeyVector continuousParentKeys = continuousParents();
|
||||
const GaussianMixtureFactor::Factors likelihoods(
|
||||
|
|
|
|||
|
|
@ -23,13 +23,13 @@
|
|||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/inference/Conditional.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class GaussianMixtureFactor;
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
|
|
@ -59,9 +59,6 @@ class GTSAM_EXPORT GaussianMixture
|
|||
using BaseFactor = HybridFactor;
|
||||
using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
|
||||
|
||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||
|
||||
/// typedef for Decision Tree of Gaussian Conditionals
|
||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||
|
||||
|
|
@ -71,7 +68,7 @@ class GTSAM_EXPORT GaussianMixture
|
|||
/**
|
||||
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
||||
*/
|
||||
Sum asGaussianFactorGraphTree() const;
|
||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||
|
||||
/**
|
||||
* @brief Helper function to get the pruner functor.
|
||||
|
|
@ -172,6 +169,16 @@ class GTSAM_EXPORT GaussianMixture
|
|||
*/
|
||||
double error(const HybridValues &values) const override;
|
||||
|
||||
// /// Calculate probability density for given values `x`.
|
||||
// double evaluate(const HybridValues &values) const;
|
||||
|
||||
// /// Evaluate probability density, sugar.
|
||||
// double operator()(const HybridValues &values) const { return
|
||||
// evaluate(values); }
|
||||
|
||||
// /// Calculate log-density for given values `x`.
|
||||
// double logDensity(const HybridValues &values) const;
|
||||
|
||||
/**
|
||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||
* `decisionTree`.
|
||||
|
|
@ -186,10 +193,20 @@ class GTSAM_EXPORT GaussianMixture
|
|||
* maintaining the decision tree structure.
|
||||
*
|
||||
* @param sum Decision Tree of Gaussian Factor Graphs
|
||||
* @return Sum
|
||||
* @return GaussianFactorGraphTree
|
||||
*/
|
||||
Sum add(const Sum &sum) const;
|
||||
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class Archive>
|
||||
void serialize(Archive &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||
ar &BOOST_SERIALIZATION_NVP(conditionals_);
|
||||
}
|
||||
};
|
||||
|
||||
/// Return the DiscreteKey vector as a set.
|
||||
|
|
|
|||
|
|
@ -81,32 +81,36 @@ void GaussianMixtureFactor::print(const std::string &s,
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
|
||||
return Mixture(factors_, [](const FactorAndConstant &factor_z) {
|
||||
return factor_z.factor;
|
||||
});
|
||||
GaussianFactor::shared_ptr GaussianMixtureFactor::factor(
|
||||
const DiscreteValues &assignment) const {
|
||||
return factors_(assignment).factor;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
||||
const GaussianMixtureFactor::Sum &sum) const {
|
||||
using Y = GaussianFactorGraph;
|
||||
double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
|
||||
return factors_(assignment).constant;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianFactorGraphTree GaussianMixtureFactor::add(
|
||||
const GaussianFactorGraphTree &sum) const {
|
||||
using Y = GraphAndConstant;
|
||||
auto add = [](const Y &graph1, const Y &graph2) {
|
||||
auto result = graph1;
|
||||
result.push_back(graph2);
|
||||
return result;
|
||||
auto result = graph1.graph;
|
||||
result.push_back(graph2.graph);
|
||||
return Y(result, graph1.constant + graph2.constant);
|
||||
};
|
||||
const Sum tree = asGaussianFactorGraphTree();
|
||||
const auto tree = asGaussianFactorGraphTree();
|
||||
return sum.empty() ? tree : sum.apply(tree, add);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||
const {
|
||||
auto wrap = [](const FactorAndConstant &factor_z) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor_z.factor);
|
||||
return result;
|
||||
return GraphAndConstant(result, factor_z.constant);
|
||||
};
|
||||
return {factors_, wrap};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@
|
|||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class GaussianFactorGraph;
|
||||
class HybridValues;
|
||||
class DiscreteValues;
|
||||
class VectorValues;
|
||||
|
|
@ -50,7 +50,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
using This = GaussianMixtureFactor;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||
using sharedFactor = boost::shared_ptr<GaussianFactor>;
|
||||
|
||||
/// Gaussian factor and log of normalizing constant.
|
||||
|
|
@ -60,8 +59,10 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
|
||||
// Return error with constant correction.
|
||||
double error(const VectorValues &values) const {
|
||||
// Note minus sign: constant is log of normalization constant for probabilities.
|
||||
// Errors is the negative log-likelihood, hence we subtract the constant here.
|
||||
// Note: constant is log of normalization constant for probabilities.
|
||||
// Errors is the negative log-likelihood,
|
||||
// hence we subtract the constant here.
|
||||
if (!factor) return 0.0; // If nullptr, return 0.0 error
|
||||
return factor->error(values) - constant;
|
||||
}
|
||||
|
||||
|
|
@ -69,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
bool operator==(const FactorAndConstant &other) const {
|
||||
return factor == other.factor && constant == other.constant;
|
||||
}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_NVP(factor);
|
||||
ar &BOOST_SERIALIZATION_NVP(constant);
|
||||
}
|
||||
};
|
||||
|
||||
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
||||
|
|
@ -83,9 +93,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
* @brief Helper function to return factors and functional to create a
|
||||
* DecisionTree of Gaussian Factor Graphs.
|
||||
*
|
||||
* @return Sum (DecisionTree<Key, GaussianFactorGraph>)
|
||||
* @return GaussianFactorGraphTree
|
||||
*/
|
||||
Sum asGaussianFactorGraphTree() const;
|
||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||
|
||||
public:
|
||||
/// @name Constructors
|
||||
|
|
@ -135,12 +145,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
void print(
|
||||
const std::string &s = "GaussianMixtureFactor\n",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard API
|
||||
/// @{
|
||||
|
||||
/// Getter for the underlying Gaussian Factor Decision Tree.
|
||||
const Mixture factors() const;
|
||||
/// Get factor at a given discrete assignment.
|
||||
sharedFactor factor(const DiscreteValues &assignment) const;
|
||||
|
||||
/// Get constant at a given discrete assignment.
|
||||
double constant(const DiscreteValues &assignment) const;
|
||||
|
||||
/**
|
||||
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
|
||||
|
|
@ -150,7 +164,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
* variables.
|
||||
* @return Sum
|
||||
*/
|
||||
Sum add(const Sum &sum) const;
|
||||
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
||||
|
||||
/**
|
||||
* @brief Compute error of the GaussianMixtureFactor as a tree.
|
||||
|
|
@ -168,11 +182,21 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
double error(const HybridValues &values) const override;
|
||||
|
||||
/// Add MixtureFactor to a Sum, syntactic sugar.
|
||||
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
||||
friend GaussianFactorGraphTree &operator+=(
|
||||
GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) {
|
||||
sum = factor.add(sum);
|
||||
return sum;
|
||||
}
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar &BOOST_SERIALIZATION_NVP(factors_);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -26,6 +26,17 @@ static std::mt19937_64 kRandomNumberGenerator(42);
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesNet::print(const std::string &s,
|
||||
const KeyFormatter &formatter) const {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool HybridBayesNet::equals(const This &bn, double tol) const {
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||
AlgebraicDecisionTree<Key> decisionTree;
|
||||
|
|
@ -271,12 +282,15 @@ double HybridBayesNet::evaluate(const HybridValues &values) const {
|
|||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
// TODO: should be delegated to derived classes.
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
const auto component = (*gm)(discreteValues);
|
||||
logDensity += component->logDensity(continuousValues);
|
||||
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous only, evaluate the probability and multiply.
|
||||
logDensity += gc->logDensity(continuousValues);
|
||||
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// Conditional is discrete-only, so return its probability.
|
||||
probability *= dc->operator()(discreteValues);
|
||||
|
|
|
|||
|
|
@ -50,18 +50,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** Check equality */
|
||||
bool equals(const This &bn, double tol = 1e-9) const {
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/// print graph
|
||||
/// GTSAM-style printing
|
||||
void print(
|
||||
const std::string &s = "",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// GTSAM-style equals
|
||||
bool equals(const This& fg, double tol = 1e-9) const;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
|
||||
|
|
@ -102,7 +103,37 @@ void HybridConditional::print(const std::string &s,
|
|||
/* ************************************************************************ */
|
||||
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
||||
const This *e = dynamic_cast<const This *>(&other);
|
||||
return e != nullptr && BaseFactor::equals(*e, tol);
|
||||
if (e == nullptr) return false;
|
||||
if (auto gm = asMixture()) {
|
||||
auto other = e->asMixture();
|
||||
return other != nullptr && gm->equals(*other, tol);
|
||||
}
|
||||
if (auto gc = asGaussian()) {
|
||||
auto other = e->asGaussian();
|
||||
return other != nullptr && gc->equals(*other, tol);
|
||||
}
|
||||
if (auto dc = asDiscrete()) {
|
||||
auto other = e->asDiscrete();
|
||||
return other != nullptr && dc->equals(*other, tol);
|
||||
}
|
||||
|
||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||
: !(e->inner_);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridConditional::error(const HybridValues &values) const {
|
||||
if (auto gm = asMixture()) {
|
||||
return gm->error(values);
|
||||
}
|
||||
if (auto gc = asGaussian()) {
|
||||
return gc->error(values.continuous());
|
||||
}
|
||||
if (auto dc = asDiscrete()) {
|
||||
return -log((*dc)(values.discrete()));
|
||||
}
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::error: conditional type not handled");
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -176,15 +176,7 @@ class GTSAM_EXPORT HybridConditional
|
|||
boost::shared_ptr<Factor> inner() const { return inner_; }
|
||||
|
||||
/// Return the error of the underlying conditional.
|
||||
/// Currently only implemented for Gaussian mixture.
|
||||
double error(const HybridValues& values) const override {
|
||||
if (auto gm = asMixture()) {
|
||||
return gm->error(values);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"HybridConditional::error: only implemented for Gaussian mixture");
|
||||
}
|
||||
}
|
||||
double error(const HybridValues& values) const override;
|
||||
|
||||
/// @}
|
||||
|
||||
|
|
@ -195,6 +187,20 @@ class GTSAM_EXPORT HybridConditional
|
|||
void serialize(Archive& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||
ar& BOOST_SERIALIZATION_NVP(inner_);
|
||||
|
||||
// register the various casts based on the type of inner_
|
||||
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
|
||||
if (isDiscrete()) {
|
||||
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
|
||||
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
|
||||
} else if (isContinuous()) {
|
||||
boost::serialization::void_cast_register<GaussianConditional, Factor>(
|
||||
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
|
||||
} else {
|
||||
boost::serialization::void_cast_register<GaussianMixture, Factor>(
|
||||
static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
|
||||
}
|
||||
}
|
||||
|
||||
}; // HybridConditional
|
||||
|
|
|
|||
|
|
@ -26,7 +26,6 @@
|
|||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************ */
|
||||
// TODO(fan): THIS IS VERY VERY DIRTY! We need to get DiscreteFactor right!
|
||||
HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
|
||||
: Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other)
|
||||
->discreteKeys()),
|
||||
|
|
@ -40,8 +39,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
|
|||
/* ************************************************************************ */
|
||||
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
|
||||
const This *e = dynamic_cast<const This *>(&lf);
|
||||
// TODO(Varun) How to compare inner_ when they are abstract types?
|
||||
return e != nullptr && Base::equals(*e, tol);
|
||||
if (e == nullptr) return false;
|
||||
if (!Base::equals(*e, tol)) return false;
|
||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||
: !(e->inner_);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
|||
|
|
@ -45,6 +45,9 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
|||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor - for serialization.
|
||||
HybridDiscreteFactor() = default;
|
||||
|
||||
// Implicit conversion from a shared ptr of DF
|
||||
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
|
||||
|
||||
|
|
@ -70,6 +73,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
|||
/// Return the error of the underlying Discrete Factor.
|
||||
double error(const HybridValues &values) const override;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar &BOOST_SERIALIZATION_NVP(inner_);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@
|
|||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/inference/Factor.h>
|
||||
#include <gtsam/nonlinear/Values.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
|
@ -28,6 +30,36 @@ namespace gtsam {
|
|||
|
||||
class HybridValues;
|
||||
|
||||
/// Gaussian factor graph and log of normalizing constant.
|
||||
struct GraphAndConstant {
|
||||
GaussianFactorGraph graph;
|
||||
double constant;
|
||||
|
||||
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
|
||||
: graph(graph), constant(constant) {}
|
||||
|
||||
// Check pointer equality.
|
||||
bool operator==(const GraphAndConstant &other) const {
|
||||
return graph == other.graph && constant == other.constant;
|
||||
}
|
||||
|
||||
// Implement GTSAM-style print:
|
||||
void print(const std::string &s = "Graph: ",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const {
|
||||
graph.print(s, formatter);
|
||||
std::cout << "Constant: " << constant << std::endl;
|
||||
}
|
||||
|
||||
// Implement GTSAM-style equals:
|
||||
bool equals(const GraphAndConstant &other, double tol = 1e-9) const {
|
||||
return graph.equals(other.graph, tol) &&
|
||||
fabs(constant - other.constant) < tol;
|
||||
}
|
||||
};
|
||||
|
||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||
using GaussianFactorGraphTree = DecisionTree<Key, GraphAndConstant>;
|
||||
|
||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys);
|
||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
|
||||
|
|
@ -160,4 +192,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
|||
template <>
|
||||
struct traits<HybridFactor> : public Testable<HybridFactor> {};
|
||||
|
||||
template <>
|
||||
struct traits<GraphAndConstant> : public Testable<GraphAndConstant> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -44,15 +44,21 @@ HybridGaussianFactor::HybridGaussianFactor(HessianFactor &&hf)
|
|||
/* ************************************************************************* */
|
||||
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
|
||||
const This *e = dynamic_cast<const This *>(&other);
|
||||
// TODO(Varun) How to compare inner_ when they are abstract types?
|
||||
return e != nullptr && Base::equals(*e, tol);
|
||||
if (e == nullptr) return false;
|
||||
if (!Base::equals(*e, tol)) return false;
|
||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||
: !(e->inner_);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridGaussianFactor::print(const std::string &s,
|
||||
const KeyFormatter &formatter) const {
|
||||
HybridFactor::print(s, formatter);
|
||||
inner_->print("\n", formatter);
|
||||
if (inner_) {
|
||||
inner_->print("\n", formatter);
|
||||
} else {
|
||||
std::cout << "\nGaussian: nullptr" << std::endl;
|
||||
}
|
||||
};
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
|||
|
|
@ -43,14 +43,17 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
using This = HybridGaussianFactor;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor - for serialization.
|
||||
HybridGaussianFactor() = default;
|
||||
|
||||
/**
|
||||
* Constructor from shared_ptr of GaussianFactor.
|
||||
* Example:
|
||||
* boost::shared_ptr<GaussianFactor> ptr =
|
||||
* boost::make_shared<JacobianFactor>(...);
|
||||
*
|
||||
* auto ptr = boost::make_shared<JacobianFactor>(...);
|
||||
* HybridGaussianFactor factor(ptr);
|
||||
*/
|
||||
explicit HybridGaussianFactor(const boost::shared_ptr<GaussianFactor> &ptr);
|
||||
|
||||
|
|
@ -80,7 +83,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
*/
|
||||
explicit HybridGaussianFactor(HessianFactor &&hf);
|
||||
|
||||
public:
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
|
|
@ -99,9 +102,18 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
/// Return pointer to the internal Gaussian factor.
|
||||
GaussianFactor::shared_ptr inner() const { return inner_; }
|
||||
|
||||
/// Return the error of the underlying Discrete Factor.
|
||||
/// Return the error of the underlying Gaussian factor.
|
||||
double error(const HybridValues &values) const override;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar &BOOST_SERIALIZATION_NVP(inner_);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
|
|
|
|||
|
|
@ -59,51 +59,50 @@ namespace gtsam {
|
|||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||
|
||||
/* ************************************************************************ */
|
||||
static GaussianMixtureFactor::Sum &addGaussian(
|
||||
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
||||
using Y = GaussianFactorGraph;
|
||||
static GaussianFactorGraphTree addGaussian(
|
||||
const GaussianFactorGraphTree &gfgTree,
|
||||
const GaussianFactor::shared_ptr &factor) {
|
||||
// If the decision tree is not initialized, then initialize it.
|
||||
if (sum.empty()) {
|
||||
if (gfgTree.empty()) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor);
|
||||
sum = GaussianMixtureFactor::Sum(result);
|
||||
return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
|
||||
|
||||
} else {
|
||||
auto add = [&factor](const Y &graph) {
|
||||
auto result = graph;
|
||||
auto add = [&factor](const GraphAndConstant &graph_z) {
|
||||
auto result = graph_z.graph;
|
||||
result.push_back(factor);
|
||||
return result;
|
||||
return GraphAndConstant(result, graph_z.constant);
|
||||
};
|
||||
sum = sum.apply(add);
|
||||
return gfgTree.apply(add);
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
GaussianMixtureFactor::Sum sumFrontals(
|
||||
const HybridGaussianFactorGraph &factors) {
|
||||
// sum out frontals, this is the factor on the separator
|
||||
gttic(sum);
|
||||
// TODO(dellaert): Implementation-wise, it's probably more efficient to first
|
||||
// collect the discrete keys, and then loop over all assignments to populate a
|
||||
// vector.
|
||||
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||
gttic(assembleGraphTree);
|
||||
|
||||
GaussianMixtureFactor::Sum sum;
|
||||
std::vector<GaussianFactor::shared_ptr> deferredFactors;
|
||||
GaussianFactorGraphTree result;
|
||||
|
||||
for (auto &f : factors) {
|
||||
for (auto &f : factors_) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
if (f->isHybrid()) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
sum = gm->add(sum);
|
||||
result = gm->add(result);
|
||||
}
|
||||
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
sum = gm->asMixture()->add(sum);
|
||||
result = gm->asMixture()->add(result);
|
||||
}
|
||||
|
||||
} else if (f->isContinuous()) {
|
||||
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
deferredFactors.push_back(gf->inner());
|
||||
result = addGaussian(result, gf->inner());
|
||||
}
|
||||
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
deferredFactors.push_back(cg->asGaussian());
|
||||
result = addGaussian(result, cg->asGaussian());
|
||||
}
|
||||
|
||||
} else if (f->isDiscrete()) {
|
||||
|
|
@ -125,17 +124,13 @@ GaussianMixtureFactor::Sum sumFrontals(
|
|||
}
|
||||
}
|
||||
|
||||
for (auto &f : deferredFactors) {
|
||||
sum = addGaussian(sum, f);
|
||||
}
|
||||
gttoc(assembleGraphTree);
|
||||
|
||||
gttoc(sum);
|
||||
|
||||
return sum;
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
continuousElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
GaussianFactorGraph gfg;
|
||||
|
|
@ -156,7 +151,7 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
DiscreteFactorGraph dfg;
|
||||
|
|
@ -173,48 +168,53 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
}
|
||||
}
|
||||
|
||||
auto result = EliminateForMPE(dfg, frontalKeys);
|
||||
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
|
||||
return {boost::make_shared<HybridConditional>(result.first),
|
||||
boost::make_shared<HybridDiscreteFactor>(result.second)};
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
||||
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
|
||||
// otherwise create a GFG with a single (null) factor.
|
||||
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
||||
auto emptyGaussian = [](const GraphAndConstant &graph_z) {
|
||||
bool hasNull =
|
||||
std::any_of(graph_z.graph.begin(), graph_z.graph.end(),
|
||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||
return hasNull ? GraphAndConstant{GaussianFactorGraph(), 0.0} : graph_z;
|
||||
};
|
||||
return GaussianFactorGraphTree(sum, emptyGaussian);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys,
|
||||
const KeySet &continuousSeparator,
|
||||
const KeyVector &continuousSeparator,
|
||||
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
||||
// NOTE: since we use the special JunctionTree,
|
||||
// only possibility is continuous conditioned on discrete.
|
||||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
||||
discreteSeparatorSet.end());
|
||||
|
||||
// sum out frontals, this is the factor 𝜏 on the separator
|
||||
GaussianMixtureFactor::Sum sum = sumFrontals(factors);
|
||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||
// decision tree indexed by all discrete keys involved.
|
||||
GaussianFactorGraphTree sum = factors.assembleGraphTree();
|
||||
|
||||
// If a tree leaf contains nullptr,
|
||||
// convert that leaf to an empty GaussianFactorGraph.
|
||||
// Needed since the DecisionTree will otherwise create
|
||||
// a GFG with a single (null) factor.
|
||||
auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
|
||||
bool hasNull =
|
||||
std::any_of(gfg.begin(), gfg.end(),
|
||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||
|
||||
return hasNull ? GaussianFactorGraph() : gfg;
|
||||
};
|
||||
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
||||
// Convert factor graphs with a nullptr to an empty factor graph.
|
||||
// This is done after assembly since it is non-trivial to keep track of which
|
||||
// FG has a nullptr as we're looping over the factors.
|
||||
sum = removeEmpty(sum);
|
||||
|
||||
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
|
||||
GaussianMixtureFactor::FactorAndConstant>;
|
||||
|
||||
KeyVector keysOfEliminated; // Not the ordering
|
||||
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
||||
|
||||
// This is the elimination method on the leaf nodes
|
||||
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
|
||||
if (graph.empty()) {
|
||||
auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair {
|
||||
if (graph_z.graph.empty()) {
|
||||
return {nullptr, {nullptr, 0.0}};
|
||||
}
|
||||
|
||||
|
|
@ -222,24 +222,34 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
gttic_(hybrid_eliminate);
|
||||
#endif
|
||||
|
||||
std::pair<boost::shared_ptr<GaussianConditional>,
|
||||
boost::shared_ptr<GaussianFactor>>
|
||||
conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
|
||||
boost::shared_ptr<GaussianConditional> conditional;
|
||||
boost::shared_ptr<GaussianFactor> newFactor;
|
||||
boost::tie(conditional, newFactor) =
|
||||
EliminatePreferCholesky(graph_z.graph, frontalKeys);
|
||||
|
||||
// Initialize the keysOfEliminated to be the keys of the
|
||||
// eliminated GaussianConditional
|
||||
keysOfEliminated = conditional_factor.first->keys();
|
||||
keysOfSeparator = conditional_factor.second->keys();
|
||||
// Get the log of the log normalization constant inverse and
|
||||
// add it to the previous constant.
|
||||
const double logZ =
|
||||
graph_z.constant - conditional->logNormalizationConstant();
|
||||
// Get the log of the log normalization constant inverse.
|
||||
// double logZ = -conditional->logNormalizationConstant();
|
||||
// // IF this is the last continuous variable to eliminated, we need to
|
||||
// // calculate the error here: the value of all factors at the mean, see
|
||||
// // ml_map_rao.pdf.
|
||||
// if (continuousSeparator.empty()) {
|
||||
// const auto posterior_mean = conditional->solve(VectorValues());
|
||||
// logZ += graph_z.graph.error(posterior_mean);
|
||||
// }
|
||||
|
||||
#ifdef HYBRID_TIMING
|
||||
gttoc_(hybrid_eliminate);
|
||||
#endif
|
||||
|
||||
return {conditional_factor.first, {conditional_factor.second, 0.0}};
|
||||
return {conditional, {newFactor, logZ}};
|
||||
};
|
||||
|
||||
// Perform elimination!
|
||||
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate);
|
||||
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminateFunc);
|
||||
|
||||
#ifdef HYBRID_TIMING
|
||||
tictoc_print_();
|
||||
|
|
@ -247,46 +257,50 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
#endif
|
||||
|
||||
// Separate out decision tree into conditionals and remaining factors.
|
||||
auto pair = unzip(eliminationResults);
|
||||
const auto &separatorFactors = pair.second;
|
||||
GaussianMixture::Conditionals conditionals;
|
||||
GaussianMixtureFactor::Factors newFactors;
|
||||
std::tie(conditionals, newFactors) = unzip(eliminationResults);
|
||||
|
||||
// Create the GaussianMixture from the conditionals
|
||||
auto conditional = boost::make_shared<GaussianMixture>(
|
||||
frontalKeys, keysOfSeparator, discreteSeparator, pair.first);
|
||||
auto gaussianMixture = boost::make_shared<GaussianMixture>(
|
||||
frontalKeys, continuousSeparator, discreteSeparator, conditionals);
|
||||
|
||||
// If there are no more continuous parents, then we should create here a
|
||||
// DiscreteFactor, with the error for each discrete choice.
|
||||
if (keysOfSeparator.empty()) {
|
||||
VectorValues empty_values;
|
||||
// If there are no more continuous parents, then we should create a
|
||||
// DiscreteFactor here, with the error for each discrete choice.
|
||||
if (continuousSeparator.empty()) {
|
||||
auto factorProb =
|
||||
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
|
||||
GaussianFactor::shared_ptr factor = factor_z.factor;
|
||||
if (!factor) {
|
||||
return 0.0; // If nullptr, return 0.0 probability
|
||||
} else {
|
||||
// This is the probability q(μ) at the MLE point.
|
||||
double error =
|
||||
0.5 * std::abs(factor->augmentedInformation().determinant()) +
|
||||
factor_z.constant;
|
||||
return std::exp(-error);
|
||||
}
|
||||
// This is the probability q(μ) at the MLE point.
|
||||
// factor_z.factor is a factor without keys,
|
||||
// just containing the residual.
|
||||
return exp(-factor_z.error(VectorValues()));
|
||||
};
|
||||
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
|
||||
|
||||
auto discreteFactor =
|
||||
const DecisionTree<Key, double> fdt(newFactors, factorProb);
|
||||
// // Normalize the values of decision tree to be valid probabilities
|
||||
// double sum = 0.0;
|
||||
// auto visitor = [&](double y) { sum += y; };
|
||||
// fdt.visit(visitor);
|
||||
// // Check if sum is 0, and update accordingly.
|
||||
// if (sum == 0) {
|
||||
// sum = 1.0;
|
||||
// }
|
||||
// fdt = DecisionTree<Key, double>(fdt,
|
||||
// [sum](const double &x) { return x / sum;
|
||||
// });
|
||||
const auto discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||
|
||||
return {boost::make_shared<HybridConditional>(conditional),
|
||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
|
||||
|
||||
} else {
|
||||
// Create a resulting GaussianMixtureFactor on the separator.
|
||||
auto factor = boost::make_shared<GaussianMixtureFactor>(
|
||||
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
|
||||
discreteSeparator, separatorFactors);
|
||||
return {boost::make_shared<HybridConditional>(conditional), factor};
|
||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||
boost::make_shared<GaussianMixtureFactor>(
|
||||
continuousSeparator, discreteSeparator, newFactors)};
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************
|
||||
* Function to eliminate variables **under the following assumptions**:
|
||||
* 1. When the ordering is fully continuous, and the graph only contains
|
||||
|
|
@ -383,12 +397,12 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
|
||||
// Fill in discrete discrete separator keys and continuous separator keys.
|
||||
std::set<DiscreteKey> discreteSeparatorSet;
|
||||
KeySet continuousSeparator;
|
||||
KeyVector continuousSeparator;
|
||||
for (auto &k : separatorKeys) {
|
||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
||||
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
|
||||
} else {
|
||||
continuousSeparator.insert(k);
|
||||
continuousSeparator.push_back(k);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -463,15 +477,8 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
// If factor is hybrid, select based on assignment.
|
||||
GaussianMixtureFactor::shared_ptr gaussianMixture =
|
||||
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
|
||||
// Compute factor error.
|
||||
factor_error = gaussianMixture->error(continuousValues);
|
||||
|
||||
// If first factor, assign error, else add it.
|
||||
if (idx == 0) {
|
||||
error_tree = factor_error;
|
||||
} else {
|
||||
error_tree = error_tree + factor_error;
|
||||
}
|
||||
// Compute factor error and add it.
|
||||
error_tree = error_tree + gaussianMixture->error(continuousValues);
|
||||
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
// If continuous only, get the (double) error
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
|
|
@ -118,14 +119,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
: Base(graph) {}
|
||||
|
||||
/// @}
|
||||
/// @name Adding factors.
|
||||
/// @{
|
||||
|
||||
using Base::empty;
|
||||
using Base::reserve;
|
||||
using Base::size;
|
||||
using Base::operator[];
|
||||
using Base::add;
|
||||
using Base::push_back;
|
||||
using Base::resize;
|
||||
using Base::reserve;
|
||||
|
||||
/// Add a Jacobian factor to the factor graph.
|
||||
void add(JacobianFactor&& factor);
|
||||
|
|
@ -172,6 +171,25 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
}
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
// TODO(dellaert): customize print and equals.
|
||||
// void print(const std::string& s = "HybridGaussianFactorGraph",
|
||||
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const
|
||||
// override;
|
||||
// bool equals(const This& fg, double tol = 1e-9) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
using Base::empty;
|
||||
using Base::size;
|
||||
using Base::operator[];
|
||||
using Base::resize;
|
||||
|
||||
/**
|
||||
* @brief Compute error for each discrete assignment,
|
||||
* and return as a tree.
|
||||
|
|
@ -217,6 +235,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
* @return const Ordering
|
||||
*/
|
||||
const Ordering getHybridOrdering() const;
|
||||
|
||||
/**
|
||||
* @brief Create a decision tree of factor graphs out of this hybrid factor
|
||||
* graph.
|
||||
*
|
||||
* For example, if there are two mixture factors, one with a discrete key A
|
||||
* and one with a discrete key B, then the decision tree will have two levels,
|
||||
* one for A and one for B. The leaves of the tree will be the Gaussian
|
||||
* factors that have only continuous keys.
|
||||
*/
|
||||
GaussianFactorGraphTree assembleGraphTree() const;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -99,9 +99,11 @@ void HybridNonlinearISAM::print(const string& s,
|
|||
const KeyFormatter& keyFormatter) const {
|
||||
cout << s << "ReorderInterval: " << reorderInterval_
|
||||
<< " Current Count: " << reorderCounter_ << endl;
|
||||
isam_.print("HybridGaussianISAM:\n", keyFormatter);
|
||||
std::cout << "HybridGaussianISAM:" << std::endl;
|
||||
isam_.print("", keyFormatter);
|
||||
linPoint_.print("Linearization Point:\n", keyFormatter);
|
||||
factors_.print("Nonlinear Graph:\n", keyFormatter);
|
||||
std::cout << "Nonlinear Graph:" << std::endl;
|
||||
factors_.print("", keyFormatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class GTSAM_EXPORT HybridNonlinearISAM {
|
|||
const Values& getLinearizationPoint() const { return linPoint_; }
|
||||
|
||||
/** Return the current discrete assignment */
|
||||
const DiscreteValues& getAssignment() const { return assignment_; }
|
||||
const DiscreteValues& assignment() const { return assignment_; }
|
||||
|
||||
/** get underlying nonlinear graph */
|
||||
const HybridNonlinearFactorGraph& getFactorsUnsafe() const {
|
||||
|
|
|
|||
|
|
@ -168,6 +168,15 @@ class GTSAM_EXPORT HybridValues {
|
|||
return *this;
|
||||
}
|
||||
|
||||
/// Extract continuous values with given keys.
|
||||
VectorValues continuousSubset(const KeyVector& keys) const {
|
||||
VectorValues measurements;
|
||||
for (const auto& key : keys) {
|
||||
measurements.insert(key, continuous_.at(key));
|
||||
}
|
||||
return measurements;
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
|||
|
|
@ -162,14 +162,20 @@ class MixtureFactor : public HybridFactor {
|
|||
}
|
||||
|
||||
/// Error for HybridValues is not provided for nonlinear hybrid factor.
|
||||
double error(const HybridValues &values) const override {
|
||||
double error(const HybridValues& values) const override {
|
||||
throw std::runtime_error(
|
||||
"MixtureFactor::error(HybridValues) not implemented.");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the dimension of the factor (number of rows on linearization).
|
||||
* Returns the dimension of the first component factor.
|
||||
* @return size_t
|
||||
*/
|
||||
size_t dim() const {
|
||||
// TODO(Varun)
|
||||
throw std::runtime_error("MixtureFactor::dim not implemented.");
|
||||
const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_);
|
||||
auto factor = factors_(assignments.at(0));
|
||||
return factor->dim();
|
||||
}
|
||||
|
||||
/// Testable
|
||||
|
|
|
|||
|
|
@ -40,6 +40,15 @@ virtual class HybridFactor {
|
|||
bool empty() const;
|
||||
size_t size() const;
|
||||
gtsam::KeyVector keys() const;
|
||||
|
||||
// Standard interface:
|
||||
double error(const gtsam::HybridValues &values) const;
|
||||
bool isDiscrete() const;
|
||||
bool isContinuous() const;
|
||||
bool isHybrid() const;
|
||||
size_t nrContinuous() const;
|
||||
gtsam::DiscreteKeys discreteKeys() const;
|
||||
gtsam::KeyVector continuousKeys() const;
|
||||
};
|
||||
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
|
|
@ -50,7 +59,13 @@ virtual class HybridConditional {
|
|||
bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const;
|
||||
size_t nrFrontals() const;
|
||||
size_t nrParents() const;
|
||||
|
||||
// Standard interface:
|
||||
gtsam::GaussianMixture* asMixture() const;
|
||||
gtsam::GaussianConditional* asGaussian() const;
|
||||
gtsam::DiscreteConditional* asDiscrete() const;
|
||||
gtsam::Factor* inner();
|
||||
double error(const gtsam::HybridValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||
|
|
@ -61,6 +76,7 @@ virtual class HybridDiscreteFactor {
|
|||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const;
|
||||
gtsam::Factor* inner();
|
||||
double error(const gtsam::HybridValues &values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,96 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010-2023, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file TinyHybridExample.h
|
||||
* @date December, 2022
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#pragma once
|
||||
|
||||
namespace gtsam {
|
||||
namespace tiny {
|
||||
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
using symbol_shorthand::Z;
|
||||
|
||||
// Create mode key: 0 is low-noise, 1 is high-noise.
|
||||
const DiscreteKey mode{M(0), 2};
|
||||
|
||||
/**
|
||||
* Create a tiny two variable hybrid model which represents
|
||||
* the generative probability P(z,x,mode) = P(z|x,mode)P(x)P(mode).
|
||||
*/
|
||||
inline HybridBayesNet createHybridBayesNet(int num_measurements = 1) {
|
||||
HybridBayesNet bayesNet;
|
||||
|
||||
// Create Gaussian mixture z_i = x0 + noise for each measurement.
|
||||
for (int i = 0; i < num_measurements; i++) {
|
||||
const auto conditional0 = boost::make_shared<GaussianConditional>(
|
||||
GaussianConditional::FromMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 0.5));
|
||||
const auto conditional1 = boost::make_shared<GaussianConditional>(
|
||||
GaussianConditional::FromMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 3));
|
||||
GaussianMixture gm({Z(i)}, {X(0)}, {mode}, {conditional0, conditional1});
|
||||
bayesNet.emplaceMixture(gm); // copy :-(
|
||||
}
|
||||
|
||||
// Create prior on X(0).
|
||||
const auto prior_on_x0 =
|
||||
GaussianConditional::FromMeanAndStddev(X(0), Vector1(5.0), 0.5);
|
||||
bayesNet.emplaceGaussian(prior_on_x0); // copy :-(
|
||||
|
||||
// Add prior on mode.
|
||||
bayesNet.emplaceDiscrete(mode, "4/6");
|
||||
|
||||
return bayesNet;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a hybrid Bayes net to a hybrid Gaussian factor graph.
|
||||
*/
|
||||
inline HybridGaussianFactorGraph convertBayesNet(
|
||||
const HybridBayesNet& bayesNet, const VectorValues& measurements) {
|
||||
HybridGaussianFactorGraph fg;
|
||||
int num_measurements = bayesNet.size() - 2;
|
||||
for (int i = 0; i < num_measurements; i++) {
|
||||
auto conditional = bayesNet.atMixture(i);
|
||||
auto factor = conditional->likelihood({{Z(i), measurements.at(Z(i))}});
|
||||
fg.push_back(factor);
|
||||
}
|
||||
fg.push_back(bayesNet.atGaussian(num_measurements));
|
||||
fg.push_back(bayesNet.atDiscrete(num_measurements + 1));
|
||||
return fg;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a tiny two variable hybrid factor graph which represents a discrete
|
||||
* mode and a continuous variable x0, given a number of measurements of the
|
||||
* continuous variable x0. If no measurements are given, they are sampled from
|
||||
* the generative Bayes net model HybridBayesNet::Example(num_measurements)
|
||||
*/
|
||||
inline HybridGaussianFactorGraph createHybridGaussianFactorGraph(
|
||||
int num_measurements = 1,
|
||||
boost::optional<VectorValues> measurements = boost::none) {
|
||||
auto bayesNet = createHybridBayesNet(num_measurements);
|
||||
if (measurements) {
|
||||
return convertBayesNet(bayesNet, *measurements);
|
||||
} else {
|
||||
return convertBayesNet(bayesNet, bayesNet.sample().continuous());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tiny
|
||||
} // namespace gtsam
|
||||
|
|
@ -80,7 +80,7 @@ TEST(GaussianMixtureFactor, Sum) {
|
|||
|
||||
// Create sum of two mixture factors: it will be a decision tree now on both
|
||||
// discrete variables m1 and m2:
|
||||
GaussianMixtureFactor::Sum sum;
|
||||
GaussianFactorGraphTree sum;
|
||||
sum += mixtureFactorA;
|
||||
sum += mixtureFactorB;
|
||||
|
||||
|
|
@ -89,8 +89,8 @@ TEST(GaussianMixtureFactor, Sum) {
|
|||
mode[m1.first] = 1;
|
||||
mode[m2.first] = 2;
|
||||
auto actual = sum(mode);
|
||||
EXPECT(actual.at(0) == f11);
|
||||
EXPECT(actual.at(1) == f22);
|
||||
EXPECT(actual.graph.at(0) == f11);
|
||||
EXPECT(actual.graph.at(1) == f22);
|
||||
}
|
||||
|
||||
TEST(GaussianMixtureFactor, Printing) {
|
||||
|
|
|
|||
|
|
@ -18,19 +18,18 @@
|
|||
* @date December 2021
|
||||
*/
|
||||
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
|
||||
#include "Switching.h"
|
||||
#include "TinyHybridExample.h"
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
|
||||
using noiseModel::Isotropic;
|
||||
using symbol_shorthand::M;
|
||||
|
|
@ -63,7 +62,7 @@ TEST(HybridBayesNet, Add) {
|
|||
|
||||
/* ****************************************************************************/
|
||||
// Test evaluate for a pure discrete Bayes net P(Asia).
|
||||
TEST(HybridBayesNet, evaluatePureDiscrete) {
|
||||
TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||
HybridBayesNet bayesNet;
|
||||
bayesNet.emplaceDiscrete(Asia, "99/1");
|
||||
HybridValues values;
|
||||
|
|
@ -71,6 +70,13 @@ TEST(HybridBayesNet, evaluatePureDiscrete) {
|
|||
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test creation of a tiny hybrid Bayes net.
|
||||
TEST(HybridBayesNet, Tiny) {
|
||||
auto bayesNet = tiny::createHybridBayesNet();
|
||||
EXPECT_LONGS_EQUAL(3, bayesNet.size());
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
|
||||
TEST(HybridBayesNet, evaluateHybrid) {
|
||||
|
|
@ -180,7 +186,7 @@ TEST(HybridBayesNet, OptimizeAssignment) {
|
|||
/* ****************************************************************************/
|
||||
// Test Bayes net optimize
|
||||
TEST(HybridBayesNet, Optimize) {
|
||||
Switching s(4);
|
||||
Switching s(4, 1.0, 0.1, {0, 1, 2, 3}, "1/1 1/1");
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
|
|
@ -188,25 +194,24 @@ TEST(HybridBayesNet, Optimize) {
|
|||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
// TODO(Varun) The expectedAssignment should be 111, not 101
|
||||
// NOTE: The true assignment is 111, but the discrete priors cause 101
|
||||
DiscreteValues expectedAssignment;
|
||||
expectedAssignment[M(0)] = 1;
|
||||
expectedAssignment[M(1)] = 0;
|
||||
expectedAssignment[M(1)] = 1;
|
||||
expectedAssignment[M(2)] = 1;
|
||||
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
||||
|
||||
// TODO(Varun) This should be all -Vector1::Ones()
|
||||
VectorValues expectedValues;
|
||||
expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
|
||||
expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
|
||||
expectedValues.insert(X(2), -1.00971 * Vector1::Ones());
|
||||
expectedValues.insert(X(3), -1.0001 * Vector1::Ones());
|
||||
expectedValues.insert(X(0), -Vector1::Ones());
|
||||
expectedValues.insert(X(1), -Vector1::Ones());
|
||||
expectedValues.insert(X(2), -Vector1::Ones());
|
||||
expectedValues.insert(X(3), -Vector1::Ones());
|
||||
|
||||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net error
|
||||
// Test Bayes net error
|
||||
TEST(HybridBayesNet, Error) {
|
||||
Switching s(3);
|
||||
|
||||
|
|
@ -237,7 +242,7 @@ TEST(HybridBayesNet, Error) {
|
|||
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9));
|
||||
|
||||
// Verify error computation and check for specific error value
|
||||
DiscreteValues discrete_values {{M(0), 1}, {M(1), 1}};
|
||||
DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||
|
||||
double total_error = 0;
|
||||
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
|
||||
|
|
@ -323,18 +328,6 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
discrete_conditional_tree->apply(checker);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridBayesNet, Serialization) {
|
||||
Switching s(4);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||
|
||||
EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet sampling.
|
||||
TEST(HybridBayesNet, Sampling) {
|
||||
|
|
|
|||
|
|
@ -155,7 +155,7 @@ TEST(HybridBayesTree, Optimize) {
|
|||
dfg.push_back(
|
||||
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
|
||||
}
|
||||
|
||||
|
||||
// Add the probabilities for each branch
|
||||
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||
vector<double> probs = {0.012519475, 0.041280228, 0.075018647, 0.081663656,
|
||||
|
|
@ -211,29 +211,15 @@ TEST(HybridBayesTree, Choose) {
|
|||
ordering += M(0);
|
||||
ordering += M(1);
|
||||
ordering += M(2);
|
||||
|
||||
//TODO(Varun) get segfault if ordering not provided
|
||||
|
||||
// TODO(Varun) get segfault if ordering not provided
|
||||
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);
|
||||
|
||||
|
||||
auto expected_gbt = bayesTree->choose(assignment);
|
||||
|
||||
EXPECT(assert_equal(expected_gbt, gbt));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesTree serialization.
|
||||
TEST(HybridBayesTree, Serialization) {
|
||||
Switching s(4);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesTree hbt =
|
||||
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
|
||||
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
EXPECT(equalsObj<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsXML<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsBinary<HybridBayesTree>(hbt));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
|
|
@ -280,11 +280,10 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
|
|||
return probPrimeTree;
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
/**
|
||||
/*********************************************************************************
|
||||
* Test for correctness of different branches of the P'(Continuous | Discrete).
|
||||
* The values should match those of P'(Continuous) for each discrete mode.
|
||||
*/
|
||||
********************************************************************************/
|
||||
TEST(HybridEstimation, Probability) {
|
||||
constexpr size_t K = 4;
|
||||
std::vector<double> measurements = {0, 1, 2, 2};
|
||||
|
|
@ -441,18 +440,30 @@ static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
|
|||
* Do hybrid elimination and do regression test on discrete conditional.
|
||||
********************************************************************************/
|
||||
TEST(HybridEstimation, eliminateSequentialRegression) {
|
||||
// 1. Create the factor graph from the nonlinear factor graph.
|
||||
// Create the factor graph from the nonlinear factor graph.
|
||||
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
|
||||
|
||||
// 2. Eliminate into BN
|
||||
const Ordering ordering = fg->getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
|
||||
// GTSAM_PRINT(*bn);
|
||||
// Create expected discrete conditional on m0.
|
||||
DiscreteKey m(M(0), 2);
|
||||
DiscreteConditional expected(m % "0.51341712/1"); // regression
|
||||
|
||||
// TODO(dellaert): dc should be discrete conditional on m0, but it is an
|
||||
// unnormalized factor? DiscreteKey m(M(0), 2); DiscreteConditional expected(m
|
||||
// % "0.51341712/1"); auto dc = bn->back()->asDiscreteConditional();
|
||||
// EXPECT(assert_equal(expected, *dc, 1e-9));
|
||||
// Eliminate into BN using one ordering
|
||||
Ordering ordering1;
|
||||
ordering1 += X(0), X(1), M(0);
|
||||
HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
|
||||
|
||||
// Check that the discrete conditional matches the expected.
|
||||
auto dc1 = bn1->back()->asDiscrete();
|
||||
EXPECT(assert_equal(expected, *dc1, 1e-9));
|
||||
|
||||
// Eliminate into BN using a different ordering
|
||||
Ordering ordering2;
|
||||
ordering2 += X(0), X(1), M(0);
|
||||
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
|
||||
|
||||
// Check that the discrete conditional matches the expected.
|
||||
auto dc2 = bn2->back()->asDiscrete();
|
||||
EXPECT(assert_equal(expected, *dc2, 1e-9));
|
||||
}
|
||||
|
||||
/*********************************************************************************
|
||||
|
|
@ -467,45 +478,35 @@ TEST(HybridEstimation, eliminateSequentialRegression) {
|
|||
********************************************************************************/
|
||||
TEST(HybridEstimation, CorrectnessViaSampling) {
|
||||
// 1. Create the factor graph from the nonlinear factor graph.
|
||||
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
|
||||
const auto fg = createHybridGaussianFactorGraph();
|
||||
|
||||
// 2. Eliminate into BN
|
||||
const Ordering ordering = fg->getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
|
||||
const HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
|
||||
|
||||
// Set up sampling
|
||||
std::mt19937_64 rng(11);
|
||||
|
||||
// 3. Do sampling
|
||||
int num_samples = 10;
|
||||
|
||||
// Functor to compute the ratio between the
|
||||
// Bayes net and the factor graph.
|
||||
auto compute_ratio =
|
||||
[](const HybridBayesNet::shared_ptr& bayesNet,
|
||||
const HybridGaussianFactorGraph::shared_ptr& factorGraph,
|
||||
const HybridValues& sample) -> double {
|
||||
const DiscreteValues assignment = sample.discrete();
|
||||
// Compute in log form for numerical stability
|
||||
double log_ratio = bayesNet->error({sample.continuous(), assignment}) -
|
||||
factorGraph->error({sample.continuous(), assignment});
|
||||
double ratio = exp(-log_ratio);
|
||||
return ratio;
|
||||
// Compute the log-ratio between the Bayes net and the factor graph.
|
||||
auto compute_ratio = [&](const HybridValues& sample) -> double {
|
||||
return bn->evaluate(sample) / fg->probPrime(sample);
|
||||
};
|
||||
|
||||
// The error evaluated by the factor graph and the Bayes net should differ by
|
||||
// the normalizing term computed via the Bayes net determinant.
|
||||
const HybridValues sample = bn->sample(&rng);
|
||||
double ratio = compute_ratio(bn, fg, sample);
|
||||
double expected_ratio = compute_ratio(sample);
|
||||
// regression
|
||||
EXPECT_DOUBLES_EQUAL(1.0, ratio, 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(0.728588, expected_ratio, 1e-6);
|
||||
|
||||
// 4. Check that all samples == constant
|
||||
// 3. Do sampling
|
||||
constexpr int num_samples = 10;
|
||||
for (size_t i = 0; i < num_samples; i++) {
|
||||
// Sample from the bayes net
|
||||
const HybridValues sample = bn->sample(&rng);
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(ratio, compute_ratio(bn, fg, sample), 1e-9);
|
||||
// 4. Check that the ratio is constant.
|
||||
EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(sample), 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "Switching.h"
|
||||
#include "TinyHybridExample.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
|
@ -133,7 +134,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
|
|||
auto dc = result->at(2)->asDiscrete();
|
||||
DiscreteValues dv;
|
||||
dv[M(1)] = 0;
|
||||
EXPECT_DOUBLES_EQUAL(1, dc->operator()(dv), 1e-3);
|
||||
// Regression test
|
||||
EXPECT_DOUBLES_EQUAL(0.62245933120185448, dc->operator()(dv), 1e-3);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -612,6 +614,108 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
|||
EXPECT(assert_equal(expected_probs, probs, 1e-7));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Check that assembleGraphTree assembles Gaussian factor graphs for each
|
||||
// assignment.
|
||||
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
|
||||
using symbol_shorthand::Z;
|
||||
const int num_measurements = 1;
|
||||
auto fg = tiny::createHybridGaussianFactorGraph(
|
||||
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
|
||||
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||
|
||||
auto sum = fg.assembleGraphTree();
|
||||
|
||||
// Get mixture factor:
|
||||
auto mixture = boost::dynamic_pointer_cast<GaussianMixtureFactor>(fg.at(0));
|
||||
using GF = GaussianFactor::shared_ptr;
|
||||
|
||||
// Get prior factor:
|
||||
const GF prior =
|
||||
boost::dynamic_pointer_cast<HybridGaussianFactor>(fg.at(1))->inner();
|
||||
|
||||
// Create DiscreteValues for both 0 and 1:
|
||||
DiscreteValues d0{{M(0), 0}}, d1{{M(0), 1}};
|
||||
|
||||
// Expected decision tree with two factor graphs:
|
||||
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
|
||||
GaussianFactorGraphTree expectedSum{
|
||||
M(0),
|
||||
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
|
||||
mixture->constant(d0)},
|
||||
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d1), prior}),
|
||||
mixture->constant(d1)}};
|
||||
|
||||
EXPECT(assert_equal(expectedSum(d0), sum(d0), 1e-5));
|
||||
EXPECT(assert_equal(expectedSum(d1), sum(d1), 1e-5));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Check that eliminating tiny net with 1 measurement yields correct result.
|
||||
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
||||
using symbol_shorthand::Z;
|
||||
const int num_measurements = 1;
|
||||
auto fg = tiny::createHybridGaussianFactorGraph(
|
||||
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
|
||||
|
||||
// Create expected Bayes Net:
|
||||
HybridBayesNet expectedBayesNet;
|
||||
|
||||
// Create Gaussian mixture on X(0).
|
||||
using tiny::mode;
|
||||
// regression, but mean checked to be 5.0 in both cases:
|
||||
const auto conditional0 = boost::make_shared<GaussianConditional>(
|
||||
X(0), Vector1(14.1421), I_1x1 * 2.82843),
|
||||
conditional1 = boost::make_shared<GaussianConditional>(
|
||||
X(0), Vector1(10.1379), I_1x1 * 2.02759);
|
||||
GaussianMixture gm({X(0)}, {}, {mode}, {conditional0, conditional1});
|
||||
expectedBayesNet.emplaceMixture(gm); // copy :-(
|
||||
|
||||
// Add prior on mode.
|
||||
expectedBayesNet.emplaceDiscrete(mode, "74/26");
|
||||
|
||||
// Test elimination
|
||||
Ordering ordering;
|
||||
ordering.push_back(X(0));
|
||||
ordering.push_back(M(0));
|
||||
const auto posterior = fg.eliminateSequential(ordering);
|
||||
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Check that eliminating tiny net with 2 measurements yields correct result.
|
||||
TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
||||
// Create factor graph with 2 measurements such that posterior mean = 5.0.
|
||||
using symbol_shorthand::Z;
|
||||
const int num_measurements = 2;
|
||||
auto fg = tiny::createHybridGaussianFactorGraph(
|
||||
num_measurements,
|
||||
VectorValues{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}});
|
||||
|
||||
// Create expected Bayes Net:
|
||||
HybridBayesNet expectedBayesNet;
|
||||
|
||||
// Create Gaussian mixture on X(0).
|
||||
using tiny::mode;
|
||||
// regression, but mean checked to be 5.0 in both cases:
|
||||
const auto conditional0 = boost::make_shared<GaussianConditional>(
|
||||
X(0), Vector1(17.3205), I_1x1 * 3.4641),
|
||||
conditional1 = boost::make_shared<GaussianConditional>(
|
||||
X(0), Vector1(10.274), I_1x1 * 2.0548);
|
||||
GaussianMixture gm({X(0)}, {}, {mode}, {conditional0, conditional1});
|
||||
expectedBayesNet.emplaceMixture(gm); // copy :-(
|
||||
|
||||
// Add prior on mode.
|
||||
expectedBayesNet.emplaceDiscrete(mode, "23/77");
|
||||
|
||||
// Test elimination
|
||||
Ordering ordering;
|
||||
ordering.push_back(X(0));
|
||||
ordering.push_back(M(0));
|
||||
const auto posterior = fg.eliminateSequential(ordering);
|
||||
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
|
|
@ -177,19 +177,19 @@ TEST(HybridGaussianElimination, IncrementalInference) {
|
|||
|
||||
// Test the probability values with regression tests.
|
||||
DiscreteValues assignment;
|
||||
EXPECT(assert_equal(0.0619233, m00_prob, 1e-5));
|
||||
EXPECT(assert_equal(0.0952922, m00_prob, 1e-5));
|
||||
assignment[M(0)] = 0;
|
||||
assignment[M(1)] = 0;
|
||||
EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5));
|
||||
EXPECT(assert_equal(0.0952922, (*discreteConditional)(assignment), 1e-5));
|
||||
assignment[M(0)] = 1;
|
||||
assignment[M(1)] = 0;
|
||||
EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5));
|
||||
EXPECT(assert_equal(0.282758, (*discreteConditional)(assignment), 1e-5));
|
||||
assignment[M(0)] = 0;
|
||||
assignment[M(1)] = 1;
|
||||
EXPECT(assert_equal(0.204159, (*discreteConditional)(assignment), 1e-5));
|
||||
EXPECT(assert_equal(0.314175, (*discreteConditional)(assignment), 1e-5));
|
||||
assignment[M(0)] = 1;
|
||||
assignment[M(1)] = 1;
|
||||
EXPECT(assert_equal(0.2, (*discreteConditional)(assignment), 1e-5));
|
||||
EXPECT(assert_equal(0.307775, (*discreteConditional)(assignment), 1e-5));
|
||||
|
||||
// Check if the clique conditional generated from incremental elimination
|
||||
// matches that of batch elimination.
|
||||
|
|
@ -199,10 +199,10 @@ TEST(HybridGaussianElimination, IncrementalInference) {
|
|||
isam[M(1)]->conditional()->inner());
|
||||
// Account for the probability terms from evaluating continuous FGs
|
||||
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
vector<double> probs = {0.061923317, 0.20415914, 0.18374323, 0.2};
|
||||
vector<double> probs = {0.095292197, 0.31417524, 0.28275772, 0.30777485};
|
||||
auto expectedConditional =
|
||||
boost::make_shared<DecisionTreeFactor>(discrete_keys, probs);
|
||||
EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6));
|
||||
EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
|
|
|||
|
|
@ -443,7 +443,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
|
|||
ordering.clear();
|
||||
for (size_t k = 0; k < self.K - 1; k++) ordering += M(k);
|
||||
discreteBayesNet =
|
||||
*discrete_fg.eliminateSequential(ordering, EliminateForMPE);
|
||||
*discrete_fg.eliminateSequential(ordering, EliminateDiscrete);
|
||||
}
|
||||
|
||||
// Create ordering.
|
||||
|
|
@ -638,22 +638,30 @@ conditional 2: Hybrid P( x2 | m0 m1)
|
|||
0 0 Leaf p(x2)
|
||||
R = [ 10.0494 ]
|
||||
d = [ -10.1489 ]
|
||||
mean: 1 elements
|
||||
x2: -1.0099
|
||||
No noise model
|
||||
|
||||
0 1 Leaf p(x2)
|
||||
R = [ 10.0494 ]
|
||||
d = [ -10.1479 ]
|
||||
mean: 1 elements
|
||||
x2: -1.0098
|
||||
No noise model
|
||||
|
||||
1 Choice(m0)
|
||||
1 0 Leaf p(x2)
|
||||
R = [ 10.0494 ]
|
||||
d = [ -10.0504 ]
|
||||
mean: 1 elements
|
||||
x2: -1.0001
|
||||
No noise model
|
||||
|
||||
1 1 Leaf p(x2)
|
||||
R = [ 10.0494 ]
|
||||
d = [ -10.0494 ]
|
||||
mean: 1 elements
|
||||
x2: -1
|
||||
No noise model
|
||||
|
||||
)";
|
||||
|
|
|
|||
|
|
@ -191,24 +191,23 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
|
|||
*(*discreteBayesTree)[M(1)]->conditional()->asDiscrete();
|
||||
double m00_prob = decisionTree(m00);
|
||||
|
||||
auto discreteConditional =
|
||||
bayesTree[M(1)]->conditional()->asDiscrete();
|
||||
auto discreteConditional = bayesTree[M(1)]->conditional()->asDiscrete();
|
||||
|
||||
// Test the probability values with regression tests.
|
||||
DiscreteValues assignment;
|
||||
EXPECT(assert_equal(0.0619233, m00_prob, 1e-5));
|
||||
EXPECT(assert_equal(0.0952922, m00_prob, 1e-5));
|
||||
assignment[M(0)] = 0;
|
||||
assignment[M(1)] = 0;
|
||||
EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5));
|
||||
EXPECT(assert_equal(0.0952922, (*discreteConditional)(assignment), 1e-5));
|
||||
assignment[M(0)] = 1;
|
||||
assignment[M(1)] = 0;
|
||||
EXPECT(assert_equal(0.183743, (*discreteConditional)(assignment), 1e-5));
|
||||
EXPECT(assert_equal(0.282758, (*discreteConditional)(assignment), 1e-5));
|
||||
assignment[M(0)] = 0;
|
||||
assignment[M(1)] = 1;
|
||||
EXPECT(assert_equal(0.204159, (*discreteConditional)(assignment), 1e-5));
|
||||
EXPECT(assert_equal(0.314175, (*discreteConditional)(assignment), 1e-5));
|
||||
assignment[M(0)] = 1;
|
||||
assignment[M(1)] = 1;
|
||||
EXPECT(assert_equal(0.2, (*discreteConditional)(assignment), 1e-5));
|
||||
EXPECT(assert_equal(0.307775, (*discreteConditional)(assignment), 1e-5));
|
||||
|
||||
// Check if the clique conditional generated from incremental elimination
|
||||
// matches that of batch elimination.
|
||||
|
|
@ -217,10 +216,10 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
|
|||
bayesTree[M(1)]->conditional()->inner());
|
||||
// Account for the probability terms from evaluating continuous FGs
|
||||
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
vector<double> probs = {0.061923317, 0.20415914, 0.18374323, 0.2};
|
||||
vector<double> probs = {0.095292197, 0.31417524, 0.28275772, 0.30777485};
|
||||
auto expectedConditional =
|
||||
boost::make_shared<DecisionTreeFactor>(discrete_keys, probs);
|
||||
EXPECT(assert_equal(*actualConditional, *expectedConditional, 1e-6));
|
||||
EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
|
@ -358,10 +357,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
|
|||
// Run update with pruning
|
||||
size_t maxComponents = 5;
|
||||
incrementalHybrid.update(graph1, initial);
|
||||
incrementalHybrid.prune(maxComponents);
|
||||
HybridGaussianISAM bayesTree = incrementalHybrid.bayesTree();
|
||||
|
||||
bayesTree.prune(maxComponents);
|
||||
|
||||
// Check if we have a bayes tree with 4 hybrid nodes,
|
||||
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
|
||||
EXPECT_LONGS_EQUAL(4, bayesTree.size());
|
||||
|
|
@ -383,10 +381,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
|
|||
|
||||
// Run update with pruning a second time.
|
||||
incrementalHybrid.update(graph2, initial);
|
||||
incrementalHybrid.prune(maxComponents);
|
||||
bayesTree = incrementalHybrid.bayesTree();
|
||||
|
||||
bayesTree.prune(maxComponents);
|
||||
|
||||
// Check if we have a bayes tree with pruned hybrid nodes,
|
||||
// with 5 (pruned) leaves.
|
||||
CHECK_EQUAL(5, bayesTree.size());
|
||||
|
|
|
|||
|
|
@ -70,8 +70,7 @@ MixtureFactor
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test the error of the MixtureFactor
|
||||
TEST(MixtureFactor, Error) {
|
||||
static MixtureFactor getMixtureFactor() {
|
||||
DiscreteKey m1(1, 2);
|
||||
|
||||
double between0 = 0.0;
|
||||
|
|
@ -86,7 +85,13 @@ TEST(MixtureFactor, Error) {
|
|||
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between1, model);
|
||||
std::vector<NonlinearFactor::shared_ptr> factors{f0, f1};
|
||||
|
||||
MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
|
||||
return MixtureFactor({X(1), X(2)}, {m1}, factors);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test the error of the MixtureFactor
|
||||
TEST(MixtureFactor, Error) {
|
||||
auto mixtureFactor = getMixtureFactor();
|
||||
|
||||
Values continuousValues;
|
||||
continuousValues.insert<double>(X(1), 0);
|
||||
|
|
@ -94,6 +99,7 @@ TEST(MixtureFactor, Error) {
|
|||
|
||||
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
|
||||
|
||||
DiscreteKey m1(1, 2);
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
std::vector<double> errors = {0.5, 0};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
|
||||
|
|
@ -101,6 +107,13 @@ TEST(MixtureFactor, Error) {
|
|||
EXPECT(assert_equal(expected_error, error_tree));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test dim of the MixtureFactor
|
||||
TEST(MixtureFactor, Dim) {
|
||||
auto mixtureFactor = getMixtureFactor();
|
||||
EXPECT_LONGS_EQUAL(1, mixtureFactor.dim());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,179 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testSerializationHybrid.cpp
|
||||
* @brief Unit tests for hybrid serialization
|
||||
* @author Varun Agrawal
|
||||
* @date January 2023
|
||||
*/
|
||||
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/hybrid/GaussianMixture.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
|
||||
#include "Switching.h"
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
using symbol_shorthand::Z;
|
||||
|
||||
using namespace serializationTestHelpers;
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(Factor, "gtsam_Factor");
|
||||
BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
|
||||
BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
|
||||
BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
|
||||
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
|
||||
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors,
|
||||
"gtsam_GaussianMixtureFactor_Factors");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Leaf,
|
||||
"gtsam_GaussianMixtureFactor_Factors_Leaf");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors::Choice,
|
||||
"gtsam_GaussianMixtureFactor_Factors_Choice");
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixture, "gtsam_GaussianMixture");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals,
|
||||
"gtsam_GaussianMixture_Conditionals");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Leaf,
|
||||
"gtsam_GaussianMixture_Conditionals_Leaf");
|
||||
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice,
|
||||
"gtsam_GaussianMixture_Conditionals_Choice");
|
||||
// Needed since GaussianConditional::FromMeanAndStddev uses it
|
||||
BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic");
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(HybridBayesNet, "gtsam_HybridBayesNet");
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridGaussianFactor serialization.
|
||||
TEST(HybridSerialization, HybridGaussianFactor) {
|
||||
const HybridGaussianFactor factor(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||
|
||||
EXPECT(equalsObj<HybridGaussianFactor>(factor));
|
||||
EXPECT(equalsXML<HybridGaussianFactor>(factor));
|
||||
EXPECT(equalsBinary<HybridGaussianFactor>(factor));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridDiscreteFactor serialization.
|
||||
TEST(HybridSerialization, HybridDiscreteFactor) {
|
||||
DiscreteKeys discreteKeys{{M(0), 2}};
|
||||
const HybridDiscreteFactor factor(
|
||||
DecisionTreeFactor(discreteKeys, std::vector<double>{0.4, 0.6}));
|
||||
|
||||
EXPECT(equalsObj<HybridDiscreteFactor>(factor));
|
||||
EXPECT(equalsXML<HybridDiscreteFactor>(factor));
|
||||
EXPECT(equalsBinary<HybridDiscreteFactor>(factor));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test GaussianMixtureFactor serialization.
|
||||
TEST(HybridSerialization, GaussianMixtureFactor) {
|
||||
KeyVector continuousKeys{X(0)};
|
||||
DiscreteKeys discreteKeys{{M(0), 2}};
|
||||
|
||||
auto A = Matrix::Zero(2, 1);
|
||||
auto b0 = Matrix::Zero(2, 1);
|
||||
auto b1 = Matrix::Ones(2, 1);
|
||||
auto f0 = boost::make_shared<JacobianFactor>(X(0), A, b0);
|
||||
auto f1 = boost::make_shared<JacobianFactor>(X(0), A, b1);
|
||||
std::vector<GaussianFactor::shared_ptr> factors{f0, f1};
|
||||
|
||||
const GaussianMixtureFactor factor(continuousKeys, discreteKeys, factors);
|
||||
|
||||
EXPECT(equalsObj<GaussianMixtureFactor>(factor));
|
||||
EXPECT(equalsXML<GaussianMixtureFactor>(factor));
|
||||
EXPECT(equalsBinary<GaussianMixtureFactor>(factor));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridConditional serialization.
|
||||
TEST(HybridSerialization, HybridConditional) {
|
||||
const DiscreteKey mode(M(0), 2);
|
||||
Matrix1 I = Matrix1::Identity();
|
||||
const auto conditional = boost::make_shared<GaussianConditional>(
|
||||
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
|
||||
const HybridConditional hc(conditional);
|
||||
|
||||
EXPECT(equalsObj<HybridConditional>(hc));
|
||||
EXPECT(equalsXML<HybridConditional>(hc));
|
||||
EXPECT(equalsBinary<HybridConditional>(hc));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test GaussianMixture serialization.
|
||||
TEST(HybridSerialization, GaussianMixture) {
|
||||
const DiscreteKey mode(M(0), 2);
|
||||
Matrix1 I = Matrix1::Identity();
|
||||
const auto conditional0 = boost::make_shared<GaussianConditional>(
|
||||
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
|
||||
const auto conditional1 = boost::make_shared<GaussianConditional>(
|
||||
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3));
|
||||
const GaussianMixture gm({Z(0)}, {X(0)}, {mode},
|
||||
{conditional0, conditional1});
|
||||
|
||||
EXPECT(equalsObj<GaussianMixture>(gm));
|
||||
EXPECT(equalsXML<GaussianMixture>(gm));
|
||||
EXPECT(equalsBinary<GaussianMixture>(gm));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridSerialization, HybridBayesNet) {
|
||||
Switching s(2);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||
|
||||
EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesTree serialization.
|
||||
TEST(HybridSerialization, HybridBayesTree) {
|
||||
Switching s(2);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesTree hbt =
|
||||
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
|
||||
|
||||
EXPECT(equalsObj<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsXML<HybridBayesTree>(hbt));
|
||||
EXPECT(equalsBinary<HybridBayesTree>(hbt));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -67,7 +67,7 @@ namespace gtsam {
|
|||
GaussianConditional GaussianConditional::FromMeanAndStddev(Key key,
|
||||
const Vector& mu,
|
||||
double sigma) {
|
||||
// |Rx - d| = |x-(Ay + b)|/sigma
|
||||
// |Rx - d| = |x - mu|/sigma
|
||||
const Matrix R = Matrix::Identity(mu.size(), mu.size());
|
||||
const Vector& d = mu;
|
||||
return GaussianConditional(key, d, R,
|
||||
|
|
@ -120,6 +120,10 @@ namespace gtsam {
|
|||
<< endl;
|
||||
}
|
||||
cout << formatMatrixIndented(" d = ", getb(), true) << "\n";
|
||||
if (nrParents() == 0) {
|
||||
const auto mean = solve({}); // solve for mean.
|
||||
mean.print(" mean");
|
||||
}
|
||||
if (model_)
|
||||
model_->print(" Noise model: ");
|
||||
else
|
||||
|
|
@ -189,7 +193,7 @@ double GaussianConditional::logNormalizationConstant() const {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// density = k exp(-error(x))
|
||||
// log = log(k) -error(x) - 0.5 * n*log(2*pi)
|
||||
// log = log(k) -error(x)
|
||||
double GaussianConditional::logDensity(const VectorValues& x) const {
|
||||
return logNormalizationConstant() - error(x);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -466,6 +466,31 @@ TEST(GaussianConditional, sample) {
|
|||
// EXPECT(assert_equal(Vector2(31.0111856, 64.9850775), actual2[X(0)], 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(GaussianConditional, LogNormalizationConstant) {
|
||||
// Create univariate standard gaussian conditional
|
||||
auto std_gaussian =
|
||||
GaussianConditional::FromMeanAndStddev(X(0), Vector1::Zero(), 1.0);
|
||||
VectorValues values;
|
||||
values.insert(X(0), Vector1::Zero());
|
||||
double logDensity = std_gaussian.logDensity(values);
|
||||
|
||||
// Regression.
|
||||
// These values were computed by hand for a univariate standard gaussian.
|
||||
EXPECT_DOUBLES_EQUAL(-0.9189385332046727, logDensity, 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(0.3989422804014327, exp(logDensity), 1e-9);
|
||||
|
||||
// Similar test for multivariate gaussian but with sigma 2.0
|
||||
double sigma = 2.0;
|
||||
auto conditional = GaussianConditional::FromMeanAndStddev(X(0), Vector3::Zero(), sigma);
|
||||
VectorValues x;
|
||||
x.insert(X(0), Vector3::Zero());
|
||||
Matrix3 Sigma = I_3x3 * sigma * sigma;
|
||||
double expectedLogNormalizingConstant = log(1 / sqrt((2 * M_PI * Sigma).determinant()));
|
||||
|
||||
EXPECT_DOUBLES_EQUAL(expectedLogNormalizingConstant, conditional.logNormalizationConstant(), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(GaussianConditional, Print) {
|
||||
Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished();
|
||||
|
|
@ -482,6 +507,8 @@ TEST(GaussianConditional, Print) {
|
|||
" R = [ 1 0 ]\n"
|
||||
" [ 0 1 ]\n"
|
||||
" d = [ 20 40 ]\n"
|
||||
" mean: 1 elements\n"
|
||||
" x0: 20 40\n"
|
||||
"isotropic dim=2 sigma=3\n";
|
||||
EXPECT(assert_print_equal(expected, conditional, "GaussianConditional"));
|
||||
|
||||
|
|
|
|||
|
|
@ -18,9 +18,9 @@ from gtsam.utils.test_case import GtsamTestCase
|
|||
|
||||
import gtsam
|
||||
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
|
||||
GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues,
|
||||
HybridGaussianFactorGraph, JacobianFactor, Ordering,
|
||||
noiseModel)
|
||||
GaussianMixture, GaussianMixtureFactor, HybridBayesNet,
|
||||
HybridGaussianFactorGraph, HybridValues, JacobianFactor,
|
||||
Ordering, noiseModel)
|
||||
|
||||
|
||||
class TestHybridGaussianFactorGraph(GtsamTestCase):
|
||||
|
|
@ -82,10 +82,12 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
self.assertEqual(hv.atDiscrete(C(0)), 1)
|
||||
|
||||
@staticmethod
|
||||
def tiny(num_measurements: int = 1) -> HybridBayesNet:
|
||||
def tiny(num_measurements: int = 1, prior_mean: float = 5.0,
|
||||
prior_sigma: float = 0.5) -> HybridBayesNet:
|
||||
"""
|
||||
Create a tiny two variable hybrid model which represents
|
||||
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n).
|
||||
the generative probability P(Z, x0, mode) = P(Z|x0, mode)P(x0)P(mode).
|
||||
num_measurements: number of measurements in Z = {z0, z1...}
|
||||
"""
|
||||
# Create hybrid Bayes net.
|
||||
bayesNet = HybridBayesNet()
|
||||
|
|
@ -94,23 +96,24 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
mode = (M(0), 2)
|
||||
|
||||
# Create Gaussian mixture Z(0) = X(0) + noise for each measurement.
|
||||
I = np.eye(1)
|
||||
I_1x1 = np.eye(1)
|
||||
keys = DiscreteKeys()
|
||||
keys.push_back(mode)
|
||||
for i in range(num_measurements):
|
||||
conditional0 = GaussianConditional.FromMeanAndStddev(Z(i),
|
||||
I,
|
||||
I_1x1,
|
||||
X(0), [0],
|
||||
sigma=0.5)
|
||||
conditional1 = GaussianConditional.FromMeanAndStddev(Z(i),
|
||||
I,
|
||||
I_1x1,
|
||||
X(0), [0],
|
||||
sigma=3)
|
||||
bayesNet.emplaceMixture([Z(i)], [X(0)], keys,
|
||||
[conditional0, conditional1])
|
||||
|
||||
# Create prior on X(0).
|
||||
prior_on_x0 = GaussianConditional.FromMeanAndStddev(X(0), [5.0], 5.0)
|
||||
prior_on_x0 = GaussianConditional.FromMeanAndStddev(
|
||||
X(0), [prior_mean], prior_sigma)
|
||||
bayesNet.addGaussian(prior_on_x0)
|
||||
|
||||
# Add prior on mode.
|
||||
|
|
@ -118,8 +121,41 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
|
||||
return bayesNet
|
||||
|
||||
def test_evaluate(self):
|
||||
"""Test evaluate with two different prior noise models."""
|
||||
# TODO(dellaert): really a HBN test
|
||||
# Create a tiny Bayes net P(x0) P(m0) P(z0|x0)
|
||||
bayesNet1 = self.tiny(prior_sigma=0.5, num_measurements=1)
|
||||
bayesNet2 = self.tiny(prior_sigma=5.0, num_measurements=1)
|
||||
# bn1: # 1/sqrt(2*pi*0.5^2)
|
||||
# bn2: # 1/sqrt(2*pi*5.0^2)
|
||||
expected_ratio = np.sqrt(2*np.pi*5.0**2)/np.sqrt(2*np.pi*0.5**2)
|
||||
mean0 = HybridValues()
|
||||
mean0.insert(X(0), [5.0])
|
||||
mean0.insert(Z(0), [5.0])
|
||||
mean0.insert(M(0), 0)
|
||||
self.assertAlmostEqual(bayesNet1.evaluate(mean0) /
|
||||
bayesNet2.evaluate(mean0), expected_ratio,
|
||||
delta=1e-9)
|
||||
mean1 = HybridValues()
|
||||
mean1.insert(X(0), [5.0])
|
||||
mean1.insert(Z(0), [5.0])
|
||||
mean1.insert(M(0), 1)
|
||||
self.assertAlmostEqual(bayesNet1.evaluate(mean1) /
|
||||
bayesNet2.evaluate(mean1), expected_ratio,
|
||||
delta=1e-9)
|
||||
|
||||
@staticmethod
|
||||
def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues):
|
||||
def measurements(sample: HybridValues, indices) -> gtsam.VectorValues:
|
||||
"""Create measurements from a sample, grabbing Z(i) for indices."""
|
||||
measurements = gtsam.VectorValues()
|
||||
for i in indices:
|
||||
measurements.insert(Z(i), sample.at(Z(i)))
|
||||
return measurements
|
||||
|
||||
@classmethod
|
||||
def factor_graph_from_bayes_net(cls, bayesNet: HybridBayesNet,
|
||||
sample: HybridValues):
|
||||
"""Create a factor graph from the Bayes net with sampled measurements.
|
||||
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
|
||||
and thus represents the same joint probability as the Bayes net.
|
||||
|
|
@ -128,31 +164,27 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
num_measurements = bayesNet.size() - 2
|
||||
for i in range(num_measurements):
|
||||
conditional = bayesNet.atMixture(i)
|
||||
measurement = gtsam.VectorValues()
|
||||
measurement.insert(Z(i), sample.at(Z(i)))
|
||||
factor = conditional.likelihood(measurement)
|
||||
factor = conditional.likelihood(cls.measurements(sample, [i]))
|
||||
fg.push_back(factor)
|
||||
fg.push_back(bayesNet.atGaussian(num_measurements))
|
||||
fg.push_back(bayesNet.atDiscrete(num_measurements+1))
|
||||
return fg
|
||||
|
||||
@classmethod
|
||||
def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000):
|
||||
"""Do importance sampling to get an estimate of the discrete marginal P(mode)."""
|
||||
# Use prior on x0, mode as proposal density.
|
||||
prior = cls.tiny(num_measurements=0) # just P(x0)P(mode)
|
||||
|
||||
# Allocate space for marginals.
|
||||
def estimate_marginals(cls, target, proposal_density: HybridBayesNet,
|
||||
N=10000):
|
||||
"""Do importance sampling to estimate discrete marginal P(mode)."""
|
||||
# Allocate space for marginals on mode.
|
||||
marginals = np.zeros((2,))
|
||||
|
||||
# Do importance sampling.
|
||||
num_measurements = bayesNet.size() - 2
|
||||
for s in range(N):
|
||||
proposed = prior.sample()
|
||||
for i in range(num_measurements):
|
||||
z_i = sample.at(Z(i))
|
||||
proposed.insert(Z(i), z_i)
|
||||
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed)
|
||||
proposed = proposal_density.sample() # sample from proposal
|
||||
target_proposed = target(proposed) # evaluate target
|
||||
# print(target_proposed, proposal_density.evaluate(proposed))
|
||||
weight = target_proposed / proposal_density.evaluate(proposed)
|
||||
# print weight:
|
||||
# print(f"weight: {weight}")
|
||||
marginals[proposed.atDiscrete(M(0))] += weight
|
||||
|
||||
# print marginals:
|
||||
|
|
@ -161,72 +193,146 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
|
|||
|
||||
def test_tiny(self):
|
||||
"""Test a tiny two variable hybrid model."""
|
||||
bayesNet = self.tiny()
|
||||
sample = bayesNet.sample()
|
||||
# print(sample)
|
||||
# P(x0)P(mode)P(z0|x0,mode)
|
||||
prior_sigma = 0.5
|
||||
bayesNet = self.tiny(prior_sigma=prior_sigma)
|
||||
|
||||
# Deterministic values exactly at the mean, for both x and Z:
|
||||
values = HybridValues()
|
||||
values.insert(X(0), [5.0])
|
||||
values.insert(M(0), 0) # low-noise, standard deviation 0.5
|
||||
z0: float = 5.0
|
||||
values.insert(Z(0), [z0])
|
||||
|
||||
def unnormalized_posterior(x):
|
||||
"""Posterior is proportional to joint, centered at 5.0 as well."""
|
||||
x.insert(Z(0), [z0])
|
||||
return bayesNet.evaluate(x)
|
||||
|
||||
# Create proposal density on (x0, mode), making sure it has same mean:
|
||||
posterior_information = 1/(prior_sigma**2) + 1/(0.5**2)
|
||||
posterior_sigma = posterior_information**(-0.5)
|
||||
proposal_density = self.tiny(
|
||||
num_measurements=0, prior_mean=5.0, prior_sigma=posterior_sigma)
|
||||
|
||||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(bayesNet, sample)
|
||||
# print(f"True mode: {sample.atDiscrete(M(0))}")
|
||||
marginals = self.estimate_marginals(
|
||||
target=unnormalized_posterior, proposal_density=proposal_density)
|
||||
# print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
# print(f"P(mode=0; Z) = {marginals[0]}")
|
||||
# print(f"P(mode=1; Z) = {marginals[1]}")
|
||||
|
||||
# Check that the estimate is close to the true value.
|
||||
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
|
||||
self.assertAlmostEqual(marginals[1], 0.26, delta=0.01)
|
||||
|
||||
fg = self.factor_graph_from_bayes_net(bayesNet, values)
|
||||
self.assertEqual(fg.size(), 3)
|
||||
|
||||
# Test elimination.
|
||||
ordering = gtsam.Ordering()
|
||||
ordering.push_back(X(0))
|
||||
ordering.push_back(M(0))
|
||||
posterior = fg.eliminateSequential(ordering)
|
||||
|
||||
def true_posterior(x):
|
||||
"""Posterior from elimination."""
|
||||
x.insert(Z(0), [z0])
|
||||
return posterior.evaluate(x)
|
||||
|
||||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(
|
||||
target=true_posterior, proposal_density=proposal_density)
|
||||
# print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
# print(f"P(mode=0; z0) = {marginals[0]}")
|
||||
# print(f"P(mode=1; z0) = {marginals[1]}")
|
||||
|
||||
# Check that the estimate is close to the true value.
|
||||
self.assertAlmostEqual(marginals[0], 0.4, delta=0.1)
|
||||
self.assertAlmostEqual(marginals[1], 0.6, delta=0.1)
|
||||
|
||||
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||
self.assertEqual(fg.size(), 3)
|
||||
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
|
||||
self.assertAlmostEqual(marginals[1], 0.26, delta=0.01)
|
||||
|
||||
@staticmethod
|
||||
def calculate_ratio(bayesNet: HybridBayesNet,
|
||||
fg: HybridGaussianFactorGraph,
|
||||
sample: HybridValues):
|
||||
"""Calculate ratio between Bayes net probability and the factor graph."""
|
||||
return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0
|
||||
"""Calculate ratio between Bayes net and factor graph."""
|
||||
return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
|
||||
fg.probPrime(sample) > 0 else 0
|
||||
|
||||
def test_ratio(self):
|
||||
"""
|
||||
Given a tiny two variable hybrid model, with 2 measurements,
|
||||
test the ratio of the bayes net model representing P(z, x, n)=P(z|x, n)P(x)P(n)
|
||||
Given a tiny two variable hybrid model, with 2 measurements, test the
|
||||
ratio of the bayes net model representing P(z,x,n)=P(z|x, n)P(x)P(n)
|
||||
and the factor graph P(x, n | z)=P(x | n, z)P(n|z),
|
||||
both of which represent the same posterior.
|
||||
"""
|
||||
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n)
|
||||
bayesNet = self.tiny(num_measurements=2)
|
||||
# Sample from the Bayes net.
|
||||
sample: HybridValues = bayesNet.sample()
|
||||
# print(sample)
|
||||
# Create generative model P(z, x, n)=P(z|x, n)P(x)P(n)
|
||||
prior_sigma = 0.5
|
||||
bayesNet = self.tiny(prior_sigma=prior_sigma, num_measurements=2)
|
||||
|
||||
# Deterministic values exactly at the mean, for both x and Z:
|
||||
values = HybridValues()
|
||||
values.insert(X(0), [5.0])
|
||||
values.insert(M(0), 0) # high-noise, standard deviation 3
|
||||
measurements = gtsam.VectorValues()
|
||||
measurements.insert(Z(0), [4.0])
|
||||
measurements.insert(Z(1), [6.0])
|
||||
values.insert(measurements)
|
||||
|
||||
def unnormalized_posterior(x):
|
||||
"""Posterior is proportional to joint, centered at 5.0 as well."""
|
||||
x.insert(measurements)
|
||||
return bayesNet.evaluate(x)
|
||||
|
||||
# Create proposal density on (x0, mode), making sure it has same mean:
|
||||
posterior_information = 1/(prior_sigma**2) + 2.0/(3.0**2)
|
||||
posterior_sigma = posterior_information**(-0.5)
|
||||
proposal_density = self.tiny(
|
||||
num_measurements=0, prior_mean=5.0, prior_sigma=posterior_sigma)
|
||||
|
||||
# Estimate marginals using importance sampling.
|
||||
marginals = self.estimate_marginals(bayesNet, sample)
|
||||
# print(f"True mode: {sample.atDiscrete(M(0))}")
|
||||
# print(f"P(mode=0; z0, z1) = {marginals[0]}")
|
||||
# print(f"P(mode=1; z0, z1) = {marginals[1]}")
|
||||
marginals = self.estimate_marginals(
|
||||
target=unnormalized_posterior, proposal_density=proposal_density)
|
||||
# print(f"True mode: {values.atDiscrete(M(0))}")
|
||||
# print(f"P(mode=0; Z) = {marginals[0]}")
|
||||
# print(f"P(mode=1; Z) = {marginals[1]}")
|
||||
|
||||
# Check marginals based on sampled mode.
|
||||
if sample.atDiscrete(M(0)) == 0:
|
||||
self.assertGreater(marginals[0], marginals[1])
|
||||
else:
|
||||
self.assertGreater(marginals[1], marginals[0])
|
||||
# Check that the estimate is close to the true value.
|
||||
self.assertAlmostEqual(marginals[0], 0.23, delta=0.01)
|
||||
self.assertAlmostEqual(marginals[1], 0.77, delta=0.01)
|
||||
|
||||
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
|
||||
# Convert to factor graph using measurements.
|
||||
fg = self.factor_graph_from_bayes_net(bayesNet, values)
|
||||
self.assertEqual(fg.size(), 4)
|
||||
|
||||
# Calculate ratio between Bayes net probability and the factor graph:
|
||||
expected_ratio = self.calculate_ratio(bayesNet, fg, sample)
|
||||
expected_ratio = self.calculate_ratio(bayesNet, fg, values)
|
||||
# print(f"expected_ratio: {expected_ratio}\n")
|
||||
|
||||
# Create measurements from the sample.
|
||||
measurements = gtsam.VectorValues()
|
||||
for i in range(2):
|
||||
measurements.insert(Z(i), sample.at(Z(i)))
|
||||
|
||||
# Check with a number of other samples.
|
||||
for i in range(10):
|
||||
other = bayesNet.sample()
|
||||
other.update(measurements)
|
||||
ratio = self.calculate_ratio(bayesNet, fg, other)
|
||||
samples = bayesNet.sample()
|
||||
samples.update(measurements)
|
||||
ratio = self.calculate_ratio(bayesNet, fg, samples)
|
||||
# print(f"Ratio: {ratio}\n")
|
||||
if (ratio > 0):
|
||||
self.assertAlmostEqual(ratio, expected_ratio)
|
||||
|
||||
# Test elimination.
|
||||
ordering = gtsam.Ordering()
|
||||
ordering.push_back(X(0))
|
||||
ordering.push_back(M(0))
|
||||
posterior = fg.eliminateSequential(ordering)
|
||||
|
||||
# Calculate ratio between Bayes net probability and the factor graph:
|
||||
expected_ratio = self.calculate_ratio(posterior, fg, values)
|
||||
# print(f"expected_ratio: {expected_ratio}\n")
|
||||
|
||||
# Check with a number of other samples.
|
||||
for i in range(10):
|
||||
samples = posterior.sample()
|
||||
samples.insert(measurements)
|
||||
ratio = self.calculate_ratio(posterior, fg, samples)
|
||||
# print(f"Ratio: {ratio}\n")
|
||||
if (ratio > 0):
|
||||
self.assertAlmostEqual(ratio, expected_ratio)
|
||||
|
|
|
|||
Loading…
Reference in New Issue