Product now has scalars
parent
92540298e1
commit
584a71fb94
|
@ -22,8 +22,8 @@
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
|
||||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
#include <gtsam/linear/GaussianBayesNet.h>
|
#include <gtsam/linear/GaussianBayesNet.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
@ -53,8 +53,7 @@ struct HybridGaussianConditional::Helper {
|
||||||
fvs.reserve(p.size());
|
fvs.reserve(p.size());
|
||||||
gcs.reserve(p.size());
|
gcs.reserve(p.size());
|
||||||
for (auto&& [mean, sigma] : p) {
|
for (auto&& [mean, sigma] : p) {
|
||||||
auto gaussianConditional =
|
auto gaussianConditional = GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
|
||||||
GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
|
|
||||||
double value = gaussianConditional->negLogConstant();
|
double value = gaussianConditional->negLogConstant();
|
||||||
minNegLogConstant = std::min(minNegLogConstant, value);
|
minNegLogConstant = std::min(minNegLogConstant, value);
|
||||||
fvs.emplace_back(gaussianConditional, value);
|
fvs.emplace_back(gaussianConditional, value);
|
||||||
|
@ -67,8 +66,7 @@ struct HybridGaussianConditional::Helper {
|
||||||
|
|
||||||
/// Construct from tree of GaussianConditionals.
|
/// Construct from tree of GaussianConditionals.
|
||||||
explicit Helper(const Conditionals& conditionals)
|
explicit Helper(const Conditionals& conditionals)
|
||||||
: conditionals(conditionals),
|
: conditionals(conditionals), minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
||||||
minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
|
||||||
auto func = [this](const GC::shared_ptr& c) -> GaussianFactorValuePair {
|
auto func = [this](const GC::shared_ptr& c) -> GaussianFactorValuePair {
|
||||||
double value = 0.0;
|
double value = 0.0;
|
||||||
if (c) {
|
if (c) {
|
||||||
|
@ -90,8 +88,8 @@ struct HybridGaussianConditional::Helper {
|
||||||
};
|
};
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(const DiscreteKeys& discreteParents,
|
||||||
const DiscreteKeys &discreteParents, const Helper &helper)
|
const Helper& helper)
|
||||||
: BaseFactor(discreteParents, helper.pairs),
|
: BaseFactor(discreteParents, helper.pairs),
|
||||||
BaseConditional(*helper.nrFrontals),
|
BaseConditional(*helper.nrFrontals),
|
||||||
conditionals_(helper.conditionals),
|
conditionals_(helper.conditionals),
|
||||||
|
@ -104,26 +102,32 @@ HybridGaussianConditional::HybridGaussianConditional(
|
||||||
Conditionals({discreteParent}, conditionals)) {}
|
Conditionals({discreteParent}, conditionals)) {}
|
||||||
|
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const DiscreteKey &discreteParent, Key key, //
|
const DiscreteKey& discreteParent,
|
||||||
|
Key key, //
|
||||||
const std::vector<std::pair<Vector, double>>& parameters)
|
const std::vector<std::pair<Vector, double>>& parameters)
|
||||||
: HybridGaussianConditional(DiscreteKeys{discreteParent},
|
: HybridGaussianConditional(DiscreteKeys{discreteParent},
|
||||||
Helper(discreteParent, parameters, key)) {}
|
Helper(discreteParent, parameters, key)) {}
|
||||||
|
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const DiscreteKey &discreteParent, Key key, //
|
const DiscreteKey& discreteParent,
|
||||||
const Matrix &A, Key parent,
|
Key key, //
|
||||||
|
const Matrix& A,
|
||||||
|
Key parent,
|
||||||
const std::vector<std::pair<Vector, double>>& parameters)
|
const std::vector<std::pair<Vector, double>>& parameters)
|
||||||
: HybridGaussianConditional(
|
: HybridGaussianConditional(DiscreteKeys{discreteParent},
|
||||||
DiscreteKeys{discreteParent},
|
|
||||||
Helper(discreteParent, parameters, key, A, parent)) {}
|
Helper(discreteParent, parameters, key, A, parent)) {}
|
||||||
|
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const DiscreteKey &discreteParent, Key key, //
|
const DiscreteKey& discreteParent,
|
||||||
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
|
Key key, //
|
||||||
|
const Matrix& A1,
|
||||||
|
Key parent1,
|
||||||
|
const Matrix& A2,
|
||||||
|
Key parent2,
|
||||||
const std::vector<std::pair<Vector, double>>& parameters)
|
const std::vector<std::pair<Vector, double>>& parameters)
|
||||||
: HybridGaussianConditional(
|
: HybridGaussianConditional(DiscreteKeys{discreteParent},
|
||||||
DiscreteKeys{discreteParent},
|
Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {
|
||||||
Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {}
|
}
|
||||||
|
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const DiscreteKeys& discreteParents,
|
const DiscreteKeys& discreteParents,
|
||||||
|
@ -131,15 +135,14 @@ HybridGaussianConditional::HybridGaussianConditional(
|
||||||
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
: HybridGaussianConditional(discreteParents, Helper(conditionals)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
const HybridGaussianConditional::Conditionals &
|
const HybridGaussianConditional::Conditionals& HybridGaussianConditional::conditionals() const {
|
||||||
HybridGaussianConditional::conditionals() const {
|
|
||||||
return conditionals_;
|
return conditionals_;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianProductFactor HybridGaussianConditional::asProductFactor()
|
HybridGaussianProductFactor HybridGaussianConditional::asProductFactor() const {
|
||||||
const {
|
auto wrap = [this](const std::shared_ptr<GaussianConditional>& gc)
|
||||||
auto wrap = [this](const std::shared_ptr<GaussianConditional> &gc) {
|
-> std::pair<GaussianFactorGraph, double> {
|
||||||
// First check if conditional has not been pruned
|
// First check if conditional has not been pruned
|
||||||
if (gc) {
|
if (gc) {
|
||||||
const double Cgm_Kgcm = gc->negLogConstant() - this->negLogConstant_;
|
const double Cgm_Kgcm = gc->negLogConstant() - this->negLogConstant_;
|
||||||
|
@ -151,10 +154,17 @@ HybridGaussianProductFactor HybridGaussianConditional::asProductFactor()
|
||||||
Vector c(1);
|
Vector c(1);
|
||||||
c << std::sqrt(2.0 * Cgm_Kgcm);
|
c << std::sqrt(2.0 * Cgm_Kgcm);
|
||||||
auto constantFactor = std::make_shared<JacobianFactor>(c);
|
auto constantFactor = std::make_shared<JacobianFactor>(c);
|
||||||
return GaussianFactorGraph{gc, constantFactor};
|
return {GaussianFactorGraph{gc, constantFactor}, Cgm_Kgcm};
|
||||||
|
} else {
|
||||||
|
// The scalar can be zero.
|
||||||
|
// TODO(Frank): after hiding is gone, this should be only case here.
|
||||||
|
return {GaussianFactorGraph{gc}, Cgm_Kgcm};
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// If the conditional is pruned, return an empty GaussianFactorGraph with zero scalar sum
|
||||||
|
// TODO(Frank): Could we just return an *empty* GaussianFactorGraph?
|
||||||
|
return {GaussianFactorGraph{nullptr}, 0.0};
|
||||||
}
|
}
|
||||||
return GaussianFactorGraph{gc};
|
|
||||||
};
|
};
|
||||||
return {{conditionals_, wrap}};
|
return {{conditionals_, wrap}};
|
||||||
}
|
}
|
||||||
|
@ -177,13 +187,11 @@ GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
||||||
if (conditional)
|
if (conditional)
|
||||||
return conditional;
|
return conditional;
|
||||||
else
|
else
|
||||||
throw std::logic_error(
|
throw std::logic_error("A HybridGaussianConditional unexpectedly contained a non-conditional");
|
||||||
"A HybridGaussianConditional unexpectedly contained a non-conditional");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
bool HybridGaussianConditional::equals(const HybridFactor &lf,
|
bool HybridGaussianConditional::equals(const HybridFactor& lf, double tol) const {
|
||||||
double tol) const {
|
|
||||||
const This* e = dynamic_cast<const This*>(&lf);
|
const This* e = dynamic_cast<const This*>(&lf);
|
||||||
if (e == nullptr) return false;
|
if (e == nullptr) return false;
|
||||||
|
|
||||||
|
@ -193,15 +201,13 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
|
||||||
|
|
||||||
// Check the base and the factors:
|
// Check the base and the factors:
|
||||||
return BaseFactor::equals(*e, tol) &&
|
return BaseFactor::equals(*e, tol) &&
|
||||||
conditionals_.equals(
|
conditionals_.equals(e->conditionals_, [tol](const auto& f1, const auto& f2) {
|
||||||
e->conditionals_, [tol](const auto &f1, const auto &f2) {
|
|
||||||
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
void HybridGaussianConditional::print(const std::string &s,
|
void HybridGaussianConditional::print(const std::string& s, const KeyFormatter& formatter) const {
|
||||||
const KeyFormatter &formatter) const {
|
|
||||||
std::cout << (s.empty() ? "" : s + "\n");
|
std::cout << (s.empty() ? "" : s + "\n");
|
||||||
BaseConditional::print("", formatter);
|
BaseConditional::print("", formatter);
|
||||||
std::cout << " Discrete Keys = ";
|
std::cout << " Discrete Keys = ";
|
||||||
|
@ -212,7 +218,8 @@ void HybridGaussianConditional::print(const std::string &s,
|
||||||
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
<< " logNormalizationConstant: " << -negLogConstant() << std::endl
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
conditionals_.print(
|
conditionals_.print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"",
|
||||||
|
[&](Key k) { return formatter(k); },
|
||||||
[&](const GaussianConditional::shared_ptr& gf) -> std::string {
|
[&](const GaussianConditional::shared_ptr& gf) -> std::string {
|
||||||
RedirectCout rd;
|
RedirectCout rd;
|
||||||
if (gf && !gf->empty()) {
|
if (gf && !gf->empty()) {
|
||||||
|
@ -233,16 +240,15 @@ KeyVector HybridGaussianConditional::continuousParents() const {
|
||||||
for (const auto& discreteKey : discreteKeys()) {
|
for (const auto& discreteKey : discreteKeys()) {
|
||||||
const Key key = discreteKey.first;
|
const Key key = discreteKey.first;
|
||||||
// remove that key from continuousParentKeys:
|
// remove that key from continuousParentKeys:
|
||||||
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
|
continuousParentKeys.erase(
|
||||||
continuousParentKeys.end(), key),
|
std::remove(continuousParentKeys.begin(), continuousParentKeys.end(), key),
|
||||||
continuousParentKeys.end());
|
continuousParentKeys.end());
|
||||||
}
|
}
|
||||||
return continuousParentKeys;
|
return continuousParentKeys;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
bool HybridGaussianConditional::allFrontalsGiven(
|
bool HybridGaussianConditional::allFrontalsGiven(const VectorValues& given) const {
|
||||||
const VectorValues &given) const {
|
|
||||||
for (auto&& kv : given) {
|
for (auto&& kv : given) {
|
||||||
if (given.find(kv.first) == given.end()) {
|
if (given.find(kv.first) == given.end()) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -264,8 +270,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
||||||
const KeyVector continuousParentKeys = continuousParents();
|
const KeyVector continuousParentKeys = continuousParents();
|
||||||
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
||||||
conditionals_,
|
conditionals_,
|
||||||
[&](const GaussianConditional::shared_ptr &conditional)
|
[&](const GaussianConditional::shared_ptr& conditional) -> GaussianFactorValuePair {
|
||||||
-> GaussianFactorValuePair {
|
|
||||||
const auto likelihood_m = conditional->likelihood(given);
|
const auto likelihood_m = conditional->likelihood(given);
|
||||||
const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_;
|
const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_;
|
||||||
if (Cgm_Kgcm == 0.0) {
|
if (Cgm_Kgcm == 0.0) {
|
||||||
|
@ -276,8 +281,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
||||||
return {likelihood_m, Cgm_Kgcm};
|
return {likelihood_m, Cgm_Kgcm};
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
|
return std::make_shared<HybridGaussianFactor>(discreteParentKeys, likelihoods);
|
||||||
likelihoods);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -291,11 +295,10 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
const DecisionTreeFactor& discreteProbs) const {
|
const DecisionTreeFactor& discreteProbs) const {
|
||||||
// Find keys in discreteProbs.keys() but not in this->keys():
|
// Find keys in discreteProbs.keys() but not in this->keys():
|
||||||
std::set<Key> mine(this->keys().begin(), this->keys().end());
|
std::set<Key> mine(this->keys().begin(), this->keys().end());
|
||||||
std::set<Key> theirs(discreteProbs.keys().begin(),
|
std::set<Key> theirs(discreteProbs.keys().begin(), discreteProbs.keys().end());
|
||||||
discreteProbs.keys().end());
|
|
||||||
std::vector<Key> diff;
|
std::vector<Key> diff;
|
||||||
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
|
std::set_difference(
|
||||||
std::back_inserter(diff));
|
theirs.begin(), theirs.end(), mine.begin(), mine.end(), std::back_inserter(diff));
|
||||||
|
|
||||||
// Find maximum probability value for every combination of our keys.
|
// Find maximum probability value for every combination of our keys.
|
||||||
Ordering keys(diff);
|
Ordering keys(diff);
|
||||||
|
@ -303,20 +306,18 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
|
|
||||||
// Check the max value for every combination of our keys.
|
// Check the max value for every combination of our keys.
|
||||||
// If the max value is 0.0, we can prune the corresponding conditional.
|
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||||
auto pruner = [&](const Assignment<Key> &choices,
|
auto pruner =
|
||||||
const GaussianConditional::shared_ptr &conditional)
|
[&](const Assignment<Key>& choices,
|
||||||
-> GaussianConditional::shared_ptr {
|
const GaussianConditional::shared_ptr& conditional) -> GaussianConditional::shared_ptr {
|
||||||
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
|
return (max->evaluate(choices) == 0.0) ? nullptr : conditional;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||||
return std::make_shared<HybridGaussianConditional>(discreteKeys(),
|
return std::make_shared<HybridGaussianConditional>(discreteKeys(), pruned_conditionals);
|
||||||
pruned_conditionals);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double HybridGaussianConditional::logProbability(
|
double HybridGaussianConditional::logProbability(const HybridValues& values) const {
|
||||||
const HybridValues &values) const {
|
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto conditional = conditionals_(values.discrete());
|
||||||
return conditional->logProbability(values.continuous());
|
return conditional->logProbability(values.continuous());
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,8 +32,8 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianFactor::FactorValuePairs
|
HybridGaussianFactor::FactorValuePairs HybridGaussianFactor::augment(
|
||||||
HybridGaussianFactor::augment(const FactorValuePairs &factors) {
|
const FactorValuePairs& factors) {
|
||||||
// Find the minimum value so we can "proselytize" to positive values.
|
// Find the minimum value so we can "proselytize" to positive values.
|
||||||
// Done because we can't have sqrt of negative numbers.
|
// Done because we can't have sqrt of negative numbers.
|
||||||
DecisionTree<Key, GaussianFactor::shared_ptr> gaussianFactors;
|
DecisionTree<Key, GaussianFactor::shared_ptr> gaussianFactors;
|
||||||
|
@ -48,14 +48,12 @@ HybridGaussianFactor::augment(const FactorValuePairs &factors) {
|
||||||
auto [gf, value] = gfv;
|
auto [gf, value] = gfv;
|
||||||
|
|
||||||
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
|
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
|
||||||
if (!jf)
|
if (!jf) return {gf, 0.0}; // should this be zero or infinite?
|
||||||
return {gf, 0.0}; // should this be zero or infinite?
|
|
||||||
|
|
||||||
double normalized_value = value - min_value;
|
double normalized_value = value - min_value;
|
||||||
|
|
||||||
// If the value is 0, do nothing
|
// If the value is 0, do nothing
|
||||||
if (normalized_value == 0.0)
|
if (normalized_value == 0.0) return {gf, value};
|
||||||
return {gf, 0.0};
|
|
||||||
|
|
||||||
GaussianFactorGraph gfg;
|
GaussianFactorGraph gfg;
|
||||||
gfg.push_back(jf);
|
gfg.push_back(jf);
|
||||||
|
@ -66,7 +64,8 @@ HybridGaussianFactor::augment(const FactorValuePairs &factors) {
|
||||||
auto constantFactor = std::make_shared<JacobianFactor>(c);
|
auto constantFactor = std::make_shared<JacobianFactor>(c);
|
||||||
|
|
||||||
gfg.push_back(constantFactor);
|
gfg.push_back(constantFactor);
|
||||||
return {std::make_shared<JacobianFactor>(gfg), normalized_value};
|
// NOTE(Frank): we store the actual value, not the normalized value:
|
||||||
|
return {std::make_shared<JacobianFactor>(gfg), value};
|
||||||
};
|
};
|
||||||
return FactorValuePairs(factors, update);
|
return FactorValuePairs(factors, update);
|
||||||
}
|
}
|
||||||
|
@ -77,6 +76,7 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
|
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
|
||||||
FactorValuePairs pairs; // The decision tree with factors and scalars
|
FactorValuePairs pairs; // The decision tree with factors and scalars
|
||||||
|
|
||||||
|
/// Constructor for a single discrete key and a vector of Gaussian factors
|
||||||
ConstructorHelper(const DiscreteKey& discreteKey,
|
ConstructorHelper(const DiscreteKey& discreteKey,
|
||||||
const std::vector<GaussianFactor::shared_ptr>& factors)
|
const std::vector<GaussianFactor::shared_ptr>& factors)
|
||||||
: discreteKeys({discreteKey}) {
|
: discreteKeys({discreteKey}) {
|
||||||
|
@ -88,13 +88,13 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Build the FactorValuePairs DecisionTree
|
// Build the FactorValuePairs DecisionTree
|
||||||
pairs = FactorValuePairs(
|
pairs = FactorValuePairs(DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
|
||||||
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
|
|
||||||
[](const auto& f) {
|
[](const auto& f) {
|
||||||
return std::pair{f, 0.0};
|
return std::pair{f, 0.0};
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Constructor for a single discrete key and a vector of GaussianFactorValuePairs
|
||||||
ConstructorHelper(const DiscreteKey& discreteKey,
|
ConstructorHelper(const DiscreteKey& discreteKey,
|
||||||
const std::vector<GaussianFactorValuePair>& factorPairs)
|
const std::vector<GaussianFactorValuePair>& factorPairs)
|
||||||
: discreteKeys({discreteKey}) {
|
: discreteKeys({discreteKey}) {
|
||||||
|
@ -110,8 +110,8 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
pairs = FactorValuePairs(discreteKeys, factorPairs);
|
pairs = FactorValuePairs(discreteKeys, factorPairs);
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstructorHelper(const DiscreteKeys &discreteKeys,
|
/// Constructor for a vector of discrete keys and a vector of GaussianFactorValuePairs
|
||||||
const FactorValuePairs &factorPairs)
|
ConstructorHelper(const DiscreteKeys& discreteKeys, const FactorValuePairs& factorPairs)
|
||||||
: discreteKeys(discreteKeys) {
|
: discreteKeys(discreteKeys) {
|
||||||
// Extract continuous keys from the first non-null factor
|
// Extract continuous keys from the first non-null factor
|
||||||
// TODO: just stop after first non-null factor
|
// TODO: just stop after first non-null factor
|
||||||
|
@ -128,22 +128,16 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper& helper)
|
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper& helper)
|
||||||
: Base(helper.continuousKeys, helper.discreteKeys),
|
: Base(helper.continuousKeys, helper.discreteKeys), factors_(augment(helper.pairs)) {}
|
||||||
factors_(augment(helper.pairs)) {}
|
|
||||||
|
|
||||||
/* *******************************************************************************/
|
|
||||||
HybridGaussianFactor::HybridGaussianFactor(
|
HybridGaussianFactor::HybridGaussianFactor(
|
||||||
const DiscreteKey &discreteKey,
|
const DiscreteKey& discreteKey, const std::vector<GaussianFactor::shared_ptr>& factorPairs)
|
||||||
const std::vector<GaussianFactor::shared_ptr> &factorPairs)
|
|
||||||
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
|
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKey& discreteKey,
|
||||||
HybridGaussianFactor::HybridGaussianFactor(
|
|
||||||
const DiscreteKey &discreteKey,
|
|
||||||
const std::vector<GaussianFactorValuePair>& factorPairs)
|
const std::vector<GaussianFactorValuePair>& factorPairs)
|
||||||
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
|
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
|
||||||
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys& discreteKeys,
|
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys& discreteKeys,
|
||||||
const FactorValuePairs& factorPairs)
|
const FactorValuePairs& factorPairs)
|
||||||
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factorPairs)) {}
|
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factorPairs)) {}
|
||||||
|
@ -151,13 +145,11 @@ HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
|
bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
|
||||||
const This* e = dynamic_cast<const This*>(&lf);
|
const This* e = dynamic_cast<const This*>(&lf);
|
||||||
if (e == nullptr)
|
if (e == nullptr) return false;
|
||||||
return false;
|
|
||||||
|
|
||||||
// This will return false if either factors_ is empty or e->factors_ is
|
// This will return false if either factors_ is empty or e->factors_ is
|
||||||
// empty, but not if both are empty or both are not empty:
|
// empty, but not if both are empty or both are not empty:
|
||||||
if (factors_.empty() ^ e->factors_.empty())
|
if (factors_.empty() ^ e->factors_.empty()) return false;
|
||||||
return false;
|
|
||||||
|
|
||||||
// Check the base and the factors:
|
// Check the base and the factors:
|
||||||
auto compareFunc = [tol](const auto& pair1, const auto& pair2) {
|
auto compareFunc = [tol](const auto& pair1, const auto& pair2) {
|
||||||
|
@ -169,8 +161,7 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
void HybridGaussianFactor::print(const std::string &s,
|
void HybridGaussianFactor::print(const std::string& s, const KeyFormatter& formatter) const {
|
||||||
const KeyFormatter &formatter) const {
|
|
||||||
std::cout << (s.empty() ? "" : s + "\n");
|
std::cout << (s.empty() ? "" : s + "\n");
|
||||||
HybridFactor::print("", formatter);
|
HybridFactor::print("", formatter);
|
||||||
std::cout << "{\n";
|
std::cout << "{\n";
|
||||||
|
@ -178,7 +169,8 @@ void HybridGaussianFactor::print(const std::string &s,
|
||||||
std::cout << " empty" << std::endl;
|
std::cout << " empty" << std::endl;
|
||||||
} else {
|
} else {
|
||||||
factors_.print(
|
factors_.print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"",
|
||||||
|
[&](Key k) { return formatter(k); },
|
||||||
[&](const auto& pair) -> std::string {
|
[&](const auto& pair) -> std::string {
|
||||||
RedirectCout rd;
|
RedirectCout rd;
|
||||||
std::cout << ":\n";
|
std::cout << ":\n";
|
||||||
|
@ -195,21 +187,24 @@ void HybridGaussianFactor::print(const std::string &s,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianFactor::sharedFactor
|
HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
|
||||||
HybridGaussianFactor::operator()(const DiscreteValues &assignment) const {
|
const DiscreteValues& assignment) const {
|
||||||
return factors_(assignment).first;
|
return factors_(assignment).first;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
|
HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
|
||||||
return {{factors_,
|
// Implemented by creating a new DecisionTree where:
|
||||||
[](const auto &pair) { return GaussianFactorGraph{pair.first}; }}};
|
// - The structure (keys and assignments) is preserved from factors_
|
||||||
|
// - Each leaf converted to a GaussianFactorGraph with just the factor and its scalar.
|
||||||
|
return {{factors_, [](const auto& pair) -> std::pair<GaussianFactorGraph, double> {
|
||||||
|
return {GaussianFactorGraph{pair.first}, pair.second};
|
||||||
|
}}};
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
/// Helper method to compute the error of a component.
|
/// Helper method to compute the error of a component.
|
||||||
static double
|
static double PotentiallyPrunedComponentError(const GaussianFactor::shared_ptr& gf,
|
||||||
PotentiallyPrunedComponentError(const GaussianFactor::shared_ptr &gf,
|
|
||||||
const VectorValues& values) {
|
const VectorValues& values) {
|
||||||
// Check if valid pointer
|
// Check if valid pointer
|
||||||
if (gf) {
|
if (gf) {
|
||||||
|
@ -222,8 +217,8 @@ PotentiallyPrunedComponentError(const GaussianFactor::shared_ptr &gf,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
AlgebraicDecisionTree<Key>
|
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||||
HybridGaussianFactor::errorTree(const VectorValues &continuousValues) const {
|
const VectorValues& continuousValues) const {
|
||||||
// functor to convert from sharedFactor to double error value.
|
// functor to convert from sharedFactor to double error value.
|
||||||
auto errorFunc = [&continuousValues](const auto& pair) {
|
auto errorFunc = [&continuousValues](const auto& pair) {
|
||||||
return PotentiallyPrunedComponentError(pair.first, continuousValues);
|
return PotentiallyPrunedComponentError(pair.first, continuousValues);
|
||||||
|
|
|
@ -18,7 +18,6 @@
|
||||||
* @date Mar 11, 2022
|
* @date Mar 11, 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "gtsam/discrete/DiscreteValues.h"
|
|
||||||
#include <gtsam/base/utilities.h>
|
#include <gtsam/base/utilities.h>
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/Assignment.h>
|
||||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||||
|
@ -40,6 +39,7 @@
|
||||||
#include <gtsam/linear/GaussianJunctionTree.h>
|
#include <gtsam/linear/GaussianJunctionTree.h>
|
||||||
#include <gtsam/linear/HessianFactor.h>
|
#include <gtsam/linear/HessianFactor.h>
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
|
#include "gtsam/discrete/DiscreteValues.h"
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
@ -59,11 +59,10 @@ using std::dynamic_pointer_cast;
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
// Throw a runtime exception for method specified in string s, and factor f:
|
// Throw a runtime exception for method specified in string s, and factor f:
|
||||||
static void throwRuntimeError(const std::string &s,
|
static void throwRuntimeError(const std::string& s, const std::shared_ptr<Factor>& f) {
|
||||||
const std::shared_ptr<Factor> &f) {
|
|
||||||
auto& fr = *f;
|
auto& fr = *f;
|
||||||
throw std::runtime_error(s + " not implemented for factor type " +
|
throw std::runtime_error(s + " not implemented for factor type " + demangle(typeid(fr).name()) +
|
||||||
demangle(typeid(fr).name()) + ".");
|
".");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
@ -82,8 +81,7 @@ static void printFactor(const std::shared_ptr<Factor> &factor,
|
||||||
if (assignment.empty())
|
if (assignment.empty())
|
||||||
hgf->print("HybridGaussianFactor:", keyFormatter);
|
hgf->print("HybridGaussianFactor:", keyFormatter);
|
||||||
else
|
else
|
||||||
hgf->operator()(assignment)
|
hgf->operator()(assignment)->print("HybridGaussianFactor, component:", keyFormatter);
|
||||||
->print("HybridGaussianFactor, component:", keyFormatter);
|
|
||||||
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||||
factor->print("GaussianFactor:\n", keyFormatter);
|
factor->print("GaussianFactor:\n", keyFormatter);
|
||||||
|
|
||||||
|
@ -98,9 +96,7 @@ static void printFactor(const std::shared_ptr<Factor> &factor,
|
||||||
if (assignment.empty())
|
if (assignment.empty())
|
||||||
hc->print("HybridConditional:", keyFormatter);
|
hc->print("HybridConditional:", keyFormatter);
|
||||||
else
|
else
|
||||||
hc->asHybrid()
|
hc->asHybrid()->choose(assignment)->print("HybridConditional, component:\n", keyFormatter);
|
||||||
->choose(assignment)
|
|
||||||
->print("HybridConditional, component:\n", keyFormatter);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
factor->print("Unknown factor type\n", keyFormatter);
|
factor->print("Unknown factor type\n", keyFormatter);
|
||||||
|
@ -129,11 +125,11 @@ void HybridGaussianFactorGraph::print(const std::string &s,
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
void HybridGaussianFactorGraph::printErrors(
|
void HybridGaussianFactorGraph::printErrors(
|
||||||
const HybridValues &values, const std::string &str,
|
const HybridValues& values,
|
||||||
|
const std::string& str,
|
||||||
const KeyFormatter& keyFormatter,
|
const KeyFormatter& keyFormatter,
|
||||||
const std::function<bool(const Factor * /*factor*/,
|
const std::function<bool(const Factor* /*factor*/, double /*whitenedError*/, size_t /*index*/)>&
|
||||||
double /*whitenedError*/, size_t /*index*/)>
|
printCondition) const {
|
||||||
&printCondition) const {
|
|
||||||
std::cout << str << " size: " << size() << std::endl << std::endl;
|
std::cout << str << " size: " << size() << std::endl << std::endl;
|
||||||
|
|
||||||
for (size_t i = 0; i < factors_.size(); i++) {
|
for (size_t i = 0; i < factors_.size(); i++) {
|
||||||
|
@ -157,8 +153,7 @@ void HybridGaussianFactorGraph::printErrors(
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
||||||
// keys, and then loop over all assignments to populate a vector.
|
// keys, and then loop over all assignments to populate a vector.
|
||||||
HybridGaussianProductFactor
|
HybridGaussianProductFactor HybridGaussianFactorGraph::collectProductFactor() const {
|
||||||
HybridGaussianFactorGraph::collectProductFactor() const {
|
|
||||||
HybridGaussianProductFactor result;
|
HybridGaussianProductFactor result;
|
||||||
|
|
||||||
for (auto& f : factors_) {
|
for (auto& f : factors_) {
|
||||||
|
@ -198,9 +193,8 @@ HybridGaussianFactorGraph::collectProductFactor() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> continuousElimination(
|
||||||
continuousElimination(const HybridGaussianFactorGraph &factors,
|
const HybridGaussianFactorGraph& factors, const Ordering& frontalKeys) {
|
||||||
const Ordering &frontalKeys) {
|
|
||||||
GaussianFactorGraph gfg;
|
GaussianFactorGraph gfg;
|
||||||
for (auto& f : factors) {
|
for (auto& f : factors) {
|
||||||
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
|
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||||
|
@ -241,9 +235,8 @@ static AlgebraicDecisionTree<Key> probabilitiesFromNegativeLogValues(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
|
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> discreteElimination(
|
||||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
const HybridGaussianFactorGraph& factors, const Ordering& frontalKeys) {
|
||||||
const Ordering &frontalKeys) {
|
|
||||||
DiscreteFactorGraph dfg;
|
DiscreteFactorGraph dfg;
|
||||||
|
|
||||||
for (auto& f : factors) {
|
for (auto& f : factors) {
|
||||||
|
@ -262,8 +255,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
||||||
AlgebraicDecisionTree<Key> probabilities =
|
AlgebraicDecisionTree<Key> probabilities =
|
||||||
probabilitiesFromNegativeLogValues(logProbabilities);
|
probabilitiesFromNegativeLogValues(logProbabilities);
|
||||||
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(),
|
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(), probabilities);
|
||||||
probabilities);
|
|
||||||
|
|
||||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||||
// Ignore orphaned clique.
|
// Ignore orphaned clique.
|
||||||
|
@ -284,8 +276,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
using Result = std::pair<std::shared_ptr<GaussianConditional>,
|
using Result = std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
|
||||||
HybridGaussianFactor::sharedFactor>;
|
using ResultTree = DecisionTree<Key, std::pair<Result, double>>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Compute the probability p(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m)
|
* Compute the probability p(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m)
|
||||||
|
@ -293,11 +285,10 @@ using Result = std::pair<std::shared_ptr<GaussianConditional>,
|
||||||
* The residual error contains no keys, and only
|
* The residual error contains no keys, and only
|
||||||
* depends on the discrete separator if present.
|
* depends on the discrete separator if present.
|
||||||
*/
|
*/
|
||||||
static std::shared_ptr<Factor> createDiscreteFactor(
|
static std::shared_ptr<Factor> createDiscreteFactor(const ResultTree& eliminationResults,
|
||||||
const DecisionTree<Key, Result> &eliminationResults,
|
|
||||||
const DiscreteKeys& discreteSeparator) {
|
const DiscreteKeys& discreteSeparator) {
|
||||||
auto negLogProbability = [&](const Result &pair) -> double {
|
auto negLogProbability = [&](const auto& pair) -> double {
|
||||||
const auto &[conditional, factor] = pair;
|
const auto& [conditional, factor] = pair.first;
|
||||||
if (conditional && factor) {
|
if (conditional && factor) {
|
||||||
static const VectorValues kEmpty;
|
static const VectorValues kEmpty;
|
||||||
// If the factor is not null, it has no keys, just contains the residual.
|
// If the factor is not null, it has no keys, just contains the residual.
|
||||||
|
@ -324,12 +315,11 @@ static std::shared_ptr<Factor> createDiscreteFactor(
|
||||||
|
|
||||||
// Create HybridGaussianFactor on the separator, taking care to correct
|
// Create HybridGaussianFactor on the separator, taking care to correct
|
||||||
// for conditional constants.
|
// for conditional constants.
|
||||||
static std::shared_ptr<Factor> createHybridGaussianFactor(
|
static std::shared_ptr<Factor> createHybridGaussianFactor(const ResultTree& eliminationResults,
|
||||||
const DecisionTree<Key, Result> &eliminationResults,
|
|
||||||
const DiscreteKeys& discreteSeparator) {
|
const DiscreteKeys& discreteSeparator) {
|
||||||
// Correct for the normalization constant used up by the conditional
|
// Correct for the normalization constant used up by the conditional
|
||||||
auto correct = [&](const Result &pair) -> GaussianFactorValuePair {
|
auto correct = [&](const auto& pair) -> GaussianFactorValuePair {
|
||||||
const auto &[conditional, factor] = pair;
|
const auto& [conditional, factor] = pair.first;
|
||||||
if (conditional && factor) {
|
if (conditional && factor) {
|
||||||
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
|
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
|
||||||
if (!hf) throw std::runtime_error("Expected HessianFactor!");
|
if (!hf) throw std::runtime_error("Expected HessianFactor!");
|
||||||
|
@ -345,16 +335,14 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
||||||
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
|
throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
|
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults, correct);
|
||||||
correct);
|
|
||||||
|
|
||||||
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
/// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys.
|
/// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys.
|
||||||
static auto GetDiscreteKeys =
|
static auto GetDiscreteKeys = [](const HybridGaussianFactorGraph& hfg) -> DiscreteKeys {
|
||||||
[](const HybridGaussianFactorGraph &hfg) -> DiscreteKeys {
|
|
||||||
const std::set<DiscreteKey> discreteKeySet = hfg.discreteKeys();
|
const std::set<DiscreteKey> discreteKeySet = hfg.discreteKeys();
|
||||||
return {discreteKeySet.begin(), discreteKeySet.end()};
|
return {discreteKeySet.begin(), discreteKeySet.end()};
|
||||||
};
|
};
|
||||||
|
@ -377,9 +365,12 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
|
|
||||||
// This is the elimination method on the leaf nodes
|
// This is the elimination method on the leaf nodes
|
||||||
bool someContinuousLeft = false;
|
bool someContinuousLeft = false;
|
||||||
auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
|
auto eliminate =
|
||||||
|
[&](const std::pair<GaussianFactorGraph, double>& pair) -> std::pair<Result, double> {
|
||||||
|
const auto& [graph, scalar] = pair;
|
||||||
|
|
||||||
if (graph.empty()) {
|
if (graph.empty()) {
|
||||||
return {nullptr, nullptr};
|
return {{nullptr, nullptr}, 0.0};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expensive elimination of product factor.
|
// Expensive elimination of product factor.
|
||||||
|
@ -388,25 +379,24 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
// Record whether there any continuous variables left
|
// Record whether there any continuous variables left
|
||||||
someContinuousLeft |= !result.second->empty();
|
someContinuousLeft |= !result.second->empty();
|
||||||
|
|
||||||
return result;
|
return {result, scalar};
|
||||||
};
|
};
|
||||||
|
|
||||||
// Perform elimination!
|
// Perform elimination!
|
||||||
DecisionTree<Key, Result> eliminationResults(prunedProductFactor, eliminate);
|
ResultTree eliminationResults(prunedProductFactor, eliminate);
|
||||||
|
|
||||||
// If there are no more continuous parents we create a DiscreteFactor with the
|
// If there are no more continuous parents we create a DiscreteFactor with the
|
||||||
// error for each discrete choice. Otherwise, create a HybridGaussianFactor
|
// error for each discrete choice. Otherwise, create a HybridGaussianFactor
|
||||||
// on the separator, taking care to correct for conditional constants.
|
// on the separator, taking care to correct for conditional constants.
|
||||||
auto newFactor =
|
auto newFactor = someContinuousLeft
|
||||||
someContinuousLeft
|
|
||||||
? createHybridGaussianFactor(eliminationResults, discreteSeparator)
|
? createHybridGaussianFactor(eliminationResults, discreteSeparator)
|
||||||
: createDiscreteFactor(eliminationResults, discreteSeparator);
|
: createDiscreteFactor(eliminationResults, discreteSeparator);
|
||||||
|
|
||||||
// Create the HybridGaussianConditional from the conditionals
|
// Create the HybridGaussianConditional from the conditionals
|
||||||
HybridGaussianConditional::Conditionals conditionals(
|
HybridGaussianConditional::Conditionals conditionals(
|
||||||
eliminationResults, [](const Result &pair) { return pair.first; });
|
eliminationResults, [](const auto& pair) { return pair.first.first; });
|
||||||
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
|
auto hybridGaussian =
|
||||||
discreteSeparator, conditionals);
|
std::make_shared<HybridGaussianConditional>(discreteSeparator, conditionals);
|
||||||
|
|
||||||
return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
|
return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
|
||||||
}
|
}
|
||||||
|
@ -426,8 +416,7 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
* be INCORRECT and there will be NO error raised.
|
* be INCORRECT and there will be NO error raised.
|
||||||
*/
|
*/
|
||||||
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> //
|
std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>> //
|
||||||
EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys) {
|
||||||
const Ordering &keys) {
|
|
||||||
// NOTE: Because we are in the Conditional Gaussian regime there are only
|
// NOTE: Because we are in the Conditional Gaussian regime there are only
|
||||||
// a few cases:
|
// a few cases:
|
||||||
// 1. continuous variable, make a hybrid Gaussian conditional if there are
|
// 1. continuous variable, make a hybrid Gaussian conditional if there are
|
||||||
|
@ -489,11 +478,9 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
only_discrete = false;
|
only_discrete = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else if (auto cont_factor =
|
} else if (auto cont_factor = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||||
std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
|
||||||
only_discrete = false;
|
only_discrete = false;
|
||||||
} else if (auto discrete_factor =
|
} else if (auto discrete_factor = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
|
||||||
only_continuous = false;
|
only_continuous = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -553,8 +540,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::discretePosterior(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
GaussianFactorGraph HybridGaussianFactorGraph::choose(
|
GaussianFactorGraph HybridGaussianFactorGraph::choose(const DiscreteValues& assignment) const {
|
||||||
const DiscreteValues &assignment) const {
|
|
||||||
GaussianFactorGraph gfg;
|
GaussianFactorGraph gfg;
|
||||||
for (auto&& f : *this) {
|
for (auto&& f : *this) {
|
||||||
if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(f)) {
|
if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||||
|
|
|
@ -24,11 +24,12 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
static GaussianFactorGraph add(const GaussianFactorGraph &graph1,
|
using Y = HybridGaussianProductFactor::Y;
|
||||||
const GaussianFactorGraph &graph2) {
|
|
||||||
auto result = graph1;
|
static Y add(const Y& y1, const Y& y2) {
|
||||||
result.push_back(graph2);
|
GaussianFactorGraph result = y1.first;
|
||||||
return result;
|
result.push_back(y2.first);
|
||||||
|
return {result, y1.second + y2.second};
|
||||||
};
|
};
|
||||||
|
|
||||||
HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a,
|
HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a,
|
||||||
|
@ -52,36 +53,33 @@ HybridGaussianProductFactor &HybridGaussianProductFactor::operator+=(
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
HybridGaussianProductFactor &
|
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
|
||||||
HybridGaussianProductFactor::operator+=(const HybridGaussianFactor &factor) {
|
const HybridGaussianFactor& factor) {
|
||||||
*this = *this + factor;
|
*this = *this + factor;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
void HybridGaussianProductFactor::print(const std::string &s,
|
void HybridGaussianProductFactor::print(const std::string& s, const KeyFormatter& formatter) const {
|
||||||
const KeyFormatter &formatter) const {
|
|
||||||
KeySet keys;
|
KeySet keys;
|
||||||
auto printer = [&](const Y &graph) {
|
auto printer = [&](const Y& y) {
|
||||||
if (keys.size() == 0)
|
if (keys.empty()) keys = y.first.keys();
|
||||||
keys = graph.keys();
|
return "Graph of size " + std::to_string(y.first.size()) +
|
||||||
return "Graph of size " + std::to_string(graph.size());
|
", scalar sum: " + std::to_string(y.second);
|
||||||
};
|
};
|
||||||
Base::print(s, formatter, printer);
|
Base::print(s, formatter, printer);
|
||||||
if (keys.size() > 0) {
|
if (!keys.empty()) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << s << " Keys:";
|
ss << s << " Keys:";
|
||||||
for (auto &&key : keys)
|
for (auto&& key : keys) ss << " " << formatter(key);
|
||||||
ss << " " << formatter(key);
|
|
||||||
std::cout << ss.str() << "." << std::endl;
|
std::cout << ss.str() << "." << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
|
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
|
||||||
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
|
auto emptyGaussian = [](const Y& y) {
|
||||||
bool hasNull =
|
bool hasNull = std::any_of(
|
||||||
std::any_of(graph.begin(), graph.end(),
|
y.first.begin(), y.first.end(), [](const GaussianFactor::shared_ptr& ptr) { return !ptr; });
|
||||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
return hasNull ? Y{GaussianFactorGraph(), 0.0} : y;
|
||||||
return hasNull ? GaussianFactorGraph() : graph;
|
|
||||||
};
|
};
|
||||||
return {Base(*this, emptyGaussian)};
|
return {Base(*this, emptyGaussian)};
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,10 +26,11 @@ namespace gtsam {
|
||||||
|
|
||||||
class HybridGaussianFactor;
|
class HybridGaussianFactor;
|
||||||
|
|
||||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
/// Alias for DecisionTree of GaussianFactorGraphs and their scalar sums
|
||||||
class HybridGaussianProductFactor : public DecisionTree<Key, GaussianFactorGraph> {
|
class HybridGaussianProductFactor
|
||||||
|
: public DecisionTree<Key, std::pair<GaussianFactorGraph, double>> {
|
||||||
public:
|
public:
|
||||||
using Y = GaussianFactorGraph;
|
using Y = std::pair<GaussianFactorGraph, double>;
|
||||||
using Base = DecisionTree<Key, Y>;
|
using Base = DecisionTree<Key, Y>;
|
||||||
|
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
|
@ -44,7 +45,8 @@ class HybridGaussianProductFactor : public DecisionTree<Key, GaussianFactorGraph
|
||||||
* @param factor Shared pointer to the factor
|
* @param factor Shared pointer to the factor
|
||||||
*/
|
*/
|
||||||
template <class FACTOR>
|
template <class FACTOR>
|
||||||
HybridGaussianProductFactor(const std::shared_ptr<FACTOR>& factor) : Base(Y{factor}) {}
|
HybridGaussianProductFactor(const std::shared_ptr<FACTOR>& factor)
|
||||||
|
: Base(Y{GaussianFactorGraph{factor}, 0.0}) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Construct from DecisionTree
|
* @brief Construct from DecisionTree
|
||||||
|
@ -88,7 +90,9 @@ class HybridGaussianProductFactor : public DecisionTree<Key, GaussianFactorGraph
|
||||||
* @return true if equal, false otherwise
|
* @return true if equal, false otherwise
|
||||||
*/
|
*/
|
||||||
bool equals(const HybridGaussianProductFactor& other, double tol = 1e-9) const {
|
bool equals(const HybridGaussianProductFactor& other, double tol = 1e-9) const {
|
||||||
return Base::equals(other, [tol](const Y& a, const Y& b) { return a.equals(b, tol); });
|
return Base::equals(other, [tol](const Y& a, const Y& b) {
|
||||||
|
return a.first.equals(b.first, tol) && std::abs(a.second - b.second) < tol;
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
@ -101,9 +105,9 @@ class HybridGaussianProductFactor : public DecisionTree<Key, GaussianFactorGraph
|
||||||
* @return A new HybridGaussianProductFactor with empty GaussianFactorGraphs removed
|
* @return A new HybridGaussianProductFactor with empty GaussianFactorGraphs removed
|
||||||
*
|
*
|
||||||
* If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
* If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
||||||
* that leaf to an empty GaussianFactorGraph. This is needed because the DecisionTree
|
* that leaf to an empty GaussianFactorGraph with zero scalar sum. This is needed because the
|
||||||
* will otherwise create a GaussianFactorGraph with a single (null) factor,
|
* DecisionTree will otherwise create a GaussianFactorGraph with a single (null) factor, which
|
||||||
* which doesn't register as null.
|
* doesn't register as null.
|
||||||
*/
|
*/
|
||||||
HybridGaussianProductFactor removeEmpty() const;
|
HybridGaussianProductFactor removeEmpty() const;
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,7 @@
|
||||||
|
|
||||||
#include "Switching.h"
|
#include "Switching.h"
|
||||||
#include "TinyHybridExample.h"
|
#include "TinyHybridExample.h"
|
||||||
|
#include "gtsam/linear/GaussianFactorGraph.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
@ -73,8 +74,7 @@ TEST(HybridGaussianFactorGraph, Creation) {
|
||||||
HybridGaussianConditional gm(
|
HybridGaussianConditional gm(
|
||||||
m0,
|
m0,
|
||||||
{std::make_shared<GaussianConditional>(X(0), Z_3x1, I_3x3, X(1), I_3x3),
|
{std::make_shared<GaussianConditional>(X(0), Z_3x1, I_3x3, X(1), I_3x3),
|
||||||
std::make_shared<GaussianConditional>(X(0), Vector3::Ones(), I_3x3, X(1),
|
std::make_shared<GaussianConditional>(X(0), Vector3::Ones(), I_3x3, X(1), I_3x3)});
|
||||||
I_3x3)});
|
|
||||||
hfg.add(gm);
|
hfg.add(gm);
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(2, hfg.size());
|
EXPECT_LONGS_EQUAL(2, hfg.size());
|
||||||
|
@ -282,8 +282,7 @@ TEST(HybridGaussianFactorGraph, Conditionals) {
|
||||||
expected_continuous.insert<double>(X(1), 1);
|
expected_continuous.insert<double>(X(1), 1);
|
||||||
expected_continuous.insert<double>(X(2), 2);
|
expected_continuous.insert<double>(X(2), 2);
|
||||||
expected_continuous.insert<double>(X(3), 4);
|
expected_continuous.insert<double>(X(3), 4);
|
||||||
Values result_continuous =
|
Values result_continuous = switching.linearizationPoint.retract(result.continuous());
|
||||||
switching.linearizationPoint.retract(result.continuous());
|
|
||||||
EXPECT(assert_equal(expected_continuous, result_continuous));
|
EXPECT(assert_equal(expected_continuous, result_continuous));
|
||||||
|
|
||||||
DiscreteValues expected_discrete;
|
DiscreteValues expected_discrete;
|
||||||
|
@ -376,19 +375,18 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
|
||||||
auto error_tree2 = graph.errorTree(delta.continuous());
|
auto error_tree2 = graph.errorTree(delta.continuous());
|
||||||
|
|
||||||
// regression
|
// regression
|
||||||
leaves = {0.50985198, 0.0097577296, 0.50009425, 0,
|
leaves = {
|
||||||
0.52922138, 0.029127133, 0.50985105, 0.0097567964};
|
0.50985198, 0.0097577296, 0.50009425, 0, 0.52922138, 0.029127133, 0.50985105, 0.0097567964};
|
||||||
AlgebraicDecisionTree<Key> expected_error2(s.modes, leaves);
|
AlgebraicDecisionTree<Key> expected_error2(s.modes, leaves);
|
||||||
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
|
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Check that assembleGraphTree assembles Gaussian factor graphs for each
|
// Check that collectProductFactor works correctly.
|
||||||
// assignment.
|
|
||||||
TEST(HybridGaussianFactorGraph, collectProductFactor) {
|
TEST(HybridGaussianFactorGraph, collectProductFactor) {
|
||||||
const int num_measurements = 1;
|
const int num_measurements = 1;
|
||||||
auto fg = tiny::createHybridGaussianFactorGraph(
|
VectorValues vv{{Z(0), Vector1(5.0)}};
|
||||||
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
|
auto fg = tiny::createHybridGaussianFactorGraph(num_measurements, vv);
|
||||||
EXPECT_LONGS_EQUAL(3, fg.size());
|
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||||
|
|
||||||
// Assemble graph tree:
|
// Assemble graph tree:
|
||||||
|
@ -411,21 +409,24 @@ TEST(HybridGaussianFactorGraph, collectProductFactor) {
|
||||||
DiscreteValues d0{{M(0), 0}}, d1{{M(0), 1}};
|
DiscreteValues d0{{M(0), 0}}, d1{{M(0), 1}};
|
||||||
|
|
||||||
// Expected decision tree with two factor graphs:
|
// Expected decision tree with two factor graphs:
|
||||||
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
|
// f(x0;mode=0)P(x0)
|
||||||
HybridGaussianProductFactor expected{
|
GaussianFactorGraph expectedFG0{(*hybrid)(d0), prior};
|
||||||
{M(0), GaussianFactorGraph(std::vector<GF>{(*hybrid)(d0), prior}),
|
EXPECT(assert_equal(expectedFG0, actual(d0).first, 1e-5));
|
||||||
GaussianFactorGraph(std::vector<GF>{(*hybrid)(d1), prior})}};
|
EXPECT(assert_equal(0.0, actual(d0).second, 1e-5));
|
||||||
|
|
||||||
EXPECT(assert_equal(expected(d0), actual(d0), 1e-5));
|
// f(x0;mode=1)P(x0)
|
||||||
EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));
|
GaussianFactorGraph expectedFG1{(*hybrid)(d1), prior};
|
||||||
|
EXPECT(assert_equal(expectedFG1, actual(d1).first, 1e-5));
|
||||||
|
EXPECT(assert_equal(1.79176, actual(d1).second, 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Check that the factor graph unnormalized probability is proportional to the
|
// Check that the factor graph unnormalized probability is proportional to the
|
||||||
// Bayes net probability for the given measurements.
|
// Bayes net probability for the given measurements.
|
||||||
bool
|
bool ratioTest(const HybridBayesNet& bn,
|
||||||
ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
|
const VectorValues& measurements,
|
||||||
const HybridGaussianFactorGraph &fg, size_t num_samples = 100) {
|
const HybridGaussianFactorGraph& fg,
|
||||||
|
size_t num_samples = 100) {
|
||||||
auto compute_ratio = [&](HybridValues* sample) -> double {
|
auto compute_ratio = [&](HybridValues* sample) -> double {
|
||||||
sample->update(measurements); // update sample with given measurements:
|
sample->update(measurements); // update sample with given measurements:
|
||||||
return bn.evaluate(*sample) / fg.probPrime(*sample);
|
return bn.evaluate(*sample) / fg.probPrime(*sample);
|
||||||
|
@ -437,8 +438,7 @@ TEST(HybridGaussianFactorGraph, collectProductFactor) {
|
||||||
// Test ratios for a number of independent samples:
|
// Test ratios for a number of independent samples:
|
||||||
for (size_t i = 0; i < num_samples; i++) {
|
for (size_t i = 0; i < num_samples; i++) {
|
||||||
HybridValues sample = bn.sample(&kRng);
|
HybridValues sample = bn.sample(&kRng);
|
||||||
if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6)
|
if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false;
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -446,8 +446,10 @@ TEST(HybridGaussianFactorGraph, collectProductFactor) {
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Check that the bayes net unnormalized probability is proportional to the
|
// Check that the bayes net unnormalized probability is proportional to the
|
||||||
// Bayes net probability for the given measurements.
|
// Bayes net probability for the given measurements.
|
||||||
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
|
bool ratioTest(const HybridBayesNet& bn,
|
||||||
const HybridBayesNet &posterior, size_t num_samples = 100) {
|
const VectorValues& measurements,
|
||||||
|
const HybridBayesNet& posterior,
|
||||||
|
size_t num_samples = 100) {
|
||||||
auto compute_ratio = [&](HybridValues* sample) -> double {
|
auto compute_ratio = [&](HybridValues* sample) -> double {
|
||||||
sample->update(measurements); // update sample with given measurements:
|
sample->update(measurements); // update sample with given measurements:
|
||||||
return bn.evaluate(*sample) / posterior.evaluate(*sample);
|
return bn.evaluate(*sample) / posterior.evaluate(*sample);
|
||||||
|
@ -461,8 +463,7 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
|
||||||
HybridValues sample = bn.sample(&kRng);
|
HybridValues sample = bn.sample(&kRng);
|
||||||
// GTSAM_PRINT(sample);
|
// GTSAM_PRINT(sample);
|
||||||
// std::cout << "ratio: " << compute_ratio(&sample) << std::endl;
|
// std::cout << "ratio: " << compute_ratio(&sample) << std::endl;
|
||||||
if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6)
|
if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false;
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -484,10 +485,10 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
||||||
// Create hybrid Gaussian factor on X(0).
|
// Create hybrid Gaussian factor on X(0).
|
||||||
using tiny::mode;
|
using tiny::mode;
|
||||||
// regression, but mean checked to be 5.0 in both cases:
|
// regression, but mean checked to be 5.0 in both cases:
|
||||||
const auto conditional0 = std::make_shared<GaussianConditional>(
|
const auto conditional0 =
|
||||||
X(0), Vector1(14.1421), I_1x1 * 2.82843),
|
std::make_shared<GaussianConditional>(X(0), Vector1(14.1421), I_1x1 * 2.82843),
|
||||||
conditional1 = std::make_shared<GaussianConditional>(
|
conditional1 =
|
||||||
X(0), Vector1(10.1379), I_1x1 * 2.02759);
|
std::make_shared<GaussianConditional>(X(0), Vector1(10.1379), I_1x1 * 2.02759);
|
||||||
expectedBayesNet.emplace_shared<HybridGaussianConditional>(
|
expectedBayesNet.emplace_shared<HybridGaussianConditional>(
|
||||||
mode, std::vector{conditional0, conditional1});
|
mode, std::vector{conditional0, conditional1});
|
||||||
|
|
||||||
|
@ -515,8 +516,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
|
||||||
bn.emplace_shared<HybridGaussianConditional>(m1, Z(0), I_1x1, X(0), parms);
|
bn.emplace_shared<HybridGaussianConditional>(m1, Z(0), I_1x1, X(0), parms);
|
||||||
|
|
||||||
// Create prior on X(0).
|
// Create prior on X(0).
|
||||||
bn.push_back(
|
bn.push_back(GaussianConditional::sharedMeanAndStddev(X(0), Vector1(5.0), 0.5));
|
||||||
GaussianConditional::sharedMeanAndStddev(X(0), Vector1(5.0), 0.5));
|
|
||||||
|
|
||||||
// Add prior on m1.
|
// Add prior on m1.
|
||||||
bn.emplace_shared<DiscreteConditional>(m1, "1/1");
|
bn.emplace_shared<DiscreteConditional>(m1, "1/1");
|
||||||
|
@ -534,10 +534,10 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
|
||||||
|
|
||||||
// Create hybrid Gaussian factor on X(0).
|
// Create hybrid Gaussian factor on X(0).
|
||||||
// regression, but mean checked to be 5.0 in both cases:
|
// regression, but mean checked to be 5.0 in both cases:
|
||||||
const auto conditional0 = std::make_shared<GaussianConditional>(
|
const auto conditional0 =
|
||||||
X(0), Vector1(10.1379), I_1x1 * 2.02759),
|
std::make_shared<GaussianConditional>(X(0), Vector1(10.1379), I_1x1 * 2.02759),
|
||||||
conditional1 = std::make_shared<GaussianConditional>(
|
conditional1 =
|
||||||
X(0), Vector1(14.1421), I_1x1 * 2.82843);
|
std::make_shared<GaussianConditional>(X(0), Vector1(14.1421), I_1x1 * 2.82843);
|
||||||
expectedBayesNet.emplace_shared<HybridGaussianConditional>(
|
expectedBayesNet.emplace_shared<HybridGaussianConditional>(
|
||||||
m1, std::vector{conditional0, conditional1});
|
m1, std::vector{conditional0, conditional1});
|
||||||
|
|
||||||
|
@ -570,10 +570,10 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
||||||
// Create hybrid Gaussian factor on X(0).
|
// Create hybrid Gaussian factor on X(0).
|
||||||
using tiny::mode;
|
using tiny::mode;
|
||||||
// regression, but mean checked to be 5.0 in both cases:
|
// regression, but mean checked to be 5.0 in both cases:
|
||||||
const auto conditional0 = std::make_shared<GaussianConditional>(
|
const auto conditional0 =
|
||||||
X(0), Vector1(17.3205), I_1x1 * 3.4641),
|
std::make_shared<GaussianConditional>(X(0), Vector1(17.3205), I_1x1 * 3.4641),
|
||||||
conditional1 = std::make_shared<GaussianConditional>(
|
conditional1 =
|
||||||
X(0), Vector1(10.274), I_1x1 * 2.0548);
|
std::make_shared<GaussianConditional>(X(0), Vector1(10.274), I_1x1 * 2.0548);
|
||||||
expectedBayesNet.emplace_shared<HybridGaussianConditional>(
|
expectedBayesNet.emplace_shared<HybridGaussianConditional>(
|
||||||
mode, std::vector{conditional0, conditional1});
|
mode, std::vector{conditional0, conditional1});
|
||||||
|
|
||||||
|
@ -617,27 +617,25 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
|
||||||
// NOTE: we add reverse topological so we can sample from the Bayes net.:
|
// NOTE: we add reverse topological so we can sample from the Bayes net.:
|
||||||
|
|
||||||
// Add measurements:
|
// Add measurements:
|
||||||
std::vector<std::pair<Vector, double>> measurementModels{{Z_1x1, 3},
|
std::vector<std::pair<Vector, double>> measurementModels{{Z_1x1, 3}, {Z_1x1, 0.5}};
|
||||||
{Z_1x1, 0.5}};
|
|
||||||
for (size_t t : {0, 1, 2}) {
|
for (size_t t : {0, 1, 2}) {
|
||||||
// Create hybrid Gaussian factor on Z(t) conditioned on X(t) and mode N(t):
|
// Create hybrid Gaussian factor on Z(t) conditioned on X(t) and mode N(t):
|
||||||
const auto noise_mode_t = DiscreteKey{N(t), 2};
|
const auto noise_mode_t = DiscreteKey{N(t), 2};
|
||||||
bn.emplace_shared<HybridGaussianConditional>(noise_mode_t, Z(t), I_1x1,
|
bn.emplace_shared<HybridGaussianConditional>(
|
||||||
X(t), measurementModels);
|
noise_mode_t, Z(t), I_1x1, X(t), measurementModels);
|
||||||
|
|
||||||
// Create prior on discrete mode N(t):
|
// Create prior on discrete mode N(t):
|
||||||
bn.emplace_shared<DiscreteConditional>(noise_mode_t, "20/80");
|
bn.emplace_shared<DiscreteConditional>(noise_mode_t, "20/80");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add motion models. TODO(frank): why are they exactly the same?
|
// Add motion models. TODO(frank): why are they exactly the same?
|
||||||
std::vector<std::pair<Vector, double>> motionModels{{Z_1x1, 0.2},
|
std::vector<std::pair<Vector, double>> motionModels{{Z_1x1, 0.2}, {Z_1x1, 0.2}};
|
||||||
{Z_1x1, 0.2}};
|
|
||||||
for (size_t t : {2, 1}) {
|
for (size_t t : {2, 1}) {
|
||||||
// Create hybrid Gaussian factor on X(t) conditioned on X(t-1)
|
// Create hybrid Gaussian factor on X(t) conditioned on X(t-1)
|
||||||
// and mode M(t-1):
|
// and mode M(t-1):
|
||||||
const auto motion_model_t = DiscreteKey{M(t), 2};
|
const auto motion_model_t = DiscreteKey{M(t), 2};
|
||||||
bn.emplace_shared<HybridGaussianConditional>(motion_model_t, X(t), I_1x1,
|
bn.emplace_shared<HybridGaussianConditional>(
|
||||||
X(t - 1), motionModels);
|
motion_model_t, X(t), I_1x1, X(t - 1), motionModels);
|
||||||
|
|
||||||
// Create prior on motion model M(t):
|
// Create prior on motion model M(t):
|
||||||
bn.emplace_shared<DiscreteConditional>(motion_model_t, "40/60");
|
bn.emplace_shared<DiscreteConditional>(motion_model_t, "40/60");
|
||||||
|
@ -650,8 +648,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
|
||||||
EXPECT_LONGS_EQUAL(6, bn.sample().continuous().size());
|
EXPECT_LONGS_EQUAL(6, bn.sample().continuous().size());
|
||||||
|
|
||||||
// Create measurements consistent with moving right every time:
|
// Create measurements consistent with moving right every time:
|
||||||
const VectorValues measurements{
|
const VectorValues measurements{{Z(0), Vector1(0.0)}, {Z(1), Vector1(1.0)}, {Z(2), Vector1(2.0)}};
|
||||||
{Z(0), Vector1(0.0)}, {Z(1), Vector1(1.0)}, {Z(2), Vector1(2.0)}};
|
|
||||||
const HybridGaussianFactorGraph fg = bn.toFactorGraph(measurements);
|
const HybridGaussianFactorGraph fg = bn.toFactorGraph(measurements);
|
||||||
|
|
||||||
// Factor graph is:
|
// Factor graph is:
|
||||||
|
|
|
@ -16,11 +16,11 @@
|
||||||
* @date October 2024
|
* @date October 2024
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "gtsam/inference/Key.h"
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/TestableAssertions.h>
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
#include <gtsam/inference/Symbol.h>
|
#include <gtsam/inference/Symbol.h>
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
|
@ -39,29 +39,27 @@ using symbol_shorthand::X;
|
||||||
namespace examples {
|
namespace examples {
|
||||||
static const DiscreteKey m1(M(1), 2), m2(M(2), 3);
|
static const DiscreteKey m1(M(1), 2), m2(M(2), 3);
|
||||||
|
|
||||||
auto A1 = Matrix::Zero(2, 1);
|
const auto A1 = Matrix::Zero(2, 1);
|
||||||
auto A2 = Matrix::Zero(2, 2);
|
const auto A2 = Matrix::Zero(2, 2);
|
||||||
auto b = Matrix::Zero(2, 1);
|
const auto b = Matrix::Zero(2, 1);
|
||||||
|
|
||||||
auto f10 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
|
const auto f10 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
|
||||||
auto f11 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
|
const auto f11 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
|
||||||
|
const HybridGaussianFactor hybridFactorA(m1, {{f10, 10}, {f11, 11}});
|
||||||
|
|
||||||
auto A3 = Matrix::Zero(2, 3);
|
const auto A3 = Matrix::Zero(2, 3);
|
||||||
auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
const auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||||
auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
const auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||||
auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
const auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||||
|
|
||||||
HybridGaussianFactor hybridFactorA(m1, {f10, f11});
|
const HybridGaussianFactor hybridFactorB(m2, {{f20, 20}, {f21, 21}, {f22, 22}});
|
||||||
HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});
|
|
||||||
// Simulate a pruned hybrid factor, in this case m2==1 is nulled out.
|
// Simulate a pruned hybrid factor, in this case m2==1 is nulled out.
|
||||||
HybridGaussianFactor prunedFactorB(m2, {f20, nullptr, f22});
|
const HybridGaussianFactor prunedFactorB(m2, {{f20, 20}, {nullptr, 1000}, {f22, 22}});
|
||||||
} // namespace examples
|
} // namespace examples
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Constructor
|
// Constructor
|
||||||
TEST(HybridGaussianProductFactor, Construct) {
|
TEST(HybridGaussianProductFactor, Construct) { HybridGaussianProductFactor product; }
|
||||||
HybridGaussianProductFactor product;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Add two Gaussian factors and check only one leaf in tree
|
// Add two Gaussian factors and check only one leaf in tree
|
||||||
|
@ -80,9 +78,10 @@ TEST(HybridGaussianProductFactor, AddTwoGaussianFactors) {
|
||||||
auto leaf = product(Assignment<Key>());
|
auto leaf = product(Assignment<Key>());
|
||||||
|
|
||||||
// Check that the leaf contains both factors
|
// Check that the leaf contains both factors
|
||||||
EXPECT_LONGS_EQUAL(2, leaf.size());
|
EXPECT_LONGS_EQUAL(2, leaf.first.size());
|
||||||
EXPECT(leaf.at(0) == f10);
|
EXPECT(leaf.first.at(0) == f10);
|
||||||
EXPECT(leaf.at(1) == f11);
|
EXPECT(leaf.first.at(1) == f11);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0, leaf.second, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -107,9 +106,10 @@ TEST(HybridGaussianProductFactor, AddTwoGaussianConditionals) {
|
||||||
auto leaf = product(Assignment<Key>());
|
auto leaf = product(Assignment<Key>());
|
||||||
|
|
||||||
// Check that the leaf contains both conditionals
|
// Check that the leaf contains both conditionals
|
||||||
EXPECT_LONGS_EQUAL(2, leaf.size());
|
EXPECT_LONGS_EQUAL(2, leaf.first.size());
|
||||||
EXPECT(leaf.at(0) == gc1);
|
EXPECT(leaf.first.at(0) == gc1);
|
||||||
EXPECT(leaf.at(1) == gc2);
|
EXPECT(leaf.first.at(1) == gc2);
|
||||||
|
EXPECT_DOUBLES_EQUAL(0, leaf.second, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -120,9 +120,12 @@ TEST(HybridGaussianProductFactor, AsProductFactor) {
|
||||||
|
|
||||||
// Let's check that this worked:
|
// Let's check that this worked:
|
||||||
Assignment<Key> mode;
|
Assignment<Key> mode;
|
||||||
mode[m1.first] = 1;
|
mode[m1.first] = 0;
|
||||||
auto actual = product(mode);
|
auto actual = product(mode);
|
||||||
EXPECT(actual.at(0) == f11);
|
EXPECT(actual.first.at(0) == f10);
|
||||||
|
EXPECT_DOUBLES_EQUAL(10, actual.second, 1e-9);
|
||||||
|
|
||||||
|
// TODO(Frank): when killed hiding, f11 should also be there
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -134,9 +137,12 @@ TEST(HybridGaussianProductFactor, AddOne) {
|
||||||
|
|
||||||
// Let's check that this worked:
|
// Let's check that this worked:
|
||||||
Assignment<Key> mode;
|
Assignment<Key> mode;
|
||||||
mode[m1.first] = 1;
|
mode[m1.first] = 0;
|
||||||
auto actual = product(mode);
|
auto actual = product(mode);
|
||||||
EXPECT(actual.at(0) == f11);
|
EXPECT(actual.first.at(0) == f10);
|
||||||
|
EXPECT_DOUBLES_EQUAL(10, actual.second, 1e-9);
|
||||||
|
|
||||||
|
// TODO(Frank): when killed hiding, f11 should also be there
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -152,12 +158,15 @@ TEST(HybridGaussianProductFactor, AddTwo) {
|
||||||
|
|
||||||
// Let's check that this worked:
|
// Let's check that this worked:
|
||||||
auto actual00 = product({{M(1), 0}, {M(2), 0}});
|
auto actual00 = product({{M(1), 0}, {M(2), 0}});
|
||||||
EXPECT(actual00.at(0) == f10);
|
EXPECT(actual00.first.at(0) == f10);
|
||||||
EXPECT(actual00.at(1) == f20);
|
EXPECT(actual00.first.at(1) == f20);
|
||||||
|
EXPECT_DOUBLES_EQUAL(10 + 20, actual00.second, 1e-9);
|
||||||
|
|
||||||
auto actual12 = product({{M(1), 1}, {M(2), 2}});
|
auto actual12 = product({{M(1), 1}, {M(2), 2}});
|
||||||
EXPECT(actual12.at(0) == f11);
|
// TODO(Frank): when killed hiding, these should also equal:
|
||||||
EXPECT(actual12.at(1) == f22);
|
// EXPECT(actual12.first.at(0) == f11);
|
||||||
|
// EXPECT(actual12.first.at(1) == f22);
|
||||||
|
EXPECT_DOUBLES_EQUAL(11 + 22, actual12.second, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue