initial changes

release/4.3a0
Varun Agrawal 2023-07-08 13:09:35 -04:00
parent 4e13fb717b
commit f6b1872b13
6 changed files with 16 additions and 13 deletions

View File

@ -17,7 +17,6 @@
* @date January, 2023 * @date January, 2023
*/ */
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h> #include <gtsam/hybrid/HybridFactorGraph.h>
namespace gtsam { namespace gtsam {
@ -26,7 +25,7 @@ namespace gtsam {
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const { std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
std::set<DiscreteKey> keys; std::set<DiscreteKey> keys;
for (auto& factor : factors_) { for (auto& factor : factors_) {
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) { if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) { for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key); keys.insert(key);
} }

View File

@ -48,8 +48,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
// #define HYBRID_TIMING
namespace gtsam { namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
// TODO(dellaert): in C++20, we can use std::visit. // TODO(dellaert): in C++20, we can use std::visit.
continue; continue;
} }
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// Don't do anything for discrete-only factors // Don't do anything for discrete-only factors
// since we want to eliminate continuous values only. // since we want to eliminate continuous values only.
continue; continue;
@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
for (auto &f : factors) { for (auto &f : factors) {
if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) { if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
dfg.push_back(dtf); dfg.push_back(df);
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) { } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique. // Ignore orphaned clique.
// TODO(dellaert): is this correct? If so explain here. // TODO(dellaert): is this correct? If so explain here.
@ -262,9 +260,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
}; };
DecisionTree<Key, double> probabilities(eliminationResults, probability); DecisionTree<Key, double> probabilities(eliminationResults, probability);
auto dtf =
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
return { return {
std::make_shared<HybridConditional>(gaussianMixture), std::make_shared<HybridConditional>(gaussianMixture),
std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)}; std::make_shared<TableFactor>(discreteSeparator, dtf->probabilities())};
} else { } else {
// Otherwise, we create a resulting GaussianMixtureFactor on the separator, // Otherwise, we create a resulting GaussianMixtureFactor on the separator,
// taking care to correct for conditional constant. // taking care to correct for conditional constant.
@ -433,7 +435,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// Add the gaussian factor error to every leaf of the error tree. // Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply( error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; }); [error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip. // If factor at `idx` is discrete-only, we skip.
continue; continue;
} else { } else {

View File

@ -40,6 +40,7 @@ class HybridEliminationTree;
class HybridBayesTree; class HybridBayesTree;
class HybridJunctionTree; class HybridJunctionTree;
class DecisionTreeFactor; class DecisionTreeFactor;
class TableFactor;
class JacobianFactor; class JacobianFactor;
class HybridValues; class HybridValues;

View File

@ -66,7 +66,7 @@ struct HybridConstructorTraversalData {
for (auto& k : hf->discreteKeys()) { for (auto& k : hf->discreteKeys()) {
data.discreteKeys.insert(k.first); data.discreteKeys.insert(k.first);
} }
} else if (auto hf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (auto hf = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
for (auto& k : hf->discreteKeys()) { for (auto& k : hf->discreteKeys()) {
data.discreteKeys.insert(k.first); data.discreteKeys.insert(k.first);
} }
@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree(
Data rootData(0); Data rootData(0);
rootData.junctionTreeNode = rootData.junctionTreeNode =
std::make_shared<typename Base::Node>(); // Make a dummy node to gather std::make_shared<typename Base::Node>(); // Make a dummy node to gather
// the junction tree roots // the junction tree roots
treeTraversal::DepthFirstForest(eliminationTree, rootData, treeTraversal::DepthFirstForest(eliminationTree, rootData,
Data::ConstructorTraversalVisitorPre, Data::ConstructorTraversalVisitorPre,
Data::ConstructorTraversalVisitorPost); Data::ConstructorTraversalVisitorPost);

View File

@ -17,6 +17,7 @@
*/ */
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h> #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h> #include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
} else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) { } else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues); const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
linearFG->push_back(gf); linearFG->push_back(gf);
} else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If discrete-only: doesn't need linearization. // If discrete-only: doesn't need linearization.
linearFG->push_back(f); linearFG->push_back(f);
} else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) { } else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {

View File

@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) {
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
for (auto&& f : *remainingFactorGraph) { for (auto&& f : *remainingFactorGraph) {
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f); auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
assert(discreteFactor); assert(discreteFactor);
dfg.push_back(discreteFactor); dfg.push_back(discreteFactor);
} }