minor clean up and get tests to pass

release/4.3a0
Varun Agrawal 2022-11-08 14:00:44 -05:00
parent 3987b036a7
commit 1a3b343537
3 changed files with 87 additions and 26 deletions

View File

@ -512,9 +512,17 @@ HybridGaussianFactorGraph::continuousDelta(
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
const DiscreteKeys &discrete_keys,
const boost::shared_ptr<BayesNetType> &continuousBayesNet,
const std::vector<DiscreteValues> &assignments) const {
const DiscreteKeys &orig_discrete_keys,
const boost::shared_ptr<BayesNetType> &continuousBayesNet) const {
// Generate all possible assignments.
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(orig_discrete_keys);
// Save a copy of the original discrete key ordering
DiscreteKeys discrete_keys(orig_discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());
// Create a decision tree of all the different VectorValues
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
this->continuousDelta(discrete_keys, continuousBayesNet, assignments);
@ -532,7 +540,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
}
double error = 0.0;
// Compute the error given the delta and the assignment.
for (size_t idx = 0; idx < size(); idx++) {
auto factor = factors_.at(idx);
@ -563,15 +571,21 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateHybridSequential(const boost::optional<Ordering> continuous, const boost::optional<Ordering> discrete) const {
Ordering continuous_ordering(this->continuousKeys()),
discrete_ordering(this->discreteKeys());
HybridGaussianFactorGraph::eliminateHybridSequential(
const boost::optional<Ordering> continuous,
const boost::optional<Ordering> discrete, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
Ordering continuous_ordering =
continuous ? *continuous : Ordering(this->continuousKeys());
Ordering discrete_ordering =
discrete ? *discrete : Ordering(this->discreteKeys());
// Eliminate continuous
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
BaseEliminateable::eliminatePartialSequential(continuous_ordering);
BaseEliminateable::eliminatePartialSequential(continuous_ordering,
function, variableIndex);
// Get the last continuous conditional which will have all the discrete keys
auto last_conditional = bayesNet->at(bayesNet->size() - 1);
@ -582,26 +596,54 @@ HybridGaussianFactorGraph::eliminateHybridSequential(const boost::optional<Order
return bayesNet;
}
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(discrete_keys);
// Save a copy of the original discrete key ordering
DiscreteKeys orig_discrete_keys(discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());
AlgebraicDecisionTree<Key> probPrimeTree =
continuousProbPrimes(discrete_keys, bayesNet, assignments);
this->continuousProbPrimes(discrete_keys, bayesNet);
discreteGraph->add(DecisionTreeFactor(orig_discrete_keys, probPrimeTree));
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
// Perform discrete elimination
HybridBayesNet::shared_ptr discreteBayesNet =
discreteGraph->eliminateSequential(discrete_ordering);
discreteGraph->BaseEliminateable::eliminateSequential(
discrete_ordering, function, variableIndex);
bayesNet->add(*discreteBayesNet);
return bayesNet;
}
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateSequential(
OptionalOrderingType orderingType, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
return BaseEliminateable::eliminateSequential(orderingType, function,
variableIndex);
}
/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateSequential(
const Ordering &ordering, const Eliminate &function,
OptionalVariableIndex variableIndex) const {
KeySet all_continuous_keys = this->continuousKeys();
KeySet all_discrete_keys = this->discreteKeys();
Ordering continuous_ordering, discrete_ordering;
// Segregate the continuous and the discrete keys
for (auto &&key : ordering) {
if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(),
key) != all_continuous_keys.end()) {
continuous_ordering.push_back(key);
} else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(),
key) != all_discrete_keys.end()) {
discrete_ordering.push_back(key);
} else {
throw std::runtime_error("Key in ordering not present in factors.");
}
}
return this->eliminateHybridSequential(continuous_ordering,
discrete_ordering);
}
} // namespace gtsam

View File

@ -214,14 +214,11 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @param discrete_keys The discrete keys which form all the modes.
* @param continuousBayesNet The Bayes Net representing the continuous
* eliminated variables.
* @param assignments List of all discrete assignments to create the final
* decision tree.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> continuousProbPrimes(
const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesNetType>& continuousBayesNet,
const std::vector<DiscreteValues>& assignments) const;
const boost::shared_ptr<BayesNetType>& continuousBayesNet) const;
/**
* @brief Custom elimination function which computes the correct
@ -232,8 +229,20 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @return boost::shared_ptr<BayesNetType>
*/
boost::shared_ptr<BayesNetType> eliminateHybridSequential(
const boost::optional<Ordering> continuous,
const boost::optional<Ordering> discrete) const;
const boost::optional<Ordering> continuous = boost::none,
const boost::optional<Ordering> discrete = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;
boost::shared_ptr<BayesNetType> eliminateSequential(
OptionalOrderingType orderingType = boost::none,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;
boost::shared_ptr<BayesNetType> eliminateSequential(
const Ordering& ordering,
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = boost::none) const;
/**
* @brief Return a Colamd constrained ordering where the discrete keys are

View File

@ -372,7 +372,8 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
dynamic_pointer_cast<DecisionTreeFactor>(hybridDiscreteFactor->inner());
CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
EXPECT(discreteFactor->root_->isLeaf() == false);
// All leaves should be probability 1 since this is not P*(X|M,Z)
EXPECT(discreteFactor->root_->isLeaf());
// TODO(Varun) Test emplace_discrete
}
@ -439,6 +440,15 @@ TEST(HybridFactorGraph, Full_Elimination) {
auto df = dynamic_pointer_cast<HybridDiscreteFactor>(factor);
discrete_fg.push_back(df->inner());
}
// Get the probabilit P*(X | M, Z)
DiscreteKeys discrete_keys =
remainingFactorGraph_partial->at(2)->discreteKeys();
AlgebraicDecisionTree<Key> probPrimeTree =
linearizedFactorGraph.continuousProbPrimes(discrete_keys,
hybridBayesNet_partial);
discrete_fg.add(DecisionTreeFactor(discrete_keys, probPrimeTree));
ordering.clear();
for (size_t k = 0; k < self.K - 1; k++) ordering += M(k);
discreteBayesNet =