Merge pull request #1270 from borglab/hybrid/hybrid-optimize

Linear HybridBayesNet optimization
release/4.3a0
Shangjie Xue 2022-08-22 09:31:37 -04:00 committed by GitHub
commit ef066a0747
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 776 additions and 0 deletions

View File

@ -11,10 +11,13 @@
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
* @date January 2022
*/
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
namespace gtsam {
@ -40,4 +43,10 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn;
}
/* *******************************************************************************/
HybridValues HybridBayesNet::optimize() const {
auto dag = HybridLookupDAG::FromBayesNet(*this);
return dag.argmax();
}
} // namespace gtsam

View File

@ -18,6 +18,7 @@
#pragma once
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/linear/GaussianBayesNet.h>
@ -61,6 +62,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @return GaussianBayesNet
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;
/// Solve the HybridBayesNet by back-substitution.
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
HybridValues optimize() const;
};
} // namespace gtsam

View File

@ -0,0 +1,76 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteLookupDAG.cpp
* @date Aug, 2022
* @author Shangjie Xue
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/VectorValues.h>
#include <string>
#include <utility>
using std::pair;
using std::vector;
namespace gtsam {
/* ************************************************************************** */
void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
// For discrete conditional, uses argmaxInPlace() method in
// DiscreteLookupTable.
if (isDiscrete()) {
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(
&(values->discrete));
} else if (isContinuous()) {
// For Gaussian conditional, uses solve() method in GaussianConditional.
values->continuous.insert(
boost::static_pointer_cast<GaussianConditional>(inner_)->solve(
values->continuous));
} else if (isHybrid()) {
// For hybrid conditional, since children should not contain discrete
// variable, we can condition on the discrete variable in the parents and
// solve the resulting GaussianConditional.
auto conditional =
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(
values->discrete);
values->continuous.insert(conditional->solve(values->continuous));
}
}
/* ************************************************************************** */
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
HybridLookupDAG dag;
for (auto&& conditional : bayesNet) {
HybridLookupTable hlt(*conditional);
dag.push_back(hlt);
}
return dag;
}
/* ************************************************************************** */
HybridValues HybridLookupDAG::argmax(HybridValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}
} // namespace gtsam

View File

@ -0,0 +1,119 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridLookupDAG.h
* @date Aug, 2022
* @author Shangjie Xue
*/
#pragma once
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>
namespace gtsam {
/**
* @brief HybridLookupTable table for max-product
*
* Similar to DiscreteLookupTable, inherits from hybrid conditional for
* convenience. Is used in the max-product algorithm.
*/
class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
public:
using Base = HybridConditional;
using This = HybridLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;
/**
* @brief Construct a new Hybrid Lookup Table object form a HybridConditional.
*
* @param conditional input hybrid conditional
*/
HybridLookupTable(HybridConditional& conditional) : Base(conditional){};
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(HybridValues* parentsValues) const;
};
/** A DAG made from hybrid lookup tables, as defined above. Similar to
* DiscreteLookupDAG */
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
public:
using Base = BayesNet<HybridLookupTable>;
using This = HybridLookupDAG;
using shared_ptr = boost::shared_ptr<This>;
/// @name Standard Constructors
/// @{
/// Construct empty DAG.
HybridLookupDAG() {}
/// Create from BayesNet with LookupTables
static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet);
/// Destructor
virtual ~HybridLookupDAG() {}
/// @}
/// @name Standard Interface
/// @{
/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<HybridLookupTable>(std::forward<Args>(args)...);
}
/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
* @return given assignment extended w. optimal assignment for all variables.
*/
HybridValues argmax(HybridValues given = HybridValues()) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};
// traits
template <>
struct traits<HybridLookupDAG> : public Testable<HybridLookupDAG> {};
} // namespace gtsam

127
gtsam/hybrid/HybridValues.h Normal file
View File

@ -0,0 +1,127 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridValues.h
* @date Jul 28, 2022
* @author Shangjie Xue
*/
#pragma once
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Key.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/nonlinear/Values.h>
#include <map>
#include <string>
#include <vector>
namespace gtsam {
/**
* HybridValues represents a collection of DiscreteValues and VectorValues. It
* is typically used to store the variables of a HybridGaussianFactorGraph.
* Optimizing a HybridGaussianBayesNet returns this class.
*/
class GTSAM_EXPORT HybridValues {
public:
// DiscreteValue stored the discrete components of the HybridValues.
DiscreteValues discrete;
// VectorValue stored the continuous components of the HybridValues.
VectorValues continuous;
// Default constructor creates an empty HybridValues.
HybridValues() : discrete(), continuous(){};
// Construct from DiscreteValues and VectorValues.
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
: discrete(dv), continuous(cv){};
// print required by Testable for unit testing
void print(const std::string& s = "HybridValues",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": \n";
discrete.print(" Discrete", keyFormatter); // print discrete components
continuous.print(" Continuous",
keyFormatter); // print continuous components
};
// equals required by Testable for unit testing
bool equals(const HybridValues& other, double tol = 1e-9) const {
return discrete.equals(other.discrete, tol) &&
continuous.equals(other.continuous, tol);
}
// Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); };
// Check whether a variable with key \c j exists in VectorValue.
bool existsVector(Key j) { return continuous.exists(j); };
// Check whether a variable with key \c j exists.
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };
/** Insert a discrete \c value with key \c j. Replaces the existing value if
* the key \c j is already used.
* @param value The vector to be inserted.
* @param j The index with which the value will be associated. */
void insert(Key j, int value) { discrete[j] = value; };
/** Insert a vector \c value with key \c j. Throws an invalid_argument
* exception if the key \c j is already used.
* @param value The vector to be inserted.
* @param j The index with which the value will be associated. */
void insert(Key j, const Vector& value) { continuous.insert(j, value); }
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
/**
* Read/write access to the discrete value with key \c j, throws
* std::out_of_range if \c j does not exist.
*/
size_t& atDiscrete(Key j) { return discrete.at(j); };
/**
* Read/write access to the vector value with key \c j, throws
* std::out_of_range if \c j does not exist.
*/
Vector& at(Key j) { return continuous.at(j); };
/// @name Wrapper support
/// @{
/**
* @brief Output as a html table.
*
* @param keyFormatter function that formats keys.
* @return string html output.
*/
std::string html(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::stringstream ss;
ss << this->discrete.html(keyFormatter);
ss << this->continuous.html(keyFormatter);
return ss.str();
};
/// @}
};
// traits
template <>
struct traits<HybridValues> : public Testable<HybridValues> {};
} // namespace gtsam

View File

@ -4,6 +4,22 @@
namespace gtsam {
#include <gtsam/hybrid/HybridValues.h>
class HybridValues {
gtsam::DiscreteValues discrete;
gtsam::VectorValues continuous;
HybridValues();
HybridValues(const gtsam::DiscreteValues &dv, const gtsam::VectorValues &cv);
void print(string s = "HybridValues",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridValues& other, double tol) const;
void insert(gtsam::Key j, int value);
void insert(gtsam::Key j, const gtsam::Vector& value);
size_t& atDiscrete(gtsam::Key j);
gtsam::Vector& at(gtsam::Key j);
};
#include <gtsam/hybrid/HybridFactor.h>
virtual class HybridFactor {
void print(string s = "HybridFactor\n",
@ -84,6 +100,7 @@ class HybridBayesNet {
size_t size() const;
gtsam::KeySet keys() const;
const gtsam::HybridConditional* at(size_t i) const;
gtsam::HybridValues optimize() const;
void print(string s = "HybridBayesNet\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;

View File

@ -17,6 +17,7 @@
#include <CppUnitLite/Test.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
@ -30,6 +31,7 @@
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/DotWriter.h>
#include <gtsam/inference/Key.h>
@ -501,6 +503,27 @@ TEST_DISABLED(HybridGaussianFactorGraph, SwitchingTwoVar) {
}
}
TEST(HybridGaussianFactorGraph, optimize) {
HybridGaussianFactorGraph hfg;
DiscreteKey c1(C(1), 2);
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
DecisionTree<Key, GaussianFactor::shared_ptr> dt(
C(1), boost::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
boost::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));
auto result =
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {C(1)}));
HybridValues hv = result->optimize();
EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -0,0 +1,272 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testHybridLookupDAG.cpp
* @date Aug, 2022
* @author Shangjie Xue
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Key.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/nonlinear/Values.h>
// Include for test suite
#include <CppUnitLite/TestHarness.h>
#include <iostream>
using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
TEST(HybridLookupTable, basics) {
// create a conditional gaussian node
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 2;
S1(0, 1) = 3;
S1(1, 1) = 4;
Matrix S2(2, 2);
S2(0, 0) = 6;
S2(1, 0) = 0.2;
S2(0, 1) = 8;
S2(1, 1) = 0.4;
Matrix R1(2, 2);
R1(0, 0) = 0.1;
R1(1, 0) = 0.3;
R1(0, 1) = 0.0;
R1(1, 1) = 0.34;
Matrix R2(2, 2);
R2(0, 0) = 0.1;
R2(1, 0) = 0.3;
R2(0, 1) = 0.0;
R2(1, 1) = 0.34;
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
Vector2 d1(0.2, 0.5), d2(0.5, 0.2);
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, R1,
X(2), S1, model),
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, R2,
X(2), S2, model);
// Create decision tree
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals);
boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals));
HybridConditional hc(mixtureFactor);
GaussianMixture::Conditionals conditional2 =
boost::static_pointer_cast<GaussianMixture>(hc.inner())->conditionals();
DiscreteValues dv;
dv[1] = 1;
VectorValues cv;
cv.insert(X(2), Vector2(0.0, 0.0));
HybridValues hv(dv, cv);
// std::cout << conditional2(values).markdown();
EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6));
EXPECT(conditional2(dv) == conditionals(dv));
HybridLookupTable hlt(hc);
// hlt.argmaxInPlace(&hv);
HybridLookupDAG dag;
dag.push_back(hlt);
dag.argmax(hv);
// HybridBayesNet hbn;
// hbn.push_back(hc);
// hbn.optimize();
}
TEST(HybridLookupTable, hybrid_argmax) {
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 0;
S1(0, 1) = 0;
S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional0 =
boost::make_shared<GaussianConditional>(X(1), d1, S1, model),
conditional1 =
boost::make_shared<GaussianConditional>(X(1), d2, S1, model);
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(1)}, {}, {m1}, conditionals));
HybridConditional hc(mixtureFactor);
DiscreteValues dv;
dv[1] = 1;
VectorValues cv;
// cv.insert(X(2),Vector2(0.0, 0.0));
HybridValues hv(dv, cv);
HybridLookupTable hlt(hc);
hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.at(X(1)), d2));
}
TEST(HybridLookupTable, discrete_argmax) {
DiscreteKey X(0, 2), Y(1, 2);
auto conditional = boost::make_shared<DiscreteConditional>(X | Y = "0/1 3/2");
HybridConditional hc(conditional);
HybridLookupTable hlt(hc);
DiscreteValues dv;
dv[1] = 0;
VectorValues cv;
// cv.insert(X(2),Vector2(0.0, 0.0));
HybridValues hv(dv, cv);
hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.atDiscrete(0), 1));
DecisionTreeFactor f1(X, "2 3");
auto conditional2 = boost::make_shared<DiscreteConditional>(1, f1);
HybridConditional hc2(conditional2);
HybridLookupTable hlt2(hc2);
HybridValues hv2;
hlt2.argmaxInPlace(&hv2);
EXPECT(assert_equal(hv2.atDiscrete(0), 1));
}
TEST(HybridLookupTable, gaussian_argmax) {
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 0;
S1(0, 1) = 0;
S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional =
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
HybridConditional hc(conditional);
HybridLookupTable hlt(hc);
DiscreteValues dv;
// dv[1]=0;
VectorValues cv;
cv.insert(X(2), d2);
HybridValues hv(dv, cv);
hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.at(X(1)), d1 + d2));
}
TEST(HybridLookupDAG, argmax) {
Matrix S1(2, 2);
S1(0, 0) = 1;
S1(1, 0) = 0;
S1(0, 1) = 0;
S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional0 =
boost::make_shared<GaussianConditional>(X(2), d1, S1, model),
conditional1 =
boost::make_shared<GaussianConditional>(X(2), d2, S1, model);
DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(2)}, {}, {m1}, conditionals));
HybridConditional hc2(mixtureFactor);
HybridLookupTable hlt2(hc2);
auto conditional2 =
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
HybridConditional hc1(conditional2);
HybridLookupTable hlt1(hc1);
DecisionTreeFactor f1(m1, "2 3");
auto discrete_conditional = boost::make_shared<DiscreteConditional>(1, f1);
HybridConditional hc3(discrete_conditional);
HybridLookupTable hlt3(hc3);
HybridLookupDAG dag;
dag.push_back(hlt1);
dag.push_back(hlt2);
dag.push_back(hlt3);
auto hv = dag.argmax();
EXPECT(assert_equal(hv.atDiscrete(1), 1));
EXPECT(assert_equal(hv.at(X(2)), d2));
EXPECT(assert_equal(hv.at(X(1)), d2 + d1));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -0,0 +1,58 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testHybridValues.cpp
* @date Jul 28, 2022
* @author Shangjie Xue
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Key.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/nonlinear/Values.h>
// Include for test suite
#include <CppUnitLite/TestHarness.h>
using namespace std;
using namespace gtsam;
TEST(HybridValues, basics) {
HybridValues values;
values.insert(99, Vector2(2, 3));
values.insert(100, 3);
EXPECT(assert_equal(values.at(99), Vector2(2, 3)));
EXPECT(assert_equal(values.atDiscrete(100), int(3)));
values.print();
HybridValues values2;
values2.insert(100, 3);
values2.insert(99, Vector2(2, 3));
EXPECT(assert_equal(values2, values));
values2.insert(98, Vector2(2, 3));
EXPECT(!assert_equal(values2, values));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@ -55,6 +55,34 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
discrete_conditional = hbn.at(hbn.size() - 1).inner()
self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional)
def test_optimize(self):
"""Test contruction of hybrid factor graph."""
noiseModel = gtsam.noiseModel.Unit.Create(3)
dk = gtsam.DiscreteKeys()
dk.push_back((C(0), 2))
jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)),
noiseModel)
jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)),
noiseModel)
gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2])
hfg = gtsam.HybridGaussianFactorGraph()
hfg.add(jf1)
hfg.add(jf2)
hfg.push_back(gmf)
dtf = gtsam.DecisionTreeFactor([(C(0), 2)],"0 1")
hfg.add(dtf)
hbn = hfg.eliminateSequential(
gtsam.Ordering.ColamdConstrainedLastHybridGaussianFactorGraph(
hfg, [C(0)]))
# print("hbn = ", hbn)
hv = hbn.optimize()
self.assertEqual(hv.atDiscrete(C(0)), 1)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,41 @@
"""
GTSAM Copyright 2010-2019, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Unit tests for Hybrid Values.
Author: Shangjie Xue
"""
# pylint: disable=invalid-name, no-name-in-module, no-member
from __future__ import print_function
import unittest
import gtsam
import numpy as np
from gtsam.symbol_shorthand import C, X
from gtsam.utils.test_case import GtsamTestCase
class TestHybridGaussianFactorGraph(GtsamTestCase):
"""Unit tests for HybridValues."""
def test_basic(self):
"""Test contruction and basic methods of hybrid values."""
hv1 = gtsam.HybridValues()
hv1.insert(X(0), np.ones((3,1)))
hv1.insert(C(0), 2)
hv2 = gtsam.HybridValues()
hv2.insert(C(0), 2)
hv2.insert(X(0), np.ones((3,1)))
self.assertEqual(hv1.atDiscrete(C(0)), 2)
self.assertEqual(hv1.at(X(0))[0], np.ones((3,1))[0])
if __name__ == "__main__":
unittest.main()