Merge pull request #1960 from borglab/table-factor-fixes
commit
05d8030af4
|
|
@ -87,7 +87,15 @@ static Eigen::SparseVector<double> ComputeSparseTable(
|
||||||
});
|
});
|
||||||
sparseTable.reserve(nrValues);
|
sparseTable.reserve(nrValues);
|
||||||
|
|
||||||
std::set<Key> allKeys(dt.keys().begin(), dt.keys().end());
|
KeySet allKeys(dt.keys().begin(), dt.keys().end());
|
||||||
|
|
||||||
|
// Compute denominators to be used in computing sparse table indices
|
||||||
|
std::map<Key, size_t> denominators;
|
||||||
|
double denom = sparseTable.size();
|
||||||
|
for (const DiscreteKey& dkey : dkeys) {
|
||||||
|
denom /= dkey.second;
|
||||||
|
denominators.insert(std::pair<Key, double>(dkey.first, denom));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Functor which is called by the DecisionTree for each leaf.
|
* @brief Functor which is called by the DecisionTree for each leaf.
|
||||||
|
|
@ -102,13 +110,13 @@ static Eigen::SparseVector<double> ComputeSparseTable(
|
||||||
auto op = [&](const Assignment<Key>& assignment, double p) {
|
auto op = [&](const Assignment<Key>& assignment, double p) {
|
||||||
if (p > 0) {
|
if (p > 0) {
|
||||||
// Get all the keys involved in this assignment
|
// Get all the keys involved in this assignment
|
||||||
std::set<Key> assignmentKeys;
|
KeySet assignmentKeys;
|
||||||
for (auto&& [k, _] : assignment) {
|
for (auto&& [k, _] : assignment) {
|
||||||
assignmentKeys.insert(k);
|
assignmentKeys.insert(k);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the keys missing in the assignment
|
// Find the keys missing in the assignment
|
||||||
std::vector<Key> diff;
|
KeyVector diff;
|
||||||
std::set_difference(allKeys.begin(), allKeys.end(),
|
std::set_difference(allKeys.begin(), allKeys.end(),
|
||||||
assignmentKeys.begin(), assignmentKeys.end(),
|
assignmentKeys.begin(), assignmentKeys.end(),
|
||||||
std::back_inserter(diff));
|
std::back_inserter(diff));
|
||||||
|
|
@ -127,12 +135,10 @@ static Eigen::SparseVector<double> ComputeSparseTable(
|
||||||
|
|
||||||
// Generate index and add to the sparse vector.
|
// Generate index and add to the sparse vector.
|
||||||
Eigen::Index idx = 0;
|
Eigen::Index idx = 0;
|
||||||
size_t previousCardinality = 1;
|
|
||||||
// We go in reverse since a DecisionTree has the highest label first
|
// We go in reverse since a DecisionTree has the highest label first
|
||||||
for (auto&& it = updatedAssignment.rbegin();
|
for (auto&& it = updatedAssignment.rbegin();
|
||||||
it != updatedAssignment.rend(); it++) {
|
it != updatedAssignment.rend(); it++) {
|
||||||
idx += previousCardinality * it->second;
|
idx += it->second * denominators.at(it->first);
|
||||||
previousCardinality *= dt.cardinality(it->first);
|
|
||||||
}
|
}
|
||||||
sparseTable.coeffRef(idx) = p;
|
sparseTable.coeffRef(idx) = p;
|
||||||
}
|
}
|
||||||
|
|
@ -252,9 +258,19 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||||
DiscreteKeys dkeys = discreteKeys();
|
DiscreteKeys dkeys = discreteKeys();
|
||||||
|
|
||||||
std::vector<double> table;
|
// If no keys, then return empty DecisionTreeFactor
|
||||||
for (auto i = 0; i < sparse_table_.size(); i++) {
|
if (dkeys.size() == 0) {
|
||||||
table.push_back(sparse_table_.coeff(i));
|
AlgebraicDecisionTree<Key> tree;
|
||||||
|
// We can have an empty sparse_table_ or one with a single value.
|
||||||
|
if (sparse_table_.size() != 0) {
|
||||||
|
tree = AlgebraicDecisionTree<Key>(sparse_table_.coeff(0));
|
||||||
|
}
|
||||||
|
return DecisionTreeFactor(dkeys, tree);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<double> table(sparse_table_.size(), 0.0);
|
||||||
|
for (SparseIt it(sparse_table_); it; ++it) {
|
||||||
|
table[it.index()] = it.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
AlgebraicDecisionTree<Key> tree(dkeys, table);
|
AlgebraicDecisionTree<Key> tree(dkeys, table);
|
||||||
|
|
|
||||||
|
|
@ -173,6 +173,36 @@ TEST(TableFactor, Conversion) {
|
||||||
TableFactor tf(dtf.discreteKeys(), dtf);
|
TableFactor tf(dtf.discreteKeys(), dtf);
|
||||||
|
|
||||||
EXPECT(assert_equal(dtf, tf.toDecisionTreeFactor()));
|
EXPECT(assert_equal(dtf, tf.toDecisionTreeFactor()));
|
||||||
|
|
||||||
|
// Test for correct construction when keys are not in reverse order.
|
||||||
|
// This is possible in conditionals e.g. P(x1 | x0)
|
||||||
|
DiscreteKey X(1, 2), Y(0, 2);
|
||||||
|
DiscreteConditional dtf2(
|
||||||
|
X, {Y}, std::vector<double>{0.33333333, 0.6, 0.66666667, 0.4});
|
||||||
|
|
||||||
|
TableFactor tf2(dtf2);
|
||||||
|
// GTSAM_PRINT(dtf2);
|
||||||
|
// GTSAM_PRINT(tf2);
|
||||||
|
// GTSAM_PRINT(tf2.toDecisionTreeFactor());
|
||||||
|
|
||||||
|
// Check for ADT equality since the order of keys is irrelevant
|
||||||
|
EXPECT(assert_equal<AlgebraicDecisionTree<Key>>(dtf2,
|
||||||
|
tf2.toDecisionTreeFactor()));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(TableFactor, Empty) {
|
||||||
|
DiscreteKey X(1, 2);
|
||||||
|
|
||||||
|
TableFactor single = *TableFactor({X}, "1 1").sum(1);
|
||||||
|
// Should not throw a segfault
|
||||||
|
EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
|
||||||
|
single.toDecisionTreeFactor()));
|
||||||
|
|
||||||
|
TableFactor empty = *TableFactor({X}, "0 0").sum(1);
|
||||||
|
// Should not throw a segfault
|
||||||
|
EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
|
||||||
|
empty.toDecisionTreeFactor()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue