initial changes
parent
4e13fb717b
commit
f6b1872b13
|
@ -17,7 +17,6 @@
|
||||||
* @date January, 2023
|
* @date January, 2023
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
@ -26,7 +25,7 @@ namespace gtsam {
|
||||||
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
|
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
|
||||||
std::set<DiscreteKey> keys;
|
std::set<DiscreteKey> keys;
|
||||||
for (auto& factor : factors_) {
|
for (auto& factor : factors_) {
|
||||||
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||||
for (const DiscreteKey& key : p->discreteKeys()) {
|
for (const DiscreteKey& key : p->discreteKeys()) {
|
||||||
keys.insert(key);
|
keys.insert(key);
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,8 +48,6 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// #define HYBRID_TIMING
|
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||||
|
@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
// TODO(dellaert): in C++20, we can use std::visit.
|
// TODO(dellaert): in C++20, we can use std::visit.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
// Don't do anything for discrete-only factors
|
// Don't do anything for discrete-only factors
|
||||||
// since we want to eliminate continuous values only.
|
// since we want to eliminate continuous values only.
|
||||||
continue;
|
continue;
|
||||||
|
@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
DiscreteFactorGraph dfg;
|
DiscreteFactorGraph dfg;
|
||||||
|
|
||||||
for (auto &f : factors) {
|
for (auto &f : factors) {
|
||||||
if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
dfg.push_back(dtf);
|
dfg.push_back(df);
|
||||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||||
// Ignore orphaned clique.
|
// Ignore orphaned clique.
|
||||||
// TODO(dellaert): is this correct? If so explain here.
|
// TODO(dellaert): is this correct? If so explain here.
|
||||||
|
@ -262,9 +260,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
};
|
};
|
||||||
|
|
||||||
DecisionTree<Key, double> probabilities(eliminationResults, probability);
|
DecisionTree<Key, double> probabilities(eliminationResults, probability);
|
||||||
|
|
||||||
|
auto dtf =
|
||||||
|
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
std::make_shared<HybridConditional>(gaussianMixture),
|
std::make_shared<HybridConditional>(gaussianMixture),
|
||||||
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
|
std::make_shared<TableFactor>(discreteSeparator, dtf->probabilities())};
|
||||||
} else {
|
} else {
|
||||||
// Otherwise, we create a resulting GaussianMixtureFactor on the separator,
|
// Otherwise, we create a resulting GaussianMixtureFactor on the separator,
|
||||||
// taking care to correct for conditional constant.
|
// taking care to correct for conditional constant.
|
||||||
|
@ -433,7 +435,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
// Add the gaussian factor error to every leaf of the error tree.
|
// Add the gaussian factor error to every leaf of the error tree.
|
||||||
error_tree = error_tree.apply(
|
error_tree = error_tree.apply(
|
||||||
[error](double leaf_value) { return leaf_value + error; });
|
[error](double leaf_value) { return leaf_value + error; });
|
||||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
// If factor at `idx` is discrete-only, we skip.
|
// If factor at `idx` is discrete-only, we skip.
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -40,6 +40,7 @@ class HybridEliminationTree;
|
||||||
class HybridBayesTree;
|
class HybridBayesTree;
|
||||||
class HybridJunctionTree;
|
class HybridJunctionTree;
|
||||||
class DecisionTreeFactor;
|
class DecisionTreeFactor;
|
||||||
|
class TableFactor;
|
||||||
class JacobianFactor;
|
class JacobianFactor;
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ struct HybridConstructorTraversalData {
|
||||||
for (auto& k : hf->discreteKeys()) {
|
for (auto& k : hf->discreteKeys()) {
|
||||||
data.discreteKeys.insert(k.first);
|
data.discreteKeys.insert(k.first);
|
||||||
}
|
}
|
||||||
} else if (auto hf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (auto hf = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
for (auto& k : hf->discreteKeys()) {
|
for (auto& k : hf->discreteKeys()) {
|
||||||
data.discreteKeys.insert(k.first);
|
data.discreteKeys.insert(k.first);
|
||||||
}
|
}
|
||||||
|
@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree(
|
||||||
Data rootData(0);
|
Data rootData(0);
|
||||||
rootData.junctionTreeNode =
|
rootData.junctionTreeNode =
|
||||||
std::make_shared<typename Base::Node>(); // Make a dummy node to gather
|
std::make_shared<typename Base::Node>(); // Make a dummy node to gather
|
||||||
// the junction tree roots
|
// the junction tree roots
|
||||||
treeTraversal::DepthFirstForest(eliminationTree, rootData,
|
treeTraversal::DepthFirstForest(eliminationTree, rootData,
|
||||||
Data::ConstructorTraversalVisitorPre,
|
Data::ConstructorTraversalVisitorPre,
|
||||||
Data::ConstructorTraversalVisitorPost);
|
Data::ConstructorTraversalVisitorPost);
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/hybrid/GaussianMixture.h>
|
#include <gtsam/hybrid/GaussianMixture.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
|
||||||
|
@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
|
||||||
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
|
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
|
||||||
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
|
||||||
linearFG->push_back(gf);
|
linearFG->push_back(gf);
|
||||||
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
// If discrete-only: doesn't need linearization.
|
// If discrete-only: doesn't need linearization.
|
||||||
linearFG->push_back(f);
|
linearFG->push_back(f);
|
||||||
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
|
|
|
@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) {
|
||||||
|
|
||||||
DiscreteFactorGraph dfg;
|
DiscreteFactorGraph dfg;
|
||||||
for (auto&& f : *remainingFactorGraph) {
|
for (auto&& f : *remainingFactorGraph) {
|
||||||
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f);
|
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
|
||||||
assert(discreteFactor);
|
assert(discreteFactor);
|
||||||
dfg.push_back(discreteFactor);
|
dfg.push_back(discreteFactor);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue