Merge pull request #1590 from borglab/hybrid-tablefactor-3
commit
c5740b2221
|
@ -82,6 +82,22 @@ namespace gtsam {
|
||||||
ADT::print("", formatter);
|
ADT::print("", formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const {
|
||||||
|
// apply operand
|
||||||
|
ADT result = ADT::apply(op);
|
||||||
|
// Make a new factor
|
||||||
|
return DecisionTreeFactor(discreteKeys(), result);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
|
||||||
|
// apply operand
|
||||||
|
ADT result = ADT::apply(op);
|
||||||
|
// Make a new factor
|
||||||
|
return DecisionTreeFactor(discreteKeys(), result);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
||||||
ADT::Binary op) const {
|
ADT::Binary op) const {
|
||||||
|
@ -101,14 +117,6 @@ namespace gtsam {
|
||||||
return DecisionTreeFactor(keys, result);
|
return DecisionTreeFactor(keys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
|
|
||||||
// apply operand
|
|
||||||
ADT result = ADT::apply(op);
|
|
||||||
// Make a new factor
|
|
||||||
return DecisionTreeFactor(discreteKeys(), result);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||||
size_t nrFrontals, ADT::Binary op) const {
|
size_t nrFrontals, ADT::Binary op) const {
|
||||||
|
@ -188,10 +196,45 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
std::vector<double> DecisionTreeFactor::probabilities() const {
|
std::vector<double> DecisionTreeFactor::probabilities() const {
|
||||||
|
// Set of all keys
|
||||||
|
std::set<Key> allKeys(keys().begin(), keys().end());
|
||||||
|
|
||||||
std::vector<double> probs;
|
std::vector<double> probs;
|
||||||
for (auto&& [key, value] : enumerate()) {
|
|
||||||
probs.push_back(value);
|
/* An operation that takes each leaf probability, and computes the
|
||||||
|
* nrAssignments by checking the difference between the keys in the factor
|
||||||
|
* and the keys in the assignment.
|
||||||
|
* The nrAssignments is then used to append
|
||||||
|
* the correct number of leaf probability values to the `probs` vector
|
||||||
|
* defined above.
|
||||||
|
*/
|
||||||
|
auto op = [&](const Assignment<Key>& a, double p) {
|
||||||
|
// Get all the keys in the current assignment
|
||||||
|
std::set<Key> assignment_keys;
|
||||||
|
for (auto&& [k, _] : a) {
|
||||||
|
assignment_keys.insert(k);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Find the keys missing in the assignment
|
||||||
|
std::vector<Key> diff;
|
||||||
|
std::set_difference(allKeys.begin(), allKeys.end(),
|
||||||
|
assignment_keys.begin(), assignment_keys.end(),
|
||||||
|
std::back_inserter(diff));
|
||||||
|
|
||||||
|
// Compute the total number of assignments in the (pruned) subtree
|
||||||
|
size_t nrAssignments = 1;
|
||||||
|
for (auto&& k : diff) {
|
||||||
|
nrAssignments *= cardinalities_.at(k);
|
||||||
|
}
|
||||||
|
// Add p `nrAssignments` times to the probs vector.
|
||||||
|
probs.insert(probs.end(), nrAssignments, p);
|
||||||
|
|
||||||
|
return p;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Go through the tree
|
||||||
|
this->apply(op);
|
||||||
|
|
||||||
return probs;
|
return probs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -305,11 +348,7 @@ namespace gtsam {
|
||||||
const size_t N = maxNrAssignments;
|
const size_t N = maxNrAssignments;
|
||||||
|
|
||||||
// Get the probabilities in the decision tree so we can threshold.
|
// Get the probabilities in the decision tree so we can threshold.
|
||||||
std::vector<double> probabilities;
|
std::vector<double> probabilities = this->probabilities();
|
||||||
// NOTE(Varun) this is potentially slow due to the cartesian product
|
|
||||||
for (auto&& [assignment, prob] : this->enumerate()) {
|
|
||||||
probabilities.push_back(prob);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The number of probabilities can be lower than max_leaves
|
// The number of probabilities can be lower than max_leaves
|
||||||
if (probabilities.size() <= N) {
|
if (probabilities.size() <= N) {
|
||||||
|
|
|
@ -186,6 +186,13 @@ namespace gtsam {
|
||||||
* Apply unary operator (*this) "op" f
|
* Apply unary operator (*this) "op" f
|
||||||
* @param op a unary operator that operates on AlgebraicDecisionTree
|
* @param op a unary operator that operates on AlgebraicDecisionTree
|
||||||
*/
|
*/
|
||||||
|
DecisionTreeFactor apply(ADT::Unary op) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Apply unary operator (*this) "op" f
|
||||||
|
* @param op a unary operator that operates on AlgebraicDecisionTree. Takes
|
||||||
|
* both the assignment and the value.
|
||||||
|
*/
|
||||||
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
|
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -56,9 +56,45 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
|
const DecisionTree<Key, double>& dtree)
|
||||||
|
: TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute the correct ordering of the leaves in the decision tree.
|
||||||
|
*
|
||||||
|
* This is done by first taking all the values which have modulo 0 value with
|
||||||
|
* the cardinality of the innermost key `n`, and we go up to modulo n.
|
||||||
|
*
|
||||||
|
* @param dt The DecisionTree
|
||||||
|
* @return std::vector<double>
|
||||||
|
*/
|
||||||
|
std::vector<double> ComputeLeafOrdering(const DiscreteKeys& dkeys,
|
||||||
|
const DecisionTreeFactor& dt) {
|
||||||
|
std::vector<double> probs = dt.probabilities();
|
||||||
|
std::vector<double> ordered;
|
||||||
|
|
||||||
|
size_t n = dkeys[0].second;
|
||||||
|
|
||||||
|
for (size_t k = 0; k < n; ++k) {
|
||||||
|
for (size_t idx = 0; idx < probs.size(); ++idx) {
|
||||||
|
if (idx % n == k) {
|
||||||
|
ordered.push_back(probs[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ordered;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
|
const DecisionTreeFactor& dtf)
|
||||||
|
: TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DiscreteConditional& c)
|
TableFactor::TableFactor(const DiscreteConditional& c)
|
||||||
: TableFactor(c.discreteKeys(), c.probabilities()) {}
|
: TableFactor(c.discreteKeys(), c) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
Eigen::SparseVector<double> TableFactor::Convert(
|
Eigen::SparseVector<double> TableFactor::Convert(
|
||||||
|
|
|
@ -144,6 +144,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
||||||
: TableFactor(DiscreteKeys{key}, row) {}
|
: TableFactor(DiscreteKeys{key}, row) {}
|
||||||
|
|
||||||
|
/// Constructor from DecisionTreeFactor
|
||||||
|
TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf);
|
||||||
|
|
||||||
|
/// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree
|
||||||
|
TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree);
|
||||||
|
|
||||||
/** Construct from a DiscreteConditional type */
|
/** Construct from a DiscreteConditional type */
|
||||||
explicit TableFactor(const DiscreteConditional& c);
|
explicit TableFactor(const DiscreteConditional& c);
|
||||||
|
|
||||||
|
@ -180,7 +186,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
return apply(f, Ring::mul);
|
return apply(f, Ring::mul);
|
||||||
};
|
};
|
||||||
|
|
||||||
/// multiple with DecisionTreeFactor
|
/// multiply with DecisionTreeFactor
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||||
|
|
||||||
static double safe_div(const double& a, const double& b);
|
static double safe_div(const double& a, const double& b);
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/serializationTestHelpers.h>
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
#include <gtsam/discrete/TableFactor.h>
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
|
@ -131,6 +132,16 @@ TEST(TableFactor, constructors) {
|
||||||
// Manually constructed via inspection and comparison to DecisionTreeFactor
|
// Manually constructed via inspection and comparison to DecisionTreeFactor
|
||||||
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
EXPECT(assert_equal(expected, f4));
|
EXPECT(assert_equal(expected, f4));
|
||||||
|
|
||||||
|
// Test for 9=3x3 values.
|
||||||
|
DiscreteKey V(0, 3), W(1, 3);
|
||||||
|
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
|
||||||
|
TableFactor f5(conditional5);
|
||||||
|
// GTSAM_PRINT(f5);
|
||||||
|
TableFactor expected_f5(
|
||||||
|
X & Y,
|
||||||
|
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
|
||||||
|
EXPECT(assert_equal(expected_f5, f5, 1e-6));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
@ -286,8 +286,6 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
|
void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
|
||||||
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
|
||||||
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
|
||||||
// Functional which loops over all assignments and create a set of
|
// Functional which loops over all assignments and create a set of
|
||||||
// GaussianConditionals
|
// GaussianConditionals
|
||||||
auto pruner = prunerFunc(discreteProbs);
|
auto pruner = prunerFunc(discreteProbs);
|
||||||
|
|
|
@ -129,7 +129,6 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
||||||
size_t maxNrLeaves) {
|
size_t maxNrLeaves) {
|
||||||
// Get the joint distribution of only the discrete keys
|
// Get the joint distribution of only the discrete keys
|
||||||
gttic_(HybridBayesNet_PruneDiscreteConditionals);
|
|
||||||
// The joint discrete probability.
|
// The joint discrete probability.
|
||||||
DiscreteConditional discreteProbs;
|
DiscreteConditional discreteProbs;
|
||||||
|
|
||||||
|
@ -147,12 +146,11 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
||||||
discrete_factor_idxs.push_back(i);
|
discrete_factor_idxs.push_back(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const DecisionTreeFactor prunedDiscreteProbs =
|
const DecisionTreeFactor prunedDiscreteProbs =
|
||||||
discreteProbs.prune(maxNrLeaves);
|
discreteProbs.prune(maxNrLeaves);
|
||||||
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
|
||||||
|
|
||||||
// Eliminate joint probability back into conditionals
|
// Eliminate joint probability back into conditionals
|
||||||
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
|
|
||||||
DiscreteFactorGraph dfg{prunedDiscreteProbs};
|
DiscreteFactorGraph dfg{prunedDiscreteProbs};
|
||||||
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
|
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
|
||||||
|
|
||||||
|
@ -161,7 +159,6 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
||||||
size_t idx = discrete_factor_idxs.at(i);
|
size_t idx = discrete_factor_idxs.at(i);
|
||||||
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
|
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
|
||||||
}
|
}
|
||||||
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
|
||||||
|
|
||||||
return prunedDiscreteProbs;
|
return prunedDiscreteProbs;
|
||||||
}
|
}
|
||||||
|
@ -180,7 +177,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
|
|
||||||
HybridBayesNet prunedBayesNetFragment;
|
HybridBayesNet prunedBayesNetFragment;
|
||||||
|
|
||||||
gttic_(HybridBayesNet_PruneMixtures);
|
|
||||||
// Go through all the conditionals in the
|
// Go through all the conditionals in the
|
||||||
// Bayes Net and prune them as per prunedDiscreteProbs.
|
// Bayes Net and prune them as per prunedDiscreteProbs.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
|
@ -197,7 +193,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
prunedBayesNetFragment.push_back(conditional);
|
prunedBayesNetFragment.push_back(conditional);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
gttoc_(HybridBayesNet_PruneMixtures);
|
|
||||||
|
|
||||||
return prunedBayesNetFragment;
|
return prunedBayesNetFragment;
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,7 +96,6 @@ static GaussianFactorGraphTree addGaussian(
|
||||||
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
||||||
// keys, and then loop over all assignments to populate a vector.
|
// keys, and then loop over all assignments to populate a vector.
|
||||||
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
gttic_(assembleGraphTree);
|
|
||||||
|
|
||||||
GaussianFactorGraphTree result;
|
GaussianFactorGraphTree result;
|
||||||
|
|
||||||
|
@ -129,8 +128,6 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
gttoc_(assembleGraphTree);
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -420,7 +420,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
|
||||||
DiscreteFactorGraph discrete_fg;
|
DiscreteFactorGraph discrete_fg;
|
||||||
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
|
// TODO(Varun) Make this a function of HybridGaussianFactorGraph?
|
||||||
for (auto& factor : (*remainingFactorGraph_partial)) {
|
for (auto& factor : (*remainingFactorGraph_partial)) {
|
||||||
auto df = dynamic_pointer_cast<DecisionTreeFactor>(factor);
|
auto df = dynamic_pointer_cast<DiscreteFactor>(factor);
|
||||||
assert(df);
|
assert(df);
|
||||||
discrete_fg.push_back(df);
|
discrete_fg.push_back(df);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue