Merge pull request #1282 from borglab/hybrid/optimize-2
commit
f7e1d2a1d3
|
@ -15,8 +15,9 @@
|
|||
* @date January 2022
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -111,10 +112,15 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
|
||||
return factors_.at(i)->asMixture();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||
return factors_.at(i)->asGaussian();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
||||
return factors_.at(i)->asDiscreteConditional();
|
||||
|
@ -125,22 +131,39 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
const DiscreteValues &assignment) const {
|
||||
GaussianBayesNet gbn;
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
try {
|
||||
GaussianMixture gm = *this->atGaussian(idx);
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment.
|
||||
GaussianMixture gm = *this->atMixture(idx);
|
||||
gbn.push_back(gm(assignment));
|
||||
|
||||
} catch (std::exception &exc) {
|
||||
// if factor at `idx` is discrete-only, just continue.
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
// If continuous only, add gaussian conditional.
|
||||
gbn.push_back((this->atGaussian(idx)));
|
||||
|
||||
} else if (factors_.at(idx)->isDiscrete()) {
|
||||
// If factor at `idx` is discrete-only, we simply continue.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return gbn;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridValues HybridBayesNet::optimize() const {
|
||||
auto dag = HybridLookupDAG::FromBayesNet(*this);
|
||||
return dag.argmax();
|
||||
// Solve for the MPE
|
||||
DiscreteBayesNet discrete_bn;
|
||||
for (auto &conditional : factors_) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discrete_bn.push_back(conditional->asDiscreteConditional());
|
||||
}
|
||||
}
|
||||
|
||||
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
|
||||
|
||||
// Given the MPE, compute the optimal continuous values.
|
||||
GaussianBayesNet gbn = this->choose(mpe);
|
||||
return HybridValues(mpe, gbn.optimize());
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -54,7 +54,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
}
|
||||
|
||||
/// Get a specific Gaussian mixture by index `i`.
|
||||
GaussianMixture::shared_ptr atGaussian(size_t i) const;
|
||||
GaussianMixture::shared_ptr atMixture(size_t i) const;
|
||||
|
||||
/// Get a specific Gaussian conditional by index `i`.
|
||||
GaussianConditional::shared_ptr atGaussian(size_t i) const;
|
||||
|
||||
/// Get a specific discrete conditional by index `i`.
|
||||
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/base/treeTraversal-inst.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/inference/BayesTree-inst.h>
|
||||
|
@ -35,6 +37,52 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
|
|||
return Base::equals(other, tol);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesTree::optimize() const {
|
||||
HybridBayesNet hbn;
|
||||
DiscreteBayesNet dbn;
|
||||
|
||||
KeyVector added_keys;
|
||||
|
||||
// Iterate over all the nodes in the BayesTree
|
||||
for (auto&& node : nodes()) {
|
||||
// Check if conditional being added is already in the Bayes net.
|
||||
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
|
||||
added_keys.end()) {
|
||||
// Access the clique and get the underlying hybrid conditional
|
||||
HybridBayesTreeClique::shared_ptr clique = node.second;
|
||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||
|
||||
// Record the key being added
|
||||
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
|
||||
conditional->frontals().end());
|
||||
|
||||
if (conditional->isDiscrete()) {
|
||||
// If discrete, we use it to compute the MPE
|
||||
dbn.push_back(conditional->asDiscreteConditional());
|
||||
|
||||
} else {
|
||||
// Else conditional is hybrid or continuous-only,
|
||||
// so we directly add it to the Hybrid Bayes net.
|
||||
hbn.push_back(conditional);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Get the MPE
|
||||
DiscreteValues mpe = DiscreteFactorGraph(dbn).optimize();
|
||||
// Given the MPE, compute the optimal continuous values.
|
||||
GaussianBayesNet gbn = hbn.choose(mpe);
|
||||
|
||||
// If TBB is enabled, the bayes net order gets reversed,
|
||||
// so we pre-reverse it
|
||||
#ifdef GTSAM_USE_TBB
|
||||
auto reversed = boost::adaptors::reverse(gbn);
|
||||
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
|
||||
#endif
|
||||
|
||||
return HybridValues(mpe, gbn.optimize());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||
GaussianBayesNet gbn;
|
||||
|
@ -50,11 +98,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
HybridBayesTreeClique::shared_ptr clique = node.second;
|
||||
HybridConditional::shared_ptr conditional = clique->conditional();
|
||||
|
||||
KeyVector frontals(conditional->frontals().begin(),
|
||||
conditional->frontals().end());
|
||||
|
||||
// Record the key being added
|
||||
added_keys.insert(added_keys.end(), frontals.begin(), frontals.end());
|
||||
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
|
||||
conditional->frontals().end());
|
||||
|
||||
// If conditional is hybrid (and not discrete-only), we get the Gaussian
|
||||
// Conditional corresponding to the assignment and add it to the Gaussian
|
||||
|
@ -65,9 +111,14 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
(*gm)(assignment);
|
||||
|
||||
gbn.push_back(gaussian_conditional);
|
||||
|
||||
} else if (conditional->isContinuous()) {
|
||||
// If conditional is Gaussian, we simply add it to the Bayes net.
|
||||
gbn.push_back(conditional->asGaussian());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If TBB is enabled, the bayes net order gets reversed,
|
||||
// so we pre-reverse it
|
||||
#ifdef GTSAM_USE_TBB
|
||||
|
|
|
@ -70,6 +70,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
|
|||
/** Check equality */
|
||||
bool equals(const This& other, double tol = 1e-9) const;
|
||||
|
||||
/**
|
||||
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
|
||||
* set of discrete variables and using it to compute the best continuous
|
||||
* update delta.
|
||||
*
|
||||
* @return HybridValues
|
||||
*/
|
||||
HybridValues optimize() const;
|
||||
|
||||
/**
|
||||
* @brief Recursively optimize the BayesTree to produce a vector solution.
|
||||
*
|
||||
|
|
|
@ -139,6 +139,17 @@ class GTSAM_EXPORT HybridConditional
|
|||
return boost::static_pointer_cast<GaussianMixture>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianConditional
|
||||
*
|
||||
* @return GaussianConditional::shared_ptr
|
||||
*/
|
||||
GaussianConditional::shared_ptr asGaussian() {
|
||||
if (!isContinuous())
|
||||
throw std::invalid_argument("Not a continuous conditional");
|
||||
return boost::static_pointer_cast<GaussianConditional>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return conditional as a DiscreteConditional
|
||||
*
|
||||
|
|
|
@ -1,76 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file DiscreteLookupDAG.cpp
|
||||
* @date Aug, 2022
|
||||
* @author Shangjie Xue
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
using std::pair;
|
||||
using std::vector;
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************** */
|
||||
void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
|
||||
// For discrete conditional, uses argmaxInPlace() method in
|
||||
// DiscreteLookupTable.
|
||||
if (isDiscrete()) {
|
||||
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(
|
||||
&(values->discrete));
|
||||
} else if (isContinuous()) {
|
||||
// For Gaussian conditional, uses solve() method in GaussianConditional.
|
||||
values->continuous.insert(
|
||||
boost::static_pointer_cast<GaussianConditional>(inner_)->solve(
|
||||
values->continuous));
|
||||
} else if (isHybrid()) {
|
||||
// For hybrid conditional, since children should not contain discrete
|
||||
// variable, we can condition on the discrete variable in the parents and
|
||||
// solve the resulting GaussianConditional.
|
||||
auto conditional =
|
||||
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(
|
||||
values->discrete);
|
||||
values->continuous.insert(conditional->solve(values->continuous));
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
|
||||
HybridLookupDAG dag;
|
||||
for (auto&& conditional : bayesNet) {
|
||||
HybridLookupTable hlt(*conditional);
|
||||
dag.push_back(hlt);
|
||||
}
|
||||
return dag;
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
HybridValues HybridLookupDAG::argmax(HybridValues result) const {
|
||||
// Argmax each node in turn in topological sort order (parents first).
|
||||
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||
lookupTable->argmaxInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
|
@ -1,119 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file HybridLookupDAG.h
|
||||
* @date Aug, 2022
|
||||
* @author Shangjie Xue
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* @brief HybridLookupTable table for max-product
|
||||
*
|
||||
* Similar to DiscreteLookupTable, inherits from hybrid conditional for
|
||||
* convenience. Is used in the max-product algorithm.
|
||||
*/
|
||||
class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
|
||||
public:
|
||||
using Base = HybridConditional;
|
||||
using This = HybridLookupTable;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Hybrid Lookup Table object form a HybridConditional.
|
||||
*
|
||||
* @param conditional input hybrid conditional
|
||||
*/
|
||||
HybridLookupTable(HybridConditional& conditional) : Base(conditional){};
|
||||
|
||||
/**
|
||||
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||
* @param (in/out) parentsValues Known assignments for the parents.
|
||||
*/
|
||||
void argmaxInPlace(HybridValues* parentsValues) const;
|
||||
};
|
||||
|
||||
/** A DAG made from hybrid lookup tables, as defined above. Similar to
|
||||
* DiscreteLookupDAG */
|
||||
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
|
||||
public:
|
||||
using Base = BayesNet<HybridLookupTable>;
|
||||
using This = HybridLookupDAG;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/// Construct empty DAG.
|
||||
HybridLookupDAG() {}
|
||||
|
||||
/// Create from BayesNet with LookupTables
|
||||
static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet);
|
||||
|
||||
/// Destructor
|
||||
virtual ~HybridLookupDAG() {}
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Add a DiscreteLookupTable */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<HybridLookupTable>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief argmax by back-substitution, optionally given certain variables.
|
||||
*
|
||||
* Assumes the DAG is reverse topologically sorted, i.e. last
|
||||
* conditional will be optimized first *and* that the
|
||||
* DAG does not contain any conditionals for the given variables. If the DAG
|
||||
* resulted from eliminating a factor graph, this is true for the elimination
|
||||
* ordering.
|
||||
*
|
||||
* @return given assignment extended w. optimal assignment for all variables.
|
||||
*/
|
||||
HybridValues argmax(HybridValues given = HybridValues()) const;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<HybridLookupDAG> : public Testable<HybridLookupDAG> {};
|
||||
|
||||
} // namespace gtsam
|
|
@ -36,40 +36,58 @@ namespace gtsam {
|
|||
* Optimizing a HybridGaussianBayesNet returns this class.
|
||||
*/
|
||||
class GTSAM_EXPORT HybridValues {
|
||||
public:
|
||||
private:
|
||||
// DiscreteValue stored the discrete components of the HybridValues.
|
||||
DiscreteValues discrete;
|
||||
DiscreteValues discrete_;
|
||||
|
||||
// VectorValue stored the continuous components of the HybridValues.
|
||||
VectorValues continuous;
|
||||
VectorValues continuous_;
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
// Default constructor creates an empty HybridValues.
|
||||
HybridValues() : discrete(), continuous(){};
|
||||
HybridValues() = default;
|
||||
|
||||
// Construct from DiscreteValues and VectorValues.
|
||||
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
|
||||
: discrete(dv), continuous(cv){};
|
||||
: discrete_(dv), continuous_(cv){};
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
// print required by Testable for unit testing
|
||||
void print(const std::string& s = "HybridValues",
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||
std::cout << s << ": \n";
|
||||
discrete.print(" Discrete", keyFormatter); // print discrete components
|
||||
continuous.print(" Continuous",
|
||||
keyFormatter); // print continuous components
|
||||
discrete_.print(" Discrete", keyFormatter); // print discrete components
|
||||
continuous_.print(" Continuous",
|
||||
keyFormatter); // print continuous components
|
||||
};
|
||||
|
||||
// equals required by Testable for unit testing
|
||||
bool equals(const HybridValues& other, double tol = 1e-9) const {
|
||||
return discrete.equals(other.discrete, tol) &&
|
||||
continuous.equals(other.continuous, tol);
|
||||
return discrete_.equals(other.discrete_, tol) &&
|
||||
continuous_.equals(other.continuous_, tol);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Interface
|
||||
/// @{
|
||||
|
||||
/// Return the discrete MPE assignment
|
||||
DiscreteValues discrete() const { return discrete_; }
|
||||
|
||||
/// Return the delta update for the continuous vectors
|
||||
VectorValues continuous() const { return continuous_; }
|
||||
|
||||
// Check whether a variable with key \c j exists in DiscreteValue.
|
||||
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); };
|
||||
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };
|
||||
|
||||
// Check whether a variable with key \c j exists in VectorValue.
|
||||
bool existsVector(Key j) { return continuous.exists(j); };
|
||||
bool existsVector(Key j) { return continuous_.exists(j); };
|
||||
|
||||
// Check whether a variable with key \c j exists.
|
||||
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };
|
||||
|
@ -78,13 +96,13 @@ class GTSAM_EXPORT HybridValues {
|
|||
* the key \c j is already used.
|
||||
* @param value The vector to be inserted.
|
||||
* @param j The index with which the value will be associated. */
|
||||
void insert(Key j, int value) { discrete[j] = value; };
|
||||
void insert(Key j, int value) { discrete_[j] = value; };
|
||||
|
||||
/** Insert a vector \c value with key \c j. Throws an invalid_argument
|
||||
* exception if the key \c j is already used.
|
||||
* @param value The vector to be inserted.
|
||||
* @param j The index with which the value will be associated. */
|
||||
void insert(Key j, const Vector& value) { continuous.insert(j, value); }
|
||||
void insert(Key j, const Vector& value) { continuous_.insert(j, value); }
|
||||
|
||||
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
|
||||
|
||||
|
@ -92,13 +110,13 @@ class GTSAM_EXPORT HybridValues {
|
|||
* Read/write access to the discrete value with key \c j, throws
|
||||
* std::out_of_range if \c j does not exist.
|
||||
*/
|
||||
size_t& atDiscrete(Key j) { return discrete.at(j); };
|
||||
size_t& atDiscrete(Key j) { return discrete_.at(j); };
|
||||
|
||||
/**
|
||||
* Read/write access to the vector value with key \c j, throws
|
||||
* std::out_of_range if \c j does not exist.
|
||||
*/
|
||||
Vector& at(Key j) { return continuous.at(j); };
|
||||
Vector& at(Key j) { return continuous_.at(j); };
|
||||
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
@ -112,8 +130,8 @@ class GTSAM_EXPORT HybridValues {
|
|||
std::string html(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||
std::stringstream ss;
|
||||
ss << this->discrete.html(keyFormatter);
|
||||
ss << this->continuous.html(keyFormatter);
|
||||
ss << this->discrete_.html(keyFormatter);
|
||||
ss << this->continuous_.html(keyFormatter);
|
||||
return ss.str();
|
||||
};
|
||||
|
||||
|
|
|
@ -6,8 +6,8 @@ namespace gtsam {
|
|||
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
class HybridValues {
|
||||
gtsam::DiscreteValues discrete;
|
||||
gtsam::VectorValues continuous;
|
||||
gtsam::DiscreteValues discrete() const;
|
||||
gtsam::VectorValues continuous() const;
|
||||
HybridValues();
|
||||
HybridValues(const gtsam::DiscreteValues &dv, const gtsam::VectorValues &cv);
|
||||
void print(string s = "HybridValues",
|
||||
|
|
|
@ -73,16 +73,16 @@ TEST(HybridBayesNet, Choose) {
|
|||
EXPECT_LONGS_EQUAL(4, gbn.size());
|
||||
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(0)))(assignment),
|
||||
hybridBayesNet->atMixture(0)))(assignment),
|
||||
*gbn.at(0)));
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(1)))(assignment),
|
||||
hybridBayesNet->atMixture(1)))(assignment),
|
||||
*gbn.at(1)));
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(2)))(assignment),
|
||||
hybridBayesNet->atMixture(2)))(assignment),
|
||||
*gbn.at(2)));
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(3)))(assignment),
|
||||
hybridBayesNet->atMixture(3)))(assignment),
|
||||
*gbn.at(3)));
|
||||
}
|
||||
|
||||
|
@ -125,35 +125,25 @@ TEST(HybridBayesNet, OptimizeAssignment) {
|
|||
TEST(HybridBayesNet, Optimize) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering ordering;
|
||||
for (auto&& kvp : s.linearizationPoint) {
|
||||
ordering += kvp.key;
|
||||
}
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
delta.print();
|
||||
VectorValues correct;
|
||||
correct.insert(X(1), 0 * Vector1::Ones());
|
||||
correct.insert(X(2), 1 * Vector1::Ones());
|
||||
correct.insert(X(3), 2 * Vector1::Ones());
|
||||
correct.insert(X(4), 3 * Vector1::Ones());
|
||||
DiscreteValues expectedAssignment;
|
||||
expectedAssignment[M(1)] = 1;
|
||||
expectedAssignment[M(2)] = 0;
|
||||
expectedAssignment[M(3)] = 1;
|
||||
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
||||
|
||||
DiscreteValues assignment111;
|
||||
assignment111[M(1)] = 1;
|
||||
assignment111[M(2)] = 1;
|
||||
assignment111[M(3)] = 1;
|
||||
std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl;
|
||||
VectorValues expectedValues;
|
||||
expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
|
||||
expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
|
||||
expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
|
||||
expectedValues.insert(X(4), -1.0001 * Vector1::Ones());
|
||||
|
||||
DiscreteValues assignment101;
|
||||
assignment101[M(1)] = 1;
|
||||
assignment101[M(2)] = 0;
|
||||
assignment101[M(3)] = 1;
|
||||
std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl;
|
||||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
* @date August 2022
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||
|
||||
|
@ -31,8 +32,8 @@ using symbol_shorthand::M;
|
|||
using symbol_shorthand::X;
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test for optimizing a HybridBayesTree.
|
||||
TEST(HybridBayesTree, Optimize) {
|
||||
// Test for optimizing a HybridBayesTree with a given assignment.
|
||||
TEST(HybridBayesTree, OptimizeAssignment) {
|
||||
Switching s(4);
|
||||
|
||||
HybridGaussianISAM isam;
|
||||
|
@ -50,6 +51,11 @@ TEST(HybridBayesTree, Optimize) {
|
|||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the discrete factors
|
||||
for (size_t i = 7; i <= 9; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
isam.update(graph1);
|
||||
|
||||
DiscreteValues assignment;
|
||||
|
@ -85,6 +91,58 @@ TEST(HybridBayesTree, Optimize) {
|
|||
EXPECT(assert_equal(expected, delta));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test for optimizing a HybridBayesTree.
|
||||
TEST(HybridBayesTree, Optimize) {
|
||||
Switching s(4);
|
||||
|
||||
HybridGaussianISAM isam;
|
||||
HybridGaussianFactorGraph graph1;
|
||||
|
||||
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
|
||||
for (size_t i = 1; i < 4; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the Gaussian factors, 1 prior on X(1),
|
||||
// 3 measurements on X(2), X(3), X(4)
|
||||
graph1.push_back(s.linearizedFactorGraph.at(0));
|
||||
for (size_t i = 4; i <= 6; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
// Add the discrete factors
|
||||
for (size_t i = 7; i <= 9; i++) {
|
||||
graph1.push_back(s.linearizedFactorGraph.at(i));
|
||||
}
|
||||
|
||||
isam.update(graph1);
|
||||
|
||||
HybridValues delta = isam.optimize();
|
||||
|
||||
// Create ordering.
|
||||
Ordering ordering;
|
||||
for (size_t k = 1; k <= s.K; k++) ordering += X(k);
|
||||
|
||||
HybridBayesNet::shared_ptr hybridBayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
|
||||
std::tie(hybridBayesNet, remainingFactorGraph) =
|
||||
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||
|
||||
DiscreteFactorGraph dfg;
|
||||
for (auto&& f : *remainingFactorGraph) {
|
||||
auto factor = dynamic_pointer_cast<HybridDiscreteFactor>(f);
|
||||
dfg.push_back(
|
||||
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
|
||||
}
|
||||
|
||||
DiscreteValues expectedMPE = dfg.optimize();
|
||||
VectorValues expectedValues = hybridBayesNet->optimize(expectedMPE);
|
||||
|
||||
EXPECT(assert_equal(expectedMPE, delta.discrete()));
|
||||
EXPECT(assert_equal(expectedValues, delta.continuous()));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -1,272 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testHybridLookupDAG.cpp
|
||||
* @date Aug, 2022
|
||||
* @author Shangjie Xue
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/TestableAssertions.h>
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/hybrid/GaussianMixture.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
#include <gtsam/nonlinear/Values.h>
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using noiseModel::Isotropic;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
|
||||
TEST(HybridLookupTable, basics) {
|
||||
// create a conditional gaussian node
|
||||
Matrix S1(2, 2);
|
||||
S1(0, 0) = 1;
|
||||
S1(1, 0) = 2;
|
||||
S1(0, 1) = 3;
|
||||
S1(1, 1) = 4;
|
||||
|
||||
Matrix S2(2, 2);
|
||||
S2(0, 0) = 6;
|
||||
S2(1, 0) = 0.2;
|
||||
S2(0, 1) = 8;
|
||||
S2(1, 1) = 0.4;
|
||||
|
||||
Matrix R1(2, 2);
|
||||
R1(0, 0) = 0.1;
|
||||
R1(1, 0) = 0.3;
|
||||
R1(0, 1) = 0.0;
|
||||
R1(1, 1) = 0.34;
|
||||
|
||||
Matrix R2(2, 2);
|
||||
R2(0, 0) = 0.1;
|
||||
R2(1, 0) = 0.3;
|
||||
R2(0, 1) = 0.0;
|
||||
R2(1, 1) = 0.34;
|
||||
|
||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||
|
||||
Vector2 d1(0.2, 0.5), d2(0.5, 0.2);
|
||||
|
||||
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
|
||||
X(2), S1, model),
|
||||
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
|
||||
X(2), S2, model);
|
||||
|
||||
// Create decision tree
|
||||
DiscreteKey m1(1, 2);
|
||||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals);
|
||||
|
||||
boost::shared_ptr<GaussianMixture> mixtureFactor(
|
||||
new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals));
|
||||
|
||||
HybridConditional hc(mixtureFactor);
|
||||
|
||||
GaussianMixture::Conditionals conditional2 =
|
||||
boost::static_pointer_cast<GaussianMixture>(hc.inner())->conditionals();
|
||||
|
||||
DiscreteValues dv;
|
||||
dv[1] = 1;
|
||||
|
||||
VectorValues cv;
|
||||
cv.insert(X(2), Vector2(0.0, 0.0));
|
||||
|
||||
HybridValues hv(dv, cv);
|
||||
|
||||
// std::cout << conditional2(values).markdown();
|
||||
EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6));
|
||||
EXPECT(conditional2(dv) == conditionals(dv));
|
||||
HybridLookupTable hlt(hc);
|
||||
|
||||
// hlt.argmaxInPlace(&hv);
|
||||
|
||||
HybridLookupDAG dag;
|
||||
dag.push_back(hlt);
|
||||
dag.argmax(hv);
|
||||
|
||||
// HybridBayesNet hbn;
|
||||
// hbn.push_back(hc);
|
||||
// hbn.optimize();
|
||||
}
|
||||
|
||||
TEST(HybridLookupTable, hybrid_argmax) {
|
||||
Matrix S1(2, 2);
|
||||
S1(0, 0) = 1;
|
||||
S1(1, 0) = 0;
|
||||
S1(0, 1) = 0;
|
||||
S1(1, 1) = 1;
|
||||
|
||||
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
|
||||
|
||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||
|
||||
auto conditional0 =
|
||||
boost::make_shared<GaussianConditional>(X(1), d1, S1, model),
|
||||
conditional1 =
|
||||
boost::make_shared<GaussianConditional>(X(1), d2, S1, model);
|
||||
|
||||
DiscreteKey m1(1, 2);
|
||||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
boost::shared_ptr<GaussianMixture> mixtureFactor(
|
||||
new GaussianMixture({X(1)}, {}, {m1}, conditionals));
|
||||
|
||||
HybridConditional hc(mixtureFactor);
|
||||
|
||||
DiscreteValues dv;
|
||||
dv[1] = 1;
|
||||
VectorValues cv;
|
||||
// cv.insert(X(2),Vector2(0.0, 0.0));
|
||||
HybridValues hv(dv, cv);
|
||||
|
||||
HybridLookupTable hlt(hc);
|
||||
|
||||
hlt.argmaxInPlace(&hv);
|
||||
|
||||
EXPECT(assert_equal(hv.at(X(1)), d2));
|
||||
}
|
||||
|
||||
TEST(HybridLookupTable, discrete_argmax) {
|
||||
DiscreteKey X(0, 2), Y(1, 2);
|
||||
|
||||
auto conditional = boost::make_shared<DiscreteConditional>(X | Y = "0/1 3/2");
|
||||
|
||||
HybridConditional hc(conditional);
|
||||
|
||||
HybridLookupTable hlt(hc);
|
||||
|
||||
DiscreteValues dv;
|
||||
dv[1] = 0;
|
||||
VectorValues cv;
|
||||
// cv.insert(X(2),Vector2(0.0, 0.0));
|
||||
HybridValues hv(dv, cv);
|
||||
|
||||
hlt.argmaxInPlace(&hv);
|
||||
|
||||
EXPECT(assert_equal(hv.atDiscrete(0), 1));
|
||||
|
||||
DecisionTreeFactor f1(X, "2 3");
|
||||
auto conditional2 = boost::make_shared<DiscreteConditional>(1, f1);
|
||||
|
||||
HybridConditional hc2(conditional2);
|
||||
|
||||
HybridLookupTable hlt2(hc2);
|
||||
|
||||
HybridValues hv2;
|
||||
|
||||
hlt2.argmaxInPlace(&hv2);
|
||||
|
||||
EXPECT(assert_equal(hv2.atDiscrete(0), 1));
|
||||
}
|
||||
|
||||
TEST(HybridLookupTable, gaussian_argmax) {
|
||||
Matrix S1(2, 2);
|
||||
S1(0, 0) = 1;
|
||||
S1(1, 0) = 0;
|
||||
S1(0, 1) = 0;
|
||||
S1(1, 1) = 1;
|
||||
|
||||
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
|
||||
|
||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||
|
||||
auto conditional =
|
||||
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
|
||||
|
||||
HybridConditional hc(conditional);
|
||||
|
||||
HybridLookupTable hlt(hc);
|
||||
|
||||
DiscreteValues dv;
|
||||
// dv[1]=0;
|
||||
VectorValues cv;
|
||||
cv.insert(X(2), d2);
|
||||
HybridValues hv(dv, cv);
|
||||
|
||||
hlt.argmaxInPlace(&hv);
|
||||
|
||||
EXPECT(assert_equal(hv.at(X(1)), d1 + d2));
|
||||
}
|
||||
|
||||
TEST(HybridLookupDAG, argmax) {
|
||||
Matrix S1(2, 2);
|
||||
S1(0, 0) = 1;
|
||||
S1(1, 0) = 0;
|
||||
S1(0, 1) = 0;
|
||||
S1(1, 1) = 1;
|
||||
|
||||
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
|
||||
|
||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||
|
||||
auto conditional0 =
|
||||
boost::make_shared<GaussianConditional>(X(2), d1, S1, model),
|
||||
conditional1 =
|
||||
boost::make_shared<GaussianConditional>(X(2), d2, S1, model);
|
||||
|
||||
DiscreteKey m1(1, 2);
|
||||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
boost::shared_ptr<GaussianMixture> mixtureFactor(
|
||||
new GaussianMixture({X(2)}, {}, {m1}, conditionals));
|
||||
HybridConditional hc2(mixtureFactor);
|
||||
HybridLookupTable hlt2(hc2);
|
||||
|
||||
auto conditional2 =
|
||||
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
|
||||
|
||||
HybridConditional hc1(conditional2);
|
||||
HybridLookupTable hlt1(hc1);
|
||||
|
||||
DecisionTreeFactor f1(m1, "2 3");
|
||||
auto discrete_conditional = boost::make_shared<DiscreteConditional>(1, f1);
|
||||
|
||||
HybridConditional hc3(discrete_conditional);
|
||||
HybridLookupTable hlt3(hc3);
|
||||
|
||||
HybridLookupDAG dag;
|
||||
dag.push_back(hlt1);
|
||||
dag.push_back(hlt2);
|
||||
dag.push_back(hlt3);
|
||||
auto hv = dag.argmax();
|
||||
|
||||
EXPECT(assert_equal(hv.atDiscrete(1), 1));
|
||||
EXPECT(assert_equal(hv.at(X(2)), d2));
|
||||
EXPECT(assert_equal(hv.at(X(1)), d2 + d1));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
Loading…
Reference in New Issue