HybridValues and optimize() method for hybrid gaussian bayes net
parent
a700c87504
commit
58068503f4
|
|
@ -10,7 +10,20 @@
|
||||||
* @file HybridBayesNet.cpp
|
* @file HybridBayesNet.cpp
|
||||||
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
|
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
|
* @author Shangjie Xue
|
||||||
* @date January 2022
|
* @date January 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
HybridValues HybridBayesNet::optimize() const {
|
||||||
|
auto dag = HybridLookupDAG::FromBayesNet(*this);
|
||||||
|
return dag.argmax();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -18,6 +18,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
@ -34,8 +35,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
|
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
|
||||||
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
||||||
|
|
||||||
/** Construct empty bayes net */
|
/// Construct empty bayes net
|
||||||
HybridBayesNet() = default;
|
HybridBayesNet() = default;
|
||||||
|
|
||||||
|
/// Destructor
|
||||||
|
virtual ~HybridBayesNet() {}
|
||||||
|
|
||||||
|
/// Solve the HybridBayesNet by back-substitution.
|
||||||
|
HybridValues optimize() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteLookupDAG.cpp
|
||||||
|
* @date Aug, 2022
|
||||||
|
* @author Shangjie Xue
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
|
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||||
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
using std::pair;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
|
||||||
|
// For discrete conditional, uses argmaxInPlace() method in DiscreteLookupTable.
|
||||||
|
if (isDiscrete()){
|
||||||
|
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(&(values->discrete));
|
||||||
|
} else if (isContinuous()){
|
||||||
|
// For Gaussian conditional, uses solve() method in GaussianConditional.
|
||||||
|
values->continuous.insert(boost::static_pointer_cast<GaussianConditional>(inner_)->solve(values->continuous));
|
||||||
|
} else if (isHybrid()){
|
||||||
|
// For hybrid conditional, since children should not contain discrete variable, we can condition on
|
||||||
|
// the discrete variable in the parents and solve the resulting GaussianConditional.
|
||||||
|
auto conditional = boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(values->discrete);
|
||||||
|
values->continuous.insert(conditional->solve(values->continuous));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// /* ************************************************************************** */
|
||||||
|
HybridLookupDAG HybridLookupDAG::FromBayesNet(
|
||||||
|
const HybridBayesNet& bayesNet) {
|
||||||
|
HybridLookupDAG dag;
|
||||||
|
for (auto&& conditional : bayesNet) {
|
||||||
|
HybridLookupTable hlt(*conditional);
|
||||||
|
dag.push_back(hlt);
|
||||||
|
}
|
||||||
|
return dag;
|
||||||
|
}
|
||||||
|
|
||||||
|
HybridValues HybridLookupDAG::argmax(HybridValues result) const {
|
||||||
|
// Argmax each node in turn in topological sort order (parents first).
|
||||||
|
for (auto lookupTable : boost::adaptors::reverse(*this))
|
||||||
|
lookupTable->argmaxInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -0,0 +1,118 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file HybridLookupDAG.h
|
||||||
|
* @date Aug, 2022
|
||||||
|
* @author Shangjie Xue
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief HybridLookupTable table for max-product
|
||||||
|
*
|
||||||
|
* Similar to DiscreteLookupTable, inherits from hybrid conditional for convenience.
|
||||||
|
* Is used in the max-product algorithm.
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
|
||||||
|
public:
|
||||||
|
using Base = HybridConditional;
|
||||||
|
using This = HybridLookupTable;
|
||||||
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
using BaseConditional = Conditional<DecisionTreeFactor, This>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Hybrid Lookup Table object form a HybridConditional.
|
||||||
|
*
|
||||||
|
* @param conditional input hybrid conditional
|
||||||
|
*/
|
||||||
|
HybridLookupTable(HybridConditional& conditional) : Base(conditional){};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||||
|
* @param (in/out) parentsValues Known assignments for the parents.
|
||||||
|
*/
|
||||||
|
void argmaxInPlace(HybridValues* parentsValues) const;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** A DAG made from hybrid lookup tables, as defined above. Similar to DiscreteLookupDAG */
|
||||||
|
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
|
||||||
|
public:
|
||||||
|
using Base = BayesNet<HybridLookupTable>;
|
||||||
|
using This = HybridLookupDAG;
|
||||||
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Construct empty DAG.
|
||||||
|
HybridLookupDAG() {}
|
||||||
|
|
||||||
|
/// Create from BayesNet with LookupTables
|
||||||
|
static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet);
|
||||||
|
|
||||||
|
/// Destructor
|
||||||
|
virtual ~HybridLookupDAG() {}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/** Add a DiscreteLookupTable */
|
||||||
|
template <typename... Args>
|
||||||
|
void add(Args&&... args) {
|
||||||
|
emplace_shared<HybridLookupTable>(std::forward<Args>(args)...);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief argmax by back-substitution, optionally given certain variables.
|
||||||
|
*
|
||||||
|
* Assumes the DAG is reverse topologically sorted, i.e. last
|
||||||
|
* conditional will be optimized first *and* that the
|
||||||
|
* DAG does not contain any conditionals for the given variables. If the DAG
|
||||||
|
* resulted from eliminating a factor graph, this is true for the elimination
|
||||||
|
* ordering.
|
||||||
|
*
|
||||||
|
* @return given assignment extended w. optimal assignment for all variables.
|
||||||
|
*/
|
||||||
|
HybridValues argmax(HybridValues given = HybridValues()) const;
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<HybridLookupDAG> : public Testable<HybridLookupDAG> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -0,0 +1,138 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file HybridValues.h
|
||||||
|
* @date Jul 28, 2022
|
||||||
|
* @author Shangjie Xue
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/Assignment.h>
|
||||||
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* HybridValues represents a collection of DiscreteValues and VectorValues. It is typically used to store the variables
|
||||||
|
* of a HybridGaussianFactorGraph. Optimizing a HybridGaussianBayesNet returns this class.
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT HybridValues {
|
||||||
|
public:
|
||||||
|
// DiscreteValue stored the discrete components of the HybridValues.
|
||||||
|
DiscreteValues discrete;
|
||||||
|
|
||||||
|
// VectorValue stored the continuous components of the HybridValues.
|
||||||
|
VectorValues continuous;
|
||||||
|
|
||||||
|
// Default constructor creates an empty HybridValues.
|
||||||
|
HybridValues() : discrete(), continuous() {};
|
||||||
|
|
||||||
|
// Construct from DiscreteValues and VectorValues.
|
||||||
|
HybridValues(const DiscreteValues &dv, const VectorValues &cv) : discrete(dv), continuous(cv) {};
|
||||||
|
|
||||||
|
// print required by Testable for unit testing
|
||||||
|
void print(const std::string& s = "HybridValues",
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{
|
||||||
|
std::cout << s << ": \n";
|
||||||
|
discrete.print(" Discrete", keyFormatter);
|
||||||
|
continuous.print(" Continuous", keyFormatter);
|
||||||
|
};
|
||||||
|
|
||||||
|
// equals required by Testable for unit testing
|
||||||
|
bool equals(const HybridValues& other, double tol = 1e-9) const {
|
||||||
|
return discrete.equals(other.discrete, tol) && continuous.equals(other.continuous, tol);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check whether a variable with key \c j exists in DiscreteValue.
|
||||||
|
bool existsDiscrete(Key j){
|
||||||
|
return (discrete.find(j) != discrete.end());
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check whether a variable with key \c j exists in VectorValue.
|
||||||
|
bool existsVector(Key j){
|
||||||
|
return continuous.exists(j);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check whether a variable with key \c j exists.
|
||||||
|
bool exists(Key j){
|
||||||
|
return existsDiscrete(j) || existsVector(j);
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Insert a discrete \c value with key \c j. Replaces the existing value if the key \c
|
||||||
|
* j is already used.
|
||||||
|
* @param value The vector to be inserted.
|
||||||
|
* @param j The index with which the value will be associated. */
|
||||||
|
void insert(Key j, int value){
|
||||||
|
discrete[j] = value;
|
||||||
|
};
|
||||||
|
|
||||||
|
/** Insert a vector \c value with key \c j. Throws an invalid_argument exception if the key \c
|
||||||
|
* j is already used.
|
||||||
|
* @param value The vector to be inserted.
|
||||||
|
* @param j The index with which the value will be associated. */
|
||||||
|
void insert(Key j, const Vector& value) {
|
||||||
|
continuous.insert(j, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read/write access to the discrete value with key \c j, throws
|
||||||
|
* std::out_of_range if \c j does not exist.
|
||||||
|
*/
|
||||||
|
size_t& atDiscrete(Key j){
|
||||||
|
return discrete.at(j);
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read/write access to the vector value with key \c j, throws
|
||||||
|
* std::out_of_range if \c j does not exist.
|
||||||
|
*/
|
||||||
|
Vector& at(Key j) {
|
||||||
|
return continuous.at(j);
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/// @name Wrapper support
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Output as a html table.
|
||||||
|
*
|
||||||
|
* @param keyFormatter function that formats keys.
|
||||||
|
* @return string html output.
|
||||||
|
*/
|
||||||
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << this->discrete.html(keyFormatter);
|
||||||
|
ss << this->continuous.html(keyFormatter);
|
||||||
|
return ss.str();
|
||||||
|
};
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
};
|
||||||
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<HybridValues> : public Testable<HybridValues> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -4,6 +4,22 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
class HybridValues {
|
||||||
|
gtsam::DiscreteValues discrete;
|
||||||
|
gtsam::VectorValues continuous;
|
||||||
|
HybridValues();
|
||||||
|
HybridValues(const gtsam::DiscreteValues &dv, const gtsam::VectorValues &cv);
|
||||||
|
void print(string s = "HybridValues",
|
||||||
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
bool equals(const gtsam::HybridValues& other, double tol) const;
|
||||||
|
void insert(gtsam::Key j, int value);
|
||||||
|
void insert(gtsam::Key j, const gtsam::Vector& value);
|
||||||
|
size_t& atDiscrete(gtsam::Key j);
|
||||||
|
gtsam::Vector& at(gtsam::Key j);
|
||||||
|
};
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
virtual class HybridFactor {
|
virtual class HybridFactor {
|
||||||
void print(string s = "HybridFactor\n",
|
void print(string s = "HybridFactor\n",
|
||||||
|
|
@ -84,6 +100,7 @@ class HybridBayesNet {
|
||||||
size_t size() const;
|
size_t size() const;
|
||||||
gtsam::KeySet keys() const;
|
gtsam::KeySet keys() const;
|
||||||
const gtsam::HybridConditional* at(size_t i) const;
|
const gtsam::HybridConditional* at(size_t i) const;
|
||||||
|
gtsam::HybridValues optimize() const;
|
||||||
void print(string s = "HybridBayesNet\n",
|
void print(string s = "HybridBayesNet\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
#include <CppUnitLite/Test.h>
|
#include <CppUnitLite/Test.h>
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
|
@ -30,6 +31,7 @@
|
||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
#include <gtsam/inference/DotWriter.h>
|
#include <gtsam/inference/DotWriter.h>
|
||||||
#include <gtsam/inference/Key.h>
|
#include <gtsam/inference/Key.h>
|
||||||
|
|
@ -585,6 +587,28 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(HybridGaussianFactorGraph, optimize) {
|
||||||
|
HybridGaussianFactorGraph hfg;
|
||||||
|
|
||||||
|
DiscreteKey c1(C(1), 2);
|
||||||
|
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
||||||
|
|
||||||
|
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
|
||||||
|
C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
|
||||||
|
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
|
||||||
|
|
||||||
|
hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));
|
||||||
|
|
||||||
|
auto result =
|
||||||
|
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {C(1)}));
|
||||||
|
|
||||||
|
HybridValues hv = result->optimize();
|
||||||
|
|
||||||
|
EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
|
||||||
|
|
||||||
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,276 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file testHybridLookupDAG.cpp
|
||||||
|
* @date Aug, 2022
|
||||||
|
* @author Shangjie Xue
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
#include <gtsam/discrete/Assignment.h>
|
||||||
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
|
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||||
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
|
||||||
|
// Include for test suite
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace gtsam;
|
||||||
|
using noiseModel::Isotropic;
|
||||||
|
using symbol_shorthand::M;
|
||||||
|
using symbol_shorthand::X;
|
||||||
|
|
||||||
|
TEST(HybridLookupTable, basics) {
|
||||||
|
// create a conditional gaussian node
|
||||||
|
Matrix S1(2, 2);
|
||||||
|
S1(0, 0) = 1;
|
||||||
|
S1(1, 0) = 2;
|
||||||
|
S1(0, 1) = 3;
|
||||||
|
S1(1, 1) = 4;
|
||||||
|
|
||||||
|
Matrix S2(2, 2);
|
||||||
|
S2(0, 0) = 6;
|
||||||
|
S2(1, 0) = 0.2;
|
||||||
|
S2(0, 1) = 8;
|
||||||
|
S2(1, 1) = 0.4;
|
||||||
|
|
||||||
|
Matrix R1(2, 2);
|
||||||
|
R1(0, 0) = 0.1;
|
||||||
|
R1(1, 0) = 0.3;
|
||||||
|
R1(0, 1) = 0.0;
|
||||||
|
R1(1, 1) = 0.34;
|
||||||
|
|
||||||
|
Matrix R2(2, 2);
|
||||||
|
R2(0, 0) = 0.1;
|
||||||
|
R2(1, 0) = 0.3;
|
||||||
|
R2(0, 1) = 0.0;
|
||||||
|
R2(1, 1) = 0.34;
|
||||||
|
|
||||||
|
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||||
|
|
||||||
|
Vector2 d1(0.2, 0.5), d2(0.5, 0.2);
|
||||||
|
|
||||||
|
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
|
||||||
|
X(2), S1, model),
|
||||||
|
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
|
||||||
|
X(2), S2, model);
|
||||||
|
|
||||||
|
// Create decision tree
|
||||||
|
DiscreteKey m1(1, 2);
|
||||||
|
GaussianMixture::Conditionals conditionals(
|
||||||
|
{m1},
|
||||||
|
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||||
|
// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals);
|
||||||
|
|
||||||
|
boost::shared_ptr<GaussianMixture> mixtureFactor(new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals));
|
||||||
|
|
||||||
|
HybridConditional hc(mixtureFactor);
|
||||||
|
|
||||||
|
GaussianMixture::Conditionals conditional2 = boost::static_pointer_cast<GaussianMixture>(hc.inner())->conditionals();
|
||||||
|
|
||||||
|
DiscreteValues dv;
|
||||||
|
dv[1]=1;
|
||||||
|
|
||||||
|
VectorValues cv;
|
||||||
|
cv.insert(X(2),Vector2(0.0, 0.0));
|
||||||
|
|
||||||
|
HybridValues hv(dv, cv);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// std::cout << conditional2(values).markdown();
|
||||||
|
EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6));
|
||||||
|
EXPECT(conditional2(dv)==conditionals(dv));
|
||||||
|
HybridLookupTable hlt(hc);
|
||||||
|
|
||||||
|
// hlt.argmaxInPlace(&hv);
|
||||||
|
|
||||||
|
HybridLookupDAG dag;
|
||||||
|
dag.push_back(hlt);
|
||||||
|
dag.argmax(hv);
|
||||||
|
|
||||||
|
// HybridBayesNet hbn;
|
||||||
|
// hbn.push_back(hc);
|
||||||
|
// hbn.optimize();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(HybridLookupTable, hybrid_argmax) {
|
||||||
|
Matrix S1(2, 2);
|
||||||
|
S1(0, 0) = 1;
|
||||||
|
S1(1, 0) = 0;
|
||||||
|
S1(0, 1) = 0;
|
||||||
|
S1(1, 1) = 1;
|
||||||
|
|
||||||
|
Vector2 d1(0.2, 0.5), d2(-0.5,0.6);
|
||||||
|
|
||||||
|
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||||
|
|
||||||
|
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, S1, model),
|
||||||
|
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, S1, model);
|
||||||
|
|
||||||
|
DiscreteKey m1(1, 2);
|
||||||
|
GaussianMixture::Conditionals conditionals(
|
||||||
|
{m1},
|
||||||
|
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||||
|
boost::shared_ptr<GaussianMixture> mixtureFactor(new GaussianMixture({X(1)},{}, {m1}, conditionals));
|
||||||
|
|
||||||
|
HybridConditional hc(mixtureFactor);
|
||||||
|
|
||||||
|
DiscreteValues dv;
|
||||||
|
dv[1]=1;
|
||||||
|
VectorValues cv;
|
||||||
|
// cv.insert(X(2),Vector2(0.0, 0.0));
|
||||||
|
HybridValues hv(dv, cv);
|
||||||
|
|
||||||
|
HybridLookupTable hlt(hc);
|
||||||
|
|
||||||
|
hlt.argmaxInPlace(&hv);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(hv.at(X(1)), d2));
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(HybridLookupTable, discrete_argmax) {
|
||||||
|
DiscreteKey X(0, 2), Y(1, 2);
|
||||||
|
|
||||||
|
auto conditional = boost::make_shared<DiscreteConditional>(X | Y = "0/1 3/2");
|
||||||
|
|
||||||
|
HybridConditional hc(conditional);
|
||||||
|
|
||||||
|
HybridLookupTable hlt(hc);
|
||||||
|
|
||||||
|
DiscreteValues dv;
|
||||||
|
dv[1]=0;
|
||||||
|
VectorValues cv;
|
||||||
|
// cv.insert(X(2),Vector2(0.0, 0.0));
|
||||||
|
HybridValues hv(dv, cv);
|
||||||
|
|
||||||
|
|
||||||
|
hlt.argmaxInPlace(&hv);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(hv.atDiscrete(0), 1));
|
||||||
|
|
||||||
|
DecisionTreeFactor f1(X , "2 3");
|
||||||
|
auto conditional2 = boost::make_shared<DiscreteConditional>(1,f1);
|
||||||
|
|
||||||
|
HybridConditional hc2(conditional2);
|
||||||
|
|
||||||
|
HybridLookupTable hlt2(hc2);
|
||||||
|
|
||||||
|
HybridValues hv2;
|
||||||
|
|
||||||
|
|
||||||
|
hlt2.argmaxInPlace(&hv2);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(hv2.atDiscrete(0), 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(HybridLookupTable, gaussian_argmax) {
|
||||||
|
Matrix S1(2, 2);
|
||||||
|
S1(0, 0) = 1;
|
||||||
|
S1(1, 0) = 0;
|
||||||
|
S1(0, 1) = 0;
|
||||||
|
S1(1, 1) = 1;
|
||||||
|
|
||||||
|
Vector2 d1(0.2, 0.5), d2(-0.5,0.6);
|
||||||
|
|
||||||
|
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||||
|
|
||||||
|
auto conditional = boost::make_shared<GaussianConditional>(X(1), d1, S1,
|
||||||
|
X(2), -S1, model);
|
||||||
|
|
||||||
|
HybridConditional hc(conditional);
|
||||||
|
|
||||||
|
HybridLookupTable hlt(hc);
|
||||||
|
|
||||||
|
DiscreteValues dv;
|
||||||
|
// dv[1]=0;
|
||||||
|
VectorValues cv;
|
||||||
|
cv.insert(X(2),d2);
|
||||||
|
HybridValues hv(dv, cv);
|
||||||
|
|
||||||
|
|
||||||
|
hlt.argmaxInPlace(&hv);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(hv.at(X(1)), d1+d2));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(HybridLookupDAG, argmax) {
|
||||||
|
|
||||||
|
Matrix S1(2, 2);
|
||||||
|
S1(0, 0) = 1;
|
||||||
|
S1(1, 0) = 0;
|
||||||
|
S1(0, 1) = 0;
|
||||||
|
S1(1, 1) = 1;
|
||||||
|
|
||||||
|
Vector2 d1(0.2, 0.5), d2(-0.5,0.6);
|
||||||
|
|
||||||
|
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||||
|
|
||||||
|
auto conditional0 = boost::make_shared<GaussianConditional>(X(2), d1, S1, model),
|
||||||
|
conditional1 = boost::make_shared<GaussianConditional>(X(2), d2, S1, model);
|
||||||
|
|
||||||
|
DiscreteKey m1(1, 2);
|
||||||
|
GaussianMixture::Conditionals conditionals(
|
||||||
|
{m1},
|
||||||
|
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||||
|
boost::shared_ptr<GaussianMixture> mixtureFactor(new GaussianMixture({X(2)},{}, {m1}, conditionals));
|
||||||
|
HybridConditional hc2(mixtureFactor);
|
||||||
|
HybridLookupTable hlt2(hc2);
|
||||||
|
|
||||||
|
|
||||||
|
auto conditional2 = boost::make_shared<GaussianConditional>(X(1), d1, S1,
|
||||||
|
X(2), -S1, model);
|
||||||
|
|
||||||
|
HybridConditional hc1(conditional2);
|
||||||
|
HybridLookupTable hlt1(hc1);
|
||||||
|
|
||||||
|
DecisionTreeFactor f1(m1 , "2 3");
|
||||||
|
auto discrete_conditional = boost::make_shared<DiscreteConditional>(1,f1);
|
||||||
|
|
||||||
|
HybridConditional hc3(discrete_conditional);
|
||||||
|
HybridLookupTable hlt3(hc3);
|
||||||
|
|
||||||
|
HybridLookupDAG dag;
|
||||||
|
dag.push_back(hlt1);
|
||||||
|
dag.push_back(hlt2);
|
||||||
|
dag.push_back(hlt3);
|
||||||
|
auto hv = dag.argmax();
|
||||||
|
|
||||||
|
EXPECT(assert_equal(hv.atDiscrete(1), 1));
|
||||||
|
EXPECT(assert_equal(hv.at(X(2)), d2));
|
||||||
|
EXPECT(assert_equal(hv.at(X(1)), d2+d1));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file testHybridValues.cpp
|
||||||
|
* @date Jul 28, 2022
|
||||||
|
* @author Shangjie Xue
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
#include <gtsam/discrete/Assignment.h>
|
||||||
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
|
// Include for test suite
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
TEST(HybridValues, basics) {
|
||||||
|
HybridValues values;
|
||||||
|
values.insert(99, Vector2(2, 3));
|
||||||
|
values.insert(100, 3);
|
||||||
|
EXPECT(assert_equal(values.at(99), Vector2(2, 3)));
|
||||||
|
EXPECT(assert_equal(values.atDiscrete(100), int(3)));
|
||||||
|
|
||||||
|
values.print();
|
||||||
|
|
||||||
|
HybridValues values2;
|
||||||
|
values2.insert(100, 3);
|
||||||
|
values2.insert(99, Vector2(2, 3));
|
||||||
|
EXPECT(assert_equal(values2, values));
|
||||||
|
|
||||||
|
values2.insert(98, Vector2(2,3));
|
||||||
|
EXPECT(!assert_equal(values2, values));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main() {
|
||||||
|
TestResult tr;
|
||||||
|
return TestRegistry::runAllTests(tr);
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
Loading…
Reference in New Issue