Merge pull request #1865 from borglab/feature/no_hiding-2

Updates to `No Hiding` PR
release/4.3a0
Frank Dellaert 2024-10-09 13:37:30 +09:00 committed by GitHub
commit 59f97d64eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 141 additions and 126 deletions

View File

@ -22,14 +22,13 @@
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/inference/BayesTree-inst.h> #include <gtsam/inference/BayesTree-inst.h>
#include <gtsam/inference/BayesTreeCliqueBase-inst.h> #include <gtsam/inference/BayesTreeCliqueBase-inst.h>
#include <gtsam/linear/GaussianJunctionTree.h> #include <gtsam/linear/GaussianJunctionTree.h>
#include <memory> #include <memory>
#include "gtsam/hybrid/HybridConditional.h"
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class

View File

@ -13,6 +13,7 @@
* @file HybridConditional.cpp * @file HybridConditional.cpp
* @date Mar 11, 2022 * @date Mar 11, 2022
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal
*/ */
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>

View File

@ -13,6 +13,7 @@
* @file HybridConditional.h * @file HybridConditional.h
* @date Mar 11, 2022 * @date Mar 11, 2022
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal
*/ */
#pragma once #pragma once

View File

@ -79,7 +79,7 @@ struct HybridGaussianConditional::Helper {
explicit Helper(const Conditionals &conditionals) explicit Helper(const Conditionals &conditionals)
: conditionals(conditionals), : conditionals(conditionals),
minNegLogConstant(std::numeric_limits<double>::infinity()) { minNegLogConstant(std::numeric_limits<double>::infinity()) {
auto func = [this](const GC::shared_ptr& gc) -> GaussianFactorValuePair { auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()}; if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
if (!nrFrontals) nrFrontals = gc->nrFrontals(); if (!nrFrontals) nrFrontals = gc->nrFrontals();
double value = gc->negLogConstant(); double value = gc->negLogConstant();
@ -97,10 +97,10 @@ struct HybridGaussianConditional::Helper {
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys& discreteParents, const Helper& helper) const DiscreteKeys &discreteParents, const Helper &helper)
: BaseFactor(discreteParents, : BaseFactor(discreteParents,
FactorValuePairs(helper.pairs, FactorValuePairs(helper.pairs,
[&](const GaussianFactorValuePair& [&](const GaussianFactorValuePair &
pair) { // subtract minNegLogConstant pair) { // subtract minNegLogConstant
return GaussianFactorValuePair{ return GaussianFactorValuePair{
pair.first, pair.first,
@ -183,10 +183,12 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
// Check the base and the factors: // Check the base and the factors:
return BaseFactor::equals(*e, tol) && return BaseFactor::equals(*e, tol) &&
conditionals_.equals( conditionals_.equals(e->conditionals_,
e->conditionals_, [tol](const auto &f1, const auto &f2) { [tol](const GaussianConditional::shared_ptr &f1,
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol)); const GaussianConditional::shared_ptr &f2) {
}); return (!f1 && !f2) ||
(f1 && f2 && f1->equals(*f2, tol));
});
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -225,7 +227,7 @@ KeyVector HybridGaussianConditional::continuousParents() const {
// remove that key from continuousParentKeys: // remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(), continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key), continuousParentKeys.end(), key),
continuousParentKeys.end()); continuousParentKeys.end());
} }
return continuousParentKeys; return continuousParentKeys;
} }

View File

@ -24,9 +24,9 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>

View File

@ -51,7 +51,7 @@ struct HybridGaussianFactor::ConstructorHelper {
// Build the FactorValuePairs DecisionTree // Build the FactorValuePairs DecisionTree
pairs = FactorValuePairs( pairs = FactorValuePairs(
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors), DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
[](const auto& f) { [](const sharedFactor& f) {
return std::pair{f, return std::pair{f,
f ? 0.0 : std::numeric_limits<double>::infinity()}; f ? 0.0 : std::numeric_limits<double>::infinity()};
}); });
@ -63,7 +63,7 @@ struct HybridGaussianFactor::ConstructorHelper {
const std::vector<GaussianFactorValuePair>& factorPairs) const std::vector<GaussianFactorValuePair>& factorPairs)
: discreteKeys({discreteKey}) { : discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor // Extract continuous keys from the first non-null factor
for (const auto& pair : factorPairs) { for (const GaussianFactorValuePair& pair : factorPairs) {
if (pair.first && continuousKeys.empty()) { if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys(); continuousKeys = pair.first->keys();
break; break;
@ -93,27 +93,27 @@ struct HybridGaussianFactor::ConstructorHelper {
}; };
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper) HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper& helper)
: Base(helper.continuousKeys, helper.discreteKeys), : Base(helper.continuousKeys, helper.discreteKeys),
factors_(helper.pairs) {} factors_(helper.pairs) {}
HybridGaussianFactor::HybridGaussianFactor( HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey, const DiscreteKey& discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors) const std::vector<GaussianFactor::shared_ptr>& factors)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {} : HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}
HybridGaussianFactor::HybridGaussianFactor( HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey, const DiscreteKey& discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs) const std::vector<GaussianFactorValuePair>& factorPairs)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {} : HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys, HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys& discreteKeys,
const FactorValuePairs &factors) const FactorValuePairs& factors)
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {} : HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}
/* *******************************************************************************/ /* *******************************************************************************/
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf); const This* e = dynamic_cast<const This*>(&lf);
if (e == nullptr) return false; if (e == nullptr) return false;
// This will return false if either factors_ is empty or e->factors_ is // This will return false if either factors_ is empty or e->factors_ is
@ -121,7 +121,8 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
if (factors_.empty() ^ e->factors_.empty()) return false; if (factors_.empty() ^ e->factors_.empty()) return false;
// Check the base and the factors: // Check the base and the factors:
auto compareFunc = [tol](const auto& pair1, const auto& pair2) { auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
const GaussianFactorValuePair& pair2) {
auto f1 = pair1.first, f2 = pair2.first; auto f1 = pair1.first, f2 = pair2.first;
bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol)); bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
return match && gtsam::equal(pair1.second, pair2.second, tol); return match && gtsam::equal(pair1.second, pair2.second, tol);
@ -130,8 +131,8 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
} }
/* *******************************************************************************/ /* *******************************************************************************/
void HybridGaussianFactor::print(const std::string &s, void HybridGaussianFactor::print(const std::string& s,
const KeyFormatter &formatter) const { const KeyFormatter& formatter) const {
std::cout << (s.empty() ? "" : s + "\n"); std::cout << (s.empty() ? "" : s + "\n");
HybridFactor::print("", formatter); HybridFactor::print("", formatter);
std::cout << "{\n"; std::cout << "{\n";
@ -139,9 +140,8 @@ void HybridGaussianFactor::print(const std::string &s,
std::cout << " empty" << std::endl; std::cout << " empty" << std::endl;
} else { } else {
factors_.print( factors_.print(
"", "", [&](Key k) { return formatter(k); },
[&](Key k) { return formatter(k); }, [&](const GaussianFactorValuePair& pair) -> std::string {
[&](const auto& pair) -> std::string {
RedirectCout rd; RedirectCout rd;
std::cout << ":\n"; std::cout << ":\n";
if (pair.first) { if (pair.first) {
@ -158,7 +158,7 @@ void HybridGaussianFactor::print(const std::string &s,
/* *******************************************************************************/ /* *******************************************************************************/
GaussianFactorValuePair HybridGaussianFactor::operator()( GaussianFactorValuePair HybridGaussianFactor::operator()(
const DiscreteValues &assignment) const { const DiscreteValues& assignment) const {
return factors_(assignment); return factors_(assignment);
} }
@ -169,18 +169,25 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
// - Each leaf converted to a GaussianFactorGraph with just the factor and its // - Each leaf converted to a GaussianFactorGraph with just the factor and its
// scalar. // scalar.
return {{factors_, return {{factors_,
[](const auto& pair) -> std::pair<GaussianFactorGraph, double> { [](const GaussianFactorValuePair& pair)
-> std::pair<GaussianFactorGraph, double> {
return {GaussianFactorGraph{pair.first}, pair.second}; return {GaussianFactorGraph{pair.first}, pair.second};
}}}; }}};
} }
/* *******************************************************************************/
inline static double PotentiallyPrunedComponentError(
const GaussianFactorValuePair& pair, const VectorValues& continuousValues) {
return pair.first ? pair.first->error(continuousValues) + pair.second
: std::numeric_limits<double>::infinity();
}
/* *******************************************************************************/ /* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree( AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const { const VectorValues& continuousValues) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const auto& pair) { auto errorFunc = [&continuousValues](const GaussianFactorValuePair& pair) {
return pair.first ? pair.first->error(continuousValues) + pair.second return PotentiallyPrunedComponentError(pair, continuousValues);
: std::numeric_limits<double>::infinity();
}; };
DecisionTree<Key, double> error_tree(factors_, errorFunc); DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree; return error_tree;
@ -189,9 +196,8 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
/* *******************************************************************************/ /* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues& values) const { double HybridGaussianFactor::error(const HybridValues& values) const {
// Directly index to get the component, no need to build the whole tree. // Directly index to get the component, no need to build the whole tree.
const auto pair = factors_(values.discrete()); const GaussianFactorValuePair pair = factors_(values.discrete());
return pair.first ? pair.first->error(values.continuous()) + pair.second return PotentiallyPrunedComponentError(pair, values.continuous());
: std::numeric_limits<double>::infinity();
} }
} // namespace gtsam } // namespace gtsam

View File

@ -58,7 +58,7 @@ using GaussianFactorValuePair = std::pair<GaussianFactor::shared_ptr, double>;
* @ingroup hybrid * @ingroup hybrid
*/ */
class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
public: public:
using Base = HybridFactor; using Base = HybridFactor;
using This = HybridGaussianFactor; using This = HybridGaussianFactor;
using shared_ptr = std::shared_ptr<This>; using shared_ptr = std::shared_ptr<This>;
@ -68,11 +68,11 @@ public:
/// typedef for Decision Tree of Gaussian factors and arbitrary value. /// typedef for Decision Tree of Gaussian factors and arbitrary value.
using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>; using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;
private: private:
/// Decision tree of Gaussian factors indexed by discrete keys. /// Decision tree of Gaussian factors indexed by discrete keys.
FactorValuePairs factors_; FactorValuePairs factors_;
public: public:
/// @name Constructors /// @name Constructors
/// @{ /// @{
@ -120,9 +120,8 @@ public:
bool equals(const HybridFactor &lf, double tol = 1e-9) const override; bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
void void print(const std::string &s = "", const KeyFormatter &formatter =
print(const std::string &s = "", DefaultKeyFormatter) const override;
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard API /// @name Standard API
@ -138,8 +137,8 @@ public:
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error. * as the factors involved, and leaf values as the error.
*/ */
AlgebraicDecisionTree<Key> AlgebraicDecisionTree<Key> errorTree(
errorTree(const VectorValues &continuousValues) const override; const VectorValues &continuousValues) const override;
/** /**
* @brief Compute the log-likelihood, including the log-normalizing constant. * @brief Compute the log-likelihood, including the log-normalizing constant.
@ -159,7 +158,7 @@ public:
/// @} /// @}
private: private:
/** /**
* @brief Helper function to augment the [A|b] matrices in the factor * @brief Helper function to augment the [A|b] matrices in the factor
* components with the additional scalar values. This is done by storing the * components with the additional scalar values. This is done by storing the

View File

@ -24,6 +24,7 @@
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h> #include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridEliminationTree.h> #include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
@ -39,7 +40,6 @@
#include <gtsam/linear/GaussianJunctionTree.h> #include <gtsam/linear/GaussianJunctionTree.h>
#include <gtsam/linear/HessianFactor.h> #include <gtsam/linear/HessianFactor.h>
#include <gtsam/linear/JacobianFactor.h> #include <gtsam/linear/JacobianFactor.h>
#include "gtsam/discrete/DiscreteValues.h"
#include <cstddef> #include <cstddef>
#include <iostream> #include <iostream>
@ -57,15 +57,16 @@ using std::dynamic_pointer_cast;
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>; using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
using Result = using Result =
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>; std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
using ResultTree = DecisionTree<Key, std::pair<Result, double>>; using ResultValuePair = std::pair<Result, double>;
using ResultTree = DecisionTree<Key, ResultValuePair>;
static const VectorValues kEmpty; static const VectorValues kEmpty;
/* ************************************************************************ */ /* ************************************************************************ */
// Throw a runtime exception for method specified in string s, and factor f: // Throw a runtime exception for method specified in string s, and factor f:
static void throwRuntimeError(const std::string& s, static void throwRuntimeError(const std::string &s,
const std::shared_ptr<Factor>& f) { const std::shared_ptr<Factor> &f) {
auto& fr = *f; auto &fr = *f;
throw std::runtime_error(s + " not implemented for factor type " + throw std::runtime_error(s + " not implemented for factor type " +
demangle(typeid(fr).name()) + "."); demangle(typeid(fr).name()) + ".");
} }
@ -83,11 +84,12 @@ static void printFactor(const std::shared_ptr<Factor> &factor,
const DiscreteValues &assignment, const DiscreteValues &assignment,
const KeyFormatter &keyFormatter) { const KeyFormatter &keyFormatter) {
if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(factor)) { if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
if (assignment.empty()) if (assignment.empty()) {
hgf->print("HybridGaussianFactor:", keyFormatter); hgf->print("HybridGaussianFactor:", keyFormatter);
else } else {
hgf->operator()(assignment) hgf->operator()(assignment)
.first->print("HybridGaussianFactor, component:", keyFormatter); .first->print("HybridGaussianFactor, component:", keyFormatter);
}
} else if (auto gf = dynamic_pointer_cast<GaussianFactor>(factor)) { } else if (auto gf = dynamic_pointer_cast<GaussianFactor>(factor)) {
factor->print("GaussianFactor:\n", keyFormatter); factor->print("GaussianFactor:\n", keyFormatter);
@ -99,12 +101,13 @@ static void printFactor(const std::shared_ptr<Factor> &factor,
} else if (hc->isDiscrete()) { } else if (hc->isDiscrete()) {
factor->print("DiscreteConditional:\n", keyFormatter); factor->print("DiscreteConditional:\n", keyFormatter);
} else { } else {
if (assignment.empty()) if (assignment.empty()) {
hc->print("HybridConditional:", keyFormatter); hc->print("HybridConditional:", keyFormatter);
else } else {
hc->asHybrid() hc->asHybrid()
->choose(assignment) ->choose(assignment)
->print("HybridConditional, component:\n", keyFormatter); ->print("HybridConditional, component:\n", keyFormatter);
}
} }
} else { } else {
factor->print("Unknown factor type\n", keyFormatter); factor->print("Unknown factor type\n", keyFormatter);
@ -112,13 +115,13 @@ static void printFactor(const std::shared_ptr<Factor> &factor,
} }
/* ************************************************************************ */ /* ************************************************************************ */
void HybridGaussianFactorGraph::print(const std::string& s, void HybridGaussianFactorGraph::print(const std::string &s,
const KeyFormatter& keyFormatter) const { const KeyFormatter &keyFormatter) const {
std::cout << (s.empty() ? "" : s + " ") << std::endl; std::cout << (s.empty() ? "" : s + " ") << std::endl;
std::cout << "size: " << size() << std::endl; std::cout << "size: " << size() << std::endl;
for (size_t i = 0; i < factors_.size(); i++) { for (size_t i = 0; i < factors_.size(); i++) {
auto&& factor = factors_[i]; auto &&factor = factors_[i];
if (factor == nullptr) { if (factor == nullptr) {
std::cout << "Factor " << i << ": nullptr\n"; std::cout << "Factor " << i << ": nullptr\n";
continue; continue;
@ -163,7 +166,7 @@ HybridGaussianProductFactor HybridGaussianFactorGraph::collectProductFactor()
const { const {
HybridGaussianProductFactor result; HybridGaussianProductFactor result;
for (auto& f : factors_) { for (auto &f : factors_) {
// TODO(dellaert): can we make this cleaner and less error-prone? // TODO(dellaert): can we make this cleaner and less error-prone?
if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) { if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
continue; // Ignore OrphanWrapper continue; // Ignore OrphanWrapper
@ -235,7 +238,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Case where we have a HybridGaussianFactor with no continuous keys. // Case where we have a HybridGaussianFactor with no continuous keys.
// In this case, compute discrete probabilities. // In this case, compute discrete probabilities.
auto potential = [&](const auto& pair) -> double { auto potential = [&](const auto &pair) -> double {
auto [factor, scalar] = pair; auto [factor, scalar] = pair;
// If factor is null, it has been pruned, hence return potential zero // If factor is null, it has been pruned, hence return potential zero
if (!factor) return 0.0; if (!factor) return 0.0;
@ -270,10 +273,10 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
* depends on the discrete separator if present. * depends on the discrete separator if present.
*/ */
static std::shared_ptr<Factor> createDiscreteFactor( static std::shared_ptr<Factor> createDiscreteFactor(
const ResultTree& eliminationResults, const ResultTree &eliminationResults,
const DiscreteKeys &discreteSeparator) { const DiscreteKeys &discreteSeparator) {
auto potential = [&](const auto &pair) -> double { auto potential = [&](const auto &pair) -> double {
const auto& [conditional, factor] = pair.first; const auto &[conditional, factor] = pair.first;
const double scalar = pair.second; const double scalar = pair.second;
if (conditional && factor) { if (conditional && factor) {
// `error` has the following contributions: // `error` has the following contributions:
@ -303,7 +306,7 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
const ResultTree &eliminationResults, const ResultTree &eliminationResults,
const DiscreteKeys &discreteSeparator) { const DiscreteKeys &discreteSeparator) {
// Correct for the normalization constant used up by the conditional // Correct for the normalization constant used up by the conditional
auto correct = [&](const auto &pair) -> GaussianFactorValuePair { auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
const auto &[conditional, factor] = pair.first; const auto &[conditional, factor] = pair.first;
const double scalar = pair.second; const double scalar = pair.second;
if (conditional && factor) { if (conditional && factor) {
@ -350,9 +353,9 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
bool someContinuousLeft = false; bool someContinuousLeft = false;
auto eliminate = [&](const std::pair<GaussianFactorGraph, double>& pair) auto eliminate = [&](const std::pair<GaussianFactorGraph, double> &pair)
-> std::pair<Result, double> { -> std::pair<Result, double> {
const auto& [graph, scalar] = pair; const auto &[graph, scalar] = pair;
if (graph.empty()) { if (graph.empty()) {
return {{nullptr, nullptr}, 0.0}; return {{nullptr, nullptr}, 0.0};
@ -382,7 +385,8 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// Create the HybridGaussianConditional from the conditionals // Create the HybridGaussianConditional from the conditionals
HybridGaussianConditional::Conditionals conditionals( HybridGaussianConditional::Conditionals conditionals(
eliminationResults, [](const auto& pair) { return pair.first.first; }); eliminationResults,
[](const ResultValuePair &pair) { return pair.first.first; });
auto hybridGaussian = std::make_shared<HybridGaussianConditional>( auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
discreteSeparator, conditionals); discreteSeparator, conditionals);

View File

@ -145,9 +145,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @name Testable /// @name Testable
/// @{ /// @{
void void print(
print(const std::string &s = "HybridGaussianFactorGraph", const std::string& s = "HybridGaussianFactorGraph",
const KeyFormatter &keyFormatter = DefaultKeyFormatter) const override; const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
/** /**
* @brief Print the errors of each factor in the hybrid factor graph. * @brief Print the errors of each factor in the hybrid factor graph.

View File

@ -22,43 +22,52 @@
#include <gtsam/hybrid/HybridGaussianProductFactor.h> #include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
#include <string>
namespace gtsam { namespace gtsam {
using Y = HybridGaussianProductFactor::Y; using Y = GaussianFactorGraphValuePair;
/* *******************************************************************************/
static Y add(const Y& y1, const Y& y2) { static Y add(const Y& y1, const Y& y2) {
GaussianFactorGraph result = y1.first; GaussianFactorGraph result = y1.first;
result.push_back(y2.first); result.push_back(y2.first);
return {result, y1.second + y2.second}; return {result, y1.second + y2.second};
}; };
/* *******************************************************************************/
HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a, HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a,
const HybridGaussianProductFactor& b) { const HybridGaussianProductFactor& b) {
return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add)); return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add));
} }
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::operator+( HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const HybridGaussianFactor& factor) const { const HybridGaussianFactor& factor) const {
return *this + factor.asProductFactor(); return *this + factor.asProductFactor();
} }
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::operator+( HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const GaussianFactor::shared_ptr& factor) const { const GaussianFactor::shared_ptr& factor) const {
return *this + HybridGaussianProductFactor(factor); return *this + HybridGaussianProductFactor(factor);
} }
/* *******************************************************************************/
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=( HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
const GaussianFactor::shared_ptr& factor) { const GaussianFactor::shared_ptr& factor) {
*this = *this + factor; *this = *this + factor;
return *this; return *this;
} }
/* *******************************************************************************/
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=( HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
const HybridGaussianFactor& factor) { const HybridGaussianFactor& factor) {
*this = *this + factor; *this = *this + factor;
return *this; return *this;
} }
/* *******************************************************************************/
void HybridGaussianProductFactor::print(const std::string& s, void HybridGaussianProductFactor::print(const std::string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
KeySet keys; KeySet keys;
@ -69,18 +78,25 @@ void HybridGaussianProductFactor::print(const std::string& s,
}; };
Base::print(s, formatter, printer); Base::print(s, formatter, printer);
if (!keys.empty()) { if (!keys.empty()) {
std::stringstream ss; std::cout << s << " Keys:";
ss << s << " Keys:"; for (auto&& key : keys) std::cout << " " << formatter(key);
for (auto&& key : keys) ss << " " << formatter(key); std::cout << "." << std::endl;
std::cout << ss.str() << "." << std::endl;
} }
} }
/* *******************************************************************************/
bool HybridGaussianProductFactor::equals(
const HybridGaussianProductFactor& other, double tol) const {
return Base::equals(other, [tol](const Y& a, const Y& b) {
return a.first.equals(b.first, tol) && std::abs(a.second - b.second) < tol;
});
}
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const { HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
auto emptyGaussian = [](const Y& y) { auto emptyGaussian = [](const Y& y) {
bool hasNull = bool hasNull =
std::any_of(y.first.begin(), std::any_of(y.first.begin(), y.first.end(),
y.first.end(),
[](const GaussianFactor::shared_ptr& ptr) { return !ptr; }); [](const GaussianFactor::shared_ptr& ptr) { return !ptr; });
return hasNull ? Y{GaussianFactorGraph(), 0.0} : y; return hasNull ? Y{GaussianFactorGraph(), 0.0} : y;
}; };

View File

@ -26,12 +26,13 @@ namespace gtsam {
class HybridGaussianFactor; class HybridGaussianFactor;
using GaussianFactorGraphValuePair = std::pair<GaussianFactorGraph, double>;
/// Alias for DecisionTree of GaussianFactorGraphs and their scalar sums /// Alias for DecisionTree of GaussianFactorGraphs and their scalar sums
class HybridGaussianProductFactor class GTSAM_EXPORT HybridGaussianProductFactor
: public DecisionTree<Key, std::pair<GaussianFactorGraph, double>> { : public DecisionTree<Key, GaussianFactorGraphValuePair> {
public: public:
using Y = std::pair<GaussianFactorGraph, double>; using Base = DecisionTree<Key, GaussianFactorGraphValuePair>;
using Base = DecisionTree<Key, Y>;
/// @name Constructors /// @name Constructors
/// @{ /// @{
@ -46,7 +47,7 @@ class HybridGaussianProductFactor
*/ */
template <class FACTOR> template <class FACTOR>
HybridGaussianProductFactor(const std::shared_ptr<FACTOR>& factor) HybridGaussianProductFactor(const std::shared_ptr<FACTOR>& factor)
: Base(Y{GaussianFactorGraph{factor}, 0.0}) {} : Base(GaussianFactorGraphValuePair{GaussianFactorGraph{factor}, 0.0}) {}
/** /**
* @brief Construct from DecisionTree * @brief Construct from DecisionTree
@ -94,12 +95,7 @@ class HybridGaussianProductFactor
* @return true if equal, false otherwise * @return true if equal, false otherwise
*/ */
bool equals(const HybridGaussianProductFactor& other, bool equals(const HybridGaussianProductFactor& other,
double tol = 1e-9) const { double tol = 1e-9) const;
return Base::equals(other, [tol](const Y& a, const Y& b) {
return a.first.equals(b.first, tol) &&
std::abs(a.second - b.second) < tol;
});
}
/// @} /// @}

View File

@ -199,7 +199,7 @@ TEST(HybridBayesNet, Tiny) {
factors_x0.push_back(fg.at(1)); factors_x0.push_back(fg.at(1));
auto productFactor = factors_x0.collectProductFactor(); auto productFactor = factors_x0.collectProductFactor();
// Check that scalars are 0 and 1.79 // Check that scalars are 0 and 1.79 (regression)
EXPECT_DOUBLES_EQUAL(0.0, productFactor({{M(0), 0}}).second, 1e-9); EXPECT_DOUBLES_EQUAL(0.0, productFactor({{M(0), 0}}).second, 1e-9);
EXPECT_DOUBLES_EQUAL(1.791759, productFactor({{M(0), 1}}).second, 1e-5); EXPECT_DOUBLES_EQUAL(1.791759, productFactor({{M(0), 1}}).second, 1e-5);

View File

@ -22,6 +22,8 @@
#include <gtsam/hybrid/HybridGaussianISAM.h> #include <gtsam/hybrid/HybridGaussianISAM.h>
#include <gtsam/inference/DotWriter.h> #include <gtsam/inference/DotWriter.h>
#include <numeric>
#include "Switching.h" #include "Switching.h"
// Include for test suite // Include for test suite
@ -62,7 +64,8 @@ std::vector<GaussianFactor::shared_ptr> components(Key key) {
} // namespace two } // namespace two
/* ************************************************************************* */ /* ************************************************************************* */
TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) { TEST(HybridGaussianFactorGraph,
HybridGaussianFactorGraphEliminateFullMultifrontalSimple) {
HybridGaussianFactorGraph hfg; HybridGaussianFactorGraph hfg;
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
@ -179,10 +182,8 @@ TEST(HybridGaussianFactorGraph, Switching) {
std::vector<int> naturalX(N); std::vector<int> naturalX(N);
std::iota(naturalX.begin(), naturalX.end(), 1); std::iota(naturalX.begin(), naturalX.end(), 1);
std::vector<Key> ordX; std::vector<Key> ordX;
std::transform( std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX),
naturalX.begin(), naturalX.end(), std::back_inserter(ordX), [](int x) { [](int x) { return X(x); });
return X(x);
});
auto [ndX, lvls] = makeBinaryOrdering(ordX); auto [ndX, lvls] = makeBinaryOrdering(ordX);
std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering)); std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
@ -195,10 +196,8 @@ TEST(HybridGaussianFactorGraph, Switching) {
std::vector<int> naturalC(N - 1); std::vector<int> naturalC(N - 1);
std::iota(naturalC.begin(), naturalC.end(), 1); std::iota(naturalC.begin(), naturalC.end(), 1);
std::vector<Key> ordC; std::vector<Key> ordC;
std::transform( std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC),
naturalC.begin(), naturalC.end(), std::back_inserter(ordC), [](int x) { [](int x) { return M(x); });
return M(x);
});
// std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering)); // std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
const auto [ndC, lvls] = makeBinaryOrdering(ordC); const auto [ndC, lvls] = makeBinaryOrdering(ordC);
@ -237,10 +236,8 @@ TEST(HybridGaussianFactorGraph, SwitchingISAM) {
std::vector<int> naturalX(N); std::vector<int> naturalX(N);
std::iota(naturalX.begin(), naturalX.end(), 1); std::iota(naturalX.begin(), naturalX.end(), 1);
std::vector<Key> ordX; std::vector<Key> ordX;
std::transform( std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX),
naturalX.begin(), naturalX.end(), std::back_inserter(ordX), [](int x) { [](int x) { return X(x); });
return X(x);
});
auto [ndX, lvls] = makeBinaryOrdering(ordX); auto [ndX, lvls] = makeBinaryOrdering(ordX);
std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering)); std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
@ -253,10 +250,8 @@ TEST(HybridGaussianFactorGraph, SwitchingISAM) {
std::vector<int> naturalC(N - 1); std::vector<int> naturalC(N - 1);
std::iota(naturalC.begin(), naturalC.end(), 1); std::iota(naturalC.begin(), naturalC.end(), 1);
std::vector<Key> ordC; std::vector<Key> ordC;
std::transform( std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC),
naturalC.begin(), naturalC.end(), std::back_inserter(ordC), [](int x) { [](int x) { return M(x); });
return M(x);
});
// std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering)); // std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
const auto [ndC, lvls] = makeBinaryOrdering(ordC); const auto [ndC, lvls] = makeBinaryOrdering(ordC);

View File

@ -17,6 +17,8 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <CppUnitLite/Test.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h> #include <gtsam/base/TestableAssertions.h>
#include <gtsam/base/Vector.h> #include <gtsam/base/Vector.h>
@ -37,9 +39,6 @@
#include <gtsam/inference/Symbol.h> #include <gtsam/inference/Symbol.h>
#include <gtsam/linear/JacobianFactor.h> #include <gtsam/linear/JacobianFactor.h>
#include <CppUnitLite/Test.h>
#include <CppUnitLite/TestHarness.h>
#include <cstddef> #include <cstddef>
#include <memory> #include <memory>
#include <vector> #include <vector>
@ -73,8 +72,8 @@ TEST(HybridGaussianFactorGraph, Creation) {
HybridGaussianConditional gm( HybridGaussianConditional gm(
m0, m0,
{std::make_shared<GaussianConditional>(X(0), Z_3x1, I_3x3, X(1), I_3x3), {std::make_shared<GaussianConditional>(X(0), Z_3x1, I_3x3, X(1), I_3x3),
std::make_shared<GaussianConditional>( std::make_shared<GaussianConditional>(X(0), Vector3::Ones(), I_3x3, X(1),
X(0), Vector3::Ones(), I_3x3, X(1), I_3x3)}); I_3x3)});
hfg.add(gm); hfg.add(gm);
EXPECT_LONGS_EQUAL(2, hfg.size()); EXPECT_LONGS_EQUAL(2, hfg.size());
@ -118,8 +117,7 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second); auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
CHECK(factor); CHECK(factor);
// regression test // regression test
EXPECT( EXPECT(assert_equal(DecisionTreeFactor{m1, "15.74961 15.74961"}, *factor, 1e-5));
assert_equal(DecisionTreeFactor{m1, "15.74961 15.74961"}, *factor, 1e-5));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -177,7 +175,7 @@ TEST(HybridBayesNet, Switching) {
Switching s(2, betweenSigma, priorSigma); Switching s(2, betweenSigma, priorSigma);
// Check size of linearized factor graph // Check size of linearized factor graph
const HybridGaussianFactorGraph& graph = s.linearizedFactorGraph; const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
EXPECT_LONGS_EQUAL(4, graph.size()); EXPECT_LONGS_EQUAL(4, graph.size());
// Create some continuous and discrete values // Create some continuous and discrete values
@ -203,20 +201,20 @@ TEST(HybridBayesNet, Switching) {
// Check error for M(0) = 0 // Check error for M(0) = 0
const HybridValues values0{continuousValues, modeZero}; const HybridValues values0{continuousValues, modeZero};
double expectedError0 = 0; double expectedError0 = 0;
for (const auto& factor : graph) expectedError0 += factor->error(values0); for (const auto &factor : graph) expectedError0 += factor->error(values0);
EXPECT_DOUBLES_EQUAL(expectedError0, graph.error(values0), 1e-5); EXPECT_DOUBLES_EQUAL(expectedError0, graph.error(values0), 1e-5);
// Check error for M(0) = 1 // Check error for M(0) = 1
const HybridValues values1{continuousValues, modeOne}; const HybridValues values1{continuousValues, modeOne};
double expectedError1 = 0; double expectedError1 = 0;
for (const auto& factor : graph) expectedError1 += factor->error(values1); for (const auto &factor : graph) expectedError1 += factor->error(values1);
EXPECT_DOUBLES_EQUAL(expectedError1, graph.error(values1), 1e-5); EXPECT_DOUBLES_EQUAL(expectedError1, graph.error(values1), 1e-5);
// Check errorTree // Check errorTree
AlgebraicDecisionTree<Key> actualErrors = graph.errorTree(continuousValues); AlgebraicDecisionTree<Key> actualErrors = graph.errorTree(continuousValues);
// Create expected error tree // Create expected error tree
const AlgebraicDecisionTree<Key> expectedErrors( const AlgebraicDecisionTree<Key> expectedErrors(M(0), expectedError0,
M(0), expectedError0, expectedError1); expectedError1);
// Check that the actual error tree matches the expected one // Check that the actual error tree matches the expected one
EXPECT(assert_equal(expectedErrors, actualErrors, 1e-5)); EXPECT(assert_equal(expectedErrors, actualErrors, 1e-5));
@ -232,8 +230,8 @@ TEST(HybridBayesNet, Switching) {
const AlgebraicDecisionTree<Key> graphPosterior = const AlgebraicDecisionTree<Key> graphPosterior =
graph.discretePosterior(continuousValues); graph.discretePosterior(continuousValues);
const double sum = probPrime0 + probPrime1; const double sum = probPrime0 + probPrime1;
const AlgebraicDecisionTree<Key> expectedPosterior( const AlgebraicDecisionTree<Key> expectedPosterior(M(0), probPrime0 / sum,
M(0), probPrime0 / sum, probPrime1 / sum); probPrime1 / sum);
EXPECT(assert_equal(expectedPosterior, graphPosterior, 1e-5)); EXPECT(assert_equal(expectedPosterior, graphPosterior, 1e-5));
// Make the clique of factors connected to x0: // Make the clique of factors connected to x0:
@ -275,15 +273,13 @@ TEST(HybridBayesNet, Switching) {
// Check that the scalars incorporate the negative log constant of the // Check that the scalars incorporate the negative log constant of the
// conditional // conditional
EXPECT_DOUBLES_EQUAL(scalar0 - (*p_x0_given_x1_m)(modeZero)->negLogConstant(), EXPECT_DOUBLES_EQUAL(scalar0 - (*p_x0_given_x1_m)(modeZero)->negLogConstant(),
(*phi_x1_m)(modeZero).second, (*phi_x1_m)(modeZero).second, 1e-9);
1e-9);
EXPECT_DOUBLES_EQUAL(scalar1 - (*p_x0_given_x1_m)(modeOne)->negLogConstant(), EXPECT_DOUBLES_EQUAL(scalar1 - (*p_x0_given_x1_m)(modeOne)->negLogConstant(),
(*phi_x1_m)(modeOne).second, (*phi_x1_m)(modeOne).second, 1e-9);
1e-9);
// Check that the conditional and remaining factor are consistent for both // Check that the conditional and remaining factor are consistent for both
// modes // modes
for (auto&& mode : {modeZero, modeOne}) { for (auto &&mode : {modeZero, modeOne}) {
const auto gc = (*p_x0_given_x1_m)(mode); const auto gc = (*p_x0_given_x1_m)(mode);
const auto [gf, scalar] = (*phi_x1_m)(mode); const auto [gf, scalar] = (*phi_x1_m)(mode);
@ -342,7 +338,7 @@ TEST(HybridBayesNet, Switching) {
// However, we can still check the total error for the clique factors_x1 and // However, we can still check the total error for the clique factors_x1 and
// the elimination results are equal, modulo -again- the negative log constant // the elimination results are equal, modulo -again- the negative log constant
// of the conditional. // of the conditional.
for (auto&& mode : {modeZero, modeOne}) { for (auto &&mode : {modeZero, modeOne}) {
auto gc_x1 = (*p_x1_given_m)(mode); auto gc_x1 = (*p_x1_given_m)(mode);
double originalError_x1 = factors_x1.error({continuousValues, mode}); double originalError_x1 = factors_x1.error({continuousValues, mode});
const double actualError = gc_x1->negLogConstant() + const double actualError = gc_x1->negLogConstant() +
@ -372,7 +368,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
Switching s(3); Switching s(3);
// Check size of linearized factor graph // Check size of linearized factor graph
const HybridGaussianFactorGraph& graph = s.linearizedFactorGraph; const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
EXPECT_LONGS_EQUAL(7, graph.size()); EXPECT_LONGS_EQUAL(7, graph.size());
// Eliminate the graph // Eliminate the graph