Removed HybridDiscreteFactor wrapper

release/4.3a0
Frank Dellaert 2023-01-06 21:07:51 -08:00
parent b93de21295
commit 7e32a8739e
10 changed files with 69 additions and 244 deletions

View File

@ -1,61 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridDiscreteFactor.cpp
* @brief Wrapper for a discrete factor
* @date Mar 11, 2022
* @author Fan Jiang
*/
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <boost/make_shared.hpp>
#include "gtsam/discrete/DecisionTreeFactor.h"
namespace gtsam {
/* ************************************************************************ */
HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
: Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other)
->discreteKeys()),
inner_(other) {}
/* ************************************************************************ */
HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
: Base(dtf.discreteKeys()),
inner_(boost::make_shared<DecisionTreeFactor>(std::move(dtf))) {}
/* ************************************************************************ */
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
if (e == nullptr) return false;
if (!Base::equals(*e, tol)) return false;
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}
/* ************************************************************************ */
void HybridDiscreteFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
inner_->print("\n", formatter);
};
/* ************************************************************************ */
double HybridDiscreteFactor::error(const HybridValues &values) const {
return -log((*inner_)(values.discrete()));
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -1,91 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridDiscreteFactor.h
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/
#pragma once
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
namespace gtsam {
class HybridValues;
/**
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which
* allows us to hide the implementation of DiscreteFactor and thus avoid
* diamond inheritance.
*
* @ingroup hybrid
*/
class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
private:
DiscreteFactor::shared_ptr inner_;
public:
using Base = HybridFactor;
using This = HybridDiscreteFactor;
using shared_ptr = boost::shared_ptr<This>;
/// @name Constructors
/// @{
/// Default constructor - for serialization.
HybridDiscreteFactor() = default;
// Implicit conversion from a shared ptr of DF
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
// Forwarding constructor from concrete DecisionTreeFactor
HybridDiscreteFactor(DecisionTreeFactor &&dtf);
/// @}
/// @name Testable
/// @{
virtual bool equals(const HybridFactor &lf, double tol) const override;
void print(
const std::string &s = "HybridFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @}
/// @name Standard Interface
/// @{
/// Return pointer to the internal discrete factor.
DiscreteFactor::shared_ptr inner() const { return inner_; }
/// Return the error of the underlying Discrete Factor.
double error(const HybridValues &values) const override;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(inner_);
}
};
// traits
template <>
struct traits<HybridDiscreteFactor> : public Testable<HybridDiscreteFactor> {};
} // namespace gtsam

View File

@ -67,11 +67,10 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
const DiscreteKeys &key2); const DiscreteKeys &key2);
/** /**
* Base class for hybrid probabilistic factors * Base class for *truly* hybrid probabilistic factors
* *
* Examples: * Examples:
* - HybridGaussianFactor * - MixtureFactor
* - HybridDiscreteFactor
* - GaussianMixtureFactor * - GaussianMixtureFactor
* - GaussianMixture * - GaussianMixture
* *

View File

@ -26,7 +26,6 @@
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridEliminationTree.h> #include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
@ -47,7 +46,6 @@
#include <iterator> #include <iterator>
#include <memory> #include <memory>
#include <stdexcept> #include <stdexcept>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -58,6 +56,15 @@ namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>; template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */
// Throw a runtime exception for method specified in string s, and factor f:
static void throwRuntimeError(const std::string &s,
const boost::shared_ptr<Factor> &f) {
auto &fr = *f;
throw std::runtime_error(s + " not implemented for factor type " +
demangle(typeid(fr).name()) + ".");
}
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianFactorGraphTree addGaussian( static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &gfgTree, const GaussianFactorGraphTree &gfgTree,
@ -67,7 +74,6 @@ static GaussianFactorGraphTree addGaussian(
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);
return GaussianFactorGraphTree(GraphAndConstant(result, 0.0)); return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
} else { } else {
auto add = [&factor](const GraphAndConstant &graph_z) { auto add = [&factor](const GraphAndConstant &graph_z) {
auto result = graph_z.graph; auto result = graph_z.graph;
@ -103,8 +109,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
} }
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { } else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
result = addGaussian(result, gf->inner()); result = addGaussian(result, gf->inner());
} else if (dynamic_pointer_cast<DiscreteFactor>(f) || } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// Don't do anything for discrete-only factors // Don't do anything for discrete-only factors
// since we want to eliminate continuous values only. // since we want to eliminate continuous values only.
continue; continue;
@ -116,10 +121,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
"gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented " "gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented "
"yet."); "yet.");
} else { } else {
auto &fr = *f; throwRuntimeError("gtsam::assembleGraphTree", f);
throw std::invalid_argument(
std::string("gtsam::assembleGraphTree: factor type not handled: ") +
demangle(typeid(fr).name()));
} }
} }
@ -129,16 +131,18 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
} }
/* ************************************************************************ */ /* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
continuousElimination(const HybridGaussianFactorGraph &factors, continuousElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
using boost::dynamic_pointer_cast;
GaussianFactorGraph gfg; GaussianFactorGraph gfg;
for (auto &fp : factors) { for (auto &fp : factors) {
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) { if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
gfg.push_back(ptr->inner()); gfg.push_back(hgf->inner());
} else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) { } else if (auto hc = dynamic_pointer_cast<HybridConditional>(fp)) {
gfg.push_back( auto gc = hc->asGaussian();
boost::static_pointer_cast<GaussianConditional>(ptr->inner())); assert(gc);
gfg.push_back(gc);
} else { } else {
// It is an orphan wrapped conditional // It is an orphan wrapped conditional
} }
@ -150,18 +154,17 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
} }
/* ************************************************************************ */ /* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
discreteElimination(const HybridGaussianFactorGraph &factors, discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
for (auto &factor : factors) { for (auto &factor : factors) {
if (auto p = boost::dynamic_pointer_cast<HybridDiscreteFactor>(factor)) { if (auto dtf = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
dfg.push_back(p->inner()); dfg.push_back(dtf);
} else if (auto p = boost::static_pointer_cast<HybridConditional>(factor)) { } else if (auto hc =
auto discrete_conditional = boost::static_pointer_cast<HybridConditional>(factor)) {
boost::static_pointer_cast<DiscreteConditional>(p->inner()); dfg.push_back(hc->asDiscrete());
dfg.push_back(discrete_conditional);
} else { } else {
// It is an orphan wrapper // It is an orphan wrapper
} }
@ -170,8 +173,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
// NOTE: This does sum-product. For max-product, use EliminateForMPE. // NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys); auto result = EliminateDiscrete(dfg, frontalKeys);
return {boost::make_shared<HybridConditional>(result.first), return {boost::make_shared<HybridConditional>(result.first), result.second};
boost::make_shared<HybridDiscreteFactor>(result.second)};
} }
/* ************************************************************************ */ /* ************************************************************************ */
@ -189,7 +191,7 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
} }
/* ************************************************************************ */ /* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
hybridElimination(const HybridGaussianFactorGraph &factors, hybridElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys, const Ordering &frontalKeys,
const KeyVector &continuousSeparator, const KeyVector &continuousSeparator,
@ -291,7 +293,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt); boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
return {boost::make_shared<HybridConditional>(gaussianMixture), return {boost::make_shared<HybridConditional>(gaussianMixture),
boost::make_shared<HybridDiscreteFactor>(discreteFactor)}; discreteFactor};
} else { } else {
// Create a resulting GaussianMixtureFactor on the separator. // Create a resulting GaussianMixtureFactor on the separator.
return {boost::make_shared<HybridConditional>(gaussianMixture), return {boost::make_shared<HybridConditional>(gaussianMixture),
@ -314,7 +316,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
* eliminate a discrete variable (as specified in the ordering), the result will * eliminate a discrete variable (as specified in the ordering), the result will
* be INCORRECT and there will be NO error raised. * be INCORRECT and there will be NO error raised.
*/ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> // std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>> //
EliminateHybrid(const HybridGaussianFactorGraph &factors, EliminateHybrid(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
// NOTE: Because we are in the Conditional Gaussian regime there are only // NOTE: Because we are in the Conditional Gaussian regime there are only
@ -374,14 +376,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
} }
// Build a map from keys to DiscreteKeys // Build a map from keys to DiscreteKeys
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey; auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();
for (auto &&factor : factors) {
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
for (auto &k : p->discreteKeys()) {
mapFromKeyToDiscreteKey[k.first] = k;
}
}
}
// Fill in discrete frontals and continuous frontals. // Fill in discrete frontals and continuous frontals.
std::set<DiscreteKey> discreteFrontals; std::set<DiscreteKey> discreteFrontals;
@ -433,23 +428,25 @@ void HybridGaussianFactorGraph::add(JacobianFactor &&factor) {
} }
/* ************************************************************************ */ /* ************************************************************************ */
void HybridGaussianFactorGraph::add(boost::shared_ptr<JacobianFactor> &factor) { void HybridGaussianFactorGraph::add(
const boost::shared_ptr<JacobianFactor> &factor) {
FactorGraph::add(boost::make_shared<HybridGaussianFactor>(factor)); FactorGraph::add(boost::make_shared<HybridGaussianFactor>(factor));
} }
/* ************************************************************************ */ /* ************************************************************************ */
void HybridGaussianFactorGraph::add(DecisionTreeFactor &&factor) { void HybridGaussianFactorGraph::add(DecisionTreeFactor &&factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(std::move(factor))); FactorGraph::add(std::move(factor));
} }
/* ************************************************************************ */ /* ************************************************************************ */
void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { void HybridGaussianFactorGraph::add(
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor)); const DecisionTreeFactor::shared_ptr &factor) {
FactorGraph::add(factor);
} }
/* ************************************************************************ */ /* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = discreteKeys(); const KeySet discrete_keys = discreteKeySet();
const VariableIndex index(factors_); const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast( Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
@ -484,16 +481,11 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
error_tree = error_tree.apply( error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; }); [error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DiscreteFactor>(f) || } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip. // If factor at `idx` is discrete-only, we skip.
continue; continue;
} else { } else {
auto &fr = *f; throwRuntimeError("HybridGaussianFactorGraph::error", f);
throw std::invalid_argument(
std::string(
"HybridGaussianFactorGraph::error: factor type not handled: ") +
demangle(typeid(fr).name()));
} }
} }
@ -503,9 +495,14 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
/* ************************************************************************ */ /* ************************************************************************ */
double HybridGaussianFactorGraph::error(const HybridValues &values) const { double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0; double error = 0.0;
for (auto &factor : factors_) { for (auto &f : factors_) {
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) { if (auto hf = boost::dynamic_pointer_cast<HybridFactor>(f)) {
error += p->error(values); // TODO(dellaert): needs to change when we discard other wrappers.
error += hf->error(values);
} else if (auto dtf = boost::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
error -= log((*dtf)(values.discrete()));
} else {
throwRuntimeError("HybridGaussianFactorGraph::error", f);
} }
} }
return error; return error;

View File

@ -50,13 +50,13 @@ class HybridValues;
* @ingroup hybrid * @ingroup hybrid
*/ */
GTSAM_EXPORT GTSAM_EXPORT
std::pair<boost::shared_ptr<HybridConditional>, HybridFactor::shared_ptr> std::pair<boost::shared_ptr<HybridConditional>, boost::shared_ptr<Factor>>
EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys); EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys);
/* ************************************************************************* */ /* ************************************************************************* */
template <> template <>
struct EliminationTraits<HybridGaussianFactorGraph> { struct EliminationTraits<HybridGaussianFactorGraph> {
typedef HybridFactor FactorType; ///< Type of factors in factor graph typedef Factor FactorType; ///< Type of factors in factor graph
typedef HybridGaussianFactorGraph typedef HybridGaussianFactorGraph
FactorGraphType; ///< Type of the factor graph (e.g. FactorGraphType; ///< Type of the factor graph (e.g.
///< HybridGaussianFactorGraph) ///< HybridGaussianFactorGraph)
@ -70,7 +70,7 @@ struct EliminationTraits<HybridGaussianFactorGraph> {
typedef HybridJunctionTree JunctionTreeType; ///< Type of Junction tree typedef HybridJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function /// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>, static std::pair<boost::shared_ptr<ConditionalType>,
boost::shared_ptr<FactorType> > boost::shared_ptr<FactorType>>
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateHybrid(factors, keys); return EliminateHybrid(factors, keys);
} }
@ -80,7 +80,6 @@ struct EliminationTraits<HybridGaussianFactorGraph> {
* Hybrid Gaussian Factor Graph * Hybrid Gaussian Factor Graph
* ----------------------- * -----------------------
* This is the linearized version of a hybrid factor graph. * This is the linearized version of a hybrid factor graph.
* Everything inside needs to be hybrid factor or hybrid conditional.
* *
* @ingroup hybrid * @ingroup hybrid
*/ */
@ -130,13 +129,13 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
void add(JacobianFactor&& factor); void add(JacobianFactor&& factor);
/// Add a Jacobian factor as a shared ptr. /// Add a Jacobian factor as a shared ptr.
void add(boost::shared_ptr<JacobianFactor>& factor); void add(const boost::shared_ptr<JacobianFactor>& factor);
/// Add a DecisionTreeFactor to the factor graph. /// Add a DecisionTreeFactor to the factor graph.
void add(DecisionTreeFactor&& factor); void add(DecisionTreeFactor&& factor);
/// Add a DecisionTreeFactor as a shared ptr. /// Add a DecisionTreeFactor as a shared ptr.
void add(DecisionTreeFactor::shared_ptr factor); void add(const boost::shared_ptr<DecisionTreeFactor>& factor);
/** /**
* Add a gaussian factor *pointer* to the internal gaussian factor graph * Add a gaussian factor *pointer* to the internal gaussian factor graph

View File

@ -43,7 +43,7 @@ Ordering HybridGaussianISAM::GetOrdering(
HybridGaussianFactorGraph& factors, HybridGaussianFactorGraph& factors,
const HybridGaussianFactorGraph& newFactors) { const HybridGaussianFactorGraph& newFactors) {
// Get all the discrete keys from the factors // Get all the discrete keys from the factors
KeySet allDiscrete = factors.discreteKeys(); const KeySet allDiscrete = factors.discreteKeySet();
// Create KeyVector with continuous keys followed by discrete keys. // Create KeyVector with continuous keys followed by discrete keys.
KeyVector newKeysDiscreteLast; KeyVector newKeysDiscreteLast;

View File

@ -16,7 +16,11 @@
* @date May 28, 2022 * @date May 28, 2022
*/ */
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h> #include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
#include <gtsam/hybrid/MixtureFactor.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
namespace gtsam { namespace gtsam {
@ -63,8 +67,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues); const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
const auto hgf = boost::make_shared<HybridGaussianFactor>(gf); const auto hgf = boost::make_shared<HybridGaussianFactor>(gf);
linearFG->push_back(hgf); linearFG->push_back(hgf);
} else if (dynamic_pointer_cast<DiscreteFactor>(f) || } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
// If discrete-only: doesn't need linearization. // If discrete-only: doesn't need linearization.
linearFG->push_back(f); linearFG->push_back(f);
} else { } else {

View File

@ -18,16 +18,12 @@
#pragma once #pragma once
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h> #include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/MixtureFactor.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
#include <boost/format.hpp>
namespace gtsam { namespace gtsam {
class HybridGaussianFactorGraph;
/** /**
* Nonlinear Hybrid Factor Graph * Nonlinear Hybrid Factor Graph
* ----------------------- * -----------------------

View File

@ -68,17 +68,6 @@ virtual class HybridConditional {
double error(const gtsam::HybridValues& values) const; double error(const gtsam::HybridValues& values) const;
}; };
#include <gtsam/hybrid/HybridDiscreteFactor.h>
virtual class HybridDiscreteFactor {
HybridDiscreteFactor(gtsam::DecisionTreeFactor dtf);
void print(string s = "HybridDiscreteFactor\n",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const;
gtsam::Factor* inner();
double error(const gtsam::HybridValues &values) const;
};
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
class GaussianMixtureFactor : gtsam::HybridFactor { class GaussianMixtureFactor : gtsam::HybridFactor {
GaussianMixtureFactor( GaussianMixtureFactor(
@ -217,9 +206,7 @@ class HybridNonlinearFactorGraph {
HybridNonlinearFactorGraph(const gtsam::HybridNonlinearFactorGraph& graph); HybridNonlinearFactorGraph(const gtsam::HybridNonlinearFactorGraph& graph);
void push_back(gtsam::HybridFactor* factor); void push_back(gtsam::HybridFactor* factor);
void push_back(gtsam::NonlinearFactor* factor); void push_back(gtsam::NonlinearFactor* factor);
void push_back(gtsam::HybridDiscreteFactor* factor); void push_back(gtsam::DiscreteFactor* factor);
void add(gtsam::NonlinearFactor* factor);
void add(gtsam::DiscreteFactor* factor);
gtsam::HybridGaussianFactorGraph linearize(const gtsam::Values& continuousValues) const; gtsam::HybridGaussianFactorGraph linearize(const gtsam::Values& continuousValues) const;
bool empty() const; bool empty() const;

View File

@ -206,13 +206,11 @@ struct Switching {
*/ */
void addModeChain(HybridNonlinearFactorGraph *fg, void addModeChain(HybridNonlinearFactorGraph *fg,
std::string discrete_transition_prob = "1/2 3/2") { std::string discrete_transition_prob = "1/2 3/2") {
auto prior = boost::make_shared<DiscreteDistribution>(modes[0], "1/1"); fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
fg->push_discrete(prior);
for (size_t k = 0; k < K - 2; k++) { for (size_t k = 0; k < K - 2; k++) {
auto parents = {modes[k]}; auto parents = {modes[k]};
auto conditional = boost::make_shared<DiscreteConditional>( fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
modes[k + 1], parents, discrete_transition_prob); discrete_transition_prob);
fg->push_discrete(conditional);
} }
} }
@ -224,13 +222,11 @@ struct Switching {
*/ */
void addModeChain(HybridGaussianFactorGraph *fg, void addModeChain(HybridGaussianFactorGraph *fg,
std::string discrete_transition_prob = "1/2 3/2") { std::string discrete_transition_prob = "1/2 3/2") {
auto prior = boost::make_shared<DiscreteDistribution>(modes[0], "1/1"); fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
fg->push_discrete(prior);
for (size_t k = 0; k < K - 2; k++) { for (size_t k = 0; k < K - 2; k++) {
auto parents = {modes[k]}; auto parents = {modes[k]};
auto conditional = boost::make_shared<DiscreteConditional>( fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
modes[k + 1], parents, discrete_transition_prob); discrete_transition_prob);
fg->push_discrete(conditional);
} }
} }
}; };