Removed HybridDiscreteFactor wrapper
parent
b93de21295
commit
7e32a8739e
|
@ -1,61 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file HybridDiscreteFactor.cpp
|
|
||||||
* @brief Wrapper for a discrete factor
|
|
||||||
* @date Mar 11, 2022
|
|
||||||
* @author Fan Jiang
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
|
|
||||||
#include "gtsam/discrete/DecisionTreeFactor.h"
|
|
||||||
|
|
||||||
namespace gtsam {
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
|
|
||||||
: Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other)
|
|
||||||
->discreteKeys()),
|
|
||||||
inner_(other) {}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
|
|
||||||
: Base(dtf.discreteKeys()),
|
|
||||||
inner_(boost::make_shared<DecisionTreeFactor>(std::move(dtf))) {}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
|
|
||||||
const This *e = dynamic_cast<const This *>(&lf);
|
|
||||||
if (e == nullptr) return false;
|
|
||||||
if (!Base::equals(*e, tol)) return false;
|
|
||||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
|
||||||
: !(e->inner_);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
void HybridDiscreteFactor::print(const std::string &s,
|
|
||||||
const KeyFormatter &formatter) const {
|
|
||||||
HybridFactor::print(s, formatter);
|
|
||||||
inner_->print("\n", formatter);
|
|
||||||
};
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
double HybridDiscreteFactor::error(const HybridValues &values) const {
|
|
||||||
return -log((*inner_)(values.discrete()));
|
|
||||||
}
|
|
||||||
/* ************************************************************************ */
|
|
||||||
|
|
||||||
} // namespace gtsam
|
|
|
@ -1,91 +0,0 @@
|
||||||
/* ----------------------------------------------------------------------------
|
|
||||||
|
|
||||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
||||||
* Atlanta, Georgia 30332-0415
|
|
||||||
* All Rights Reserved
|
|
||||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
||||||
|
|
||||||
* See LICENSE for the license information
|
|
||||||
|
|
||||||
* -------------------------------------------------------------------------- */
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @file HybridDiscreteFactor.h
|
|
||||||
* @date Mar 11, 2022
|
|
||||||
* @author Fan Jiang
|
|
||||||
* @author Varun Agrawal
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
|
||||||
|
|
||||||
namespace gtsam {
|
|
||||||
|
|
||||||
class HybridValues;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which
|
|
||||||
* allows us to hide the implementation of DiscreteFactor and thus avoid
|
|
||||||
* diamond inheritance.
|
|
||||||
*
|
|
||||||
* @ingroup hybrid
|
|
||||||
*/
|
|
||||||
class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
|
||||||
private:
|
|
||||||
DiscreteFactor::shared_ptr inner_;
|
|
||||||
|
|
||||||
public:
|
|
||||||
using Base = HybridFactor;
|
|
||||||
using This = HybridDiscreteFactor;
|
|
||||||
using shared_ptr = boost::shared_ptr<This>;
|
|
||||||
|
|
||||||
/// @name Constructors
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
/// Default constructor - for serialization.
|
|
||||||
HybridDiscreteFactor() = default;
|
|
||||||
|
|
||||||
// Implicit conversion from a shared ptr of DF
|
|
||||||
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
|
|
||||||
|
|
||||||
// Forwarding constructor from concrete DecisionTreeFactor
|
|
||||||
HybridDiscreteFactor(DecisionTreeFactor &&dtf);
|
|
||||||
|
|
||||||
/// @}
|
|
||||||
/// @name Testable
|
|
||||||
/// @{
|
|
||||||
virtual bool equals(const HybridFactor &lf, double tol) const override;
|
|
||||||
|
|
||||||
void print(
|
|
||||||
const std::string &s = "HybridFactor\n",
|
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
|
||||||
|
|
||||||
/// @}
|
|
||||||
/// @name Standard Interface
|
|
||||||
/// @{
|
|
||||||
|
|
||||||
/// Return pointer to the internal discrete factor.
|
|
||||||
DiscreteFactor::shared_ptr inner() const { return inner_; }
|
|
||||||
|
|
||||||
/// Return the error of the underlying Discrete Factor.
|
|
||||||
double error(const HybridValues &values) const override;
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/** Serialization function */
|
|
||||||
friend class boost::serialization::access;
|
|
||||||
template <class ARCHIVE>
|
|
||||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
|
||||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
|
||||||
ar &BOOST_SERIALIZATION_NVP(inner_);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// traits
|
|
||||||
template <>
|
|
||||||
struct traits<HybridDiscreteFactor> : public Testable<HybridDiscreteFactor> {};
|
|
||||||
|
|
||||||
} // namespace gtsam
|
|
|
@ -67,11 +67,10 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
|
||||||
const DiscreteKeys &key2);
|
const DiscreteKeys &key2);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base class for hybrid probabilistic factors
|
* Base class for *truly* hybrid probabilistic factors
|
||||||
*
|
*
|
||||||
* Examples:
|
* Examples:
|
||||||
* - HybridGaussianFactor
|
* - MixtureFactor
|
||||||
* - HybridDiscreteFactor
|
|
||||||
* - GaussianMixtureFactor
|
* - GaussianMixtureFactor
|
||||||
* - GaussianMixture
|
* - GaussianMixture
|
||||||
*
|
*
|
||||||
|
|
|
@ -26,7 +26,6 @@
|
||||||
#include <gtsam/hybrid/GaussianMixture.h>
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
|
||||||
#include <gtsam/hybrid/HybridEliminationTree.h>
|
#include <gtsam/hybrid/HybridEliminationTree.h>
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||||
|
@ -47,7 +46,6 @@
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <unordered_map>
|
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -58,6 +56,15 @@ namespace gtsam {
|
||||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
// Throw a runtime exception for method specified in string s, and factor f:
|
||||||
|
static void throwRuntimeError(const std::string &s,
|
||||||
|
const boost::shared_ptr<Factor> &f) {
|
||||||
|
auto &fr = *f;
|
||||||
|
throw std::runtime_error(s + " not implemented for factor type " +
|
||||||
|
demangle(typeid(fr).name()) + ".");
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static GaussianFactorGraphTree addGaussian(
|
static GaussianFactorGraphTree addGaussian(
|
||||||
const GaussianFactorGraphTree &gfgTree,
|
const GaussianFactorGraphTree &gfgTree,
|
||||||
|
@ -67,7 +74,6 @@ static GaussianFactorGraphTree addGaussian(
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor);
|
result.push_back(factor);
|
||||||
return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
|
return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
auto add = [&factor](const GraphAndConstant &graph_z) {
|
auto add = [&factor](const GraphAndConstant &graph_z) {
|
||||||
auto result = graph_z.graph;
|
auto result = graph_z.graph;
|
||||||
|
@ -103,8 +109,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
}
|
}
|
||||||
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||||
result = addGaussian(result, gf->inner());
|
result = addGaussian(result, gf->inner());
|
||||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
|
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
|
|
||||||
// Don't do anything for discrete-only factors
|
// Don't do anything for discrete-only factors
|
||||||
// since we want to eliminate continuous values only.
|
// since we want to eliminate continuous values only.
|
||||||
continue;
|
continue;
|
||||||
|
@ -116,10 +121,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
"gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented "
|
"gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented "
|
||||||
"yet.");
|
"yet.");
|
||||||
} else {
|
} else {
|
||||||
auto &fr = *f;
|
throwRuntimeError("gtsam::assembleGraphTree", f);
|
||||||
throw std::invalid_argument(
|
|
||||||
std::string("gtsam::assembleGraphTree: factor type not handled: ") +
|
|
||||||
demangle(typeid(fr).name()));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -129,16 +131,18 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
|
||||||
continuousElimination(const HybridGaussianFactorGraph &factors,
|
continuousElimination(const HybridGaussianFactorGraph &factors,
|
||||||
const Ordering &frontalKeys) {
|
const Ordering &frontalKeys) {
|
||||||
|
using boost::dynamic_pointer_cast;
|
||||||
GaussianFactorGraph gfg;
|
GaussianFactorGraph gfg;
|
||||||
for (auto &fp : factors) {
|
for (auto &fp : factors) {
|
||||||
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
|
if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
|
||||||
gfg.push_back(ptr->inner());
|
gfg.push_back(hgf->inner());
|
||||||
} else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) {
|
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(fp)) {
|
||||||
gfg.push_back(
|
auto gc = hc->asGaussian();
|
||||||
boost::static_pointer_cast<GaussianConditional>(ptr->inner()));
|
assert(gc);
|
||||||
|
gfg.push_back(gc);
|
||||||
} else {
|
} else {
|
||||||
// It is an orphan wrapped conditional
|
// It is an orphan wrapped conditional
|
||||||
}
|
}
|
||||||
|
@ -150,18 +154,17 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
|
||||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
const Ordering &frontalKeys) {
|
const Ordering &frontalKeys) {
|
||||||
DiscreteFactorGraph dfg;
|
DiscreteFactorGraph dfg;
|
||||||
|
|
||||||
for (auto &factor : factors) {
|
for (auto &factor : factors) {
|
||||||
if (auto p = boost::dynamic_pointer_cast<HybridDiscreteFactor>(factor)) {
|
if (auto dtf = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||||
dfg.push_back(p->inner());
|
dfg.push_back(dtf);
|
||||||
} else if (auto p = boost::static_pointer_cast<HybridConditional>(factor)) {
|
} else if (auto hc =
|
||||||
auto discrete_conditional =
|
boost::static_pointer_cast<HybridConditional>(factor)) {
|
||||||
boost::static_pointer_cast<DiscreteConditional>(p->inner());
|
dfg.push_back(hc->asDiscrete());
|
||||||
dfg.push_back(discrete_conditional);
|
|
||||||
} else {
|
} else {
|
||||||
// It is an orphan wrapper
|
// It is an orphan wrapper
|
||||||
}
|
}
|
||||||
|
@ -170,8 +173,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
||||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||||
|
|
||||||
return {boost::make_shared<HybridConditional>(result.first),
|
return {boost::make_shared<HybridConditional>(result.first), result.second};
|
||||||
boost::make_shared<HybridDiscreteFactor>(result.second)};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
@ -189,7 +191,7 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
|
||||||
hybridElimination(const HybridGaussianFactorGraph &factors,
|
hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
const Ordering &frontalKeys,
|
const Ordering &frontalKeys,
|
||||||
const KeyVector &continuousSeparator,
|
const KeyVector &continuousSeparator,
|
||||||
|
@ -291,7 +293,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||||
|
|
||||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||||
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
|
discreteFactor};
|
||||||
} else {
|
} else {
|
||||||
// Create a resulting GaussianMixtureFactor on the separator.
|
// Create a resulting GaussianMixtureFactor on the separator.
|
||||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||||
|
@ -314,7 +316,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
* eliminate a discrete variable (as specified in the ordering), the result will
|
* eliminate a discrete variable (as specified in the ordering), the result will
|
||||||
* be INCORRECT and there will be NO error raised.
|
* be INCORRECT and there will be NO error raised.
|
||||||
*/
|
*/
|
||||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
|
std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>> //
|
||||||
EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
const Ordering &frontalKeys) {
|
const Ordering &frontalKeys) {
|
||||||
// NOTE: Because we are in the Conditional Gaussian regime there are only
|
// NOTE: Because we are in the Conditional Gaussian regime there are only
|
||||||
|
@ -374,14 +376,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build a map from keys to DiscreteKeys
|
// Build a map from keys to DiscreteKeys
|
||||||
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
|
auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();
|
||||||
for (auto &&factor : factors) {
|
|
||||||
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
|
|
||||||
for (auto &k : p->discreteKeys()) {
|
|
||||||
mapFromKeyToDiscreteKey[k.first] = k;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill in discrete frontals and continuous frontals.
|
// Fill in discrete frontals and continuous frontals.
|
||||||
std::set<DiscreteKey> discreteFrontals;
|
std::set<DiscreteKey> discreteFrontals;
|
||||||
|
@ -433,23 +428,25 @@ void HybridGaussianFactorGraph::add(JacobianFactor &&factor) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
void HybridGaussianFactorGraph::add(boost::shared_ptr<JacobianFactor> &factor) {
|
void HybridGaussianFactorGraph::add(
|
||||||
|
const boost::shared_ptr<JacobianFactor> &factor) {
|
||||||
FactorGraph::add(boost::make_shared<HybridGaussianFactor>(factor));
|
FactorGraph::add(boost::make_shared<HybridGaussianFactor>(factor));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
void HybridGaussianFactorGraph::add(DecisionTreeFactor &&factor) {
|
void HybridGaussianFactorGraph::add(DecisionTreeFactor &&factor) {
|
||||||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(std::move(factor)));
|
FactorGraph::add(std::move(factor));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
|
void HybridGaussianFactorGraph::add(
|
||||||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
const DecisionTreeFactor::shared_ptr &factor) {
|
||||||
|
FactorGraph::add(factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
||||||
KeySet discrete_keys = discreteKeys();
|
const KeySet discrete_keys = discreteKeySet();
|
||||||
const VariableIndex index(factors_);
|
const VariableIndex index(factors_);
|
||||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||||
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||||
|
@ -484,16 +481,11 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
error_tree = error_tree.apply(
|
error_tree = error_tree.apply(
|
||||||
[error](double leaf_value) { return leaf_value + error; });
|
[error](double leaf_value) { return leaf_value + error; });
|
||||||
|
|
||||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
|
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
|
|
||||||
// If factor at `idx` is discrete-only, we skip.
|
// If factor at `idx` is discrete-only, we skip.
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
auto &fr = *f;
|
throwRuntimeError("HybridGaussianFactorGraph::error", f);
|
||||||
throw std::invalid_argument(
|
|
||||||
std::string(
|
|
||||||
"HybridGaussianFactorGraph::error: factor type not handled: ") +
|
|
||||||
demangle(typeid(fr).name()));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -503,9 +495,14 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||||
double error = 0.0;
|
double error = 0.0;
|
||||||
for (auto &factor : factors_) {
|
for (auto &f : factors_) {
|
||||||
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
|
if (auto hf = boost::dynamic_pointer_cast<HybridFactor>(f)) {
|
||||||
error += p->error(values);
|
// TODO(dellaert): needs to change when we discard other wrappers.
|
||||||
|
error += hf->error(values);
|
||||||
|
} else if (auto dtf = boost::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
|
error -= log((*dtf)(values.discrete()));
|
||||||
|
} else {
|
||||||
|
throwRuntimeError("HybridGaussianFactorGraph::error", f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return error;
|
return error;
|
||||||
|
|
|
@ -50,13 +50,13 @@ class HybridValues;
|
||||||
* @ingroup hybrid
|
* @ingroup hybrid
|
||||||
*/
|
*/
|
||||||
GTSAM_EXPORT
|
GTSAM_EXPORT
|
||||||
std::pair<boost::shared_ptr<HybridConditional>, HybridFactor::shared_ptr>
|
std::pair<boost::shared_ptr<HybridConditional>, boost::shared_ptr<Factor>>
|
||||||
EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys);
|
EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
template <>
|
template <>
|
||||||
struct EliminationTraits<HybridGaussianFactorGraph> {
|
struct EliminationTraits<HybridGaussianFactorGraph> {
|
||||||
typedef HybridFactor FactorType; ///< Type of factors in factor graph
|
typedef Factor FactorType; ///< Type of factors in factor graph
|
||||||
typedef HybridGaussianFactorGraph
|
typedef HybridGaussianFactorGraph
|
||||||
FactorGraphType; ///< Type of the factor graph (e.g.
|
FactorGraphType; ///< Type of the factor graph (e.g.
|
||||||
///< HybridGaussianFactorGraph)
|
///< HybridGaussianFactorGraph)
|
||||||
|
@ -80,7 +80,6 @@ struct EliminationTraits<HybridGaussianFactorGraph> {
|
||||||
* Hybrid Gaussian Factor Graph
|
* Hybrid Gaussian Factor Graph
|
||||||
* -----------------------
|
* -----------------------
|
||||||
* This is the linearized version of a hybrid factor graph.
|
* This is the linearized version of a hybrid factor graph.
|
||||||
* Everything inside needs to be hybrid factor or hybrid conditional.
|
|
||||||
*
|
*
|
||||||
* @ingroup hybrid
|
* @ingroup hybrid
|
||||||
*/
|
*/
|
||||||
|
@ -130,13 +129,13 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
void add(JacobianFactor&& factor);
|
void add(JacobianFactor&& factor);
|
||||||
|
|
||||||
/// Add a Jacobian factor as a shared ptr.
|
/// Add a Jacobian factor as a shared ptr.
|
||||||
void add(boost::shared_ptr<JacobianFactor>& factor);
|
void add(const boost::shared_ptr<JacobianFactor>& factor);
|
||||||
|
|
||||||
/// Add a DecisionTreeFactor to the factor graph.
|
/// Add a DecisionTreeFactor to the factor graph.
|
||||||
void add(DecisionTreeFactor&& factor);
|
void add(DecisionTreeFactor&& factor);
|
||||||
|
|
||||||
/// Add a DecisionTreeFactor as a shared ptr.
|
/// Add a DecisionTreeFactor as a shared ptr.
|
||||||
void add(DecisionTreeFactor::shared_ptr factor);
|
void add(const boost::shared_ptr<DecisionTreeFactor>& factor);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a gaussian factor *pointer* to the internal gaussian factor graph
|
* Add a gaussian factor *pointer* to the internal gaussian factor graph
|
||||||
|
|
|
@ -43,7 +43,7 @@ Ordering HybridGaussianISAM::GetOrdering(
|
||||||
HybridGaussianFactorGraph& factors,
|
HybridGaussianFactorGraph& factors,
|
||||||
const HybridGaussianFactorGraph& newFactors) {
|
const HybridGaussianFactorGraph& newFactors) {
|
||||||
// Get all the discrete keys from the factors
|
// Get all the discrete keys from the factors
|
||||||
KeySet allDiscrete = factors.discreteKeys();
|
const KeySet allDiscrete = factors.discreteKeySet();
|
||||||
|
|
||||||
// Create KeyVector with continuous keys followed by discrete keys.
|
// Create KeyVector with continuous keys followed by discrete keys.
|
||||||
KeyVector newKeysDiscreteLast;
|
KeyVector newKeysDiscreteLast;
|
||||||
|
|
|
@ -16,7 +16,11 @@
|
||||||
* @date May 28, 2022
|
* @date May 28, 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||||
|
#include <gtsam/hybrid/MixtureFactor.h>
|
||||||
|
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
@ -63,8 +67,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
||||||
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
||||||
const auto hgf = boost::make_shared<HybridGaussianFactor>(gf);
|
const auto hgf = boost::make_shared<HybridGaussianFactor>(gf);
|
||||||
linearFG->push_back(hgf);
|
linearFG->push_back(hgf);
|
||||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
|
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
|
|
||||||
// If discrete-only: doesn't need linearization.
|
// If discrete-only: doesn't need linearization.
|
||||||
linearFG->push_back(f);
|
linearFG->push_back(f);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -18,16 +18,12 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
|
||||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
|
||||||
#include <gtsam/hybrid/MixtureFactor.h>
|
|
||||||
#include <gtsam/inference/Ordering.h>
|
|
||||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
|
||||||
|
|
||||||
#include <boost/format.hpp>
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
class HybridGaussianFactorGraph;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Nonlinear Hybrid Factor Graph
|
* Nonlinear Hybrid Factor Graph
|
||||||
* -----------------------
|
* -----------------------
|
||||||
|
|
|
@ -68,17 +68,6 @@ virtual class HybridConditional {
|
||||||
double error(const gtsam::HybridValues& values) const;
|
double error(const gtsam::HybridValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
|
||||||
virtual class HybridDiscreteFactor {
|
|
||||||
HybridDiscreteFactor(gtsam::DecisionTreeFactor dtf);
|
|
||||||
void print(string s = "HybridDiscreteFactor\n",
|
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
|
||||||
gtsam::DefaultKeyFormatter) const;
|
|
||||||
bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const;
|
|
||||||
gtsam::Factor* inner();
|
|
||||||
double error(const gtsam::HybridValues &values) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
class GaussianMixtureFactor : gtsam::HybridFactor {
|
class GaussianMixtureFactor : gtsam::HybridFactor {
|
||||||
GaussianMixtureFactor(
|
GaussianMixtureFactor(
|
||||||
|
@ -217,9 +206,7 @@ class HybridNonlinearFactorGraph {
|
||||||
HybridNonlinearFactorGraph(const gtsam::HybridNonlinearFactorGraph& graph);
|
HybridNonlinearFactorGraph(const gtsam::HybridNonlinearFactorGraph& graph);
|
||||||
void push_back(gtsam::HybridFactor* factor);
|
void push_back(gtsam::HybridFactor* factor);
|
||||||
void push_back(gtsam::NonlinearFactor* factor);
|
void push_back(gtsam::NonlinearFactor* factor);
|
||||||
void push_back(gtsam::HybridDiscreteFactor* factor);
|
void push_back(gtsam::DiscreteFactor* factor);
|
||||||
void add(gtsam::NonlinearFactor* factor);
|
|
||||||
void add(gtsam::DiscreteFactor* factor);
|
|
||||||
gtsam::HybridGaussianFactorGraph linearize(const gtsam::Values& continuousValues) const;
|
gtsam::HybridGaussianFactorGraph linearize(const gtsam::Values& continuousValues) const;
|
||||||
|
|
||||||
bool empty() const;
|
bool empty() const;
|
||||||
|
|
|
@ -206,13 +206,11 @@ struct Switching {
|
||||||
*/
|
*/
|
||||||
void addModeChain(HybridNonlinearFactorGraph *fg,
|
void addModeChain(HybridNonlinearFactorGraph *fg,
|
||||||
std::string discrete_transition_prob = "1/2 3/2") {
|
std::string discrete_transition_prob = "1/2 3/2") {
|
||||||
auto prior = boost::make_shared<DiscreteDistribution>(modes[0], "1/1");
|
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||||
fg->push_discrete(prior);
|
|
||||||
for (size_t k = 0; k < K - 2; k++) {
|
for (size_t k = 0; k < K - 2; k++) {
|
||||||
auto parents = {modes[k]};
|
auto parents = {modes[k]};
|
||||||
auto conditional = boost::make_shared<DiscreteConditional>(
|
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
||||||
modes[k + 1], parents, discrete_transition_prob);
|
discrete_transition_prob);
|
||||||
fg->push_discrete(conditional);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -224,13 +222,11 @@ struct Switching {
|
||||||
*/
|
*/
|
||||||
void addModeChain(HybridGaussianFactorGraph *fg,
|
void addModeChain(HybridGaussianFactorGraph *fg,
|
||||||
std::string discrete_transition_prob = "1/2 3/2") {
|
std::string discrete_transition_prob = "1/2 3/2") {
|
||||||
auto prior = boost::make_shared<DiscreteDistribution>(modes[0], "1/1");
|
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||||
fg->push_discrete(prior);
|
|
||||||
for (size_t k = 0; k < K - 2; k++) {
|
for (size_t k = 0; k < K - 2; k++) {
|
||||||
auto parents = {modes[k]};
|
auto parents = {modes[k]};
|
||||||
auto conditional = boost::make_shared<DiscreteConditional>(
|
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
||||||
modes[k + 1], parents, discrete_transition_prob);
|
discrete_transition_prob);
|
||||||
fg->push_discrete(conditional);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue