Merge pull request #1592 from borglab/tablefactor-improvements

release/4.3a0
Varun Agrawal 2023-07-27 11:11:12 -04:00 committed by GitHub
commit e649fc6967
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 2 deletions

View File

@ -56,9 +56,45 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const DecisionTree<Key, double>& dtree)
: TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {}
/**
* @brief Compute the correct ordering of the leaves in the decision tree.
*
* This is done by first taking all the values which have modulo 0 value with
* the cardinality of the innermost key `n`, and we go up to modulo n.
*
* @param dt The DecisionTree
* @return std::vector<double>
*/
std::vector<double> ComputeLeafOrdering(const DiscreteKeys& dkeys,
const DecisionTreeFactor& dt) {
std::vector<double> probs = dt.probabilities();
std::vector<double> ordered;
size_t n = dkeys[0].second;
for (size_t k = 0; k < n; ++k) {
for (size_t idx = 0; idx < probs.size(); ++idx) {
if (idx % n == k) {
ordered.push_back(probs[idx]);
}
}
}
return ordered;
}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteKeys& dkeys,
const DecisionTreeFactor& dtf)
: TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c)
: TableFactor(c.discreteKeys(), c.probabilities()) {}
: TableFactor(c.discreteKeys(), c) {}
/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(

View File

@ -144,6 +144,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {}
/// Constructor from DecisionTreeFactor
TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf);
/// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree
TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree);
/** Construct from a DiscreteConditional type */
explicit TableFactor(const DiscreteConditional& c);
@ -180,7 +186,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, Ring::mul);
};
/// multiple with DecisionTreeFactor
/// multiply with DecisionTreeFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
static double safe_div(const double& a, const double& b);

View File

@ -19,6 +19,7 @@
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/discrete/TableFactor.h>
@ -131,6 +132,16 @@ TEST(TableFactor, constructors) {
// Manually constructed via inspection and comparison to DecisionTreeFactor
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
EXPECT(assert_equal(expected, f4));
// Test for 9=3x3 values.
DiscreteKey V(0, 3), W(1, 3);
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
TableFactor f5(conditional5);
// GTSAM_PRINT(f5);
TableFactor expected_f5(
X & Y,
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
EXPECT(assert_equal(expected_f5, f5, 1e-6));
}
/* ************************************************************************* */