commit
59f97d64eb
|
@ -22,14 +22,13 @@
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||||
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/inference/BayesTree-inst.h>
|
#include <gtsam/inference/BayesTree-inst.h>
|
||||||
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
|
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
|
||||||
#include <gtsam/linear/GaussianJunctionTree.h>
|
#include <gtsam/linear/GaussianJunctionTree.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "gtsam/hybrid/HybridConditional.h"
|
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
* @file HybridConditional.cpp
|
* @file HybridConditional.cpp
|
||||||
* @date Mar 11, 2022
|
* @date Mar 11, 2022
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
|
* @author Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
* @file HybridConditional.h
|
* @file HybridConditional.h
|
||||||
* @date Mar 11, 2022
|
* @date Mar 11, 2022
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
|
* @author Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
|
@ -79,7 +79,7 @@ struct HybridGaussianConditional::Helper {
|
||||||
explicit Helper(const Conditionals &conditionals)
|
explicit Helper(const Conditionals &conditionals)
|
||||||
: conditionals(conditionals),
|
: conditionals(conditionals),
|
||||||
minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
minNegLogConstant(std::numeric_limits<double>::infinity()) {
|
||||||
auto func = [this](const GC::shared_ptr& gc) -> GaussianFactorValuePair {
|
auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
|
||||||
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
|
||||||
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
if (!nrFrontals) nrFrontals = gc->nrFrontals();
|
||||||
double value = gc->negLogConstant();
|
double value = gc->negLogConstant();
|
||||||
|
@ -97,10 +97,10 @@ struct HybridGaussianConditional::Helper {
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const DiscreteKeys& discreteParents, const Helper& helper)
|
const DiscreteKeys &discreteParents, const Helper &helper)
|
||||||
: BaseFactor(discreteParents,
|
: BaseFactor(discreteParents,
|
||||||
FactorValuePairs(helper.pairs,
|
FactorValuePairs(helper.pairs,
|
||||||
[&](const GaussianFactorValuePair&
|
[&](const GaussianFactorValuePair &
|
||||||
pair) { // subtract minNegLogConstant
|
pair) { // subtract minNegLogConstant
|
||||||
return GaussianFactorValuePair{
|
return GaussianFactorValuePair{
|
||||||
pair.first,
|
pair.first,
|
||||||
|
@ -183,10 +183,12 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
|
||||||
|
|
||||||
// Check the base and the factors:
|
// Check the base and the factors:
|
||||||
return BaseFactor::equals(*e, tol) &&
|
return BaseFactor::equals(*e, tol) &&
|
||||||
conditionals_.equals(
|
conditionals_.equals(e->conditionals_,
|
||||||
e->conditionals_, [tol](const auto &f1, const auto &f2) {
|
[tol](const GaussianConditional::shared_ptr &f1,
|
||||||
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
const GaussianConditional::shared_ptr &f2) {
|
||||||
});
|
return (!f1 && !f2) ||
|
||||||
|
(f1 && f2 && f1->equals(*f2, tol));
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
@ -225,7 +227,7 @@ KeyVector HybridGaussianConditional::continuousParents() const {
|
||||||
// remove that key from continuousParentKeys:
|
// remove that key from continuousParentKeys:
|
||||||
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
|
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
|
||||||
continuousParentKeys.end(), key),
|
continuousParentKeys.end(), key),
|
||||||
continuousParentKeys.end());
|
continuousParentKeys.end());
|
||||||
}
|
}
|
||||||
return continuousParentKeys;
|
return continuousParentKeys;
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,9 +24,9 @@
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||||
#include <gtsam/inference/Conditional.h>
|
#include <gtsam/inference/Conditional.h>
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
|
||||||
|
|
|
@ -51,7 +51,7 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
// Build the FactorValuePairs DecisionTree
|
// Build the FactorValuePairs DecisionTree
|
||||||
pairs = FactorValuePairs(
|
pairs = FactorValuePairs(
|
||||||
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
|
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
|
||||||
[](const auto& f) {
|
[](const sharedFactor& f) {
|
||||||
return std::pair{f,
|
return std::pair{f,
|
||||||
f ? 0.0 : std::numeric_limits<double>::infinity()};
|
f ? 0.0 : std::numeric_limits<double>::infinity()};
|
||||||
});
|
});
|
||||||
|
@ -63,7 +63,7 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
const std::vector<GaussianFactorValuePair>& factorPairs)
|
const std::vector<GaussianFactorValuePair>& factorPairs)
|
||||||
: discreteKeys({discreteKey}) {
|
: discreteKeys({discreteKey}) {
|
||||||
// Extract continuous keys from the first non-null factor
|
// Extract continuous keys from the first non-null factor
|
||||||
for (const auto& pair : factorPairs) {
|
for (const GaussianFactorValuePair& pair : factorPairs) {
|
||||||
if (pair.first && continuousKeys.empty()) {
|
if (pair.first && continuousKeys.empty()) {
|
||||||
continuousKeys = pair.first->keys();
|
continuousKeys = pair.first->keys();
|
||||||
break;
|
break;
|
||||||
|
@ -93,27 +93,27 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
};
|
};
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper)
|
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper& helper)
|
||||||
: Base(helper.continuousKeys, helper.discreteKeys),
|
: Base(helper.continuousKeys, helper.discreteKeys),
|
||||||
factors_(helper.pairs) {}
|
factors_(helper.pairs) {}
|
||||||
|
|
||||||
HybridGaussianFactor::HybridGaussianFactor(
|
HybridGaussianFactor::HybridGaussianFactor(
|
||||||
const DiscreteKey &discreteKey,
|
const DiscreteKey& discreteKey,
|
||||||
const std::vector<GaussianFactor::shared_ptr> &factors)
|
const std::vector<GaussianFactor::shared_ptr>& factors)
|
||||||
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}
|
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}
|
||||||
|
|
||||||
HybridGaussianFactor::HybridGaussianFactor(
|
HybridGaussianFactor::HybridGaussianFactor(
|
||||||
const DiscreteKey &discreteKey,
|
const DiscreteKey& discreteKey,
|
||||||
const std::vector<GaussianFactorValuePair> &factorPairs)
|
const std::vector<GaussianFactorValuePair>& factorPairs)
|
||||||
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
|
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
|
||||||
|
|
||||||
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
|
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys& discreteKeys,
|
||||||
const FactorValuePairs &factors)
|
const FactorValuePairs& factors)
|
||||||
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}
|
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
|
bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
|
||||||
const This *e = dynamic_cast<const This *>(&lf);
|
const This* e = dynamic_cast<const This*>(&lf);
|
||||||
if (e == nullptr) return false;
|
if (e == nullptr) return false;
|
||||||
|
|
||||||
// This will return false if either factors_ is empty or e->factors_ is
|
// This will return false if either factors_ is empty or e->factors_ is
|
||||||
|
@ -121,7 +121,8 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
if (factors_.empty() ^ e->factors_.empty()) return false;
|
if (factors_.empty() ^ e->factors_.empty()) return false;
|
||||||
|
|
||||||
// Check the base and the factors:
|
// Check the base and the factors:
|
||||||
auto compareFunc = [tol](const auto& pair1, const auto& pair2) {
|
auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
|
||||||
|
const GaussianFactorValuePair& pair2) {
|
||||||
auto f1 = pair1.first, f2 = pair2.first;
|
auto f1 = pair1.first, f2 = pair2.first;
|
||||||
bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
||||||
return match && gtsam::equal(pair1.second, pair2.second, tol);
|
return match && gtsam::equal(pair1.second, pair2.second, tol);
|
||||||
|
@ -130,8 +131,8 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
void HybridGaussianFactor::print(const std::string &s,
|
void HybridGaussianFactor::print(const std::string& s,
|
||||||
const KeyFormatter &formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
std::cout << (s.empty() ? "" : s + "\n");
|
std::cout << (s.empty() ? "" : s + "\n");
|
||||||
HybridFactor::print("", formatter);
|
HybridFactor::print("", formatter);
|
||||||
std::cout << "{\n";
|
std::cout << "{\n";
|
||||||
|
@ -139,9 +140,8 @@ void HybridGaussianFactor::print(const std::string &s,
|
||||||
std::cout << " empty" << std::endl;
|
std::cout << " empty" << std::endl;
|
||||||
} else {
|
} else {
|
||||||
factors_.print(
|
factors_.print(
|
||||||
"",
|
"", [&](Key k) { return formatter(k); },
|
||||||
[&](Key k) { return formatter(k); },
|
[&](const GaussianFactorValuePair& pair) -> std::string {
|
||||||
[&](const auto& pair) -> std::string {
|
|
||||||
RedirectCout rd;
|
RedirectCout rd;
|
||||||
std::cout << ":\n";
|
std::cout << ":\n";
|
||||||
if (pair.first) {
|
if (pair.first) {
|
||||||
|
@ -158,7 +158,7 @@ void HybridGaussianFactor::print(const std::string &s,
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianFactorValuePair HybridGaussianFactor::operator()(
|
GaussianFactorValuePair HybridGaussianFactor::operator()(
|
||||||
const DiscreteValues &assignment) const {
|
const DiscreteValues& assignment) const {
|
||||||
return factors_(assignment);
|
return factors_(assignment);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,18 +169,25 @@ HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
|
||||||
// - Each leaf converted to a GaussianFactorGraph with just the factor and its
|
// - Each leaf converted to a GaussianFactorGraph with just the factor and its
|
||||||
// scalar.
|
// scalar.
|
||||||
return {{factors_,
|
return {{factors_,
|
||||||
[](const auto& pair) -> std::pair<GaussianFactorGraph, double> {
|
[](const GaussianFactorValuePair& pair)
|
||||||
|
-> std::pair<GaussianFactorGraph, double> {
|
||||||
return {GaussianFactorGraph{pair.first}, pair.second};
|
return {GaussianFactorGraph{pair.first}, pair.second};
|
||||||
}}};
|
}}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
inline static double PotentiallyPrunedComponentError(
|
||||||
|
const GaussianFactorValuePair& pair, const VectorValues& continuousValues) {
|
||||||
|
return pair.first ? pair.first->error(continuousValues) + pair.second
|
||||||
|
: std::numeric_limits<double>::infinity();
|
||||||
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues& continuousValues) const {
|
||||||
// functor to convert from sharedFactor to double error value.
|
// functor to convert from sharedFactor to double error value.
|
||||||
auto errorFunc = [&continuousValues](const auto& pair) {
|
auto errorFunc = [&continuousValues](const GaussianFactorValuePair& pair) {
|
||||||
return pair.first ? pair.first->error(continuousValues) + pair.second
|
return PotentiallyPrunedComponentError(pair, continuousValues);
|
||||||
: std::numeric_limits<double>::infinity();
|
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> error_tree(factors_, errorFunc);
|
DecisionTree<Key, double> error_tree(factors_, errorFunc);
|
||||||
return error_tree;
|
return error_tree;
|
||||||
|
@ -189,9 +196,8 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double HybridGaussianFactor::error(const HybridValues& values) const {
|
double HybridGaussianFactor::error(const HybridValues& values) const {
|
||||||
// Directly index to get the component, no need to build the whole tree.
|
// Directly index to get the component, no need to build the whole tree.
|
||||||
const auto pair = factors_(values.discrete());
|
const GaussianFactorValuePair pair = factors_(values.discrete());
|
||||||
return pair.first ? pair.first->error(values.continuous()) + pair.second
|
return PotentiallyPrunedComponentError(pair, values.continuous());
|
||||||
: std::numeric_limits<double>::infinity();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -58,7 +58,7 @@ using GaussianFactorValuePair = std::pair<GaussianFactor::shared_ptr, double>;
|
||||||
* @ingroup hybrid
|
* @ingroup hybrid
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
public:
|
public:
|
||||||
using Base = HybridFactor;
|
using Base = HybridFactor;
|
||||||
using This = HybridGaussianFactor;
|
using This = HybridGaussianFactor;
|
||||||
using shared_ptr = std::shared_ptr<This>;
|
using shared_ptr = std::shared_ptr<This>;
|
||||||
|
@ -68,11 +68,11 @@ public:
|
||||||
/// typedef for Decision Tree of Gaussian factors and arbitrary value.
|
/// typedef for Decision Tree of Gaussian factors and arbitrary value.
|
||||||
using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;
|
using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Decision tree of Gaussian factors indexed by discrete keys.
|
/// Decision tree of Gaussian factors indexed by discrete keys.
|
||||||
FactorValuePairs factors_;
|
FactorValuePairs factors_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
@ -120,9 +120,8 @@ public:
|
||||||
|
|
||||||
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
|
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
|
||||||
|
|
||||||
void
|
void print(const std::string &s = "", const KeyFormatter &formatter =
|
||||||
print(const std::string &s = "",
|
DefaultKeyFormatter) const override;
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Standard API
|
/// @name Standard API
|
||||||
|
@ -138,8 +137,8 @@ public:
|
||||||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||||
* as the factors involved, and leaf values as the error.
|
* as the factors involved, and leaf values as the error.
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key>
|
AlgebraicDecisionTree<Key> errorTree(
|
||||||
errorTree(const VectorValues &continuousValues) const override;
|
const VectorValues &continuousValues) const override;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the log-likelihood, including the log-normalizing constant.
|
* @brief Compute the log-likelihood, including the log-normalizing constant.
|
||||||
|
@ -159,7 +158,7 @@ public:
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to augment the [A|b] matrices in the factor
|
* @brief Helper function to augment the [A|b] matrices in the factor
|
||||||
* components with the additional scalar values. This is done by storing the
|
* components with the additional scalar values. This is done by storing the
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/hybrid/HybridEliminationTree.h>
|
#include <gtsam/hybrid/HybridEliminationTree.h>
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
|
@ -39,7 +40,6 @@
|
||||||
#include <gtsam/linear/GaussianJunctionTree.h>
|
#include <gtsam/linear/GaussianJunctionTree.h>
|
||||||
#include <gtsam/linear/HessianFactor.h>
|
#include <gtsam/linear/HessianFactor.h>
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
#include "gtsam/discrete/DiscreteValues.h"
|
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
@ -57,15 +57,16 @@ using std::dynamic_pointer_cast;
|
||||||
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
|
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
|
||||||
using Result =
|
using Result =
|
||||||
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
|
std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
|
||||||
using ResultTree = DecisionTree<Key, std::pair<Result, double>>;
|
using ResultValuePair = std::pair<Result, double>;
|
||||||
|
using ResultTree = DecisionTree<Key, ResultValuePair>;
|
||||||
|
|
||||||
static const VectorValues kEmpty;
|
static const VectorValues kEmpty;
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
// Throw a runtime exception for method specified in string s, and factor f:
|
// Throw a runtime exception for method specified in string s, and factor f:
|
||||||
static void throwRuntimeError(const std::string& s,
|
static void throwRuntimeError(const std::string &s,
|
||||||
const std::shared_ptr<Factor>& f) {
|
const std::shared_ptr<Factor> &f) {
|
||||||
auto& fr = *f;
|
auto &fr = *f;
|
||||||
throw std::runtime_error(s + " not implemented for factor type " +
|
throw std::runtime_error(s + " not implemented for factor type " +
|
||||||
demangle(typeid(fr).name()) + ".");
|
demangle(typeid(fr).name()) + ".");
|
||||||
}
|
}
|
||||||
|
@ -83,11 +84,12 @@ static void printFactor(const std::shared_ptr<Factor> &factor,
|
||||||
const DiscreteValues &assignment,
|
const DiscreteValues &assignment,
|
||||||
const KeyFormatter &keyFormatter) {
|
const KeyFormatter &keyFormatter) {
|
||||||
if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||||
if (assignment.empty())
|
if (assignment.empty()) {
|
||||||
hgf->print("HybridGaussianFactor:", keyFormatter);
|
hgf->print("HybridGaussianFactor:", keyFormatter);
|
||||||
else
|
} else {
|
||||||
hgf->operator()(assignment)
|
hgf->operator()(assignment)
|
||||||
.first->print("HybridGaussianFactor, component:", keyFormatter);
|
.first->print("HybridGaussianFactor, component:", keyFormatter);
|
||||||
|
}
|
||||||
} else if (auto gf = dynamic_pointer_cast<GaussianFactor>(factor)) {
|
} else if (auto gf = dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||||
factor->print("GaussianFactor:\n", keyFormatter);
|
factor->print("GaussianFactor:\n", keyFormatter);
|
||||||
|
|
||||||
|
@ -99,12 +101,13 @@ static void printFactor(const std::shared_ptr<Factor> &factor,
|
||||||
} else if (hc->isDiscrete()) {
|
} else if (hc->isDiscrete()) {
|
||||||
factor->print("DiscreteConditional:\n", keyFormatter);
|
factor->print("DiscreteConditional:\n", keyFormatter);
|
||||||
} else {
|
} else {
|
||||||
if (assignment.empty())
|
if (assignment.empty()) {
|
||||||
hc->print("HybridConditional:", keyFormatter);
|
hc->print("HybridConditional:", keyFormatter);
|
||||||
else
|
} else {
|
||||||
hc->asHybrid()
|
hc->asHybrid()
|
||||||
->choose(assignment)
|
->choose(assignment)
|
||||||
->print("HybridConditional, component:\n", keyFormatter);
|
->print("HybridConditional, component:\n", keyFormatter);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
factor->print("Unknown factor type\n", keyFormatter);
|
factor->print("Unknown factor type\n", keyFormatter);
|
||||||
|
@ -112,13 +115,13 @@ static void printFactor(const std::shared_ptr<Factor> &factor,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
void HybridGaussianFactorGraph::print(const std::string& s,
|
void HybridGaussianFactorGraph::print(const std::string &s,
|
||||||
const KeyFormatter& keyFormatter) const {
|
const KeyFormatter &keyFormatter) const {
|
||||||
std::cout << (s.empty() ? "" : s + " ") << std::endl;
|
std::cout << (s.empty() ? "" : s + " ") << std::endl;
|
||||||
std::cout << "size: " << size() << std::endl;
|
std::cout << "size: " << size() << std::endl;
|
||||||
|
|
||||||
for (size_t i = 0; i < factors_.size(); i++) {
|
for (size_t i = 0; i < factors_.size(); i++) {
|
||||||
auto&& factor = factors_[i];
|
auto &&factor = factors_[i];
|
||||||
if (factor == nullptr) {
|
if (factor == nullptr) {
|
||||||
std::cout << "Factor " << i << ": nullptr\n";
|
std::cout << "Factor " << i << ": nullptr\n";
|
||||||
continue;
|
continue;
|
||||||
|
@ -163,7 +166,7 @@ HybridGaussianProductFactor HybridGaussianFactorGraph::collectProductFactor()
|
||||||
const {
|
const {
|
||||||
HybridGaussianProductFactor result;
|
HybridGaussianProductFactor result;
|
||||||
|
|
||||||
for (auto& f : factors_) {
|
for (auto &f : factors_) {
|
||||||
// TODO(dellaert): can we make this cleaner and less error-prone?
|
// TODO(dellaert): can we make this cleaner and less error-prone?
|
||||||
if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||||
continue; // Ignore OrphanWrapper
|
continue; // Ignore OrphanWrapper
|
||||||
|
@ -235,7 +238,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||||
// Case where we have a HybridGaussianFactor with no continuous keys.
|
// Case where we have a HybridGaussianFactor with no continuous keys.
|
||||||
// In this case, compute discrete probabilities.
|
// In this case, compute discrete probabilities.
|
||||||
auto potential = [&](const auto& pair) -> double {
|
auto potential = [&](const auto &pair) -> double {
|
||||||
auto [factor, scalar] = pair;
|
auto [factor, scalar] = pair;
|
||||||
// If factor is null, it has been pruned, hence return potential zero
|
// If factor is null, it has been pruned, hence return potential zero
|
||||||
if (!factor) return 0.0;
|
if (!factor) return 0.0;
|
||||||
|
@ -270,10 +273,10 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
* depends on the discrete separator if present.
|
* depends on the discrete separator if present.
|
||||||
*/
|
*/
|
||||||
static std::shared_ptr<Factor> createDiscreteFactor(
|
static std::shared_ptr<Factor> createDiscreteFactor(
|
||||||
const ResultTree& eliminationResults,
|
const ResultTree &eliminationResults,
|
||||||
const DiscreteKeys &discreteSeparator) {
|
const DiscreteKeys &discreteSeparator) {
|
||||||
auto potential = [&](const auto &pair) -> double {
|
auto potential = [&](const auto &pair) -> double {
|
||||||
const auto& [conditional, factor] = pair.first;
|
const auto &[conditional, factor] = pair.first;
|
||||||
const double scalar = pair.second;
|
const double scalar = pair.second;
|
||||||
if (conditional && factor) {
|
if (conditional && factor) {
|
||||||
// `error` has the following contributions:
|
// `error` has the following contributions:
|
||||||
|
@ -303,7 +306,7 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
||||||
const ResultTree &eliminationResults,
|
const ResultTree &eliminationResults,
|
||||||
const DiscreteKeys &discreteSeparator) {
|
const DiscreteKeys &discreteSeparator) {
|
||||||
// Correct for the normalization constant used up by the conditional
|
// Correct for the normalization constant used up by the conditional
|
||||||
auto correct = [&](const auto &pair) -> GaussianFactorValuePair {
|
auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
|
||||||
const auto &[conditional, factor] = pair.first;
|
const auto &[conditional, factor] = pair.first;
|
||||||
const double scalar = pair.second;
|
const double scalar = pair.second;
|
||||||
if (conditional && factor) {
|
if (conditional && factor) {
|
||||||
|
@ -350,9 +353,9 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
|
|
||||||
// This is the elimination method on the leaf nodes
|
// This is the elimination method on the leaf nodes
|
||||||
bool someContinuousLeft = false;
|
bool someContinuousLeft = false;
|
||||||
auto eliminate = [&](const std::pair<GaussianFactorGraph, double>& pair)
|
auto eliminate = [&](const std::pair<GaussianFactorGraph, double> &pair)
|
||||||
-> std::pair<Result, double> {
|
-> std::pair<Result, double> {
|
||||||
const auto& [graph, scalar] = pair;
|
const auto &[graph, scalar] = pair;
|
||||||
|
|
||||||
if (graph.empty()) {
|
if (graph.empty()) {
|
||||||
return {{nullptr, nullptr}, 0.0};
|
return {{nullptr, nullptr}, 0.0};
|
||||||
|
@ -382,7 +385,8 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
||||||
|
|
||||||
// Create the HybridGaussianConditional from the conditionals
|
// Create the HybridGaussianConditional from the conditionals
|
||||||
HybridGaussianConditional::Conditionals conditionals(
|
HybridGaussianConditional::Conditionals conditionals(
|
||||||
eliminationResults, [](const auto& pair) { return pair.first.first; });
|
eliminationResults,
|
||||||
|
[](const ResultValuePair &pair) { return pair.first.first; });
|
||||||
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
|
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
|
||||||
discreteSeparator, conditionals);
|
discreteSeparator, conditionals);
|
||||||
|
|
||||||
|
|
|
@ -145,9 +145,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
void
|
void print(
|
||||||
print(const std::string &s = "HybridGaussianFactorGraph",
|
const std::string& s = "HybridGaussianFactorGraph",
|
||||||
const KeyFormatter &keyFormatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Print the errors of each factor in the hybrid factor graph.
|
* @brief Print the errors of each factor in the hybrid factor graph.
|
||||||
|
|
|
@ -22,43 +22,52 @@
|
||||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
using Y = HybridGaussianProductFactor::Y;
|
using Y = GaussianFactorGraphValuePair;
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
static Y add(const Y& y1, const Y& y2) {
|
static Y add(const Y& y1, const Y& y2) {
|
||||||
GaussianFactorGraph result = y1.first;
|
GaussianFactorGraph result = y1.first;
|
||||||
result.push_back(y2.first);
|
result.push_back(y2.first);
|
||||||
return {result, y1.second + y2.second};
|
return {result, y1.second + y2.second};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a,
|
HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a,
|
||||||
const HybridGaussianProductFactor& b) {
|
const HybridGaussianProductFactor& b) {
|
||||||
return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add));
|
return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
|
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
|
||||||
const HybridGaussianFactor& factor) const {
|
const HybridGaussianFactor& factor) const {
|
||||||
return *this + factor.asProductFactor();
|
return *this + factor.asProductFactor();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
|
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
|
||||||
const GaussianFactor::shared_ptr& factor) const {
|
const GaussianFactor::shared_ptr& factor) const {
|
||||||
return *this + HybridGaussianProductFactor(factor);
|
return *this + HybridGaussianProductFactor(factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
|
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
|
||||||
const GaussianFactor::shared_ptr& factor) {
|
const GaussianFactor::shared_ptr& factor) {
|
||||||
*this = *this + factor;
|
*this = *this + factor;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
|
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
|
||||||
const HybridGaussianFactor& factor) {
|
const HybridGaussianFactor& factor) {
|
||||||
*this = *this + factor;
|
*this = *this + factor;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
void HybridGaussianProductFactor::print(const std::string& s,
|
void HybridGaussianProductFactor::print(const std::string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
KeySet keys;
|
KeySet keys;
|
||||||
|
@ -69,18 +78,25 @@ void HybridGaussianProductFactor::print(const std::string& s,
|
||||||
};
|
};
|
||||||
Base::print(s, formatter, printer);
|
Base::print(s, formatter, printer);
|
||||||
if (!keys.empty()) {
|
if (!keys.empty()) {
|
||||||
std::stringstream ss;
|
std::cout << s << " Keys:";
|
||||||
ss << s << " Keys:";
|
for (auto&& key : keys) std::cout << " " << formatter(key);
|
||||||
for (auto&& key : keys) ss << " " << formatter(key);
|
std::cout << "." << std::endl;
|
||||||
std::cout << ss.str() << "." << std::endl;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
bool HybridGaussianProductFactor::equals(
|
||||||
|
const HybridGaussianProductFactor& other, double tol) const {
|
||||||
|
return Base::equals(other, [tol](const Y& a, const Y& b) {
|
||||||
|
return a.first.equals(b.first, tol) && std::abs(a.second - b.second) < tol;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
|
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
|
||||||
auto emptyGaussian = [](const Y& y) {
|
auto emptyGaussian = [](const Y& y) {
|
||||||
bool hasNull =
|
bool hasNull =
|
||||||
std::any_of(y.first.begin(),
|
std::any_of(y.first.begin(), y.first.end(),
|
||||||
y.first.end(),
|
|
||||||
[](const GaussianFactor::shared_ptr& ptr) { return !ptr; });
|
[](const GaussianFactor::shared_ptr& ptr) { return !ptr; });
|
||||||
return hasNull ? Y{GaussianFactorGraph(), 0.0} : y;
|
return hasNull ? Y{GaussianFactorGraph(), 0.0} : y;
|
||||||
};
|
};
|
||||||
|
|
|
@ -26,12 +26,13 @@ namespace gtsam {
|
||||||
|
|
||||||
class HybridGaussianFactor;
|
class HybridGaussianFactor;
|
||||||
|
|
||||||
|
using GaussianFactorGraphValuePair = std::pair<GaussianFactorGraph, double>;
|
||||||
|
|
||||||
/// Alias for DecisionTree of GaussianFactorGraphs and their scalar sums
|
/// Alias for DecisionTree of GaussianFactorGraphs and their scalar sums
|
||||||
class HybridGaussianProductFactor
|
class GTSAM_EXPORT HybridGaussianProductFactor
|
||||||
: public DecisionTree<Key, std::pair<GaussianFactorGraph, double>> {
|
: public DecisionTree<Key, GaussianFactorGraphValuePair> {
|
||||||
public:
|
public:
|
||||||
using Y = std::pair<GaussianFactorGraph, double>;
|
using Base = DecisionTree<Key, GaussianFactorGraphValuePair>;
|
||||||
using Base = DecisionTree<Key, Y>;
|
|
||||||
|
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
@ -46,7 +47,7 @@ class HybridGaussianProductFactor
|
||||||
*/
|
*/
|
||||||
template <class FACTOR>
|
template <class FACTOR>
|
||||||
HybridGaussianProductFactor(const std::shared_ptr<FACTOR>& factor)
|
HybridGaussianProductFactor(const std::shared_ptr<FACTOR>& factor)
|
||||||
: Base(Y{GaussianFactorGraph{factor}, 0.0}) {}
|
: Base(GaussianFactorGraphValuePair{GaussianFactorGraph{factor}, 0.0}) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Construct from DecisionTree
|
* @brief Construct from DecisionTree
|
||||||
|
@ -94,12 +95,7 @@ class HybridGaussianProductFactor
|
||||||
* @return true if equal, false otherwise
|
* @return true if equal, false otherwise
|
||||||
*/
|
*/
|
||||||
bool equals(const HybridGaussianProductFactor& other,
|
bool equals(const HybridGaussianProductFactor& other,
|
||||||
double tol = 1e-9) const {
|
double tol = 1e-9) const;
|
||||||
return Base::equals(other, [tol](const Y& a, const Y& b) {
|
|
||||||
return a.first.equals(b.first, tol) &&
|
|
||||||
std::abs(a.second - b.second) < tol;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
|
@ -199,7 +199,7 @@ TEST(HybridBayesNet, Tiny) {
|
||||||
factors_x0.push_back(fg.at(1));
|
factors_x0.push_back(fg.at(1));
|
||||||
auto productFactor = factors_x0.collectProductFactor();
|
auto productFactor = factors_x0.collectProductFactor();
|
||||||
|
|
||||||
// Check that scalars are 0 and 1.79
|
// Check that scalars are 0 and 1.79 (regression)
|
||||||
EXPECT_DOUBLES_EQUAL(0.0, productFactor({{M(0), 0}}).second, 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.0, productFactor({{M(0), 0}}).second, 1e-9);
|
||||||
EXPECT_DOUBLES_EQUAL(1.791759, productFactor({{M(0), 1}}).second, 1e-5);
|
EXPECT_DOUBLES_EQUAL(1.791759, productFactor({{M(0), 1}}).second, 1e-5);
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,8 @@
|
||||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||||
#include <gtsam/inference/DotWriter.h>
|
#include <gtsam/inference/DotWriter.h>
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
#include "Switching.h"
|
#include "Switching.h"
|
||||||
|
|
||||||
// Include for test suite
|
// Include for test suite
|
||||||
|
@ -62,7 +64,8 @@ std::vector<GaussianFactor::shared_ptr> components(Key key) {
|
||||||
} // namespace two
|
} // namespace two
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
|
TEST(HybridGaussianFactorGraph,
|
||||||
|
HybridGaussianFactorGraphEliminateFullMultifrontalSimple) {
|
||||||
HybridGaussianFactorGraph hfg;
|
HybridGaussianFactorGraph hfg;
|
||||||
|
|
||||||
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||||
|
@ -179,10 +182,8 @@ TEST(HybridGaussianFactorGraph, Switching) {
|
||||||
std::vector<int> naturalX(N);
|
std::vector<int> naturalX(N);
|
||||||
std::iota(naturalX.begin(), naturalX.end(), 1);
|
std::iota(naturalX.begin(), naturalX.end(), 1);
|
||||||
std::vector<Key> ordX;
|
std::vector<Key> ordX;
|
||||||
std::transform(
|
std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX),
|
||||||
naturalX.begin(), naturalX.end(), std::back_inserter(ordX), [](int x) {
|
[](int x) { return X(x); });
|
||||||
return X(x);
|
|
||||||
});
|
|
||||||
|
|
||||||
auto [ndX, lvls] = makeBinaryOrdering(ordX);
|
auto [ndX, lvls] = makeBinaryOrdering(ordX);
|
||||||
std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
|
std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
|
||||||
|
@ -195,10 +196,8 @@ TEST(HybridGaussianFactorGraph, Switching) {
|
||||||
std::vector<int> naturalC(N - 1);
|
std::vector<int> naturalC(N - 1);
|
||||||
std::iota(naturalC.begin(), naturalC.end(), 1);
|
std::iota(naturalC.begin(), naturalC.end(), 1);
|
||||||
std::vector<Key> ordC;
|
std::vector<Key> ordC;
|
||||||
std::transform(
|
std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC),
|
||||||
naturalC.begin(), naturalC.end(), std::back_inserter(ordC), [](int x) {
|
[](int x) { return M(x); });
|
||||||
return M(x);
|
|
||||||
});
|
|
||||||
|
|
||||||
// std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
|
// std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
|
||||||
const auto [ndC, lvls] = makeBinaryOrdering(ordC);
|
const auto [ndC, lvls] = makeBinaryOrdering(ordC);
|
||||||
|
@ -237,10 +236,8 @@ TEST(HybridGaussianFactorGraph, SwitchingISAM) {
|
||||||
std::vector<int> naturalX(N);
|
std::vector<int> naturalX(N);
|
||||||
std::iota(naturalX.begin(), naturalX.end(), 1);
|
std::iota(naturalX.begin(), naturalX.end(), 1);
|
||||||
std::vector<Key> ordX;
|
std::vector<Key> ordX;
|
||||||
std::transform(
|
std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX),
|
||||||
naturalX.begin(), naturalX.end(), std::back_inserter(ordX), [](int x) {
|
[](int x) { return X(x); });
|
||||||
return X(x);
|
|
||||||
});
|
|
||||||
|
|
||||||
auto [ndX, lvls] = makeBinaryOrdering(ordX);
|
auto [ndX, lvls] = makeBinaryOrdering(ordX);
|
||||||
std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
|
std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
|
||||||
|
@ -253,10 +250,8 @@ TEST(HybridGaussianFactorGraph, SwitchingISAM) {
|
||||||
std::vector<int> naturalC(N - 1);
|
std::vector<int> naturalC(N - 1);
|
||||||
std::iota(naturalC.begin(), naturalC.end(), 1);
|
std::iota(naturalC.begin(), naturalC.end(), 1);
|
||||||
std::vector<Key> ordC;
|
std::vector<Key> ordC;
|
||||||
std::transform(
|
std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC),
|
||||||
naturalC.begin(), naturalC.end(), std::back_inserter(ordC), [](int x) {
|
[](int x) { return M(x); });
|
||||||
return M(x);
|
|
||||||
});
|
|
||||||
|
|
||||||
// std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
|
// std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
|
||||||
const auto [ndC, lvls] = makeBinaryOrdering(ordC);
|
const auto [ndC, lvls] = makeBinaryOrdering(ordC);
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <CppUnitLite/Test.h>
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/TestableAssertions.h>
|
#include <gtsam/base/TestableAssertions.h>
|
||||||
#include <gtsam/base/Vector.h>
|
#include <gtsam/base/Vector.h>
|
||||||
|
@ -37,9 +39,6 @@
|
||||||
#include <gtsam/inference/Symbol.h>
|
#include <gtsam/inference/Symbol.h>
|
||||||
#include <gtsam/linear/JacobianFactor.h>
|
#include <gtsam/linear/JacobianFactor.h>
|
||||||
|
|
||||||
#include <CppUnitLite/Test.h>
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -73,8 +72,8 @@ TEST(HybridGaussianFactorGraph, Creation) {
|
||||||
HybridGaussianConditional gm(
|
HybridGaussianConditional gm(
|
||||||
m0,
|
m0,
|
||||||
{std::make_shared<GaussianConditional>(X(0), Z_3x1, I_3x3, X(1), I_3x3),
|
{std::make_shared<GaussianConditional>(X(0), Z_3x1, I_3x3, X(1), I_3x3),
|
||||||
std::make_shared<GaussianConditional>(
|
std::make_shared<GaussianConditional>(X(0), Vector3::Ones(), I_3x3, X(1),
|
||||||
X(0), Vector3::Ones(), I_3x3, X(1), I_3x3)});
|
I_3x3)});
|
||||||
hfg.add(gm);
|
hfg.add(gm);
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(2, hfg.size());
|
EXPECT_LONGS_EQUAL(2, hfg.size());
|
||||||
|
@ -118,8 +117,7 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
|
||||||
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
|
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
|
||||||
CHECK(factor);
|
CHECK(factor);
|
||||||
// regression test
|
// regression test
|
||||||
EXPECT(
|
EXPECT(assert_equal(DecisionTreeFactor{m1, "15.74961 15.74961"}, *factor, 1e-5));
|
||||||
assert_equal(DecisionTreeFactor{m1, "15.74961 15.74961"}, *factor, 1e-5));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
@ -177,7 +175,7 @@ TEST(HybridBayesNet, Switching) {
|
||||||
Switching s(2, betweenSigma, priorSigma);
|
Switching s(2, betweenSigma, priorSigma);
|
||||||
|
|
||||||
// Check size of linearized factor graph
|
// Check size of linearized factor graph
|
||||||
const HybridGaussianFactorGraph& graph = s.linearizedFactorGraph;
|
const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
|
||||||
EXPECT_LONGS_EQUAL(4, graph.size());
|
EXPECT_LONGS_EQUAL(4, graph.size());
|
||||||
|
|
||||||
// Create some continuous and discrete values
|
// Create some continuous and discrete values
|
||||||
|
@ -203,20 +201,20 @@ TEST(HybridBayesNet, Switching) {
|
||||||
// Check error for M(0) = 0
|
// Check error for M(0) = 0
|
||||||
const HybridValues values0{continuousValues, modeZero};
|
const HybridValues values0{continuousValues, modeZero};
|
||||||
double expectedError0 = 0;
|
double expectedError0 = 0;
|
||||||
for (const auto& factor : graph) expectedError0 += factor->error(values0);
|
for (const auto &factor : graph) expectedError0 += factor->error(values0);
|
||||||
EXPECT_DOUBLES_EQUAL(expectedError0, graph.error(values0), 1e-5);
|
EXPECT_DOUBLES_EQUAL(expectedError0, graph.error(values0), 1e-5);
|
||||||
|
|
||||||
// Check error for M(0) = 1
|
// Check error for M(0) = 1
|
||||||
const HybridValues values1{continuousValues, modeOne};
|
const HybridValues values1{continuousValues, modeOne};
|
||||||
double expectedError1 = 0;
|
double expectedError1 = 0;
|
||||||
for (const auto& factor : graph) expectedError1 += factor->error(values1);
|
for (const auto &factor : graph) expectedError1 += factor->error(values1);
|
||||||
EXPECT_DOUBLES_EQUAL(expectedError1, graph.error(values1), 1e-5);
|
EXPECT_DOUBLES_EQUAL(expectedError1, graph.error(values1), 1e-5);
|
||||||
|
|
||||||
// Check errorTree
|
// Check errorTree
|
||||||
AlgebraicDecisionTree<Key> actualErrors = graph.errorTree(continuousValues);
|
AlgebraicDecisionTree<Key> actualErrors = graph.errorTree(continuousValues);
|
||||||
// Create expected error tree
|
// Create expected error tree
|
||||||
const AlgebraicDecisionTree<Key> expectedErrors(
|
const AlgebraicDecisionTree<Key> expectedErrors(M(0), expectedError0,
|
||||||
M(0), expectedError0, expectedError1);
|
expectedError1);
|
||||||
|
|
||||||
// Check that the actual error tree matches the expected one
|
// Check that the actual error tree matches the expected one
|
||||||
EXPECT(assert_equal(expectedErrors, actualErrors, 1e-5));
|
EXPECT(assert_equal(expectedErrors, actualErrors, 1e-5));
|
||||||
|
@ -232,8 +230,8 @@ TEST(HybridBayesNet, Switching) {
|
||||||
const AlgebraicDecisionTree<Key> graphPosterior =
|
const AlgebraicDecisionTree<Key> graphPosterior =
|
||||||
graph.discretePosterior(continuousValues);
|
graph.discretePosterior(continuousValues);
|
||||||
const double sum = probPrime0 + probPrime1;
|
const double sum = probPrime0 + probPrime1;
|
||||||
const AlgebraicDecisionTree<Key> expectedPosterior(
|
const AlgebraicDecisionTree<Key> expectedPosterior(M(0), probPrime0 / sum,
|
||||||
M(0), probPrime0 / sum, probPrime1 / sum);
|
probPrime1 / sum);
|
||||||
EXPECT(assert_equal(expectedPosterior, graphPosterior, 1e-5));
|
EXPECT(assert_equal(expectedPosterior, graphPosterior, 1e-5));
|
||||||
|
|
||||||
// Make the clique of factors connected to x0:
|
// Make the clique of factors connected to x0:
|
||||||
|
@ -275,15 +273,13 @@ TEST(HybridBayesNet, Switching) {
|
||||||
// Check that the scalars incorporate the negative log constant of the
|
// Check that the scalars incorporate the negative log constant of the
|
||||||
// conditional
|
// conditional
|
||||||
EXPECT_DOUBLES_EQUAL(scalar0 - (*p_x0_given_x1_m)(modeZero)->negLogConstant(),
|
EXPECT_DOUBLES_EQUAL(scalar0 - (*p_x0_given_x1_m)(modeZero)->negLogConstant(),
|
||||||
(*phi_x1_m)(modeZero).second,
|
(*phi_x1_m)(modeZero).second, 1e-9);
|
||||||
1e-9);
|
|
||||||
EXPECT_DOUBLES_EQUAL(scalar1 - (*p_x0_given_x1_m)(modeOne)->negLogConstant(),
|
EXPECT_DOUBLES_EQUAL(scalar1 - (*p_x0_given_x1_m)(modeOne)->negLogConstant(),
|
||||||
(*phi_x1_m)(modeOne).second,
|
(*phi_x1_m)(modeOne).second, 1e-9);
|
||||||
1e-9);
|
|
||||||
|
|
||||||
// Check that the conditional and remaining factor are consistent for both
|
// Check that the conditional and remaining factor are consistent for both
|
||||||
// modes
|
// modes
|
||||||
for (auto&& mode : {modeZero, modeOne}) {
|
for (auto &&mode : {modeZero, modeOne}) {
|
||||||
const auto gc = (*p_x0_given_x1_m)(mode);
|
const auto gc = (*p_x0_given_x1_m)(mode);
|
||||||
const auto [gf, scalar] = (*phi_x1_m)(mode);
|
const auto [gf, scalar] = (*phi_x1_m)(mode);
|
||||||
|
|
||||||
|
@ -342,7 +338,7 @@ TEST(HybridBayesNet, Switching) {
|
||||||
// However, we can still check the total error for the clique factors_x1 and
|
// However, we can still check the total error for the clique factors_x1 and
|
||||||
// the elimination results are equal, modulo -again- the negative log constant
|
// the elimination results are equal, modulo -again- the negative log constant
|
||||||
// of the conditional.
|
// of the conditional.
|
||||||
for (auto&& mode : {modeZero, modeOne}) {
|
for (auto &&mode : {modeZero, modeOne}) {
|
||||||
auto gc_x1 = (*p_x1_given_m)(mode);
|
auto gc_x1 = (*p_x1_given_m)(mode);
|
||||||
double originalError_x1 = factors_x1.error({continuousValues, mode});
|
double originalError_x1 = factors_x1.error({continuousValues, mode});
|
||||||
const double actualError = gc_x1->negLogConstant() +
|
const double actualError = gc_x1->negLogConstant() +
|
||||||
|
@ -372,7 +368,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
|
||||||
Switching s(3);
|
Switching s(3);
|
||||||
|
|
||||||
// Check size of linearized factor graph
|
// Check size of linearized factor graph
|
||||||
const HybridGaussianFactorGraph& graph = s.linearizedFactorGraph;
|
const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
|
||||||
EXPECT_LONGS_EQUAL(7, graph.size());
|
EXPECT_LONGS_EQUAL(7, graph.size());
|
||||||
|
|
||||||
// Eliminate the graph
|
// Eliminate the graph
|
||||||
|
|
Loading…
Reference in New Issue