minor clean up and get tests to pass
parent
3987b036a7
commit
1a3b343537
|
|
@ -512,9 +512,17 @@ HybridGaussianFactorGraph::continuousDelta(
|
|||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
|
||||
const DiscreteKeys &discrete_keys,
|
||||
const boost::shared_ptr<BayesNetType> &continuousBayesNet,
|
||||
const std::vector<DiscreteValues> &assignments) const {
|
||||
const DiscreteKeys &orig_discrete_keys,
|
||||
const boost::shared_ptr<BayesNetType> &continuousBayesNet) const {
|
||||
// Generate all possible assignments.
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(orig_discrete_keys);
|
||||
|
||||
// Save a copy of the original discrete key ordering
|
||||
DiscreteKeys discrete_keys(orig_discrete_keys);
|
||||
// Reverse discrete keys order for correct tree construction
|
||||
std::reverse(discrete_keys.begin(), discrete_keys.end());
|
||||
|
||||
// Create a decision tree of all the different VectorValues
|
||||
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
|
||||
this->continuousDelta(discrete_keys, continuousBayesNet, assignments);
|
||||
|
|
@ -532,7 +540,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
|
|||
}
|
||||
|
||||
double error = 0.0;
|
||||
|
||||
// Compute the error given the delta and the assignment.
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
auto factor = factors_.at(idx);
|
||||
|
||||
|
|
@ -563,15 +571,21 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
|
|||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
|
||||
HybridGaussianFactorGraph::eliminateHybridSequential(const boost::optional<Ordering> continuous, const boost::optional<Ordering> discrete) const {
|
||||
Ordering continuous_ordering(this->continuousKeys()),
|
||||
discrete_ordering(this->discreteKeys());
|
||||
HybridGaussianFactorGraph::eliminateHybridSequential(
|
||||
const boost::optional<Ordering> continuous,
|
||||
const boost::optional<Ordering> discrete, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
Ordering continuous_ordering =
|
||||
continuous ? *continuous : Ordering(this->continuousKeys());
|
||||
Ordering discrete_ordering =
|
||||
discrete ? *discrete : Ordering(this->discreteKeys());
|
||||
|
||||
// Eliminate continuous
|
||||
HybridBayesNet::shared_ptr bayesNet;
|
||||
HybridGaussianFactorGraph::shared_ptr discreteGraph;
|
||||
std::tie(bayesNet, discreteGraph) =
|
||||
BaseEliminateable::eliminatePartialSequential(continuous_ordering);
|
||||
BaseEliminateable::eliminatePartialSequential(continuous_ordering,
|
||||
function, variableIndex);
|
||||
|
||||
// Get the last continuous conditional which will have all the discrete keys
|
||||
auto last_conditional = bayesNet->at(bayesNet->size() - 1);
|
||||
|
|
@ -582,26 +596,54 @@ HybridGaussianFactorGraph::eliminateHybridSequential(const boost::optional<Order
|
|||
return bayesNet;
|
||||
}
|
||||
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(discrete_keys);
|
||||
|
||||
// Save a copy of the original discrete key ordering
|
||||
DiscreteKeys orig_discrete_keys(discrete_keys);
|
||||
// Reverse discrete keys order for correct tree construction
|
||||
std::reverse(discrete_keys.begin(), discrete_keys.end());
|
||||
|
||||
AlgebraicDecisionTree<Key> probPrimeTree =
|
||||
continuousProbPrimes(discrete_keys, bayesNet, assignments);
|
||||
this->continuousProbPrimes(discrete_keys, bayesNet);
|
||||
|
||||
discreteGraph->add(DecisionTreeFactor(orig_discrete_keys, probPrimeTree));
|
||||
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
|
||||
// Perform discrete elimination
|
||||
HybridBayesNet::shared_ptr discreteBayesNet =
|
||||
discreteGraph->eliminateSequential(discrete_ordering);
|
||||
discreteGraph->BaseEliminateable::eliminateSequential(
|
||||
discrete_ordering, function, variableIndex);
|
||||
|
||||
bayesNet->add(*discreteBayesNet);
|
||||
|
||||
return bayesNet;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
|
||||
HybridGaussianFactorGraph::eliminateSequential(
|
||||
OptionalOrderingType orderingType, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
return BaseEliminateable::eliminateSequential(orderingType, function,
|
||||
variableIndex);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
|
||||
HybridGaussianFactorGraph::eliminateSequential(
|
||||
const Ordering &ordering, const Eliminate &function,
|
||||
OptionalVariableIndex variableIndex) const {
|
||||
KeySet all_continuous_keys = this->continuousKeys();
|
||||
KeySet all_discrete_keys = this->discreteKeys();
|
||||
Ordering continuous_ordering, discrete_ordering;
|
||||
|
||||
// Segregate the continuous and the discrete keys
|
||||
for (auto &&key : ordering) {
|
||||
if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(),
|
||||
key) != all_continuous_keys.end()) {
|
||||
continuous_ordering.push_back(key);
|
||||
} else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(),
|
||||
key) != all_discrete_keys.end()) {
|
||||
discrete_ordering.push_back(key);
|
||||
} else {
|
||||
throw std::runtime_error("Key in ordering not present in factors.");
|
||||
}
|
||||
}
|
||||
|
||||
return this->eliminateHybridSequential(continuous_ordering,
|
||||
discrete_ordering);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -214,14 +214,11 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
* @param discrete_keys The discrete keys which form all the modes.
|
||||
* @param continuousBayesNet The Bayes Net representing the continuous
|
||||
* eliminated variables.
|
||||
* @param assignments List of all discrete assignments to create the final
|
||||
* decision tree.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> continuousProbPrimes(
|
||||
const DiscreteKeys& discrete_keys,
|
||||
const boost::shared_ptr<BayesNetType>& continuousBayesNet,
|
||||
const std::vector<DiscreteValues>& assignments) const;
|
||||
const boost::shared_ptr<BayesNetType>& continuousBayesNet) const;
|
||||
|
||||
/**
|
||||
* @brief Custom elimination function which computes the correct
|
||||
|
|
@ -232,8 +229,20 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
* @return boost::shared_ptr<BayesNetType>
|
||||
*/
|
||||
boost::shared_ptr<BayesNetType> eliminateHybridSequential(
|
||||
const boost::optional<Ordering> continuous,
|
||||
const boost::optional<Ordering> discrete) const;
|
||||
const boost::optional<Ordering> continuous = boost::none,
|
||||
const boost::optional<Ordering> discrete = boost::none,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
boost::shared_ptr<BayesNetType> eliminateSequential(
|
||||
OptionalOrderingType orderingType = boost::none,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
boost::shared_ptr<BayesNetType> eliminateSequential(
|
||||
const Ordering& ordering,
|
||||
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
|
||||
OptionalVariableIndex variableIndex = boost::none) const;
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
|
|
|
|||
|
|
@ -372,7 +372,8 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
|
|||
dynamic_pointer_cast<DecisionTreeFactor>(hybridDiscreteFactor->inner());
|
||||
CHECK(discreteFactor);
|
||||
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
|
||||
EXPECT(discreteFactor->root_->isLeaf() == false);
|
||||
// All leaves should be probability 1 since this is not P*(X|M,Z)
|
||||
EXPECT(discreteFactor->root_->isLeaf());
|
||||
|
||||
// TODO(Varun) Test emplace_discrete
|
||||
}
|
||||
|
|
@ -439,6 +440,15 @@ TEST(HybridFactorGraph, Full_Elimination) {
|
|||
auto df = dynamic_pointer_cast<HybridDiscreteFactor>(factor);
|
||||
discrete_fg.push_back(df->inner());
|
||||
}
|
||||
|
||||
// Get the probabilit P*(X | M, Z)
|
||||
DiscreteKeys discrete_keys =
|
||||
remainingFactorGraph_partial->at(2)->discreteKeys();
|
||||
AlgebraicDecisionTree<Key> probPrimeTree =
|
||||
linearizedFactorGraph.continuousProbPrimes(discrete_keys,
|
||||
hybridBayesNet_partial);
|
||||
discrete_fg.add(DecisionTreeFactor(discrete_keys, probPrimeTree));
|
||||
|
||||
ordering.clear();
|
||||
for (size_t k = 0; k < self.K - 1; k++) ordering += M(k);
|
||||
discreteBayesNet =
|
||||
|
|
|
|||
Loading…
Reference in New Issue