initial changes
parent
4e13fb717b
commit
f6b1872b13
|
@ -17,7 +17,6 @@
|
|||
* @date January, 2023
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -26,7 +25,7 @@ namespace gtsam {
|
|||
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
|
||||
std::set<DiscreteKey> keys;
|
||||
for (auto& factor : factors_) {
|
||||
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
for (const DiscreteKey& key : p->discreteKeys()) {
|
||||
keys.insert(key);
|
||||
}
|
||||
|
|
|
@ -48,8 +48,6 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// #define HYBRID_TIMING
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||
|
@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
|||
// TODO(dellaert): in C++20, we can use std::visit.
|
||||
continue;
|
||||
}
|
||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
// Don't do anything for discrete-only factors
|
||||
// since we want to eliminate continuous values only.
|
||||
continue;
|
||||
|
@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
DiscreteFactorGraph dfg;
|
||||
|
||||
for (auto &f : factors) {
|
||||
if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
dfg.push_back(dtf);
|
||||
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
dfg.push_back(df);
|
||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||
// Ignore orphaned clique.
|
||||
// TODO(dellaert): is this correct? If so explain here.
|
||||
|
@ -262,9 +260,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
};
|
||||
|
||||
DecisionTree<Key, double> probabilities(eliminationResults, probability);
|
||||
|
||||
auto dtf =
|
||||
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
|
||||
|
||||
return {
|
||||
std::make_shared<HybridConditional>(gaussianMixture),
|
||||
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
|
||||
std::make_shared<TableFactor>(discreteSeparator, dtf->probabilities())};
|
||||
} else {
|
||||
// Otherwise, we create a resulting GaussianMixtureFactor on the separator,
|
||||
// taking care to correct for conditional constant.
|
||||
|
@ -433,7 +435,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
// Add the gaussian factor error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
// If factor at `idx` is discrete-only, we skip.
|
||||
continue;
|
||||
} else {
|
||||
|
|
|
@ -40,6 +40,7 @@ class HybridEliminationTree;
|
|||
class HybridBayesTree;
|
||||
class HybridJunctionTree;
|
||||
class DecisionTreeFactor;
|
||||
class TableFactor;
|
||||
class JacobianFactor;
|
||||
class HybridValues;
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ struct HybridConstructorTraversalData {
|
|||
for (auto& k : hf->discreteKeys()) {
|
||||
data.discreteKeys.insert(k.first);
|
||||
}
|
||||
} else if (auto hf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
} else if (auto hf = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
for (auto& k : hf->discreteKeys()) {
|
||||
data.discreteKeys.insert(k.first);
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
#include <gtsam/hybrid/GaussianMixture.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||
|
@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
|||
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
|
||||
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
||||
linearFG->push_back(gf);
|
||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
// If discrete-only: doesn't need linearization.
|
||||
linearFG->push_back(f);
|
||||
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
|
|
|
@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) {
|
|||
|
||||
DiscreteFactorGraph dfg;
|
||||
for (auto&& f : *remainingFactorGraph) {
|
||||
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f);
|
||||
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
|
||||
assert(discreteFactor);
|
||||
dfg.push_back(discreteFactor);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue