Merge pull request #1929 from borglab/table-factor-fix
commit
3af5360ad3
|
|
@ -62,40 +62,99 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
: TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {}
|
: TableFactor(dkeys, DecisionTreeFactor(dkeys, dtree)) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the correct ordering of the leaves in the decision tree.
|
* @brief Compute the indexing of the leaves in the decision tree based on the
|
||||||
|
* assignment and add the (index, leaf) pair to a SparseVector.
|
||||||
*
|
*
|
||||||
* This is done by first taking all the values which have modulo 0 value with
|
* We visit each leaf in the tree, and using the cardinalities of the keys,
|
||||||
* the cardinality of the innermost key `n`, and we go up to modulo n.
|
* compute the correct index to add the leaf to a SparseVector which
|
||||||
|
* is then used to create the TableFactor.
|
||||||
*
|
*
|
||||||
* @param dt The DecisionTree
|
* @param dt The DecisionTree
|
||||||
* @return std::vector<double>
|
* @return Eigen::SparseVector<double>
|
||||||
*/
|
*/
|
||||||
std::vector<double> ComputeLeafOrdering(const DiscreteKeys& dkeys,
|
static Eigen::SparseVector<double> ComputeSparseTable(
|
||||||
const DecisionTreeFactor& dt) {
|
const DiscreteKeys& dkeys, const DecisionTreeFactor& dt) {
|
||||||
std::vector<double> probs = dt.probabilities();
|
// SparseVector needs to know the maximum possible index,
|
||||||
std::vector<double> ordered;
|
// so we compute the product of cardinalities.
|
||||||
|
size_t cardinalityProduct = 1;
|
||||||
|
for (auto&& [_, c] : dt.cardinalities()) {
|
||||||
|
cardinalityProduct *= c;
|
||||||
|
}
|
||||||
|
Eigen::SparseVector<double> sparseTable(cardinalityProduct);
|
||||||
|
size_t nrValues = 0;
|
||||||
|
dt.visit([&nrValues](double x) {
|
||||||
|
if (x > 0) nrValues += 1;
|
||||||
|
});
|
||||||
|
sparseTable.reserve(nrValues);
|
||||||
|
|
||||||
size_t n = dkeys[0].second;
|
std::set<Key> allKeys(dt.keys().begin(), dt.keys().end());
|
||||||
|
|
||||||
for (size_t k = 0; k < n; ++k) {
|
/**
|
||||||
for (size_t idx = 0; idx < probs.size(); ++idx) {
|
* @brief Functor which is called by the DecisionTree for each leaf.
|
||||||
if (idx % n == k) {
|
* For each leaf value, we use the corresponding assignment to compute a
|
||||||
ordered.push_back(probs[idx]);
|
* corresponding index into a SparseVector. We then populate sparseTable with
|
||||||
|
* the value at the computed index.
|
||||||
|
*
|
||||||
|
* Takes advantage of the sparsity of the DecisionTree to be efficient. When
|
||||||
|
* merged branches are encountered, we enumerate over the missing keys.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
auto op = [&](const Assignment<Key>& assignment, double p) {
|
||||||
|
if (p > 0) {
|
||||||
|
// Get all the keys involved in this assignment
|
||||||
|
std::set<Key> assignmentKeys;
|
||||||
|
for (auto&& [k, _] : assignment) {
|
||||||
|
assignmentKeys.insert(k);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the keys missing in the assignment
|
||||||
|
std::vector<Key> diff;
|
||||||
|
std::set_difference(allKeys.begin(), allKeys.end(),
|
||||||
|
assignmentKeys.begin(), assignmentKeys.end(),
|
||||||
|
std::back_inserter(diff));
|
||||||
|
|
||||||
|
// Generate all assignments using the missing keys
|
||||||
|
DiscreteKeys extras;
|
||||||
|
for (auto&& key : diff) {
|
||||||
|
extras.push_back({key, dt.cardinality(key)});
|
||||||
|
}
|
||||||
|
auto&& extraAssignments = DiscreteValues::CartesianProduct(extras);
|
||||||
|
|
||||||
|
for (auto&& extra : extraAssignments) {
|
||||||
|
// Create new assignment using the extra assignment
|
||||||
|
DiscreteValues updatedAssignment(assignment);
|
||||||
|
updatedAssignment.insert(extra);
|
||||||
|
|
||||||
|
// Generate index and add to the sparse vector.
|
||||||
|
Eigen::Index idx = 0;
|
||||||
|
size_t previousCardinality = 1;
|
||||||
|
// We go in reverse since a DecisionTree has the highest label first
|
||||||
|
for (auto&& it = updatedAssignment.rbegin();
|
||||||
|
it != updatedAssignment.rend(); it++) {
|
||||||
|
idx += previousCardinality * it->second;
|
||||||
|
previousCardinality *= dt.cardinality(it->first);
|
||||||
|
}
|
||||||
|
sparseTable.coeffRef(idx) = p;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
return ordered;
|
|
||||||
|
// Visit each leaf in `dt` to get the Assignment and leaf value
|
||||||
|
// to populate the sparseTable.
|
||||||
|
dt.visitWith(op);
|
||||||
|
|
||||||
|
return sparseTable;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
const DecisionTreeFactor& dtf)
|
const DecisionTreeFactor& dtf)
|
||||||
: TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}
|
: TableFactor(dkeys, ComputeSparseTable(dkeys, dtf)) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DecisionTreeFactor& dtf)
|
TableFactor::TableFactor(const DecisionTreeFactor& dtf)
|
||||||
: TableFactor(dtf.discreteKeys(),
|
: TableFactor(dtf.discreteKeys(),
|
||||||
ComputeLeafOrdering(dtf.discreteKeys(), dtf)) {}
|
ComputeSparseTable(dtf.discreteKeys(), dtf)) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DiscreteConditional& c)
|
TableFactor::TableFactor(const DiscreteConditional& c)
|
||||||
|
|
|
||||||
|
|
@ -147,6 +147,34 @@ TEST(TableFactor, constructors) {
|
||||||
EXPECT(assert_inequal(f5_with_wrong_keys, f5, 1e-9));
|
EXPECT(assert_inequal(f5_with_wrong_keys, f5, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check conversion from DecisionTreeFactor.
|
||||||
|
TEST(TableFactor, Conversion) {
|
||||||
|
/* This is the DecisionTree we are using
|
||||||
|
Choice(m2)
|
||||||
|
0 Choice(m1)
|
||||||
|
0 0 Leaf 0
|
||||||
|
0 1 Choice(m0)
|
||||||
|
0 1 0 Leaf 0
|
||||||
|
0 1 1 Leaf 0.14649446 // 3
|
||||||
|
1 Choice(m1)
|
||||||
|
1 0 Choice(m0)
|
||||||
|
1 0 0 Leaf 0
|
||||||
|
1 0 1 Leaf 0.14648756 // 5
|
||||||
|
1 1 Choice(m0)
|
||||||
|
1 1 0 Leaf 0.14649446 // 6
|
||||||
|
1 1 1 Leaf 0.23918345 // 7
|
||||||
|
*/
|
||||||
|
DiscreteKeys dkeys = {{0, 2}, {1, 2}, {2, 2}};
|
||||||
|
DecisionTreeFactor dtf(
|
||||||
|
dkeys, std::vector<double>{0, 0, 0, 0.14649446, 0, 0.14648756, 0.14649446,
|
||||||
|
0.23918345});
|
||||||
|
|
||||||
|
TableFactor tf(dtf.discreteKeys(), dtf);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(dtf, tf.toDecisionTreeFactor()));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check multiplication between two TableFactors.
|
// Check multiplication between two TableFactors.
|
||||||
TEST(TableFactor, multiplication) {
|
TEST(TableFactor, multiplication) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue