Address review comments
parent
7d36a9eb98
commit
379a65f40f
|
|
@ -64,7 +64,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
/// Solve the HybridBayesNet by back-substitution.
|
/// Solve the HybridBayesNet by back-substitution.
|
||||||
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and put this method there?
|
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
|
||||||
|
/// put this method there?
|
||||||
HybridValues optimize() const;
|
HybridValues optimize() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,10 +18,10 @@
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/linear/VectorValues.h>
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -32,28 +32,32 @@ using std::vector;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
|
void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
|
||||||
// For discrete conditional, uses argmaxInPlace() method in DiscreteLookupTable.
|
// For discrete conditional, uses argmaxInPlace() method in
|
||||||
|
// DiscreteLookupTable.
|
||||||
if (isDiscrete()) {
|
if (isDiscrete()) {
|
||||||
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(&(values->discrete));
|
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(
|
||||||
|
&(values->discrete));
|
||||||
} else if (isContinuous()) {
|
} else if (isContinuous()) {
|
||||||
// For Gaussian conditional, uses solve() method in GaussianConditional.
|
// For Gaussian conditional, uses solve() method in GaussianConditional.
|
||||||
values->continuous.insert(boost::static_pointer_cast<GaussianConditional>(inner_)->solve(values->continuous));
|
values->continuous.insert(
|
||||||
|
boost::static_pointer_cast<GaussianConditional>(inner_)->solve(
|
||||||
|
values->continuous));
|
||||||
} else if (isHybrid()) {
|
} else if (isHybrid()) {
|
||||||
// For hybrid conditional, since children should not contain discrete variable, we can condition on
|
// For hybrid conditional, since children should not contain discrete
|
||||||
// the discrete variable in the parents and solve the resulting GaussianConditional.
|
// variable, we can condition on the discrete variable in the parents and
|
||||||
auto conditional = boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(values->discrete);
|
// solve the resulting GaussianConditional.
|
||||||
|
auto conditional =
|
||||||
|
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(
|
||||||
|
values->discrete);
|
||||||
values->continuous.insert(conditional->solve(values->continuous));
|
values->continuous.insert(conditional->solve(values->continuous));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// /* **************************************************************************
|
||||||
// /* ************************************************************************** */
|
// */
|
||||||
HybridLookupDAG HybridLookupDAG::FromBayesNet(
|
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
|
||||||
const HybridBayesNet& bayesNet) {
|
|
||||||
HybridLookupDAG dag;
|
HybridLookupDAG dag;
|
||||||
for (auto&& conditional : bayesNet) {
|
for (auto&& conditional : bayesNet) {
|
||||||
HybridLookupTable hlt(*conditional);
|
HybridLookupTable hlt(*conditional);
|
||||||
|
|
|
||||||
|
|
@ -19,10 +19,10 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
|
||||||
#include <gtsam/inference/FactorGraph.h>
|
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
#include <gtsam/inference/FactorGraph.h>
|
||||||
|
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -34,8 +34,8 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* @brief HybridLookupTable table for max-product
|
* @brief HybridLookupTable table for max-product
|
||||||
*
|
*
|
||||||
* Similar to DiscreteLookupTable, inherits from hybrid conditional for convenience.
|
* Similar to DiscreteLookupTable, inherits from hybrid conditional for
|
||||||
* Is used in the max-product algorithm.
|
* convenience. Is used in the max-product algorithm.
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
|
class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
|
||||||
public:
|
public:
|
||||||
|
|
@ -58,7 +58,8 @@ class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
|
||||||
void argmaxInPlace(HybridValues* parentsValues) const;
|
void argmaxInPlace(HybridValues* parentsValues) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
/** A DAG made from hybrid lookup tables, as defined above. Similar to DiscreteLookupDAG */
|
/** A DAG made from hybrid lookup tables, as defined above. Similar to
|
||||||
|
* DiscreteLookupDAG */
|
||||||
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
|
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
|
||||||
public:
|
public:
|
||||||
using Base = BayesNet<HybridLookupTable>;
|
using Base = BayesNet<HybridLookupTable>;
|
||||||
|
|
|
||||||
|
|
@ -19,11 +19,10 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/Assignment.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/inference/Key.h>
|
|
||||||
#include <gtsam/nonlinear/Values.h>
|
|
||||||
#include <gtsam/linear/VectorValues.h>
|
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -32,8 +31,9 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* HybridValues represents a collection of DiscreteValues and VectorValues. It is typically used to store the variables
|
* HybridValues represents a collection of DiscreteValues and VectorValues. It
|
||||||
* of a HybridGaussianFactorGraph. Optimizing a HybridGaussianBayesNet returns this class.
|
* is typically used to store the variables of a HybridGaussianFactorGraph.
|
||||||
|
* Optimizing a HybridGaussianBayesNet returns this class.
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT HybridValues {
|
class GTSAM_EXPORT HybridValues {
|
||||||
public:
|
public:
|
||||||
|
|
@ -47,51 +47,44 @@ class GTSAM_EXPORT HybridValues {
|
||||||
HybridValues() : discrete(), continuous(){};
|
HybridValues() : discrete(), continuous(){};
|
||||||
|
|
||||||
// Construct from DiscreteValues and VectorValues.
|
// Construct from DiscreteValues and VectorValues.
|
||||||
HybridValues(const DiscreteValues &dv, const VectorValues &cv) : discrete(dv), continuous(cv) {};
|
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
|
||||||
|
: discrete(dv), continuous(cv){};
|
||||||
|
|
||||||
// print required by Testable for unit testing
|
// print required by Testable for unit testing
|
||||||
void print(const std::string& s = "HybridValues",
|
void print(const std::string& s = "HybridValues",
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||||
std::cout << s << ": \n";
|
std::cout << s << ": \n";
|
||||||
discrete.print(" Discrete", keyFormatter); // print discrete components
|
discrete.print(" Discrete", keyFormatter); // print discrete components
|
||||||
continuous.print(" Continuous", keyFormatter); //print continuous components
|
continuous.print(" Continuous",
|
||||||
|
keyFormatter); // print continuous components
|
||||||
};
|
};
|
||||||
|
|
||||||
// equals required by Testable for unit testing
|
// equals required by Testable for unit testing
|
||||||
bool equals(const HybridValues& other, double tol = 1e-9) const {
|
bool equals(const HybridValues& other, double tol = 1e-9) const {
|
||||||
return discrete.equals(other.discrete, tol) && continuous.equals(other.continuous, tol);
|
return discrete.equals(other.discrete, tol) &&
|
||||||
|
continuous.equals(other.continuous, tol);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check whether a variable with key \c j exists in DiscreteValue.
|
// Check whether a variable with key \c j exists in DiscreteValue.
|
||||||
bool existsDiscrete(Key j){
|
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); };
|
||||||
return (discrete.find(j) != discrete.end());
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check whether a variable with key \c j exists in VectorValue.
|
// Check whether a variable with key \c j exists in VectorValue.
|
||||||
bool existsVector(Key j){
|
bool existsVector(Key j) { return continuous.exists(j); };
|
||||||
return continuous.exists(j);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check whether a variable with key \c j exists.
|
// Check whether a variable with key \c j exists.
|
||||||
bool exists(Key j){
|
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };
|
||||||
return existsDiscrete(j) || existsVector(j);
|
|
||||||
};
|
|
||||||
|
|
||||||
/** Insert a discrete \c value with key \c j. Replaces the existing value if the key \c
|
/** Insert a discrete \c value with key \c j. Replaces the existing value if
|
||||||
* j is already used.
|
* the key \c j is already used.
|
||||||
* @param value The vector to be inserted.
|
* @param value The vector to be inserted.
|
||||||
* @param j The index with which the value will be associated. */
|
* @param j The index with which the value will be associated. */
|
||||||
void insert(Key j, int value){
|
void insert(Key j, int value) { discrete[j] = value; };
|
||||||
discrete[j] = value;
|
|
||||||
};
|
|
||||||
|
|
||||||
/** Insert a vector \c value with key \c j. Throws an invalid_argument exception if the key \c
|
/** Insert a vector \c value with key \c j. Throws an invalid_argument
|
||||||
* j is already used.
|
* exception if the key \c j is already used.
|
||||||
* @param value The vector to be inserted.
|
* @param value The vector to be inserted.
|
||||||
* @param j The index with which the value will be associated. */
|
* @param j The index with which the value will be associated. */
|
||||||
void insert(Key j, const Vector& value) {
|
void insert(Key j, const Vector& value) { continuous.insert(j, value); }
|
||||||
continuous.insert(j, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
|
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
|
||||||
|
|
||||||
|
|
@ -99,18 +92,13 @@ class GTSAM_EXPORT HybridValues {
|
||||||
* Read/write access to the discrete value with key \c j, throws
|
* Read/write access to the discrete value with key \c j, throws
|
||||||
* std::out_of_range if \c j does not exist.
|
* std::out_of_range if \c j does not exist.
|
||||||
*/
|
*/
|
||||||
size_t& atDiscrete(Key j){
|
size_t& atDiscrete(Key j) { return discrete.at(j); };
|
||||||
return discrete.at(j);
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Read/write access to the vector value with key \c j, throws
|
* Read/write access to the vector value with key \c j, throws
|
||||||
* std::out_of_range if \c j does not exist.
|
* std::out_of_range if \c j does not exist.
|
||||||
*/
|
*/
|
||||||
Vector& at(Key j) {
|
Vector& at(Key j) { return continuous.at(j); };
|
||||||
return continuous.at(j);
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -121,7 +109,8 @@ class GTSAM_EXPORT HybridValues {
|
||||||
* @param keyFormatter function that formats keys.
|
* @param keyFormatter function that formats keys.
|
||||||
* @return string html output.
|
* @return string html output.
|
||||||
*/
|
*/
|
||||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter) const{
|
std::string html(
|
||||||
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << this->discrete.html(keyFormatter);
|
ss << this->discrete.html(keyFormatter);
|
||||||
ss << this->continuous.html(keyFormatter);
|
ss << this->continuous.html(keyFormatter);
|
||||||
|
|
|
||||||
|
|
@ -523,7 +523,6 @@ TEST(HybridGaussianFactorGraph, optimize) {
|
||||||
HybridValues hv = result->optimize();
|
HybridValues hv = result->optimize();
|
||||||
|
|
||||||
EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
|
EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
|
||||||
|
|
||||||
}
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
|
|
|
||||||
|
|
@ -17,19 +17,19 @@
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/TestableAssertions.h>
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/Assignment.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/inference/Key.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/nonlinear/Values.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/linear/VectorValues.h>
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
#include <gtsam/inference/Symbol.h>
|
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||||
#include <gtsam/hybrid/GaussianMixture.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
|
||||||
// Include for test suite
|
// Include for test suite
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
@ -84,11 +84,13 @@ TEST(HybridLookupTable, basics) {
|
||||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||||
// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals);
|
// GaussianMixture mixtureFactor2({X(1)}, {X(2)}, {m1}, conditionals);
|
||||||
|
|
||||||
boost::shared_ptr<GaussianMixture> mixtureFactor(new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals));
|
boost::shared_ptr<GaussianMixture> mixtureFactor(
|
||||||
|
new GaussianMixture({X(1)}, {X(2)}, {m1}, conditionals));
|
||||||
|
|
||||||
HybridConditional hc(mixtureFactor);
|
HybridConditional hc(mixtureFactor);
|
||||||
|
|
||||||
GaussianMixture::Conditionals conditional2 = boost::static_pointer_cast<GaussianMixture>(hc.inner())->conditionals();
|
GaussianMixture::Conditionals conditional2 =
|
||||||
|
boost::static_pointer_cast<GaussianMixture>(hc.inner())->conditionals();
|
||||||
|
|
||||||
DiscreteValues dv;
|
DiscreteValues dv;
|
||||||
dv[1] = 1;
|
dv[1] = 1;
|
||||||
|
|
@ -98,8 +100,6 @@ TEST(HybridLookupTable, basics) {
|
||||||
|
|
||||||
HybridValues hv(dv, cv);
|
HybridValues hv(dv, cv);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// std::cout << conditional2(values).markdown();
|
// std::cout << conditional2(values).markdown();
|
||||||
EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6));
|
EXPECT(assert_equal(*conditional2(dv), *conditionals(dv), 1e-6));
|
||||||
EXPECT(conditional2(dv) == conditionals(dv));
|
EXPECT(conditional2(dv) == conditionals(dv));
|
||||||
|
|
@ -114,7 +114,6 @@ TEST(HybridLookupTable, basics) {
|
||||||
// HybridBayesNet hbn;
|
// HybridBayesNet hbn;
|
||||||
// hbn.push_back(hc);
|
// hbn.push_back(hc);
|
||||||
// hbn.optimize();
|
// hbn.optimize();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HybridLookupTable, hybrid_argmax) {
|
TEST(HybridLookupTable, hybrid_argmax) {
|
||||||
|
|
@ -128,14 +127,17 @@ TEST(HybridLookupTable, hybrid_argmax) {
|
||||||
|
|
||||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||||
|
|
||||||
auto conditional0 = boost::make_shared<GaussianConditional>(X(1), d1, S1, model),
|
auto conditional0 =
|
||||||
conditional1 = boost::make_shared<GaussianConditional>(X(1), d2, S1, model);
|
boost::make_shared<GaussianConditional>(X(1), d1, S1, model),
|
||||||
|
conditional1 =
|
||||||
|
boost::make_shared<GaussianConditional>(X(1), d2, S1, model);
|
||||||
|
|
||||||
DiscreteKey m1(1, 2);
|
DiscreteKey m1(1, 2);
|
||||||
GaussianMixture::Conditionals conditionals(
|
GaussianMixture::Conditionals conditionals(
|
||||||
{m1},
|
{m1},
|
||||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||||
boost::shared_ptr<GaussianMixture> mixtureFactor(new GaussianMixture({X(1)},{}, {m1}, conditionals));
|
boost::shared_ptr<GaussianMixture> mixtureFactor(
|
||||||
|
new GaussianMixture({X(1)}, {}, {m1}, conditionals));
|
||||||
|
|
||||||
HybridConditional hc(mixtureFactor);
|
HybridConditional hc(mixtureFactor);
|
||||||
|
|
||||||
|
|
@ -150,8 +152,6 @@ TEST(HybridLookupTable, hybrid_argmax) {
|
||||||
hlt.argmaxInPlace(&hv);
|
hlt.argmaxInPlace(&hv);
|
||||||
|
|
||||||
EXPECT(assert_equal(hv.at(X(1)), d2));
|
EXPECT(assert_equal(hv.at(X(1)), d2));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HybridLookupTable, discrete_argmax) {
|
TEST(HybridLookupTable, discrete_argmax) {
|
||||||
|
|
@ -169,7 +169,6 @@ TEST(HybridLookupTable, discrete_argmax) {
|
||||||
// cv.insert(X(2),Vector2(0.0, 0.0));
|
// cv.insert(X(2),Vector2(0.0, 0.0));
|
||||||
HybridValues hv(dv, cv);
|
HybridValues hv(dv, cv);
|
||||||
|
|
||||||
|
|
||||||
hlt.argmaxInPlace(&hv);
|
hlt.argmaxInPlace(&hv);
|
||||||
|
|
||||||
EXPECT(assert_equal(hv.atDiscrete(0), 1));
|
EXPECT(assert_equal(hv.atDiscrete(0), 1));
|
||||||
|
|
@ -183,7 +182,6 @@ TEST(HybridLookupTable, discrete_argmax) {
|
||||||
|
|
||||||
HybridValues hv2;
|
HybridValues hv2;
|
||||||
|
|
||||||
|
|
||||||
hlt2.argmaxInPlace(&hv2);
|
hlt2.argmaxInPlace(&hv2);
|
||||||
|
|
||||||
EXPECT(assert_equal(hv2.atDiscrete(0), 1));
|
EXPECT(assert_equal(hv2.atDiscrete(0), 1));
|
||||||
|
|
@ -200,8 +198,8 @@ TEST(HybridLookupTable, gaussian_argmax) {
|
||||||
|
|
||||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||||
|
|
||||||
auto conditional = boost::make_shared<GaussianConditional>(X(1), d1, S1,
|
auto conditional =
|
||||||
X(2), -S1, model);
|
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
|
||||||
|
|
||||||
HybridConditional hc(conditional);
|
HybridConditional hc(conditional);
|
||||||
|
|
||||||
|
|
@ -213,15 +211,12 @@ TEST(HybridLookupTable, gaussian_argmax) {
|
||||||
cv.insert(X(2), d2);
|
cv.insert(X(2), d2);
|
||||||
HybridValues hv(dv, cv);
|
HybridValues hv(dv, cv);
|
||||||
|
|
||||||
|
|
||||||
hlt.argmaxInPlace(&hv);
|
hlt.argmaxInPlace(&hv);
|
||||||
|
|
||||||
EXPECT(assert_equal(hv.at(X(1)), d1 + d2));
|
EXPECT(assert_equal(hv.at(X(1)), d1 + d2));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(HybridLookupDAG, argmax) {
|
TEST(HybridLookupDAG, argmax) {
|
||||||
|
|
||||||
Matrix S1(2, 2);
|
Matrix S1(2, 2);
|
||||||
S1(0, 0) = 1;
|
S1(0, 0) = 1;
|
||||||
S1(1, 0) = 0;
|
S1(1, 0) = 0;
|
||||||
|
|
@ -232,20 +227,22 @@ TEST(HybridLookupDAG, argmax) {
|
||||||
|
|
||||||
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
SharedDiagonal model = noiseModel::Diagonal::Sigmas(Vector2(1.0, 0.34));
|
||||||
|
|
||||||
auto conditional0 = boost::make_shared<GaussianConditional>(X(2), d1, S1, model),
|
auto conditional0 =
|
||||||
conditional1 = boost::make_shared<GaussianConditional>(X(2), d2, S1, model);
|
boost::make_shared<GaussianConditional>(X(2), d1, S1, model),
|
||||||
|
conditional1 =
|
||||||
|
boost::make_shared<GaussianConditional>(X(2), d2, S1, model);
|
||||||
|
|
||||||
DiscreteKey m1(1, 2);
|
DiscreteKey m1(1, 2);
|
||||||
GaussianMixture::Conditionals conditionals(
|
GaussianMixture::Conditionals conditionals(
|
||||||
{m1},
|
{m1},
|
||||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||||
boost::shared_ptr<GaussianMixture> mixtureFactor(new GaussianMixture({X(2)},{}, {m1}, conditionals));
|
boost::shared_ptr<GaussianMixture> mixtureFactor(
|
||||||
|
new GaussianMixture({X(2)}, {}, {m1}, conditionals));
|
||||||
HybridConditional hc2(mixtureFactor);
|
HybridConditional hc2(mixtureFactor);
|
||||||
HybridLookupTable hlt2(hc2);
|
HybridLookupTable hlt2(hc2);
|
||||||
|
|
||||||
|
auto conditional2 =
|
||||||
auto conditional2 = boost::make_shared<GaussianConditional>(X(1), d1, S1,
|
boost::make_shared<GaussianConditional>(X(1), d1, S1, X(2), -S1, model);
|
||||||
X(2), -S1, model);
|
|
||||||
|
|
||||||
HybridConditional hc1(conditional2);
|
HybridConditional hc1(conditional2);
|
||||||
HybridLookupTable hlt1(hc1);
|
HybridLookupTable hlt1(hc1);
|
||||||
|
|
@ -267,7 +264,6 @@ TEST(HybridLookupDAG, argmax) {
|
||||||
EXPECT(assert_equal(hv.at(X(1)), d2 + d1));
|
EXPECT(assert_equal(hv.at(X(1)), d2 + d1));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -17,19 +17,18 @@
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/TestableAssertions.h>
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/Assignment.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/inference/Key.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/nonlinear/Values.h>
|
|
||||||
#include <gtsam/linear/VectorValues.h>
|
|
||||||
#include <gtsam/inference/Symbol.h>
|
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
#include <gtsam/inference/Symbol.h>
|
||||||
|
#include <gtsam/linear/VectorValues.h>
|
||||||
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
|
||||||
// Include for test suite
|
// Include for test suite
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue