Merge pull request #1802 from borglab/working-hybrid

Working Hybrid
release/4.3a0
Varun Agrawal 2024-09-05 09:25:36 -04:00 committed by GitHub
commit 232fa02b19
16 changed files with 597 additions and 77 deletions

View File

@@ -24,6 +24,7 @@
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {
@@ -86,7 +87,22 @@ GaussianFactorGraphTree GaussianMixture::add(
/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
  auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
// First check if conditional has not been pruned
if (gc) {
const double Cgm_Kgcm =
this->logConstant_ - gc->logNormalizationConstant();
// If there is a difference in the covariances, we need to account for
// that since the error is dependent on the mode.
if (Cgm_Kgcm > 0.0) {
// We add a constant factor which will be used when computing
// the probability of the discrete variables.
Vector c(1);
c << std::sqrt(2.0 * Cgm_Kgcm);
auto constantFactor = std::make_shared<JacobianFactor>(c);
return GaussianFactorGraph{gc, constantFactor};
}
}
    return GaussianFactorGraph{gc};
  };
  return {conditionals_, wrap};
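Note on the constant factor above: a `JacobianFactor` built from just a vector c has no keys, so (under a unit noise model) it contributes a constant ½‖c‖² to the error of any graph it joins. Choosing c = √(2·Cgm_Kgcm) therefore raises this component's error by exactly the normalization-constant gap:

\[
\tfrac{1}{2}\|c\|^2 = \tfrac{1}{2}\bigl(\sqrt{2C}\bigr)^2 = C,
\qquad C = \log K_{gm} - \log K_{gcm},
\]

so mode-dependent covariance differences are reflected when discrete probabilities are later computed from exp(-error).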
@@ -145,6 +161,8 @@ void GaussianMixture::print(const std::string &s,
    std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
  }
  std::cout << "\n";
std::cout << " logNormalizationConstant: " << logConstant_ << "\n"
<< std::endl;
  conditionals_.print(
      "", [&](Key k) { return formatter(k); },
      [&](const GaussianConditional::shared_ptr &gf) -> std::string {
@@ -312,12 +330,28 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
  return DecisionTree<Key, double>(conditionals_, probFunc);
}
/* ************************************************************************* */
double GaussianMixture::conditionalError(
const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const {
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
    // If the pointer is not valid, it means this conditional was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}
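A sketch of what `conditionalError` returns, assuming `logConstant_` is the maximum of the component log-normalization constants (consistent with the `Cgm_Kgcm > 0` check above): for mode m with conditional error E_m(x) and constant K_m,

\[
E(x; m) = E_m(x) + \max_{m'} \log K_{m'} - \log K_m,
\]

so that \(\exp(-E(x;m)) \propto K_m \exp(-E_m(x))\): relative densities across modes are preserved while every shift stays nonnegative, and pruned modes return `std::numeric_limits<double>::max()` so their probability underflows to zero.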
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
    const VectorValues &continuousValues) const {
  auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
    return conditionalError(conditional, continuousValues);
  };
  DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
  return error_tree;
@@ -327,8 +361,7 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
double GaussianMixture::error(const HybridValues &values) const {
  // Directly index to get the conditional, no need to build the whole tree.
  auto conditional = conditionals_(values.discrete());
  return conditionalError(conditional, values.continuous());
}

/* *******************************************************************************/

View File

@@ -67,7 +67,7 @@ class GTSAM_EXPORT GaussianMixture
  double logConstant_;  ///< log of the normalization constant.

  /**
   * @brief Convert a GaussianMixture of conditionals into
   * a DecisionTree of Gaussian factor graphs.
   */
  GaussianFactorGraphTree asGaussianFactorGraphTree() const;
@@ -256,6 +256,10 @@ class GTSAM_EXPORT GaussianMixture
  /// Check whether `given` has values for all frontal keys.
  bool allFrontalsGiven(const VectorValues &given) const;
/// Helper method to compute the error of a conditional.
double conditionalError(const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const;
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
  /** Serialization function */
  friend class boost::serialization::access;

View File

@@ -54,7 +54,9 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
/* *******************************************************************************/
void GaussianMixtureFactor::print(const std::string &s,
                                  const KeyFormatter &formatter) const {
  std::cout << (s.empty() ? "" : s + "\n");
std::cout << "GaussianMixtureFactor" << std::endl;
HybridFactor::print("", formatter);
std::cout << "{\n"; std::cout << "{\n";
if (factors_.empty()) { if (factors_.empty()) {
std::cout << " empty" << std::endl; std::cout << " empty" << std::endl;
@@ -64,7 +66,7 @@ void GaussianMixtureFactor::print(const std::string &s,
        [&](const sharedFactor &gf) -> std::string {
          RedirectCout rd;
          std::cout << ":\n";
          if (gf) {
            gf->print("", formatter);
            return rd.str();
          } else {
@@ -117,6 +119,5 @@ double GaussianMixtureFactor::error(const HybridValues &values) const {
  const sharedFactor gf = factors_(values.discrete());
  return gf->error(values.continuous());
}
}  // namespace gtsam

View File

@@ -80,8 +80,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
   * @param continuousKeys A vector of keys representing continuous variables.
   * @param discreteKeys A vector of keys representing discrete variables and
   * their cardinalities.
   * @param factors The decision tree of Gaussian factors stored
   * as the mixture density.
   */
  GaussianMixtureFactor(const KeyVector &continuousKeys,
                        const DiscreteKeys &discreteKeys,
@@ -107,9 +107,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
  bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

  void print(const std::string &s = "", const KeyFormatter &formatter =
                                            DefaultKeyFormatter) const override;

  /// @}
  /// @name Standard API

View File

@@ -220,15 +220,16 @@ GaussianBayesNet HybridBayesNet::choose(
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
  // Collect all the discrete factors to compute MPE
  DiscreteFactorGraph discrete_fg;

  for (auto &&conditional : *this) {
    if (conditional->isDiscrete()) {
      discrete_fg.push_back(conditional->asDiscrete());
    }
  }

  // Solve for the MPE
  DiscreteValues mpe = discrete_fg.optimize();

  // Given the MPE, compute the optimal continuous values.
  return HybridValues(optimize(mpe), mpe);
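A minimal usage sketch of this two-stage solve (the wrapper function and its names are illustrative, not part of this PR):

// Sketch: most probable explanation (MPE) of a hybrid Bayes net.
// optimize() first collects the discrete conditionals into a
// DiscreteFactorGraph and solves for the MPE assignment, then solves
// the continuous variables conditioned on that assignment.
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>

void solveHybrid(const gtsam::HybridBayesNet& hbn) {
  gtsam::HybridValues best = hbn.optimize();
  best.discrete().print("MPE discrete assignment");
  best.continuous().print("optimal continuous values");
}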

View File

@@ -61,7 +61,7 @@ class GTSAM_EXPORT HybridConditional
    : public Conditional<HybridFactor, HybridConditional> {
 public:
  // typedefs needed to play nice with gtsam
  typedef HybridConditional This;            ///< Typedef to this class
  typedef std::shared_ptr<This> shared_ptr;  ///< shared_ptr to this class
  typedef HybridFactor BaseFactor;  ///< Typedef to our factor base class
  typedef Conditional<BaseFactor, This>
@@ -185,7 +185,7 @@ class GTSAM_EXPORT HybridConditional
   * Return the log normalization constant.
   * Note this is 0.0 for discrete and hybrid conditionals, but depends
   * on the continuous parameters for Gaussian conditionals.
   */
  double logNormalizationConstant() const override;

  /// Return the probability (or density) of the underlying conditional.

View File

@@ -13,6 +13,7 @@
 * @file HybridFactor.h
 * @date Mar 11, 2022
 * @author Fan Jiang
 * @author Varun Agrawal
 */

#pragma once

View File

@@ -49,15 +49,6 @@ KeySet HybridFactorGraph::discreteKeySet() const {
  return keys;
}
/* ************************************************************************* */
std::unordered_map<Key, DiscreteKey> HybridFactorGraph::discreteKeyMap() const {
std::unordered_map<Key, DiscreteKey> result;
for (const DiscreteKey& k : discreteKeys()) {
result[k.first] = k;
}
return result;
}
/* ************************************************************************* */
const KeySet HybridFactorGraph::continuousKeySet() const {
  KeySet keys;

View File

@@ -38,7 +38,7 @@ using SharedFactor = std::shared_ptr<Factor>;
class GTSAM_EXPORT HybridFactorGraph : public FactorGraph<Factor> {
 public:
  using Base = FactorGraph<Factor>;
  using This = HybridFactorGraph;            ///< this class
  using shared_ptr = std::shared_ptr<This>;  ///< shared_ptr to This
  using Values = gtsam::Values;  ///< backwards compatibility
@@ -66,12 +66,9 @@ class GTSAM_EXPORT HybridFactorGraph : public FactorGraph<Factor> {
  /// Get all the discrete keys in the factor graph.
  std::set<DiscreteKey> discreteKeys() const;

  /// Get all the discrete keys in the factor graph, as a set of Keys.
  KeySet discreteKeySet() const;
/// Get a map from Key to corresponding DiscreteKey.
std::unordered_map<Key, DiscreteKey> discreteKeyMap() const;
  /// Get all the continuous keys in the factor graph.
  const KeySet continuousKeySet() const;

View File

@@ -97,29 +97,27 @@ void HybridGaussianFactorGraph::printErrors(
      std::cout << "nullptr"
                << "\n";
    } else {
      gmf->operator()(values.discrete())->print(ss.str(), keyFormatter);
      std::cout << "error = " << gmf->error(values) << std::endl;
    }
  } else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
    if (factor == nullptr) {
      std::cout << "nullptr"
                << "\n";
    } else {
      if (hc->isContinuous()) {
        factor->print(ss.str(), keyFormatter);
        std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
      } else if (hc->isDiscrete()) {
        factor->print(ss.str(), keyFormatter);
        std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
                  << "\n";
      } else {
        // Is hybrid
        auto mixtureComponent =
            hc->asMixture()->operator()(values.discrete());
        mixtureComponent->print(ss.str(), keyFormatter);
        std::cout << "error = " << mixtureComponent->error(values) << "\n";
      }
    }
  } else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
@@ -140,8 +138,7 @@ void HybridGaussianFactorGraph::printErrors(
                << "\n";
    } else {
      factor->print(ss.str(), keyFormatter);
      std::cout << "error = " << df->error(values.discrete()) << std::endl;
    }
  } else {
@@ -233,6 +230,25 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
  return {std::make_shared<HybridConditional>(result.first), result.second};
}
/* ************************************************************************ */
/**
* @brief Exponentiate log-values, not necessarily normalized, normalize, and
* return as AlgebraicDecisionTree<Key>.
*
* @param logValues DecisionTree of (unnormalized) log values.
* @return AlgebraicDecisionTree<Key>
*/
static AlgebraicDecisionTree<Key> probabilitiesFromLogValues(
const AlgebraicDecisionTree<Key> &logValues) {
// Perform normalization
double max_log = logValues.max();
AlgebraicDecisionTree<Key> probabilities = DecisionTree<Key, double>(
logValues, [&max_log](const double x) { return exp(x - max_log); });
probabilities = probabilities.normalize(probabilities.sum());
return probabilities;
}
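Subtracting the maximum log-value before exponentiating is the standard log-sum-exp guard against overflow; the final normalization cancels the shift, so the result is an ordinary softmax over the leaves:

\[
p_i = \frac{\exp(x_i - x_{\max})}{\sum_j \exp(x_j - x_{\max})}
    = \frac{\exp(x_i)}{\sum_j \exp(x_j)}.
\]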
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors,
@@ -242,6 +258,22 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
  for (auto &f : factors) {
    if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
      dfg.push_back(df);
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
// Case where we have a GaussianMixtureFactor with no continuous keys.
// In this case, compute discrete probabilities.
auto logProbability =
[&](const GaussianFactor::shared_ptr &factor) -> double {
if (!factor) return 0.0;
return -factor->error(VectorValues());
};
AlgebraicDecisionTree<Key> logProbabilities =
DecisionTree<Key, double>(gmf->factors(), logProbability);
AlgebraicDecisionTree<Key> probabilities =
probabilitiesFromLogValues(logProbabilities);
dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(),
probabilities);
    } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
      // Ignore orphaned clique.
      // TODO(dellaert): is this correct? If so explain here.
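The new branch above handles the corner case where a `GaussianMixtureFactor` has no continuous keys left, so each leaf factor's error at the empty `VectorValues` is a pure constant E_m. The induced discrete factor is then simply

\[
\phi(m) \propto \exp(-E_m),
\]

computed in log space and normalized by `probabilitiesFromLogValues` above; a null (pruned) leaf maps to log-probability 0, i.e. unit unnormalized weight.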
@@ -279,21 +311,32 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
using Result = std::pair<std::shared_ptr<GaussianConditional>,
                         GaussianMixtureFactor::sharedFactor>;
/**
 * Compute the probability p(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
 * from the residual error ||b||^2 at the mean μ.
 * The residual error contains no keys, and only
 * depends on the discrete separator if present.
 */
static std::shared_ptr<Factor> createDiscreteFactor(
    const DecisionTree<Key, Result> &eliminationResults,
    const DiscreteKeys &discreteSeparator) {
  auto logProbability = [&](const Result &pair) -> double {
    const auto &[conditional, factor] = pair;
    static const VectorValues kEmpty;
    // If the factor is not null, it has no keys, just contains the residual.
    if (!factor) return 1.0;  // TODO(dellaert): not loving this.

    // Logspace version of:
    // exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
    // We take the negative of logNormalizationConstant `log(1/k)`
    // to get `log(k)`.
    return -factor->error(kEmpty) - conditional->logNormalizationConstant();
  };

  AlgebraicDecisionTree<Key> logProbabilities(
      DecisionTree<Key, double>(eliminationResults, logProbability));
  AlgebraicDecisionTree<Key> probabilities =
      probabilitiesFromLogValues(logProbabilities);

  return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
}
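In other words, `createDiscreteFactor` now evaluates, per discrete assignment m,

\[
\log \phi(m) = -E_m(\emptyset) - \log K_m
\quad\Longleftrightarrow\quad
\phi(m) = \frac{\exp(-E_m(\emptyset))}{K_m},
\]

which is exactly what the previous linear-space code computed, but shifted by the maximum and normalized in `probabilitiesFromLogValues`, avoiding underflow when the residual errors are large.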
@@ -480,18 +523,9 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
                 std::inserter(continuousSeparator, continuousSeparator.begin()));

  // Similarly for the discrete separator.
  // Since we eliminate all continuous variables first,
  // the discrete separator will be *all* the discrete keys.
  std::set<DiscreteKey> discreteSeparator = factors.discreteKeys();
  return hybridElimination(factors, frontalKeys, continuousSeparator,
                           discreteSeparator);
@@ -504,10 +538,15 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
  AlgebraicDecisionTree<Key> error_tree(0.0);

  // Iterate over each factor.
  for (auto &factor : factors_) {
    // TODO(dellaert): just use a virtual method defined in HybridFactor.
    AlgebraicDecisionTree<Key> factor_error;
auto f = factor;
if (auto hc = dynamic_pointer_cast<HybridConditional>(factor)) {
f = hc->inner();
}
    if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
      // Compute factor error and add it.
      error_tree = error_tree + gaussianMixture->errorTree(continuousValues);

View File

@@ -144,6 +144,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
  //     const std::string& s = "HybridGaussianFactorGraph",
  //     const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/**
* @brief Print the errors of each factor in the hybrid factor graph.
*
* @param values The HybridValues for the variables used to compute the error.
* @param str String that is output before the factor graph and errors.
* @param keyFormatter Formatter function for the keys in the factors.
* @param printCondition A condition to check if a factor should be printed.
*/
  void printErrors(
      const HybridValues& values,
      const std::string& str = "HybridGaussianFactorGraph: ",

View File

@@ -22,9 +22,13 @@
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/PriorFactor.h>
#include <gtsam/slam/BetweenFactor.h>
// Include for test suite
#include <CppUnitLite/TestHarness.h>
@@ -32,8 +36,10 @@
using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::F;
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;
/* ************************************************************************* */
// Check iterators of empty mixture.
@@ -56,7 +62,6 @@ TEST(GaussianMixtureFactor, Sum) {
  auto b = Matrix::Zero(2, 1);
  Vector2 sigmas;
  sigmas << 1, 2;
auto model = noiseModel::Diagonal::Sigmas(sigmas, true);
  auto f10 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
  auto f11 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
@@ -106,7 +111,8 @@ TEST(GaussianMixtureFactor, Printing) {
  GaussianMixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);

  std::string expected =
      R"(GaussianMixtureFactor
Hybrid [x1 x2; 1]{
 Choice(1)
 0 Leaf :
A[x1] = [
@@ -178,7 +184,8 @@ TEST(GaussianMixtureFactor, Error) {
  continuousValues.insert(X(2), Vector2(1, 1));

  // error should return a tree of errors, with nodes for each discrete value.
  AlgebraicDecisionTree<Key> error_tree =
      mixtureFactor.errorTree(continuousValues);

  std::vector<DiscreteKey> discrete_keys = {m1};
  // Error values for regression test
@@ -191,8 +198,390 @@
  DiscreteValues discreteValues;
  discreteValues[m1.first] = 1;
  EXPECT_DOUBLES_EQUAL(
      4.0, mixtureFactor.error({continuousValues, discreteValues}), 1e-9);
}
namespace test_gmm {
/**
* Function to compute P(m=1|z). For P(m=0|z), swap mus and sigmas.
* If sigma0 == sigma1, it simplifies to a sigmoid function.
*
* Follows equation 7.108 since it is more generic.
*/
double prob_m_z(double mu0, double mu1, double sigma0, double sigma1,
double z) {
double x1 = ((z - mu0) / sigma0), x2 = ((z - mu1) / sigma1);
double d = sigma1 / sigma0;
double e = d * std::exp(-0.5 * (x1 * x1 - x2 * x2));
return 1 / (1 + e);
};
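For reference, a short derivation of `prob_m_z` under equal mode priors (matching the code above): with \(x_1 = (z-\mu_0)/\sigma_0\) and \(x_2 = (z-\mu_1)/\sigma_1\), Bayes' rule gives

\[
P(m{=}1 \mid z) = \frac{\mathcal{N}(z;\mu_1,\sigma_1)}{\mathcal{N}(z;\mu_0,\sigma_0) + \mathcal{N}(z;\mu_1,\sigma_1)}
= \frac{1}{1+e},
\qquad
e = \frac{\sigma_1}{\sigma_0}\,\exp\!\left(-\tfrac{1}{2}(x_1^2 - x_2^2)\right).
\]

For \(\sigma_0 = \sigma_1 = \sigma\) this reduces to the logistic sigmoid \(1/\bigl(1 + \exp(-(\mu_1-\mu_0)(z - \tfrac{\mu_0+\mu_1}{2})/\sigma^2)\bigr)\).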
static HybridBayesNet GetGaussianMixtureModel(double mu0, double mu1,
double sigma0, double sigma1) {
DiscreteKey m(M(0), 2);
Key z = Z(0);
auto model0 = noiseModel::Isotropic::Sigma(1, sigma0);
auto model1 = noiseModel::Isotropic::Sigma(1, sigma1);
auto c0 = make_shared<GaussianConditional>(z, Vector1(mu0), I_1x1, model0),
c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model1);
auto gm = new GaussianMixture({z}, {}, {m}, {c0, c1});
auto mixing = new DiscreteConditional(m, "0.5/0.5");
HybridBayesNet hbn;
hbn.emplace_back(gm);
hbn.emplace_back(mixing);
return hbn;
}
} // namespace test_gmm
/* ************************************************************************* */
/**
* Test a simple Gaussian Mixture Model represented as P(m)P(z|m)
* where m is a discrete variable and z is a continuous variable.
* m is binary and depending on m, we have 2 different means
* μ1 and μ2 for the Gaussian distribution around which we sample z.
*
* The resulting factor graph should eliminate to a Bayes net
* which represents a sigmoid function.
*/
TEST(GaussianMixtureFactor, GaussianMixtureModel) {
using namespace test_gmm;
double mu0 = 1.0, mu1 = 3.0;
double sigma = 2.0;
DiscreteKey m(M(0), 2);
Key z = Z(0);
auto hbn = GetGaussianMixtureModel(mu0, mu1, sigma, sigma);
  // The result should be a sigmoid, so P(m=1|z) = 0.5
  // at z = 3.0 - 1.0 = 2.0, the midway point.
double midway = mu1 - mu0, lambda = 4;
{
VectorValues given;
given.insert(z, Vector1(midway));
HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
EXPECT_DOUBLES_EQUAL(
prob_m_z(mu0, mu1, sigma, sigma, midway),
bn->at(0)->asDiscrete()->operator()(DiscreteValues{{m.first, 1}}),
1e-8);
// At the halfway point between the means, we should get P(m|z)=0.5
HybridBayesNet expected;
expected.emplace_back(new DiscreteConditional(m, "0.5/0.5"));
EXPECT(assert_equal(expected, *bn));
}
{
// Shift by -lambda
VectorValues given;
given.insert(z, Vector1(midway - lambda));
HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
EXPECT_DOUBLES_EQUAL(
prob_m_z(mu0, mu1, sigma, sigma, midway - lambda),
bn->at(0)->asDiscrete()->operator()(DiscreteValues{{m.first, 1}}),
1e-8);
}
{
// Shift by lambda
VectorValues given;
given.insert(z, Vector1(midway + lambda));
HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
EXPECT_DOUBLES_EQUAL(
prob_m_z(mu0, mu1, sigma, sigma, midway + lambda),
bn->at(0)->asDiscrete()->operator()(DiscreteValues{{m.first, 1}}),
1e-8);
}
}
/* ************************************************************************* */
/**
* Test a simple Gaussian Mixture Model represented as P(m)P(z|m)
* where m is a discrete variable and z is a continuous variable.
* m is binary and depending on m, we have 2 different means
* and covariances each for the
* Gaussian distribution around which we sample z.
*
* The resulting factor graph should eliminate to a Bayes net
 * which represents a Gaussian-like function,
 * where P(m=1|z) > P(m=0|z) for z close to 3.1333.
*/
TEST(GaussianMixtureFactor, GaussianMixtureModel2) {
using namespace test_gmm;
double mu0 = 1.0, mu1 = 3.0;
double sigma0 = 8.0, sigma1 = 4.0;
DiscreteKey m(M(0), 2);
Key z = Z(0);
auto hbn = GetGaussianMixtureModel(mu0, mu1, sigma0, sigma1);
double m1_high = 3.133, lambda = 4;
{
    // The result should be a bell-curve-like function
    // with P(m=1|z) > P(m=0|z) near z = 3.1333.
// We get 3.1333 by finding the maximum value of the function.
VectorValues given;
given.insert(z, Vector1(3.133));
HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
EXPECT_DOUBLES_EQUAL(
prob_m_z(mu0, mu1, sigma0, sigma1, m1_high),
bn->at(0)->asDiscrete()->operator()(DiscreteValues{{M(0), 1}}), 1e-8);
// At the halfway point between the means
HybridBayesNet expected;
expected.emplace_back(new DiscreteConditional(
m, {},
vector<double>{prob_m_z(mu1, mu0, sigma1, sigma0, m1_high),
prob_m_z(mu0, mu1, sigma0, sigma1, m1_high)}));
EXPECT(assert_equal(expected, *bn));
}
{
// Shift by -lambda
VectorValues given;
given.insert(z, Vector1(m1_high - lambda));
HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
EXPECT_DOUBLES_EQUAL(
prob_m_z(mu0, mu1, sigma0, sigma1, m1_high - lambda),
bn->at(0)->asDiscrete()->operator()(DiscreteValues{{m.first, 1}}),
1e-8);
}
{
// Shift by lambda
VectorValues given;
given.insert(z, Vector1(m1_high + lambda));
HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
EXPECT_DOUBLES_EQUAL(
prob_m_z(mu0, mu1, sigma0, sigma1, m1_high + lambda),
bn->at(0)->asDiscrete()->operator()(DiscreteValues{{m.first, 1}}),
1e-8);
}
}
namespace test_two_state_estimation {
/// Create Two State Bayes Network with measurements
static HybridBayesNet CreateBayesNet(double mu0, double mu1, double sigma0,
double sigma1,
bool add_second_measurement = false,
double prior_sigma = 1e-3,
double measurement_sigma = 3.0) {
DiscreteKey m1(M(1), 2);
Key z0 = Z(0), z1 = Z(1);
Key x0 = X(0), x1 = X(1);
HybridBayesNet hbn;
auto measurement_model = noiseModel::Isotropic::Sigma(1, measurement_sigma);
// Add measurement P(z0 | x0)
auto p_z0 = new GaussianConditional(z0, Vector1(0.0), -I_1x1, x0, I_1x1,
measurement_model);
hbn.emplace_back(p_z0);
// Add hybrid motion model
auto model0 = noiseModel::Isotropic::Sigma(1, sigma0);
auto model1 = noiseModel::Isotropic::Sigma(1, sigma1);
auto c0 = make_shared<GaussianConditional>(x1, Vector1(mu0), I_1x1, x0,
-I_1x1, model0),
c1 = make_shared<GaussianConditional>(x1, Vector1(mu1), I_1x1, x0,
-I_1x1, model1);
auto motion = new GaussianMixture({x1}, {x0}, {m1}, {c0, c1});
hbn.emplace_back(motion);
if (add_second_measurement) {
// Add second measurement
auto p_z1 = new GaussianConditional(z1, Vector1(0.0), -I_1x1, x1, I_1x1,
measurement_model);
hbn.emplace_back(p_z1);
}
// Discrete uniform prior.
auto p_m1 = new DiscreteConditional(m1, "0.5/0.5");
hbn.emplace_back(p_m1);
return hbn;
}
} // namespace test_two_state_estimation
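The helper above encodes the two-state model used by the tests below:

\[
P(z_0, z_1, x_0, x_1, m_1) = P(z_0 \mid x_0)\,P(x_1 \mid x_0, m_1)\,P(z_1 \mid x_1)\,P(m_1),
\]

with \(z_i \mid x_i \sim \mathcal{N}(x_i, \sigma_{\text{meas}}^2)\), \(x_1 \mid x_0, m \sim \mathcal{N}(x_0 + \mu_m, \sigma_m^2)\), and a uniform prior on \(m_1\); the \(P(z_1 \mid x_1)\) term is only present when `add_second_measurement` is set.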
/* ************************************************************************* */
/**
* Test a model P(z0|x0)P(x1|x0,m1)P(z1|x1)P(m1).
*
 * P(x1|x0,m1) has different means and same covariance.
*
* Converting to a factor graph gives us
* ϕ(x0)ϕ(x1,x0,m1)ϕ(x1)P(m1)
*
* If we only have a measurement on z0, then
* the probability of m1 should be 0.5/0.5.
 * Getting a measurement on z1 gives us more information.
*/
TEST(GaussianMixtureFactor, TwoStateModel) {
using namespace test_two_state_estimation;
double mu0 = 1.0, mu1 = 3.0;
double sigma = 2.0;
DiscreteKey m1(M(1), 2);
Key z0 = Z(0), z1 = Z(1);
// Start with no measurement on x1, only on x0
HybridBayesNet hbn = CreateBayesNet(mu0, mu1, sigma, sigma, false);
VectorValues given;
given.insert(z0, Vector1(0.5));
{
HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
// Since no measurement on x1, we hedge our bets
DiscreteConditional expected(m1, "0.5/0.5");
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete())));
}
{
// Now we add a measurement z1 on x1
hbn = CreateBayesNet(mu0, mu1, sigma, sigma, true);
// If we see z1=2.6 (> 2.5 which is the halfway point),
// discrete mode should say m1=1
given.insert(z1, Vector1(2.6));
HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
    // Since we have a measurement on z1, we get a definite result
DiscreteConditional expected(m1, "0.49772729/0.50227271");
// regression
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 1e-6));
}
}
/* ************************************************************************* */
/**
* Test a model P(z0|x0)P(x1|x0,m1)P(z1|x1)P(m1).
*
 * P(x1|x0,m1) has different means and different covariances.
*
* Converting to a factor graph gives us
* ϕ(x0)ϕ(x1,x0,m1)ϕ(x1)P(m1)
*
* If we only have a measurement on z0, then
* the P(m1) should be 0.5/0.5.
 * Getting a measurement on z1 gives us more information.
*/
TEST(GaussianMixtureFactor, TwoStateModel2) {
using namespace test_two_state_estimation;
double mu0 = 1.0, mu1 = 3.0;
double sigma0 = 6.0, sigma1 = 4.0;
auto model0 = noiseModel::Isotropic::Sigma(1, sigma0);
auto model1 = noiseModel::Isotropic::Sigma(1, sigma1);
DiscreteKey m1(M(1), 2);
Key z0 = Z(0), z1 = Z(1);
// Start with no measurement on x1, only on x0
HybridBayesNet hbn = CreateBayesNet(mu0, mu1, sigma0, sigma1, false);
VectorValues given;
given.insert(z0, Vector1(0.5));
{
// Start with no measurement on x1, only on x0
HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
{
VectorValues vv{
{X(0), Vector1(0.0)}, {X(1), Vector1(1.0)}, {Z(0), Vector1(0.5)}};
HybridValues hv0(vv, DiscreteValues{{M(1), 0}}),
hv1(vv, DiscreteValues{{M(1), 1}});
EXPECT_DOUBLES_EQUAL(gfg.error(hv0) / hbn.error(hv0),
gfg.error(hv1) / hbn.error(hv1), 1e-9);
}
{
VectorValues vv{
{X(0), Vector1(0.5)}, {X(1), Vector1(3.0)}, {Z(0), Vector1(0.5)}};
HybridValues hv0(vv, DiscreteValues{{M(1), 0}}),
hv1(vv, DiscreteValues{{M(1), 1}});
EXPECT_DOUBLES_EQUAL(gfg.error(hv0) / hbn.error(hv0),
gfg.error(hv1) / hbn.error(hv1), 1e-9);
}
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
    // Since there is no measurement on x1, we get a 50/50 probability
auto p_m = bn->at(2)->asDiscrete();
EXPECT_DOUBLES_EQUAL(0.5, p_m->operator()(DiscreteValues{{m1.first, 0}}),
1e-9);
EXPECT_DOUBLES_EQUAL(0.5, p_m->operator()(DiscreteValues{{m1.first, 1}}),
1e-9);
}
{
// Now we add a measurement z1 on x1
hbn = CreateBayesNet(mu0, mu1, sigma0, sigma1, true);
given.insert(z1, Vector1(2.2));
HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
{
VectorValues vv{{X(0), Vector1(0.0)},
{X(1), Vector1(1.0)},
{Z(0), Vector1(0.5)},
{Z(1), Vector1(2.2)}};
HybridValues hv0(vv, DiscreteValues{{M(1), 0}}),
hv1(vv, DiscreteValues{{M(1), 1}});
EXPECT_DOUBLES_EQUAL(gfg.error(hv0) / hbn.error(hv0),
gfg.error(hv1) / hbn.error(hv1), 1e-9);
}
{
VectorValues vv{{X(0), Vector1(0.5)},
{X(1), Vector1(3.0)},
{Z(0), Vector1(0.5)},
{Z(1), Vector1(2.2)}};
HybridValues hv0(vv, DiscreteValues{{M(1), 0}}),
hv1(vv, DiscreteValues{{M(1), 1}});
EXPECT_DOUBLES_EQUAL(gfg.error(hv0) / hbn.error(hv0),
gfg.error(hv1) / hbn.error(hv1), 1e-9);
}
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
    // Since we have a measurement on z1, we get a definite result
DiscreteConditional expected(m1, "0.44744586/0.55255414");
// regression
EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 1e-6));
}
}

/* ************************************************************************* */
@@ -200,4 +589,4 @@ int main() {
  TestResult tr;
  return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

View File

@@ -598,6 +598,57 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
  EXPECT(assert_equal(expected_probs, probs, 1e-7));
}
/* ****************************************************************************/
// Test hybrid gaussian factor graph errorTree when there is a HybridConditional in the graph
TEST(HybridGaussianFactorGraph, ErrorTreeWithConditional) {
using symbol_shorthand::F;
DiscreteKey m1(M(1), 2);
Key z0 = Z(0), f01 = F(0);
Key x0 = X(0), x1 = X(1);
HybridBayesNet hbn;
auto prior_model = noiseModel::Isotropic::Sigma(1, 1e-1);
auto measurement_model = noiseModel::Isotropic::Sigma(1, 2.0);
// Set a prior P(x0) at x0=0
hbn.emplace_back(
new GaussianConditional(x0, Vector1(0.0), I_1x1, prior_model));
// Add measurement P(z0 | x0)
hbn.emplace_back(new GaussianConditional(z0, Vector1(0.0), -I_1x1, x0, I_1x1,
measurement_model));
// Add hybrid motion model
double mu = 0.0;
double sigma0 = 1e2, sigma1 = 1e-2;
auto model0 = noiseModel::Isotropic::Sigma(1, sigma0);
auto model1 = noiseModel::Isotropic::Sigma(1, sigma1);
auto c0 = make_shared<GaussianConditional>(f01, Vector1(mu), I_1x1, x1, I_1x1,
x0, -I_1x1, model0),
c1 = make_shared<GaussianConditional>(f01, Vector1(mu), I_1x1, x1, I_1x1,
x0, -I_1x1, model1);
hbn.emplace_back(new GaussianMixture({f01}, {x0, x1}, {m1}, {c0, c1}));
// Discrete uniform prior.
hbn.emplace_back(new DiscreteConditional(m1, "0.5/0.5"));
VectorValues given;
given.insert(z0, Vector1(0.0));
given.insert(f01, Vector1(0.0));
auto gfg = hbn.toFactorGraph(given);
VectorValues vv;
vv.insert(x0, Vector1(1.0));
vv.insert(x1, Vector1(2.0));
AlgebraicDecisionTree<Key> errorTree = gfg.errorTree(vv);
// regression
AlgebraicDecisionTree<Key> expected(m1, 59.335390372, 5050.125);
EXPECT(assert_equal(expected, errorTree, 1e-9));
}
/* ****************************************************************************/
// Check that assembleGraphTree assembles Gaussian factor graphs for each
// assignment.

View File

@@ -510,6 +510,7 @@ factor 0:
  b = [ -10 ]
  No noise model
factor 1:
GaussianMixtureFactor
Hybrid [x0 x1; m0]{
 Choice(m0)
 0 Leaf :
@@ -534,6 +535,7 @@ Hybrid [x0 x1; m0]{
}
factor 2:
GaussianMixtureFactor
Hybrid [x1 x2; m1]{
 Choice(m1)
 0 Leaf :
@@ -675,6 +677,8 @@ factor 6: P( m1 | m0 ):
size: 3
conditional 0: Hybrid P( x0 | x1 m0)
 Discrete Keys = (m0, 2),
logNormalizationConstant: 1.38862
 Choice(m0)
 0 Leaf p(x0 | x1)
  R = [ 10.0499 ]
@@ -692,6 +696,8 @@ conditional 0: Hybrid P( x0 | x1 m0)
conditional 1: Hybrid P( x1 | x2 m0 m1)
 Discrete Keys = (m0, 2), (m1, 2),
logNormalizationConstant: 1.3935
 Choice(m1)
 0 Choice(m0)
 0 0 Leaf p(x1 | x2)
@@ -725,6 +731,8 @@ conditional 1: Hybrid P( x1 | x2 m0 m1)
conditional 2: Hybrid P( x2 | m0 m1)
 Discrete Keys = (m0, 2), (m1, 2),
logNormalizationConstant: 1.38857
 Choice(m1)
 0 Choice(m0)
 0 0 Leaf p(x2)

View File

@@ -18,6 +18,9 @@
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
#include <gtsam/hybrid/MixtureFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/slam/BetweenFactor.h>

View File

@@ -263,11 +263,6 @@ namespace gtsam {
    /** equals required by Testable for unit testing */
    bool equals(const VectorValues& x, double tol = 1e-9) const;
/// Check equality.
friend bool operator==(const VectorValues& lhs, const VectorValues& rhs) {
return lhs.equals(rhs);
}
    /// @{
    /// @name Advanced Interface
    /// @{