Removed HybridDiscreteFactor wrapper
parent
b93de21295
commit
7e32a8739e
|
@ -1,61 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file HybridDiscreteFactor.cpp
|
||||
* @brief Wrapper for a discrete factor
|
||||
* @date Mar 11, 2022
|
||||
* @author Fan Jiang
|
||||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <boost/make_shared.hpp>
|
||||
|
||||
#include "gtsam/discrete/DecisionTreeFactor.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************ */
|
||||
HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
|
||||
: Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other)
|
||||
->discreteKeys()),
|
||||
inner_(other) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
|
||||
: Base(dtf.discreteKeys()),
|
||||
inner_(boost::make_shared<DecisionTreeFactor>(std::move(dtf))) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
|
||||
const This *e = dynamic_cast<const This *>(&lf);
|
||||
if (e == nullptr) return false;
|
||||
if (!Base::equals(*e, tol)) return false;
|
||||
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||
: !(e->inner_);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
void HybridDiscreteFactor::print(const std::string &s,
|
||||
const KeyFormatter &formatter) const {
|
||||
HybridFactor::print(s, formatter);
|
||||
inner_->print("\n", formatter);
|
||||
};
|
||||
|
||||
/* ************************************************************************ */
|
||||
double HybridDiscreteFactor::error(const HybridValues &values) const {
|
||||
return -log((*inner_)(values.discrete()));
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
|
||||
} // namespace gtsam
|
|
@ -1,91 +0,0 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file HybridDiscreteFactor.h
|
||||
* @date Mar 11, 2022
|
||||
* @author Fan Jiang
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which
|
||||
* allows us to hide the implementation of DiscreteFactor and thus avoid
|
||||
* diamond inheritance.
|
||||
*
|
||||
* @ingroup hybrid
|
||||
*/
|
||||
class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
||||
private:
|
||||
DiscreteFactor::shared_ptr inner_;
|
||||
|
||||
public:
|
||||
using Base = HybridFactor;
|
||||
using This = HybridDiscreteFactor;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor - for serialization.
|
||||
HybridDiscreteFactor() = default;
|
||||
|
||||
// Implicit conversion from a shared ptr of DF
|
||||
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
|
||||
|
||||
// Forwarding constructor from concrete DecisionTreeFactor
|
||||
HybridDiscreteFactor(DecisionTreeFactor &&dtf);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
virtual bool equals(const HybridFactor &lf, double tol) const override;
|
||||
|
||||
void print(
|
||||
const std::string &s = "HybridFactor\n",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Return pointer to the internal discrete factor.
|
||||
DiscreteFactor::shared_ptr inner() const { return inner_; }
|
||||
|
||||
/// Return the error of the underlying Discrete Factor.
|
||||
double error(const HybridValues &values) const override;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar &BOOST_SERIALIZATION_NVP(inner_);
|
||||
}
|
||||
};
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<HybridDiscreteFactor> : public Testable<HybridDiscreteFactor> {};
|
||||
|
||||
} // namespace gtsam
|
|
@ -67,11 +67,10 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
|
|||
const DiscreteKeys &key2);
|
||||
|
||||
/**
|
||||
* Base class for hybrid probabilistic factors
|
||||
* Base class for *truly* hybrid probabilistic factors
|
||||
*
|
||||
* Examples:
|
||||
* - HybridGaussianFactor
|
||||
* - HybridDiscreteFactor
|
||||
* - MixtureFactor
|
||||
* - GaussianMixtureFactor
|
||||
* - GaussianMixture
|
||||
*
|
||||
|
|
|
@ -26,7 +26,6 @@
|
|||
#include <gtsam/hybrid/GaussianMixture.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||
#include <gtsam/hybrid/HybridEliminationTree.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
|
@ -47,7 +46,6 @@
|
|||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
@ -58,6 +56,15 @@ namespace gtsam {
|
|||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||
|
||||
/* ************************************************************************ */
|
||||
// Throw a runtime exception for method specified in string s, and factor f:
|
||||
static void throwRuntimeError(const std::string &s,
|
||||
const boost::shared_ptr<Factor> &f) {
|
||||
auto &fr = *f;
|
||||
throw std::runtime_error(s + " not implemented for factor type " +
|
||||
demangle(typeid(fr).name()) + ".");
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static GaussianFactorGraphTree addGaussian(
|
||||
const GaussianFactorGraphTree &gfgTree,
|
||||
|
@ -67,7 +74,6 @@ static GaussianFactorGraphTree addGaussian(
|
|||
GaussianFactorGraph result;
|
||||
result.push_back(factor);
|
||||
return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
|
||||
|
||||
} else {
|
||||
auto add = [&factor](const GraphAndConstant &graph_z) {
|
||||
auto result = graph_z.graph;
|
||||
|
@ -103,8 +109,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
|||
}
|
||||
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
result = addGaussian(result, gf->inner());
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
|
||||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
|
||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
// Don't do anything for discrete-only factors
|
||||
// since we want to eliminate continuous values only.
|
||||
continue;
|
||||
|
@ -116,10 +121,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
|||
"gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented "
|
||||
"yet.");
|
||||
} else {
|
||||
auto &fr = *f;
|
||||
throw std::invalid_argument(
|
||||
std::string("gtsam::assembleGraphTree: factor type not handled: ") +
|
||||
demangle(typeid(fr).name()));
|
||||
throwRuntimeError("gtsam::assembleGraphTree", f);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -129,16 +131,18 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
|
||||
continuousElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
using boost::dynamic_pointer_cast;
|
||||
GaussianFactorGraph gfg;
|
||||
for (auto &fp : factors) {
|
||||
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
|
||||
gfg.push_back(ptr->inner());
|
||||
} else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) {
|
||||
gfg.push_back(
|
||||
boost::static_pointer_cast<GaussianConditional>(ptr->inner()));
|
||||
if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
|
||||
gfg.push_back(hgf->inner());
|
||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(fp)) {
|
||||
auto gc = hc->asGaussian();
|
||||
assert(gc);
|
||||
gfg.push_back(gc);
|
||||
} else {
|
||||
// It is an orphan wrapped conditional
|
||||
}
|
||||
|
@ -150,18 +154,17 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
|
||||
discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
DiscreteFactorGraph dfg;
|
||||
|
||||
for (auto &factor : factors) {
|
||||
if (auto p = boost::dynamic_pointer_cast<HybridDiscreteFactor>(factor)) {
|
||||
dfg.push_back(p->inner());
|
||||
} else if (auto p = boost::static_pointer_cast<HybridConditional>(factor)) {
|
||||
auto discrete_conditional =
|
||||
boost::static_pointer_cast<DiscreteConditional>(p->inner());
|
||||
dfg.push_back(discrete_conditional);
|
||||
if (auto dtf = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
dfg.push_back(dtf);
|
||||
} else if (auto hc =
|
||||
boost::static_pointer_cast<HybridConditional>(factor)) {
|
||||
dfg.push_back(hc->asDiscrete());
|
||||
} else {
|
||||
// It is an orphan wrapper
|
||||
}
|
||||
|
@ -170,8 +173,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
|
||||
return {boost::make_shared<HybridConditional>(result.first),
|
||||
boost::make_shared<HybridDiscreteFactor>(result.second)};
|
||||
return {boost::make_shared<HybridConditional>(result.first), result.second};
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
@ -189,7 +191,7 @@ GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
|
||||
hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys,
|
||||
const KeyVector &continuousSeparator,
|
||||
|
@ -291,7 +293,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||
|
||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
|
||||
discreteFactor};
|
||||
} else {
|
||||
// Create a resulting GaussianMixtureFactor on the separator.
|
||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||
|
@ -314,7 +316,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
* eliminate a discrete variable (as specified in the ordering), the result will
|
||||
* be INCORRECT and there will be NO error raised.
|
||||
*/
|
||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
|
||||
std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>> //
|
||||
EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
// NOTE: Because we are in the Conditional Gaussian regime there are only
|
||||
|
@ -374,14 +376,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
}
|
||||
|
||||
// Build a map from keys to DiscreteKeys
|
||||
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
|
||||
for (auto &&factor : factors) {
|
||||
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
for (auto &k : p->discreteKeys()) {
|
||||
mapFromKeyToDiscreteKey[k.first] = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();
|
||||
|
||||
// Fill in discrete frontals and continuous frontals.
|
||||
std::set<DiscreteKey> discreteFrontals;
|
||||
|
@ -433,23 +428,25 @@ void HybridGaussianFactorGraph::add(JacobianFactor &&factor) {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
void HybridGaussianFactorGraph::add(boost::shared_ptr<JacobianFactor> &factor) {
|
||||
void HybridGaussianFactorGraph::add(
|
||||
const boost::shared_ptr<JacobianFactor> &factor) {
|
||||
FactorGraph::add(boost::make_shared<HybridGaussianFactor>(factor));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
void HybridGaussianFactorGraph::add(DecisionTreeFactor &&factor) {
|
||||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(std::move(factor)));
|
||||
FactorGraph::add(std::move(factor));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
|
||||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
||||
void HybridGaussianFactorGraph::add(
|
||||
const DecisionTreeFactor::shared_ptr &factor) {
|
||||
FactorGraph::add(factor);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
||||
KeySet discrete_keys = discreteKeys();
|
||||
const KeySet discrete_keys = discreteKeySet();
|
||||
const VariableIndex index(factors_);
|
||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||
|
@ -484,16 +481,11 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
|
||||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
|
||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
// If factor at `idx` is discrete-only, we skip.
|
||||
continue;
|
||||
} else {
|
||||
auto &fr = *f;
|
||||
throw std::invalid_argument(
|
||||
std::string(
|
||||
"HybridGaussianFactorGraph::error: factor type not handled: ") +
|
||||
demangle(typeid(fr).name()));
|
||||
throwRuntimeError("HybridGaussianFactorGraph::error", f);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -503,9 +495,14 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
/* ************************************************************************ */
|
||||
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||
double error = 0.0;
|
||||
for (auto &factor : factors_) {
|
||||
if (auto p = boost::dynamic_pointer_cast<HybridFactor>(factor)) {
|
||||
error += p->error(values);
|
||||
for (auto &f : factors_) {
|
||||
if (auto hf = boost::dynamic_pointer_cast<HybridFactor>(f)) {
|
||||
// TODO(dellaert): needs to change when we discard other wrappers.
|
||||
error += hf->error(values);
|
||||
} else if (auto dtf = boost::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
error -= log((*dtf)(values.discrete()));
|
||||
} else {
|
||||
throwRuntimeError("HybridGaussianFactorGraph::error", f);
|
||||
}
|
||||
}
|
||||
return error;
|
||||
|
|
|
@ -50,13 +50,13 @@ class HybridValues;
|
|||
* @ingroup hybrid
|
||||
*/
|
||||
GTSAM_EXPORT
|
||||
std::pair<boost::shared_ptr<HybridConditional>, HybridFactor::shared_ptr>
|
||||
std::pair<boost::shared_ptr<HybridConditional>, boost::shared_ptr<Factor>>
|
||||
EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys);
|
||||
|
||||
/* ************************************************************************* */
|
||||
template <>
|
||||
struct EliminationTraits<HybridGaussianFactorGraph> {
|
||||
typedef HybridFactor FactorType; ///< Type of factors in factor graph
|
||||
typedef Factor FactorType; ///< Type of factors in factor graph
|
||||
typedef HybridGaussianFactorGraph
|
||||
FactorGraphType; ///< Type of the factor graph (e.g.
|
||||
///< HybridGaussianFactorGraph)
|
||||
|
@ -70,7 +70,7 @@ struct EliminationTraits<HybridGaussianFactorGraph> {
|
|||
typedef HybridJunctionTree JunctionTreeType; ///< Type of Junction tree
|
||||
/// The default dense elimination function
|
||||
static std::pair<boost::shared_ptr<ConditionalType>,
|
||||
boost::shared_ptr<FactorType> >
|
||||
boost::shared_ptr<FactorType>>
|
||||
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
|
||||
return EliminateHybrid(factors, keys);
|
||||
}
|
||||
|
@ -80,7 +80,6 @@ struct EliminationTraits<HybridGaussianFactorGraph> {
|
|||
* Hybrid Gaussian Factor Graph
|
||||
* -----------------------
|
||||
* This is the linearized version of a hybrid factor graph.
|
||||
* Everything inside needs to be hybrid factor or hybrid conditional.
|
||||
*
|
||||
* @ingroup hybrid
|
||||
*/
|
||||
|
@ -130,13 +129,13 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
void add(JacobianFactor&& factor);
|
||||
|
||||
/// Add a Jacobian factor as a shared ptr.
|
||||
void add(boost::shared_ptr<JacobianFactor>& factor);
|
||||
void add(const boost::shared_ptr<JacobianFactor>& factor);
|
||||
|
||||
/// Add a DecisionTreeFactor to the factor graph.
|
||||
void add(DecisionTreeFactor&& factor);
|
||||
|
||||
/// Add a DecisionTreeFactor as a shared ptr.
|
||||
void add(DecisionTreeFactor::shared_ptr factor);
|
||||
void add(const boost::shared_ptr<DecisionTreeFactor>& factor);
|
||||
|
||||
/**
|
||||
* Add a gaussian factor *pointer* to the internal gaussian factor graph
|
||||
|
|
|
@ -43,7 +43,7 @@ Ordering HybridGaussianISAM::GetOrdering(
|
|||
HybridGaussianFactorGraph& factors,
|
||||
const HybridGaussianFactorGraph& newFactors) {
|
||||
// Get all the discrete keys from the factors
|
||||
KeySet allDiscrete = factors.discreteKeys();
|
||||
const KeySet allDiscrete = factors.discreteKeySet();
|
||||
|
||||
// Create KeyVector with continuous keys followed by discrete keys.
|
||||
KeyVector newKeysDiscreteLast;
|
||||
|
|
|
@ -16,7 +16,11 @@
|
|||
* @date May 28, 2022
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||
#include <gtsam/hybrid/MixtureFactor.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
|
@ -63,8 +67,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
|||
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
||||
const auto hgf = boost::make_shared<HybridGaussianFactor>(gf);
|
||||
linearFG->push_back(hgf);
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f) ||
|
||||
dynamic_pointer_cast<HybridDiscreteFactor>(f)) {
|
||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
// If discrete-only: doesn't need linearization.
|
||||
linearFG->push_back(f);
|
||||
} else {
|
||||
|
|
|
@ -18,16 +18,12 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/MixtureFactor.h>
|
||||
#include <gtsam/inference/Ordering.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||
|
||||
#include <boost/format.hpp>
|
||||
namespace gtsam {
|
||||
|
||||
class HybridGaussianFactorGraph;
|
||||
|
||||
/**
|
||||
* Nonlinear Hybrid Factor Graph
|
||||
* -----------------------
|
||||
|
|
|
@ -68,17 +68,6 @@ virtual class HybridConditional {
|
|||
double error(const gtsam::HybridValues& values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||
virtual class HybridDiscreteFactor {
|
||||
HybridDiscreteFactor(gtsam::DecisionTreeFactor dtf);
|
||||
void print(string s = "HybridDiscreteFactor\n",
|
||||
const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const;
|
||||
gtsam::Factor* inner();
|
||||
double error(const gtsam::HybridValues &values) const;
|
||||
};
|
||||
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
class GaussianMixtureFactor : gtsam::HybridFactor {
|
||||
GaussianMixtureFactor(
|
||||
|
@ -217,9 +206,7 @@ class HybridNonlinearFactorGraph {
|
|||
HybridNonlinearFactorGraph(const gtsam::HybridNonlinearFactorGraph& graph);
|
||||
void push_back(gtsam::HybridFactor* factor);
|
||||
void push_back(gtsam::NonlinearFactor* factor);
|
||||
void push_back(gtsam::HybridDiscreteFactor* factor);
|
||||
void add(gtsam::NonlinearFactor* factor);
|
||||
void add(gtsam::DiscreteFactor* factor);
|
||||
void push_back(gtsam::DiscreteFactor* factor);
|
||||
gtsam::HybridGaussianFactorGraph linearize(const gtsam::Values& continuousValues) const;
|
||||
|
||||
bool empty() const;
|
||||
|
|
|
@ -206,13 +206,11 @@ struct Switching {
|
|||
*/
|
||||
void addModeChain(HybridNonlinearFactorGraph *fg,
|
||||
std::string discrete_transition_prob = "1/2 3/2") {
|
||||
auto prior = boost::make_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||
fg->push_discrete(prior);
|
||||
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||
for (size_t k = 0; k < K - 2; k++) {
|
||||
auto parents = {modes[k]};
|
||||
auto conditional = boost::make_shared<DiscreteConditional>(
|
||||
modes[k + 1], parents, discrete_transition_prob);
|
||||
fg->push_discrete(conditional);
|
||||
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
||||
discrete_transition_prob);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -224,13 +222,11 @@ struct Switching {
|
|||
*/
|
||||
void addModeChain(HybridGaussianFactorGraph *fg,
|
||||
std::string discrete_transition_prob = "1/2 3/2") {
|
||||
auto prior = boost::make_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||
fg->push_discrete(prior);
|
||||
fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||
for (size_t k = 0; k < K - 2; k++) {
|
||||
auto parents = {modes[k]};
|
||||
auto conditional = boost::make_shared<DiscreteConditional>(
|
||||
modes[k + 1], parents, discrete_transition_prob);
|
||||
fg->push_discrete(conditional);
|
||||
fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
||||
discrete_transition_prob);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue