initial changes

release/4.3a0
Varun Agrawal 2023-07-08 13:09:35 -04:00
parent 4e13fb717b
commit f6b1872b13
6 changed files with 16 additions and 13 deletions

View File

@ -17,7 +17,6 @@
* @date January, 2023
*/
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
namespace gtsam {
@ -26,7 +25,7 @@ namespace gtsam {
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
std::set<DiscreteKey> keys;
for (auto& factor : factors_) {
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key);
}

View File

@ -48,8 +48,6 @@
#include <utility>
#include <vector>
// #define HYBRID_TIMING
namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
// TODO(dellaert): in C++20, we can use std::visit.
continue;
}
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// Don't do anything for discrete-only factors
// since we want to eliminate continuous values only.
continue;
@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
DiscreteFactorGraph dfg;
for (auto &f : factors) {
if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
dfg.push_back(dtf);
if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
dfg.push_back(df);
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique.
// TODO(dellaert): is this correct? If so explain here.
@ -262,9 +260,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
};
DecisionTree<Key, double> probabilities(eliminationResults, probability);
auto dtf =
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
return {
std::make_shared<HybridConditional>(gaussianMixture),
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
std::make_shared<TableFactor>(discreteSeparator, dtf->probabilities())};
} else {
// Otherwise, we create a resulting GaussianMixtureFactor on the separator,
// taking care to correct for conditional constant.
@ -433,7 +435,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip.
continue;
} else {

View File

@ -40,6 +40,7 @@ class HybridEliminationTree;
class HybridBayesTree;
class HybridJunctionTree;
class DecisionTreeFactor;
class TableFactor;
class JacobianFactor;
class HybridValues;

View File

@ -66,7 +66,7 @@ struct HybridConstructorTraversalData {
for (auto& k : hf->discreteKeys()) {
data.discreteKeys.insert(k.first);
}
} else if (auto hf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
} else if (auto hf = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
for (auto& k : hf->discreteKeys()) {
data.discreteKeys.insert(k.first);
}
@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree(
Data rootData(0);
rootData.junctionTreeNode =
std::make_shared<typename Base::Node>(); // Make a dummy node to gather
// the junction tree roots
// the junction tree roots
treeTraversal::DepthFirstForest(eliminationTree, rootData,
Data::ConstructorTraversalVisitorPre,
Data::ConstructorTraversalVisitorPost);

View File

@ -17,6 +17,7 @@
*/
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
linearFG->push_back(gf);
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If discrete-only: doesn't need linearization.
linearFG->push_back(f);
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {

View File

@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) {
DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) {
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f);
auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
assert(discreteFactor);
dfg.push_back(discreteFactor);
}