Address review comments

release/4.3a0
sjxue 2022-08-16 18:26:59 -04:00
parent 7d36a9eb98
commit 379a65f40f
7 changed files with 122 additions and 133 deletions

View File

@ -64,7 +64,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
GaussianBayesNet choose(const DiscreteValues &assignment) const; GaussianBayesNet choose(const DiscreteValues &assignment) const;
/// Solve the HybridBayesNet by back-substitution. /// Solve the HybridBayesNet by back-substitution.
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and put this method there? /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
HybridValues optimize() const; HybridValues optimize() const;
}; };

View File

@ -18,10 +18,10 @@
#include <gtsam/discrete/DiscreteBayesNet.h> #include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h> #include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridLookupDAG.h> #include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/VectorValues.h> #include <gtsam/linear/VectorValues.h>
#include <string> #include <string>
@ -32,28 +32,32 @@ using std::vector;
namespace gtsam { namespace gtsam {
/* ************************************************************************** */ /* ************************************************************************** */
void HybridLookupTable::argmaxInPlace(HybridValues* values) const { void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
// For discrete conditional, uses argmaxInPlace() method in DiscreteLookupTable. // For discrete conditional, uses argmaxInPlace() method in
if (isDiscrete()){ // DiscreteLookupTable.
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(&(values->discrete)); if (isDiscrete()) {
} else if (isContinuous()){ boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(
&(values->discrete));
} else if (isContinuous()) {
// For Gaussian conditional, uses solve() method in GaussianConditional. // For Gaussian conditional, uses solve() method in GaussianConditional.
values->continuous.insert(boost::static_pointer_cast<GaussianConditional>(inner_)->solve(values->continuous)); values->continuous.insert(
} else if (isHybrid()){ boost::static_pointer_cast<GaussianConditional>(inner_)->solve(
// For hybrid conditional, since children should not contain discrete variable, we can condition on values->continuous));
// the discrete variable in the parents and solve the resulting GaussianConditional. } else if (isHybrid()) {
auto conditional = boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(values->discrete); // For hybrid conditional, since children should not contain discrete
// variable, we can condition on the discrete variable in the parents and
// solve the resulting GaussianConditional.
auto conditional =
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(
values->discrete);
values->continuous.insert(conditional->solve(values->continuous)); values->continuous.insert(conditional->solve(values->continuous));
} }
} }
// /* **************************************************************************
// /* ************************************************************************** */ // */
HybridLookupDAG HybridLookupDAG::FromBayesNet( HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
const HybridBayesNet& bayesNet) {
HybridLookupDAG dag; HybridLookupDAG dag;
for (auto&& conditional : bayesNet) { for (auto&& conditional : bayesNet) {
HybridLookupTable hlt(*conditional); HybridLookupTable hlt(*conditional);

View File

@ -19,10 +19,10 @@
#include <gtsam/discrete/DiscreteDistribution.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteLookupDAG.h> #include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <string> #include <string>
@ -34,8 +34,8 @@ namespace gtsam {
/** /**
* @brief HybridLookupTable table for max-product * @brief HybridLookupTable table for max-product
* *
* Similar to DiscreteLookupTable, inherits from hybrid conditional for convenience. * Similar to DiscreteLookupTable, inherits from hybrid conditional for
* Is used in the max-product algorithm. * convenience. Is used in the max-product algorithm.
*/ */
class GTSAM_EXPORT HybridLookupTable : public HybridConditional { class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
public: public:
@ -58,7 +58,8 @@ class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
void argmaxInPlace(HybridValues* parentsValues) const; void argmaxInPlace(HybridValues* parentsValues) const;
}; };
/** A DAG made from hybrid lookup tables, as defined above. Similar to DiscreteLookupDAG */ /** A DAG made from hybrid lookup tables, as defined above. Similar to
* DiscreteLookupDAG */
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> { class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
public: public:
using Base = BayesNet<HybridLookupTable>; using Base = BayesNet<HybridLookupTable>;

View File

@ -19,11 +19,10 @@
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Key.h>
#include <gtsam/nonlinear/Values.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Key.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/nonlinear/Values.h>
#include <map> #include <map>
#include <string> #include <string>
@ -32,8 +31,9 @@
namespace gtsam { namespace gtsam {
/** /**
* HybridValues represents a collection of DiscreteValues and VectorValues. It is typically used to store the variables * HybridValues represents a collection of DiscreteValues and VectorValues. It
* of a HybridGaussianFactorGraph. Optimizing a HybridGaussianBayesNet returns this class. * is typically used to store the variables of a HybridGaussianFactorGraph.
* Optimizing a HybridGaussianBayesNet returns this class.
*/ */
class GTSAM_EXPORT HybridValues { class GTSAM_EXPORT HybridValues {
public: public:
@ -44,54 +44,47 @@ class GTSAM_EXPORT HybridValues {
VectorValues continuous; VectorValues continuous;
// Default constructor creates an empty HybridValues. // Default constructor creates an empty HybridValues.
HybridValues() : discrete(), continuous() {}; HybridValues() : discrete(), continuous(){};
// Construct from DiscreteValues and VectorValues. // Construct from DiscreteValues and VectorValues.
HybridValues(const DiscreteValues &dv, const VectorValues &cv) : discrete(dv), continuous(cv) {}; HybridValues(const DiscreteValues& dv, const VectorValues& cv)
: discrete(dv), continuous(cv){};
// print required by Testable for unit testing // print required by Testable for unit testing
void print(const std::string& s = "HybridValues", void print(const std::string& s = "HybridValues",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{ const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": \n"; std::cout << s << ": \n";
discrete.print(" Discrete", keyFormatter); // print discrete components discrete.print(" Discrete", keyFormatter); // print discrete components
continuous.print(" Continuous", keyFormatter); //print continuous components continuous.print(" Continuous",
keyFormatter); // print continuous components
}; };
// equals required by Testable for unit testing // equals required by Testable for unit testing
bool equals(const HybridValues& other, double tol = 1e-9) const { bool equals(const HybridValues& other, double tol = 1e-9) const {
return discrete.equals(other.discrete, tol) && continuous.equals(other.continuous, tol); return discrete.equals(other.discrete, tol) &&
continuous.equals(other.continuous, tol);
} }
// Check whether a variable with key \c j exists in DiscreteValue. // Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j){ bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); };
return (discrete.find(j) != discrete.end());
};
// Check whether a variable with key \c j exists in VectorValue. // Check whether a variable with key \c j exists in VectorValue.
bool existsVector(Key j){ bool existsVector(Key j) { return continuous.exists(j); };
return continuous.exists(j);
};
// Check whether a variable with key \c j exists. // Check whether a variable with key \c j exists.
bool exists(Key j){ bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };
return existsDiscrete(j) || existsVector(j);
};
/** Insert a discrete \c value with key \c j. Replaces the existing value if the key \c /** Insert a discrete \c value with key \c j. Replaces the existing value if
* j is already used. * the key \c j is already used.
* @param value The vector to be inserted. * @param value The vector to be inserted.
* @param j The index with which the value will be associated. */ * @param j The index with which the value will be associated. */
void insert(Key j, int value){ void insert(Key j, int value) { discrete[j] = value; };
discrete[j] = value;
};
/** Insert a vector \c value with key \c j. Throws an invalid_argument exception if the key \c /** Insert a vector \c value with key \c j. Throws an invalid_argument
* j is already used. * exception if the key \c j is already used.
* @param value The vector to be inserted. * @param value The vector to be inserted.
* @param j The index with which the value will be associated. */ * @param j The index with which the value will be associated. */
void insert(Key j, const Vector& value) { void insert(Key j, const Vector& value) { continuous.insert(j, value); }
continuous.insert(j, value);
}
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h // TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
@ -99,18 +92,13 @@ class GTSAM_EXPORT HybridValues {
* Read/write access to the discrete value with key \c j, throws * Read/write access to the discrete value with key \c j, throws
* std::out_of_range if \c j does not exist. * std::out_of_range if \c j does not exist.
*/ */
size_t& atDiscrete(Key j){ size_t& atDiscrete(Key j) { return discrete.at(j); };
return discrete.at(j);
};
/** /**
* Read/write access to the vector value with key \c j, throws * Read/write access to the vector value with key \c j, throws
* std::out_of_range if \c j does not exist. * std::out_of_range if \c j does not exist.
*/ */
Vector& at(Key j) { Vector& at(Key j) { return continuous.at(j); };
return continuous.at(j);
};
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{
@ -121,7 +109,8 @@ class GTSAM_EXPORT HybridValues {
* @param keyFormatter function that formats keys. * @param keyFormatter function that formats keys.
* @return string html output. * @return string html output.
*/ */
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{ std::string html(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::stringstream ss; std::stringstream ss;
ss << this->discrete.html(keyFormatter); ss << this->discrete.html(keyFormatter);
ss << this->continuous.html(keyFormatter); ss << this->continuous.html(keyFormatter);

View File

@ -523,7 +523,6 @@ TEST(HybridGaussianFactorGraph, optimize) {
HybridValues hv = result->optimize(); HybridValues hv = result->optimize();
EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0))); EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {

View File

@ -17,19 +17,19 @@
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h> #include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/inference/Key.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/nonlinear/Values.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/linear/VectorValues.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridLookupDAG.h> #include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Key.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/nonlinear/Values.h>
// Include for test suite // Include for test suite
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
@ -43,7 +43,7 @@ using symbol_shorthand::M;
using symbol_shorthand::X; using symbol_shorthand::X;
TEST(HybridLookupTable, basics) { TEST(HybridLookupTable, basics) {
// create a conditional gaussian node // create a conditional gaussian node
Matrix S1(2, 2); Matrix S1(2, 2);
S1(0, 0) = 1; S1(0, 0) = 1;
S1(1, 0) = 2; S1(1, 0) = 2;
@ -82,39 +82,38 @@ TEST(HybridLookupTable, basics) {
GaussianMixture::Conditionals conditionals( GaussianMixture::Conditionals conditionals(
{m1}, {m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1}); vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals); // GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals);
boost::shared_ptr<GaussianMixture> mixtureFactor(new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals)); boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals));
HybridConditional hc(mixtureFactor); HybridConditional hc(mixtureFactor);
GaussianMixture::Conditionals conditional2 = boost::static_pointer_cast<GaussianMixture>(hc.inner())->conditionals(); GaussianMixture::Conditionals conditional2 =
boost::static_pointer_cast<GaussianMixture>(hc.inner())->conditionals();
DiscreteValues dv; DiscreteValues dv;
dv[1]=1; dv[1] = 1;
VectorValues cv; VectorValues cv;
cv.insert(X(2),Vector2(0.0, 0.0)); cv.insert(X(2), Vector2(0.0, 0.0));
HybridValues hv(dv, cv); HybridValues hv(dv, cv);
// std::cout << conditional2(values).markdown(); // std::cout << conditional2(values).markdown();
EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6)); EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6));
EXPECT(conditional2(dv)==conditionals(dv)); EXPECT(conditional2(dv) == conditionals(dv));
HybridLookupTable hlt(hc); HybridLookupTable hlt(hc);
// hlt.argmaxInPlace(&hv); // hlt.argmaxInPlace(&hv);
HybridLookupDAG dag; HybridLookupDAG dag;
dag.push_back(hlt); dag.push_back(hlt);
dag.argmax(hv); dag.argmax(hv);
// HybridBayesNet hbn; // HybridBayesNet hbn;
// hbn.push_back(hc); // hbn.push_back(hc);
// hbn.optimize(); // hbn.optimize();
} }
TEST(HybridLookupTable, hybrid_argmax) { TEST(HybridLookupTable, hybrid_argmax) {
@ -124,23 +123,26 @@ TEST(HybridLookupTable, hybrid_argmax) {
S1(0, 1) = 0; S1(0, 1) = 0;
S1(1, 1) = 1; S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5,0.6); Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, S1, model), auto conditional0 =
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, S1, model); boost::make_shared<GaussianConditional>(X(1), d1, S1, model),
conditional1 =
boost::make_shared<GaussianConditional>(X(1), d2, S1, model);
DiscreteKey m1(1, 2); DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals( GaussianMixture::Conditionals conditionals(
{m1}, {m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1}); vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
boost::shared_ptr<GaussianMixture> mixtureFactor(new GaussianMixture({X(1)},{}, {m1}, conditionals)); boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(1)}, {}, {m1}, conditionals));
HybridConditional hc(mixtureFactor); HybridConditional hc(mixtureFactor);
DiscreteValues dv; DiscreteValues dv;
dv[1]=1; dv[1] = 1;
VectorValues cv; VectorValues cv;
// cv.insert(X(2),Vector2(0.0, 0.0)); // cv.insert(X(2),Vector2(0.0, 0.0));
HybridValues hv(dv, cv); HybridValues hv(dv, cv);
@ -150,8 +152,6 @@ TEST(HybridLookupTable, hybrid_argmax) {
hlt.argmaxInPlace(&hv); hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.at(X(1)), d2)); EXPECT(assert_equal(hv.at(X(1)), d2));
} }
TEST(HybridLookupTable, discrete_argmax) { TEST(HybridLookupTable, discrete_argmax) {
@ -164,18 +164,17 @@ TEST(HybridLookupTable, discrete_argmax) {
HybridLookupTable hlt(hc); HybridLookupTable hlt(hc);
DiscreteValues dv; DiscreteValues dv;
dv[1]=0; dv[1] = 0;
VectorValues cv; VectorValues cv;
// cv.insert(X(2),Vector2(0.0, 0.0)); // cv.insert(X(2),Vector2(0.0, 0.0));
HybridValues hv(dv, cv); HybridValues hv(dv, cv);
hlt.argmaxInPlace(&hv); hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.atDiscrete(0), 1)); EXPECT(assert_equal(hv.atDiscrete(0), 1));
DecisionTreeFactor f1(X , "2 3"); DecisionTreeFactor f1(X, "2 3");
auto conditional2 = boost::make_shared<DiscreteConditional>(1,f1); auto conditional2 = boost::make_shared<DiscreteConditional>(1, f1);
HybridConditional hc2(conditional2); HybridConditional hc2(conditional2);
@ -183,7 +182,6 @@ TEST(HybridLookupTable, discrete_argmax) {
HybridValues hv2; HybridValues hv2;
hlt2.argmaxInPlace(&hv2); hlt2.argmaxInPlace(&hv2);
EXPECT(assert_equal(hv2.atDiscrete(0), 1)); EXPECT(assert_equal(hv2.atDiscrete(0), 1));
@ -196,12 +194,12 @@ TEST(HybridLookupTable, gaussian_argmax) {
S1(0, 1) = 0; S1(0, 1) = 0;
S1(1, 1) = 1; S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5,0.6); Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional = boost::make_shared<GaussianConditional>(X(1), d1, S1, auto conditional =
X(2), -S1, model); boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
HybridConditional hc(conditional); HybridConditional hc(conditional);
@ -210,48 +208,47 @@ TEST(HybridLookupTable, gaussian_argmax) {
DiscreteValues dv; DiscreteValues dv;
// dv[1]=0; // dv[1]=0;
VectorValues cv; VectorValues cv;
cv.insert(X(2),d2); cv.insert(X(2), d2);
HybridValues hv(dv, cv); HybridValues hv(dv, cv);
hlt.argmaxInPlace(&hv); hlt.argmaxInPlace(&hv);
EXPECT(assert_equal(hv.at(X(1)), d1+d2)); EXPECT(assert_equal(hv.at(X(1)), d1 + d2));
} }
TEST(HybridLookupDAG, argmax) { TEST(HybridLookupDAG, argmax) {
Matrix S1(2, 2); Matrix S1(2, 2);
S1(0, 0) = 1; S1(0, 0) = 1;
S1(1, 0) = 0; S1(1, 0) = 0;
S1(0, 1) = 0; S1(0, 1) = 0;
S1(1, 1) = 1; S1(1, 1) = 1;
Vector2 d1(0.2, 0.5), d2(-0.5,0.6); Vector2 d1(0.2, 0.5), d2(-0.5, 0.6);
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
auto conditional0 = boost::make_shared<GaussianConditional>(X(2), d1, S1, model), auto conditional0 =
conditional1 = boost::make_shared<GaussianConditional>(X(2), d2, S1, model); boost::make_shared<GaussianConditional>(X(2), d1, S1, model),
conditional1 =
boost::make_shared<GaussianConditional>(X(2), d2, S1, model);
DiscreteKey m1(1, 2); DiscreteKey m1(1, 2);
GaussianMixture::Conditionals conditionals( GaussianMixture::Conditionals conditionals(
{m1}, {m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1}); vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
boost::shared_ptr<GaussianMixture> mixtureFactor(new GaussianMixture({X(2)},{}, {m1}, conditionals)); boost::shared_ptr<GaussianMixture> mixtureFactor(
new GaussianMixture({X(2)}, {}, {m1}, conditionals));
HybridConditional hc2(mixtureFactor); HybridConditional hc2(mixtureFactor);
HybridLookupTable hlt2(hc2); HybridLookupTable hlt2(hc2);
auto conditional2 =
auto conditional2 = boost::make_shared<GaussianConditional>(X(1), d1, S1, boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
X(2), -S1, model);
HybridConditional hc1(conditional2); HybridConditional hc1(conditional2);
HybridLookupTable hlt1(hc1); HybridLookupTable hlt1(hc1);
DecisionTreeFactor f1(m1 , "2 3"); DecisionTreeFactor f1(m1, "2 3");
auto discrete_conditional = boost::make_shared<DiscreteConditional>(1,f1); auto discrete_conditional = boost::make_shared<DiscreteConditional>(1, f1);
HybridConditional hc3(discrete_conditional); HybridConditional hc3(discrete_conditional);
HybridLookupTable hlt3(hc3); HybridLookupTable hlt3(hc3);
@ -264,10 +261,9 @@ TEST(HybridLookupDAG, argmax) {
EXPECT(assert_equal(hv.atDiscrete(1), 1)); EXPECT(assert_equal(hv.atDiscrete(1), 1));
EXPECT(assert_equal(hv.at(X(2)), d2)); EXPECT(assert_equal(hv.at(X(2)), d2));
EXPECT(assert_equal(hv.at(X(1)), d2+d1)); EXPECT(assert_equal(hv.at(X(1)), d2 + d1));
} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -17,19 +17,18 @@
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h> #include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Key.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/nonlinear/Values.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Key.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/nonlinear/Values.h>
// Include for test suite // Include for test suite
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -47,7 +46,7 @@ TEST(HybridValues, basics) {
values2.insert(99, Vector2(2, 3)); values2.insert(99, Vector2(2, 3));
EXPECT(assert_equal(values2, values)); EXPECT(assert_equal(values2, values));
values2.insert(98, Vector2(2,3)); values2.insert(98, Vector2(2, 3));
EXPECT(!assert_equal(values2, values)); EXPECT(!assert_equal(values2, values));
} }