Merge pull request #1282 from borglab/hybrid/optimize-2

release/4.3a0
Varun Agrawal 2022-08-31 12:26:42 -04:00 committed by GitHub
commit f7e1d2a1d3
12 changed files with 223 additions and 527 deletions

View File

@@ -15,8 +15,9 @@
* @date January 2022
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
namespace gtsam {
@@ -111,10 +112,15 @@ HybridBayesNet HybridBayesNet::prune(
}
/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return factors_.at(i)->asMixture();
}
/* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return factors_.at(i)->asGaussian();
}
/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return factors_.at(i)->asDiscreteConditional();
@@ -125,22 +131,39 @@ GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) {
try {
GaussianMixture gm = *this->atGaussian(idx);
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixture gm = *this->atMixture(idx);
gbn.push_back(gm(assignment));
} catch (std::exception &exc) {
// if factor at `idx` is discrete-only, just continue.
} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, add gaussian conditional.
gbn.push_back((this->atGaussian(idx)));
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we simply continue.
continue;
}
}
return gbn;
}
/* *******************************************************************************/
HybridValues HybridBayesNet::optimize() const {
auto dag = HybridLookupDAG::FromBayesNet(*this);
return dag.argmax();
// Solve for the MPE
DiscreteBayesNet discrete_bn;
for (auto &conditional : factors_) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
}
}
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = this->choose(mpe);
return HybridValues(mpe, gbn.optimize());
}
/* *******************************************************************************/
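For context, the reworked optimize() is a two-stage solve: the discrete conditionals are collected into a DiscreteBayesNet and solved as a DiscreteFactorGraph to get the MPE, and the GaussianBayesNet returned by choose(mpe) is then optimized by back-substitution. Below is a minimal usage sketch; the Switching fixture and the getHybridOrdering() helper are assumptions borrowed from the tests in this PR.

// Usage sketch, not part of the diff. Switching comes from the hybrid test fixtures.
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

using namespace gtsam;

void optimizeSketch() {
  Switching s(4);

  // Eliminate the linearized hybrid factor graph into a HybridBayesNet.
  Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
  HybridBayesNet::shared_ptr bayesNet =
      s.linearizedFactorGraph.eliminateSequential(ordering);

  // Stage 1: MPE over the discrete modes.
  // Stage 2: back-substitution of the GaussianBayesNet chosen by that MPE.
  HybridValues delta = bayesNet->optimize();
  DiscreteValues mpe = delta.discrete();     // most probable mode assignment
  VectorValues update = delta.continuous();  // optimal continuous update
}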

View File

@@ -54,7 +54,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
}
/// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atGaussian(size_t i) const;
GaussianMixture::shared_ptr atMixture(size_t i) const;
/// Get a specific Gaussian conditional by index `i`.
GaussianConditional::shared_ptr atGaussian(size_t i) const;
/// Get a specific discrete conditional by index `i`.
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
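A quick sketch of the renamed accessors (indices are illustrative; which slot holds which conditional depends on the elimination ordering):

GaussianMixture::shared_ptr mix = bayesNet->atMixture(0);      // hybrid conditional
GaussianConditional::shared_ptr gc = bayesNet->atGaussian(1);  // continuous-only conditional
DiscreteConditional::shared_ptr dc = bayesNet->atDiscrete(2);  // discrete-only conditional

atGaussian(i) forwards to HybridConditional::asGaussian(), which throws std::invalid_argument when the conditional at index i is not continuous-only (see HybridConditional.h below).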

View File

@@ -18,6 +18,8 @@
*/
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/inference/BayesTree-inst.h>
@@ -35,6 +37,52 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}
/* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const {
HybridBayesNet hbn;
DiscreteBayesNet dbn;
KeyVector added_keys;
// Iterate over all the nodes in the BayesTree
for (auto&& node : nodes()) {
// Check if conditional being added is already in the Bayes net.
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
added_keys.end()) {
// Access the clique and get the underlying hybrid conditional
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();
// Record the key being added
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
conditional->frontals().end());
if (conditional->isDiscrete()) {
// If discrete, we use it to compute the MPE
dbn.push_back(conditional->asDiscreteConditional());
} else {
// Else conditional is hybrid or continuous-only,
// so we directly add it to the Hybrid Bayes net.
hbn.push_back(conditional);
}
}
}
// Get the MPE
DiscreteValues mpe = DiscreteFactorGraph(dbn).optimize();
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = hbn.choose(mpe);
// If TBB is enabled, the bayes net order gets reversed,
// so we pre-reverse it
#ifdef GTSAM_USE_TBB
auto reversed = boost::adaptors::reverse(gbn);
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
#endif
return HybridValues(mpe, gbn.optimize());
}
/* ************************************************************************* */
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesNet gbn;
@@ -50,11 +98,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();
KeyVector frontals(conditional->frontals().begin(),
conditional->frontals().end());
// Record the key being added
added_keys.insert(added_keys.end(), frontals.begin(), frontals.end());
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
conditional->frontals().end());
// If conditional is hybrid (and not discrete-only), we get the Gaussian
// Conditional corresponding to the assignment and add it to the Gaussian
@@ -65,9 +111,14 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
(*gm)(assignment);
gbn.push_back(gaussian_conditional);
} else if (conditional->isContinuous()) {
// If conditional is Gaussian, we simply add it to the Bayes net.
gbn.push_back(conditional->asGaussian());
}
}
}
// If TBB is enabled, the bayes net order gets reversed,
// so we pre-reverse it
#ifdef GTSAM_USE_TBB

View File

@@ -70,6 +70,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;
/**
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
* set of discrete variables and using it to compute the best continuous
* update delta.
*
* @return HybridValues
*/
HybridValues optimize() const;
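In practice this is reached through the incremental interface, as in the new test below. A rough sketch, assuming HybridGaussianISAM (which wraps this Bayes tree) and an already-built HybridGaussianFactorGraph named graph:

HybridGaussianISAM isam;
isam.update(graph);                    // incremental elimination into the Bayes tree
HybridValues delta = isam.optimize();  // MPE modes plus continuous update
DiscreteValues modes = delta.discrete();
VectorValues update = delta.continuous();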
/**
* @brief Recursively optimize the BayesTree to produce a vector solution.
*

View File

@@ -139,6 +139,17 @@ class GTSAM_EXPORT HybridConditional
return boost::static_pointer_cast<GaussianMixture>(inner_);
}
/**
* @brief Return HybridConditional as a GaussianConditional
*
* @return GaussianConditional::shared_ptr
*/
GaussianConditional::shared_ptr asGaussian() {
if (!isContinuous())
throw std::invalid_argument("Not a continuous conditional");
return boost::static_pointer_cast<GaussianConditional>(inner_);
}
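Together with isDiscrete()/isContinuous()/isHybrid(), these casts give the dispatch pattern used by HybridBayesNet::choose() and optimize() above. A condensed sketch, where hybridBayesNet and assignment stand in for a HybridBayesNet and a DiscreteValues:

for (auto &&conditional : hybridBayesNet) {
  if (conditional->isHybrid()) {
    // Select the Gaussian conditional for the given discrete assignment.
    GaussianMixture::shared_ptr gm = conditional->asMixture();
    GaussianConditional::shared_ptr chosen = (*gm)(assignment);
  } else if (conditional->isContinuous()) {
    GaussianConditional::shared_ptr gc = conditional->asGaussian();
  } else if (conditional->isDiscrete()) {
    DiscreteConditional::shared_ptr dc = conditional->asDiscreteConditional();
  }
}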
/**
* @brief Return conditional as a DiscreteConditional
*

View File

@@ -1,76 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.cpp
* @date Aug, 2022
* @author Shangjie Xue
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/VectorValues.h>
#include <string>
#include <utility>
using std::pair;
using std::vector;
namespace gtsam {
/* ************************************************************************** */
void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
// For discrete conditional, uses argmaxInPlace() method in
// DiscreteLookupTable.
if (isDiscrete()) {
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(
&(values->discrete));
} else if (isContinuous()) {
// For Gaussian conditional, uses solve() method in GaussianConditional.
values->continuous.insert(
boost::static_pointer_cast<GaussianConditional>(inner_)->solve(
values->continuous));
} else if (isHybrid()) {
// For hybrid conditional, since children should not contain discrete
// variable, we can condition on the discrete variable in the parents and
// solve the resulting GaussianConditional.
auto conditional =
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(
values->discrete);
values->continuous.insert(conditional->solve(values->continuous));
}
}
/* ************************************************************************** */
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
HybridLookupDAG dag;
for (auto&& conditional : bayesNet) {
HybridLookupTable hlt(*conditional);
dag.push_back(hlt);
}
return dag;
}
/* ************************************************************************** */
HybridValues HybridLookupDAG::argmax(HybridValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}
} // namespace gtsam

View File

@@ -1,119 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridLookupDAG.h
* @date Aug, 2022
* @author Shangjie Xue
*/
#pragma once
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
/**
* @brief HybridLookupTable for max-product
*
* Similar to DiscreteLookupTable, inherits from hybrid conditional for
* convenience. Is used in the max-product algorithm.
*/
class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
public:
using Base = HybridConditional;
using This = HybridLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;
/**
* @brief Construct a new Hybrid Lookup Table object from a HybridConditional.
*
* @param conditional input hybrid conditional
*/
HybridLookupTable(HybridConditional& conditional) : Base(conditional){};
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(HybridValues* parentsValues) const;
};
/** A DAG made from hybrid lookup tables, as defined above. Similar to
* DiscreteLookupDAG */
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
public:
using Base = BayesNet<HybridLookupTable>;
using This = HybridLookupDAG;
using shared_ptr = boost::shared_ptr<This>;
/// @name Standard Constructors
/// @{
/// Construct empty DAG.
HybridLookupDAG() {}
/// Create from BayesNet with LookupTables
static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet);
/// Destructor
virtual ~HybridLookupDAG() {}
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<HybridLookupTable>(std::forward<Args>(args)...);
}
/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
* @return given assignment extended w. optimal assignment for all variables.
*/
HybridValues argmax(HybridValues given = HybridValues()) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
// traits
template <>
struct traits<HybridLookupDAG> : public Testable<HybridLookupDAG> {};
} // namespace gtsam

View File

@@ -36,40 +36,58 @@ namespace gtsam {
* Optimizing a HybridGaussianBayesNet returns this class.
*/
class GTSAM_EXPORT HybridValues {
public:
private:
// DiscreteValues stores the discrete components of the HybridValues.
DiscreteValues discrete;
DiscreteValues discrete_;
// VectorValues stores the continuous components of the HybridValues.
VectorValues continuous;
VectorValues continuous_;
public:
/// @name Standard Constructors
/// @{
// Default constructor creates an empty HybridValues.
HybridValues() : discrete(), continuous(){};
HybridValues() = default;
// Construct from DiscreteValues and VectorValues.
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
: discrete(dv), continuous(cv){};
: discrete_(dv), continuous_(cv){};
/// @}
/// @name Testable
/// @{
// print required by Testable for unit testing
void print(const std::string& s = "HybridValues",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": \n";
discrete.print(" Discrete", keyFormatter); // print discrete components
continuous.print(" Continuous",
keyFormatter); // print continuous components
discrete_.print(" Discrete", keyFormatter); // print discrete components
continuous_.print(" Continuous",
keyFormatter); // print continuous components
};
// equals required by Testable for unit testing
bool equals(const HybridValues& other, double tol = 1e-9) const {
return discrete.equals(other.discrete, tol) &&
continuous.equals(other.continuous, tol);
return discrete_.equals(other.discrete_, tol) &&
continuous_.equals(other.continuous_, tol);
}
/// @}
/// @name Interface
/// @{
/// Return the discrete MPE assignment
DiscreteValues discrete() const { return discrete_; }
/// Return the delta update for the continuous vectors
VectorValues continuous() const { return continuous_; }
// Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); };
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };
// Check whether a variable with key \c j exists in VectorValue.
bool existsVector(Key j) { return continuous.exists(j); };
bool existsVector(Key j) { return continuous_.exists(j); };
// Check whether a variable with key \c j exists.
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };
@@ -78,13 +96,13 @@ class GTSAM_EXPORT HybridValues {
* the key \c j is already used.
* @param value The vector to be inserted.
* @param j The index with which the value will be associated. */
void insert(Key j, int value) { discrete[j] = value; };
void insert(Key j, int value) { discrete_[j] = value; };
/** Insert a vector \c value with key \c j. Throws an invalid_argument
* exception if the key \c j is already used.
* @param value The vector to be inserted.
* @param j The index with which the value will be associated. */
void insert(Key j, const Vector& value) { continuous.insert(j, value); }
void insert(Key j, const Vector& value) { continuous_.insert(j, value); }
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
@@ -92,13 +110,13 @@ class GTSAM_EXPORT HybridValues {
* Read/write access to the discrete value with key \c j, throws
* std::out_of_range if \c j does not exist.
*/
size_t& atDiscrete(Key j) { return discrete.at(j); };
size_t& atDiscrete(Key j) { return discrete_.at(j); };
/**
* Read/write access to the vector value with key \c j, throws
* std::out_of_range if \c j does not exist.
*/
Vector& at(Key j) { return continuous.at(j); };
Vector& at(Key j) { return continuous_.at(j); };
/// @name Wrapper support
/// @{
@@ -112,8 +130,8 @@ class GTSAM_EXPORT HybridValues {
std::string html(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::stringstream ss;
ss << this->discrete.html(keyFormatter);
ss << this->continuous.html(keyFormatter);
ss << this->discrete_.html(keyFormatter);
ss << this->continuous_.html(keyFormatter);
return ss.str();
};
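With the members now private, callers go through the accessors. A small sketch; the X/M symbol shorthands are borrowed from the tests:

HybridValues values;
values.insert(M(1), 1);                 // discrete mode assignment
values.insert(X(1), Vector1::Ones());   // continuous vector
size_t mode = values.atDiscrete(M(1));  // read/write discrete access
Vector x1 = values.at(X(1));            // read/write vector access
DiscreteValues dv = values.discrete();  // copy of the discrete assignments
VectorValues cv = values.continuous();  // copy of the continuous vectors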

View File

@@ -6,8 +6,8 @@ namespace gtsam {
#include <gtsam/hybrid/HybridValues.h>
class HybridValues {
gtsam::DiscreteValues discrete;
gtsam::VectorValues continuous;
gtsam::DiscreteValues discrete() const;
gtsam::VectorValues continuous() const;
HybridValues();
HybridValues(const gtsam::DiscreteValues &dv, const gtsam::VectorValues &cv);
void print(string s = "HybridValues",

View File

@@ -73,16 +73,16 @@ TEST(HybridBayesNet, Choose) {
EXPECT_LONGS_EQUAL(4, gbn.size());
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(0)))(assignment),
hybridBayesNet->atMixture(0)))(assignment),
*gbn.at(0)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(1)))(assignment),
hybridBayesNet->atMixture(1)))(assignment),
*gbn.at(1)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(2)))(assignment),
hybridBayesNet->atMixture(2)))(assignment),
*gbn.at(2)));
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet->atGaussian(3)))(assignment),
hybridBayesNet->atMixture(3)))(assignment),
*gbn.at(3)));
}
@@ -125,35 +125,25 @@ TEST(HybridBayesNet, OptimizeAssignment) {
TEST(HybridBayesNet, Optimize) {
Switching s(4);
Ordering ordering;
for (auto&& kvp : s.linearizationPoint) {
ordering += kvp.key;
}
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
delta.print();
VectorValues correct;
correct.insert(X(1), 0 * Vector1::Ones());
correct.insert(X(2), 1 * Vector1::Ones());
correct.insert(X(3), 2 * Vector1::Ones());
correct.insert(X(4), 3 * Vector1::Ones());
DiscreteValues expectedAssignment;
expectedAssignment[M(1)] = 1;
expectedAssignment[M(2)] = 0;
expectedAssignment[M(3)] = 1;
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
DiscreteValues assignment111;
assignment111[M(1)] = 1;
assignment111[M(2)] = 1;
assignment111[M(3)] = 1;
std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl;
VectorValues expectedValues;
expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
expectedValues.insert(X(4), -1.0001 * Vector1::Ones());
DiscreteValues assignment101;
assignment101[M(1)] = 1;
assignment101[M(2)] = 0;
assignment101[M(3)] = 1;
std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl;
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}
/* ************************************************************************* */

View File

@@ -16,6 +16,7 @@
* @date August 2022
*/
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
@@ -31,8 +32,8 @@ using symbol_shorthand::M;
using symbol_shorthand::X;
/* ****************************************************************************/
// Test for optimizing a HybridBayesTree.
TEST(HybridBayesTree, Optimize) {
// Test for optimizing a HybridBayesTree with a given assignment.
TEST(HybridBayesTree, OptimizeAssignment) {
Switching s(4);
HybridGaussianISAM isam;
@@ -50,6 +51,11 @@ TEST(HybridBayesTree, Optimize) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
// Add the discrete factors
for (size_t i = 7; i <= 9; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
isam.update(graph1);
DiscreteValues assignment;
@@ -85,6 +91,58 @@ TEST(HybridBayesTree, Optimize) {
EXPECT(assert_equal(expected, delta));
}
/* ****************************************************************************/
// Test for optimizing a HybridBayesTree.
TEST(HybridBayesTree, Optimize) {
Switching s(4);
HybridGaussianISAM isam;
HybridGaussianFactorGraph graph1;
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
for (size_t i = 1; i < 4; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
// Add the Gaussian factors, 1 prior on X(1),
// 3 measurements on X(2), X(3), X(4)
graph1.push_back(s.linearizedFactorGraph.at(0));
for (size_t i = 4; i <= 6; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
// Add the discrete factors
for (size_t i = 7; i <= 9; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
isam.update(graph1);
HybridValues delta = isam.optimize();
// Create ordering.
Ordering ordering;
for (size_t k = 1; k <= s.K; k++) ordering += X(k);
HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) {
auto factor = dynamic_pointer_cast<HybridDiscreteFactor>(f);
dfg.push_back(
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
}
DiscreteValues expectedMPE = dfg.optimize();
VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE);
EXPECT(assert_equal(expectedMPE, delta.discrete()));
EXPECT(assert_equal(expectedValues, delta.continuous()));
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@@ -1,272 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testHybridLookupDAG.cpp
* @date Aug, 2022
* @author Shangjie Xue
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Key.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/nonlinear/Values.h>
// Include for test suite
#include <CppUnitLite/TestHarness.h>
#include <iostream>
using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
TEST(HybridLookupTable, basics) {
// create a conditional gaussian node
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 2;
S1(0, 1) = 3;
S1(1, 1) = 4;
Matrix S2(2, 2);
S2(0, 0) = 6;
S2(1, 0) = 0.2;
S2(0, 1) = 8;
S2(1, 1) = 0.4;
Matrix R1(2, 2);
R1(0, 0) = 0.1;
R1(1, 0) = 0.3;
R1(0, 1) = 0.0;
R1(1, 1) = 0.34;
Matrix R2(2, 2);
R2(0, 0) = 0.1;
R2(1, 0) = 0.3;
R2(0, 1) = 0.0;
R2(1, 1) = 0.34;
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
Vector2 d1(0.2, 0.5), d2(0.5, 0.2);
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
X(2), S1, model),
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
X(2), S2, model);
// Create decision tree
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals);
boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals));
HybridConditional hc(mixtureFactor);
GaussianMixture::Conditionals conditional2 =
boost::static_pointer_cast<GaussianMixture>(hc.inner())->conditionals();
DiscreteValues dv;
dv[1] = 1;
VectorValues cv;
cv.insert(X(2), Vector2(0.0, 0.0));
HybridValues hv(dv, cv);
// std::cout << conditional2(values).markdown();
EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6));
EXPECT(conditional2(dv) == conditionals(dv));
HybridLookupTable hlt(hc);
// hlt.argmaxInPlace(&hv);
HybridLookupDAG dag;
dag.push_back(hlt);
dag.argmax(hv);
// HybridBayesNet hbn;
// hbn.push_back(hc);
// hbn.optimize();
}
TEST(HybridLookupTable, hybrid_argmax) {
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 0;
S1(0, 1) = 0;
S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional0 =
boost::make_shared<GaussianConditional>(X(1), d1, S1, model),
conditional1 =
boost::make_shared<GaussianConditional>(X(1), d2, S1, model);
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(1)}, {}, {m1}, conditionals));
HybridConditional hc(mixtureFactor);
DiscreteValues dv;
dv[1] = 1;
VectorValues cv;
// cv.insert(X(2),Vector2(0.0, 0.0));
HybridValues hv(dv, cv);
HybridLookupTable hlt(hc);
hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.at(X(1)), d2));
}
TEST(HybridLookupTable, discrete_argmax) {
DiscreteKey X(0, 2), Y(1, 2);
auto conditional = boost::make_shared<DiscreteConditional>(X | Y = "0/1 3/2");
HybridConditional hc(conditional);
HybridLookupTable hlt(hc);
DiscreteValues dv;
dv[1] = 0;
VectorValues cv;
// cv.insert(X(2),Vector2(0.0, 0.0));
HybridValues hv(dv, cv);
hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.atDiscrete(0), 1));
DecisionTreeFactor f1(X, "2 3");
auto conditional2 = boost::make_shared<DiscreteConditional>(1, f1);
HybridConditional hc2(conditional2);
HybridLookupTable hlt2(hc2);
HybridValues hv2;
hlt2.argmaxInPlace(&hv2);
EXPECT(assert_equal(hv2.atDiscrete(0), 1));
}
TEST(HybridLookupTable, gaussian_argmax) {
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 0;
S1(0, 1) = 0;
S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional =
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
HybridConditional hc(conditional);
HybridLookupTable hlt(hc);
DiscreteValues dv;
// dv[1]=0;
VectorValues cv;
cv.insert(X(2), d2);
HybridValues hv(dv, cv);
hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.at(X(1)), d1 + d2));
}
TEST(HybridLookupDAG, argmax) {
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 0;
S1(0, 1) = 0;
S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional0 =
boost::make_shared<GaussianConditional>(X(2), d1, S1, model),
conditional1 =
boost::make_shared<GaussianConditional>(X(2), d2, S1, model);
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(2)}, {}, {m1}, conditionals));
HybridConditional hc2(mixtureFactor);
HybridLookupTable hlt2(hc2);
auto conditional2 =
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
HybridConditional hc1(conditional2);
HybridLookupTable hlt1(hc1);
DecisionTreeFactor f1(m1, "2 3");
auto discrete_conditional = boost::make_shared<DiscreteConditional>(1, f1);
HybridConditional hc3(discrete_conditional);
HybridLookupTable hlt3(hc3);
HybridLookupDAG dag;
dag.push_back(hlt1);
dag.push_back(hlt2);
dag.push_back(hlt3);
auto hv = dag.argmax();
EXPECT(assert_equal(hv.atDiscrete(1), 1));
EXPECT(assert_equal(hv.at(X(2)), d2));
EXPECT(assert_equal(hv.at(X(1)), d2 + d1));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */