Merge pull request #1590 from borglab/hybrid-tablefactor-3

release/4.3a0
Varun Agrawal 2023-07-27 12:25:55 -04:00 committed by GitHub
commit c5740b2221
9 changed files with 119 additions and 30 deletions


@@ -82,6 +82,22 @@ namespace gtsam {
     ADT::print("", formatter);
   }
 
+  /* ************************************************************************ */
+  DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const {
+    // apply operand
+    ADT result = ADT::apply(op);
+    // Make a new factor
+    return DecisionTreeFactor(discreteKeys(), result);
+  }
+
+  /* ************************************************************************ */
+  DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
+    // apply operand
+    ADT result = ADT::apply(op);
+    // Make a new factor
+    return DecisionTreeFactor(discreteKeys(), result);
+  }
+
   /* ************************************************************************ */
   DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
                                                ADT::Binary op) const {
@@ -101,14 +117,6 @@ namespace gtsam {
     return DecisionTreeFactor(keys, result);
   }
 
-  /* ************************************************************************ */
-  DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
-    // apply operand
-    ADT result = ADT::apply(op);
-    // Make a new factor
-    return DecisionTreeFactor(discreteKeys(), result);
-  }
-
   /* ************************************************************************ */
   DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
       size_t nrFrontals, ADT::Binary op) const {
@@ -188,10 +196,45 @@ namespace gtsam {
   /* ************************************************************************ */
   std::vector<double> DecisionTreeFactor::probabilities() const {
+    // Set of all keys
+    std::set<Key> allKeys(keys().begin(), keys().end());
+
     std::vector<double> probs;
-    for (auto&& [key, value] : enumerate()) {
-      probs.push_back(value);
-    }
+
+    /* An operation that takes each leaf probability, and computes the
+     * nrAssignments by checking the difference between the keys in the factor
+     * and the keys in the assignment. The nrAssignments is then used to append
+     * the correct number of leaf probability values to the `probs` vector
+     * defined above.
+     */
+    auto op = [&](const Assignment<Key>& a, double p) {
+      // Get all the keys in the current assignment
+      std::set<Key> assignment_keys;
+      for (auto&& [k, _] : a) {
+        assignment_keys.insert(k);
+      }
+
+      // Find the keys missing in the assignment
+      std::vector<Key> diff;
+      std::set_difference(allKeys.begin(), allKeys.end(),
+                          assignment_keys.begin(), assignment_keys.end(),
+                          std::back_inserter(diff));
+
+      // Compute the total number of assignments in the (pruned) subtree
+      size_t nrAssignments = 1;
+      for (auto&& k : diff) {
+        nrAssignments *= cardinalities_.at(k);
+      }
+      // Add p `nrAssignments` times to the probs vector.
+      probs.insert(probs.end(), nrAssignments, p);
+
+      return p;
+    };
+
+    // Go through the tree
+    this->apply(op);
+
     return probs;
   }
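The effect of the new probabilities() is easiest to see on a small factor. The snippet below is only a sketch: the keys, table values, and the call to prune are illustrative and not taken from the changed files.

#include <gtsam/discrete/DecisionTreeFactor.h>

#include <vector>

using namespace gtsam;

int main() {
  // Two binary keys with illustrative values.
  DiscreteKey A(0, 2), B(1, 2);
  DecisionTreeFactor f(A & B, "0.1 0.2 0.3 0.4");

  // One entry per assignment, in tree-traversal order.
  std::vector<double> probs = f.probabilities();  // size 4

  // After pruning, several assignments may share a single leaf;
  // probabilities() repeats that leaf value once per collapsed assignment,
  // so the vector size still matches the number of assignments.
  DecisionTreeFactor pruned = f.prune(2);
  std::vector<double> prunedProbs = pruned.probabilities();  // still size 4

  return 0;
}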
@@ -305,11 +348,7 @@ namespace gtsam {
     const size_t N = maxNrAssignments;
 
     // Get the probabilities in the decision tree so we can threshold.
-    std::vector<double> probabilities;
-    // NOTE(Varun) this is potentially slow due to the cartesian product
-    for (auto&& [assignment, prob] : this->enumerate()) {
-      probabilities.push_back(prob);
-    }
+    std::vector<double> probabilities = this->probabilities();
 
     // The number of probabilities can be lower than max_leaves
     if (probabilities.size() <= N) {


@@ -186,6 +186,13 @@ namespace gtsam {
    * Apply unary operator (*this) "op" f
    * @param op a unary operator that operates on AlgebraicDecisionTree
    */
+  DecisionTreeFactor apply(ADT::Unary op) const;
+
+  /**
+   * Apply unary operator (*this) "op" f
+   * @param op a unary operator that operates on AlgebraicDecisionTree. Takes
+   * both the assignment and the value.
+   */
   DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
 
   /**
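A minimal usage sketch for the two overloads; the keys, values, and lambdas below are placeholders chosen for illustration.

#include <gtsam/discrete/DecisionTreeFactor.h>

using namespace gtsam;

int main() {
  DecisionTreeFactor f(DiscreteKey(0, 2) & DiscreteKey(1, 2), "1 2 3 4");

  // ADT::Unary: the operator only sees the leaf value.
  DecisionTreeFactor squared = f.apply([](const double& x) { return x * x; });

  // ADT::UnaryAssignment: the operator also receives the (possibly partial)
  // assignment that leads to the leaf.
  DecisionTreeFactor scaled = f.apply(
      [](const Assignment<Key>& assignment, const double& x) {
        // Double the value on branches where key 0 is assigned 1.
        auto it = assignment.find(0);
        return (it != assignment.end() && it->second == 1) ? 2.0 * x : x;
      });

  return 0;
}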


@@ -56,9 +56,45 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
   sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
 }
 
+/* ************************************************************************ */
+TableFactor::TableFactor(const DiscreteKeys& dkeys,
+                         const DecisionTree<Key, double>& dtree)
+    : TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {}
+
+/**
+ * @brief Compute the correct ordering of the leaves in the decision tree.
+ *
+ * This is done by iterating over the leaf probabilities and grouping them by
+ * index modulo `n`, where `n` is the cardinality of the innermost key: first
+ * all values whose index is 0 modulo `n`, then those that are 1 modulo `n`,
+ * and so on up to `n - 1`.
+ *
+ * @param dkeys The discrete keys of the factor.
+ * @param dt The DecisionTreeFactor whose leaves should be reordered.
+ * @return std::vector<double>
+ */
+std::vector<double> ComputeLeafOrdering(const DiscreteKeys& dkeys,
+                                        const DecisionTreeFactor& dt) {
+  std::vector<double> probs = dt.probabilities();
+  std::vector<double> ordered;
+
+  size_t n = dkeys[0].second;
+
+  for (size_t k = 0; k < n; ++k) {
+    for (size_t idx = 0; idx < probs.size(); ++idx) {
+      if (idx % n == k) {
+        ordered.push_back(probs[idx]);
+      }
+    }
+  }
+  return ordered;
+}
+
+/* ************************************************************************ */
+TableFactor::TableFactor(const DiscreteKeys& dkeys,
+                         const DecisionTreeFactor& dtf)
+    : TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}
+
 /* ************************************************************************ */
 TableFactor::TableFactor(const DiscreteConditional& c)
-    : TableFactor(c.discreteKeys(), c.probabilities()) {}
+    : TableFactor(c.discreteKeys(), c) {}
 
 /* ************************************************************************ */
 Eigen::SparseVector<double> TableFactor::Convert(
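As a concrete illustration of the reordering (the keys and numbers below are made up for this sketch): with an innermost cardinality of n = 3 and leaf probabilities p0 .. p8 in DecisionTreeFactor order, ComputeLeafOrdering emits indices 0, 3, 6, then 1, 4, 7, then 2, 5, 8. The new constructor can then be exercised as follows.

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h>

using namespace gtsam;

int main() {
  // Illustrative keys and values.
  DiscreteKey A(0, 3), B(1, 3);
  DecisionTreeFactor dtf(A & B, "1 2 3 4 5 6 7 8 9");

  // Reorders dtf.probabilities() via ComputeLeafOrdering so the values land
  // in the layout TableFactor expects.
  TableFactor table(A & B, dtf);

  return 0;
}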


@@ -144,6 +144,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
   TableFactor(const DiscreteKey& key, const std::vector<double>& row)
       : TableFactor(DiscreteKeys{key}, row) {}
 
+  /// Constructor from DecisionTreeFactor
+  TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf);
+
+  /// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree
+  TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree);
+
   /** Construct from a DiscreteConditional type */
   explicit TableFactor(const DiscreteConditional& c);
@@ -180,7 +186,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
     return apply(f, Ring::mul);
   };
 
-  /// multiple with DecisionTreeFactor
+  /// multiply with DecisionTreeFactor
   DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
 
   static double safe_div(const double& a, const double& b);
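A brief usage sketch for this override; the factor contents are illustrative.

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/TableFactor.h>

using namespace gtsam;

int main() {
  DiscreteKey A(0, 2), B(1, 2);
  TableFactor t(A & B, "1 2 3 4");
  DecisionTreeFactor d(B, "0.5 0.5");

  // Element-wise product over the union of the keys, returned as a
  // DecisionTreeFactor per the override above.
  DecisionTreeFactor product = t * d;

  return 0;
}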


@@ -19,6 +19,7 @@
 #include <CppUnitLite/TestHarness.h>
 #include <gtsam/base/Testable.h>
 #include <gtsam/base/serializationTestHelpers.h>
+#include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/discrete/DiscreteDistribution.h>
 #include <gtsam/discrete/Signature.h>
 #include <gtsam/discrete/TableFactor.h>
@@ -131,6 +132,16 @@ TEST(TableFactor, constructors) {
   // Manually constructed via inspection and comparison to DecisionTreeFactor
   TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
   EXPECT(assert_equal(expected, f4));
+
+  // Test for 9=3x3 values.
+  DiscreteKey V(0, 3), W(1, 3);
+  DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
+  TableFactor f5(conditional5);
+  // GTSAM_PRINT(f5);
+  TableFactor expected_f5(
+      V & W,
+      "0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
+  EXPECT(assert_equal(expected_f5, f5, 1e-6));
 }
 /* ************************************************************************* */
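For reference, the expected values follow from normalizing each column of the signature: P(V|W=0) = (1, 2, 3)/6 ≈ (0.166667, 0.333333, 0.5), P(V|W=1) = (5, 6, 7)/18 ≈ (0.277778, 0.333333, 0.388889), and P(V|W=2) = (9, 10, 11)/30 = (0.3, 0.333333, 0.366667). In the expected_f5 string these appear grouped by the value of V, with W varying fastest within each group, which is the layout produced by the new DecisionTreeFactor-based constructor.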


@@ -286,8 +286,6 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
 /* *******************************************************************************/
 void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
-  auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
-  auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
 
   // Functional which loops over all assignments and create a set of
   // GaussianConditionals
   auto pruner = prunerFunc(discreteProbs);


@@ -129,7 +129,6 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
 DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
     size_t maxNrLeaves) {
   // Get the joint distribution of only the discrete keys
-  gttic_(HybridBayesNet_PruneDiscreteConditionals);
 
   // The joint discrete probability.
   DiscreteConditional discreteProbs;
@@ -147,12 +146,11 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
       discrete_factor_idxs.push_back(i);
     }
   }
 
   const DecisionTreeFactor prunedDiscreteProbs =
       discreteProbs.prune(maxNrLeaves);
-  gttoc_(HybridBayesNet_PruneDiscreteConditionals);
+
   // Eliminate joint probability back into conditionals
-  gttic_(HybridBayesNet_UpdateDiscreteConditionals);
   DiscreteFactorGraph dfg{prunedDiscreteProbs};
   DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
@@ -161,7 +159,6 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
     size_t idx = discrete_factor_idxs.at(i);
     this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
   }
-  gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
 
   return prunedDiscreteProbs;
 }
@@ -180,7 +177,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   HybridBayesNet prunedBayesNetFragment;
 
-  gttic_(HybridBayesNet_PruneMixtures);
   // Go through all the conditionals in the
   // Bayes Net and prune them as per prunedDiscreteProbs.
   for (auto &&conditional : *this) {
@@ -197,7 +193,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
       prunedBayesNetFragment.push_back(conditional);
     }
   }
-  gttoc_(HybridBayesNet_PruneMixtures);
 
   return prunedBayesNetFragment;
 }


@@ -96,7 +96,6 @@ static GaussianFactorGraphTree addGaussian(
 // TODO(dellaert): it's probably more efficient to first collect the discrete
 // keys, and then loop over all assignments to populate a vector.
 GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
-  gttic_(assembleGraphTree);
 
   GaussianFactorGraphTree result;
@@ -129,8 +128,6 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
     }
   }
 
-  gttoc_(assembleGraphTree);
-
   return result;
 }


@@ -420,7 +420,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
     DiscreteFactorGraph discrete_fg;
     // TODO(Varun) Make this a function of HybridGaussianFactorGraph?
     for (auto& factor : (*remainingFactorGraph_partial)) {
-      auto df = dynamic_pointer_cast<DecisionTreeFactor>(factor);
+      auto df = dynamic_pointer_cast<DiscreteFactor>(factor);
       assert(df);
       discrete_fg.push_back(df);
     }