diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index b19528120..9d6d5f236 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -64,7 +64,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { GaussianBayesNet choose(const DiscreteValues &assignment) const; /// Solve the HybridBayesNet by back-substitution. - /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and put this method there? + /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and + /// put this method there? HybridValues optimize() const; }; diff --git a/gtsam/hybrid/HybridLookupDAG.cpp b/gtsam/hybrid/HybridLookupDAG.cpp index 7232309f4..7acff081b 100644 --- a/gtsam/hybrid/HybridLookupDAG.cpp +++ b/gtsam/hybrid/HybridLookupDAG.cpp @@ -18,10 +18,10 @@ #include #include #include -#include -#include #include +#include #include +#include #include #include @@ -32,28 +32,32 @@ using std::vector; namespace gtsam { - - /* ************************************************************************** */ void HybridLookupTable::argmaxInPlace(HybridValues* values) const { - // For discrete conditional, uses argmaxInPlace() method in DiscreteLookupTable. - if (isDiscrete()){ - boost::static_pointer_cast(inner_)->argmaxInPlace(&(values->discrete)); - } else if (isContinuous()){ + // For discrete conditional, uses argmaxInPlace() method in + // DiscreteLookupTable. + if (isDiscrete()) { + boost::static_pointer_cast(inner_)->argmaxInPlace( + &(values->discrete)); + } else if (isContinuous()) { // For Gaussian conditional, uses solve() method in GaussianConditional. - values->continuous.insert(boost::static_pointer_cast(inner_)->solve(values->continuous)); - } else if (isHybrid()){ - // For hybrid conditional, since children should not contain discrete variable, we can condition on - // the discrete variable in the parents and solve the resulting GaussianConditional. - auto conditional = boost::static_pointer_cast(inner_)->conditionals()(values->discrete); + values->continuous.insert( + boost::static_pointer_cast(inner_)->solve( + values->continuous)); + } else if (isHybrid()) { + // For hybrid conditional, since children should not contain discrete + // variable, we can condition on the discrete variable in the parents and + // solve the resulting GaussianConditional. + auto conditional = + boost::static_pointer_cast(inner_)->conditionals()( + values->discrete); values->continuous.insert(conditional->solve(values->continuous)); - } + } } - -// /* ************************************************************************** */ -HybridLookupDAG HybridLookupDAG::FromBayesNet( - const HybridBayesNet& bayesNet) { +// /* ************************************************************************** +// */ +HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { HybridLookupDAG dag; for (auto&& conditional : bayesNet) { HybridLookupTable hlt(*conditional); diff --git a/gtsam/hybrid/HybridLookupDAG.h b/gtsam/hybrid/HybridLookupDAG.h index 903cc5519..cc1c58c58 100644 --- a/gtsam/hybrid/HybridLookupDAG.h +++ b/gtsam/hybrid/HybridLookupDAG.h @@ -19,10 +19,10 @@ #include #include -#include -#include #include #include +#include +#include #include #include @@ -34,8 +34,8 @@ namespace gtsam { /** * @brief HybridLookupTable table for max-product * - * Similar to DiscreteLookupTable, inherits from hybrid conditional for convenience. - * Is used in the max-product algorithm. + * Similar to DiscreteLookupTable, inherits from hybrid conditional for + * convenience. Is used in the max-product algorithm. */ class GTSAM_EXPORT HybridLookupTable : public HybridConditional { public: @@ -58,7 +58,8 @@ class GTSAM_EXPORT HybridLookupTable : public HybridConditional { void argmaxInPlace(HybridValues* parentsValues) const; }; -/** A DAG made from hybrid lookup tables, as defined above. Similar to DiscreteLookupDAG */ +/** A DAG made from hybrid lookup tables, as defined above. Similar to + * DiscreteLookupDAG */ class GTSAM_EXPORT HybridLookupDAG : public BayesNet { public: using Base = BayesNet; diff --git a/gtsam/hybrid/HybridValues.h b/gtsam/hybrid/HybridValues.h index 89f7bb58a..5e1bd4164 100644 --- a/gtsam/hybrid/HybridValues.h +++ b/gtsam/hybrid/HybridValues.h @@ -19,11 +19,10 @@ #include #include -#include -#include -#include #include - +#include +#include +#include #include #include @@ -32,8 +31,9 @@ namespace gtsam { /** - * HybridValues represents a collection of DiscreteValues and VectorValues. It is typically used to store the variables - * of a HybridGaussianFactorGraph. Optimizing a HybridGaussianBayesNet returns this class. + * HybridValues represents a collection of DiscreteValues and VectorValues. It + * is typically used to store the variables of a HybridGaussianFactorGraph. + * Optimizing a HybridGaussianBayesNet returns this class. */ class GTSAM_EXPORT HybridValues { public: @@ -44,54 +44,47 @@ class GTSAM_EXPORT HybridValues { VectorValues continuous; // Default constructor creates an empty HybridValues. - HybridValues() : discrete(), continuous() {}; + HybridValues() : discrete(), continuous(){}; // Construct from DiscreteValues and VectorValues. - HybridValues(const DiscreteValues &dv, const VectorValues &cv) : discrete(dv), continuous(cv) {}; + HybridValues(const DiscreteValues& dv, const VectorValues& cv) + : discrete(dv), continuous(cv){}; // print required by Testable for unit testing void print(const std::string& s = "HybridValues", - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{ + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { std::cout << s << ": \n"; - discrete.print(" Discrete", keyFormatter); // print discrete components - continuous.print(" Continuous", keyFormatter); //print continuous components + discrete.print(" Discrete", keyFormatter); // print discrete components + continuous.print(" Continuous", + keyFormatter); // print continuous components }; // equals required by Testable for unit testing bool equals(const HybridValues& other, double tol = 1e-9) const { - return discrete.equals(other.discrete, tol) && continuous.equals(other.continuous, tol); + return discrete.equals(other.discrete, tol) && + continuous.equals(other.continuous, tol); } // Check whether a variable with key \c j exists in DiscreteValue. - bool existsDiscrete(Key j){ - return (discrete.find(j) != discrete.end()); - }; + bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); }; // Check whether a variable with key \c j exists in VectorValue. - bool existsVector(Key j){ - return continuous.exists(j); - }; + bool existsVector(Key j) { return continuous.exists(j); }; // Check whether a variable with key \c j exists. - bool exists(Key j){ - return existsDiscrete(j) || existsVector(j); - }; + bool exists(Key j) { return existsDiscrete(j) || existsVector(j); }; - /** Insert a discrete \c value with key \c j. Replaces the existing value if the key \c - * j is already used. - * @param value The vector to be inserted. - * @param j The index with which the value will be associated. */ - void insert(Key j, int value){ - discrete[j] = value; - }; + /** Insert a discrete \c value with key \c j. Replaces the existing value if + * the key \c j is already used. + * @param value The vector to be inserted. + * @param j The index with which the value will be associated. */ + void insert(Key j, int value) { discrete[j] = value; }; - /** Insert a vector \c value with key \c j. Throws an invalid_argument exception if the key \c - * j is already used. - * @param value The vector to be inserted. - * @param j The index with which the value will be associated. */ - void insert(Key j, const Vector& value) { - continuous.insert(j, value); - } + /** Insert a vector \c value with key \c j. Throws an invalid_argument + * exception if the key \c j is already used. + * @param value The vector to be inserted. + * @param j The index with which the value will be associated. */ + void insert(Key j, const Vector& value) { continuous.insert(j, value); } // TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h @@ -99,18 +92,13 @@ class GTSAM_EXPORT HybridValues { * Read/write access to the discrete value with key \c j, throws * std::out_of_range if \c j does not exist. */ - size_t& atDiscrete(Key j){ - return discrete.at(j); - }; + size_t& atDiscrete(Key j) { return discrete.at(j); }; /** * Read/write access to the vector value with key \c j, throws * std::out_of_range if \c j does not exist. */ - Vector& at(Key j) { - return continuous.at(j); - }; - + Vector& at(Key j) { return continuous.at(j); }; /// @name Wrapper support /// @{ @@ -121,7 +109,8 @@ class GTSAM_EXPORT HybridValues { * @param keyFormatter function that formats keys. * @return string html output. */ - std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{ + std::string html( + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { std::stringstream ss; ss << this->discrete.html(keyFormatter); ss << this->continuous.html(keyFormatter); diff --git a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp index 7e532b013..17a2d94d7 100644 --- a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp @@ -523,7 +523,6 @@ TEST(HybridGaussianFactorGraph, optimize) { HybridValues hv = result->optimize(); EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0))); - } /* ************************************************************************* */ int main() { diff --git a/gtsam/hybrid/tests/testHybridLookupDAG.cpp b/gtsam/hybrid/tests/testHybridLookupDAG.cpp index 70b09ecbe..c472aa22f 100644 --- a/gtsam/hybrid/tests/testHybridLookupDAG.cpp +++ b/gtsam/hybrid/tests/testHybridLookupDAG.cpp @@ -17,19 +17,19 @@ #include #include -#include #include -#include #include -#include -#include -#include -#include -#include +#include +#include +#include #include #include -#include +#include +#include +#include #include +#include +#include // Include for test suite #include @@ -43,7 +43,7 @@ using symbol_shorthand::M; using symbol_shorthand::X; TEST(HybridLookupTable, basics) { - // create a conditional gaussian node + // create a conditional gaussian node Matrix S1(2, 2); S1(0, 0) = 1; S1(1, 0) = 2; @@ -82,39 +82,38 @@ TEST(HybridLookupTable, basics) { GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); -// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals); - - boost::shared_ptr mixtureFactor(new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals)); - + // GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals); + + boost::shared_ptr mixtureFactor( + new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals)); + HybridConditional hc(mixtureFactor); - GaussianMixture::Conditionals conditional2 = boost::static_pointer_cast(hc.inner())->conditionals(); + GaussianMixture::Conditionals conditional2 = + boost::static_pointer_cast(hc.inner())->conditionals(); DiscreteValues dv; - dv[1]=1; + dv[1] = 1; VectorValues cv; - cv.insert(X(2),Vector2(0.0, 0.0)); - - HybridValues hv(dv, cv); + cv.insert(X(2), Vector2(0.0, 0.0)); - + HybridValues hv(dv, cv); // std::cout << conditional2(values).markdown(); EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6)); - EXPECT(conditional2(dv)==conditionals(dv)); + EXPECT(conditional2(dv) == conditionals(dv)); HybridLookupTable hlt(hc); -// hlt.argmaxInPlace(&hv); - + // hlt.argmaxInPlace(&hv); + HybridLookupDAG dag; dag.push_back(hlt); dag.argmax(hv); -// HybridBayesNet hbn; -// hbn.push_back(hc); -// hbn.optimize(); - + // HybridBayesNet hbn; + // hbn.push_back(hc); + // hbn.optimize(); } TEST(HybridLookupTable, hybrid_argmax) { @@ -124,23 +123,26 @@ TEST(HybridLookupTable, hybrid_argmax) { S1(0, 1) = 0; S1(1, 1) = 1; - Vector2 d1(0.2, 0.5), d2(-0.5,0.6); + Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); - auto conditional0 = boost::make_shared(X(1), d1, S1, model), - conditional1 = boost::make_shared(X(1), d2, S1, model); + auto conditional0 = + boost::make_shared(X(1), d1, S1, model), + conditional1 = + boost::make_shared(X(1), d2, S1, model); DiscreteKey m1(1, 2); GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); - boost::shared_ptr mixtureFactor(new GaussianMixture({X(1)},{}, {m1}, conditionals)); + boost::shared_ptr mixtureFactor( + new GaussianMixture({X(1)}, {}, {m1}, conditionals)); HybridConditional hc(mixtureFactor); DiscreteValues dv; - dv[1]=1; + dv[1] = 1; VectorValues cv; // cv.insert(X(2),Vector2(0.0, 0.0)); HybridValues hv(dv, cv); @@ -150,8 +152,6 @@ TEST(HybridLookupTable, hybrid_argmax) { hlt.argmaxInPlace(&hv); EXPECT(assert_equal(hv.at(X(1)), d2)); - - } TEST(HybridLookupTable, discrete_argmax) { @@ -164,18 +164,17 @@ TEST(HybridLookupTable, discrete_argmax) { HybridLookupTable hlt(hc); DiscreteValues dv; - dv[1]=0; + dv[1] = 0; VectorValues cv; // cv.insert(X(2),Vector2(0.0, 0.0)); HybridValues hv(dv, cv); - hlt.argmaxInPlace(&hv); EXPECT(assert_equal(hv.atDiscrete(0), 1)); - DecisionTreeFactor f1(X , "2 3"); - auto conditional2 = boost::make_shared(1,f1); + DecisionTreeFactor f1(X, "2 3"); + auto conditional2 = boost::make_shared(1, f1); HybridConditional hc2(conditional2); @@ -183,7 +182,6 @@ TEST(HybridLookupTable, discrete_argmax) { HybridValues hv2; - hlt2.argmaxInPlace(&hv2); EXPECT(assert_equal(hv2.atDiscrete(0), 1)); @@ -196,12 +194,12 @@ TEST(HybridLookupTable, gaussian_argmax) { S1(0, 1) = 0; S1(1, 1) = 1; - Vector2 d1(0.2, 0.5), d2(-0.5,0.6); + Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); - auto conditional = boost::make_shared(X(1), d1, S1, - X(2), -S1, model); + auto conditional = + boost::make_shared(X(1), d1, S1, X(2), -S1, model); HybridConditional hc(conditional); @@ -210,52 +208,51 @@ TEST(HybridLookupTable, gaussian_argmax) { DiscreteValues dv; // dv[1]=0; VectorValues cv; - cv.insert(X(2),d2); + cv.insert(X(2), d2); HybridValues hv(dv, cv); - hlt.argmaxInPlace(&hv); - EXPECT(assert_equal(hv.at(X(1)), d1+d2)); - + EXPECT(assert_equal(hv.at(X(1)), d1 + d2)); } TEST(HybridLookupDAG, argmax) { - Matrix S1(2, 2); S1(0, 0) = 1; S1(1, 0) = 0; S1(0, 1) = 0; S1(1, 1) = 1; - Vector2 d1(0.2, 0.5), d2(-0.5,0.6); + Vector2 d1(0.2, 0.5), d2(-0.5, 0.6); SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34)); - auto conditional0 = boost::make_shared(X(2), d1, S1, model), - conditional1 = boost::make_shared(X(2), d2, S1, model); + auto conditional0 = + boost::make_shared(X(2), d1, S1, model), + conditional1 = + boost::make_shared(X(2), d2, S1, model); DiscreteKey m1(1, 2); GaussianMixture::Conditionals conditionals( {m1}, vector{conditional0, conditional1}); - boost::shared_ptr mixtureFactor(new GaussianMixture({X(2)},{}, {m1}, conditionals)); + boost::shared_ptr mixtureFactor( + new GaussianMixture({X(2)}, {}, {m1}, conditionals)); HybridConditional hc2(mixtureFactor); HybridLookupTable hlt2(hc2); - - auto conditional2 = boost::make_shared(X(1), d1, S1, - X(2), -S1, model); + auto conditional2 = + boost::make_shared(X(1), d1, S1, X(2), -S1, model); HybridConditional hc1(conditional2); HybridLookupTable hlt1(hc1); - DecisionTreeFactor f1(m1 , "2 3"); - auto discrete_conditional = boost::make_shared(1,f1); + DecisionTreeFactor f1(m1, "2 3"); + auto discrete_conditional = boost::make_shared(1, f1); HybridConditional hc3(discrete_conditional); HybridLookupTable hlt3(hc3); - + HybridLookupDAG dag; dag.push_back(hlt1); dag.push_back(hlt2); @@ -264,10 +261,9 @@ TEST(HybridLookupDAG, argmax) { EXPECT(assert_equal(hv.atDiscrete(1), 1)); EXPECT(assert_equal(hv.at(X(2)), d2)); - EXPECT(assert_equal(hv.at(X(1)), d2+d1)); + EXPECT(assert_equal(hv.at(X(1)), d2 + d1)); } - /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridValues.cpp b/gtsam/hybrid/tests/testHybridValues.cpp index 3e821aef2..9581faaa0 100644 --- a/gtsam/hybrid/tests/testHybridValues.cpp +++ b/gtsam/hybrid/tests/testHybridValues.cpp @@ -17,19 +17,18 @@ #include #include -#include #include #include -#include -#include -#include -#include +#include #include +#include +#include +#include +#include // Include for test suite #include - using namespace std; using namespace gtsam; @@ -47,7 +46,7 @@ TEST(HybridValues, basics) { values2.insert(99, Vector2(2, 3)); EXPECT(assert_equal(values2, values)); - values2.insert(98, Vector2(2,3)); + values2.insert(98, Vector2(2, 3)); EXPECT(!assert_equal(values2, values)); }