Merge pull request #1947 from borglab/discrete-improvements

release/4.3a0
Varun Agrawal 2025-01-02 10:49:11 -05:00 committed by GitHub
commit 6c516cc404
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 16 additions and 47 deletions

View File

@ -112,13 +112,12 @@ namespace gtsam {
// } // }
/** /**
* @brief Multiply all the `factors` and normalize the * @brief Multiply all the `factors`.
* product to prevent underflow.
* *
* @param factors The factors to multiply as a DiscreteFactorGraph. * @param factors The factors to multiply as a DiscreteFactorGraph.
* @return DecisionTreeFactor * @return DecisionTreeFactor
*/ */
static DecisionTreeFactor ProductAndNormalize( static DecisionTreeFactor DiscreteProduct(
const DiscreteFactorGraph& factors) { const DiscreteFactorGraph& factors) {
// PRODUCT: multiply all factors // PRODUCT: multiply all factors
gttic(product); gttic(product);
@ -126,10 +125,10 @@ namespace gtsam {
gttoc(product); gttoc(product);
// Max over all the potentials by pretending all keys are frontal: // Max over all the potentials by pretending all keys are frontal:
auto normalization = product.max(product.size()); auto denominator = product.max(product.size());
// Normalize the product factor to prevent underflow. // Normalize the product factor to prevent underflow.
product = product / (*normalization); product = product / (*denominator);
return product; return product;
} }
@ -139,7 +138,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateForMPE(const DiscreteFactorGraph& factors, EliminateForMPE(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) { const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors); DecisionTreeFactor product = DiscreteProduct(factors);
// max out frontals, this is the factor on the separator // max out frontals, this is the factor on the separator
gttic(max); gttic(max);
@ -207,8 +206,7 @@ namespace gtsam {
return dag.argmax(); return dag.argmax();
} }
DiscreteValues DiscreteFactorGraph::optimize( DiscreteValues DiscreteFactorGraph::optimize(const Ordering& ordering) const {
const Ordering& ordering) const {
gttic(DiscreteFactorGraph_optimize); gttic(DiscreteFactorGraph_optimize);
DiscreteLookupDAG dag = maxProduct(ordering); DiscreteLookupDAG dag = maxProduct(ordering);
return dag.argmax(); return dag.argmax();
@ -218,7 +216,7 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> // std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors, EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) { const Ordering& frontalKeys) {
DecisionTreeFactor product = ProductAndNormalize(factors); DecisionTreeFactor product = DiscreteProduct(factors);
// sum out frontals, this is the factor on the separator // sum out frontals, this is the factor on the separator
gttic(sum); gttic(sum);

View File

@ -252,41 +252,12 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys(); DiscreteKeys dkeys = discreteKeys();
// Record key assignment and value pairs in pair_table. std::vector<double> table;
// The assignments are stored in descending order of keys so that the order of
// the values matches what is expected by a DecisionTree.
// This is why we reverse the keys and then
// query for the key value/assignment.
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
std::vector<std::pair<uint64_t, double>> pair_table;
for (auto i = 0; i < sparse_table_.size(); i++) { for (auto i = 0; i < sparse_table_.size(); i++) {
std::stringstream ss; table.push_back(sparse_table_.coeff(i));
for (auto&& [key, _] : rdkeys) {
ss << keyValueForIndex(key, i);
}
// k will be in reverse key order already
uint64_t k;
ss >> k;
pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i)));
} }
// Sort the pair_table (of assignment-value pairs) based on assignment so we AlgebraicDecisionTree<Key> tree(dkeys, table);
// get values in reverse key order.
std::sort(
pair_table.begin(), pair_table.end(),
[](const std::pair<uint64_t, double>& a,
const std::pair<uint64_t, double>& b) { return a.first < b.first; });
// Create the table vector by extracting the values from pair_table.
// The pair_table has already been sorted in the desired order,
// so the values will be in descending key order.
std::vector<double> table;
std::for_each(pair_table.begin(), pair_table.end(),
[&table](const std::pair<uint64_t, double>& pair) {
table.push_back(pair.second);
});
AlgebraicDecisionTree<Key> tree(rdkeys, table);
DecisionTreeFactor f(dkeys, tree); DecisionTreeFactor f(dkeys, tree);
return f; return f;
} }

View File

@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr); *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
// Normalize newFactor by max for comparison with expected // Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size()); auto normalizer = newFactor.max(newFactor.size());
newFactor = newFactor / *normalization; newFactor = newFactor / *normalizer;
// Check Conditional // Check Conditional
CHECK(conditional); CHECK(conditional);
@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor); CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max. // Normalize by max.
normalization = expectedFactor.max(expectedFactor.size()); normalizer = expectedFactor.max(expectedFactor.size());
// Ensure normalization is correct. // Ensure normalizer is correct.
expectedFactor = expectedFactor / *normalization; expectedFactor = expectedFactor / *normalizer;
EXPECT(assert_equal(expectedFactor, newFactor)); EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree // Test using elimination tree

View File

@ -282,7 +282,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) { } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto dc = hc->asDiscrete(); auto dc = hc->asDiscrete();
if (!dc) throwRuntimeError("discreteElimination", dc); if (!dc) throwRuntimeError("discreteElimination", dc);
dfg.push_back(hc->asDiscrete()); dfg.push_back(dc);
} else { } else {
throwRuntimeError("discreteElimination", f); throwRuntimeError("discreteElimination", f);
} }