Merge pull request #1947 from borglab/discrete-improvements
commit
6c516cc404
|
@ -112,13 +112,12 @@ namespace gtsam {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Multiply all the `factors` and normalize the
|
* @brief Multiply all the `factors`.
|
||||||
* product to prevent underflow.
|
|
||||||
*
|
*
|
||||||
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
* @param factors The factors to multiply as a DiscreteFactorGraph.
|
||||||
* @return DecisionTreeFactor
|
* @return DecisionTreeFactor
|
||||||
*/
|
*/
|
||||||
static DecisionTreeFactor ProductAndNormalize(
|
static DecisionTreeFactor DiscreteProduct(
|
||||||
const DiscreteFactorGraph& factors) {
|
const DiscreteFactorGraph& factors) {
|
||||||
// PRODUCT: multiply all factors
|
// PRODUCT: multiply all factors
|
||||||
gttic(product);
|
gttic(product);
|
||||||
|
@ -126,10 +125,10 @@ namespace gtsam {
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
// Max over all the potentials by pretending all keys are frontal:
|
// Max over all the potentials by pretending all keys are frontal:
|
||||||
auto normalization = product.max(product.size());
|
auto denominator = product.max(product.size());
|
||||||
|
|
||||||
// Normalize the product factor to prevent underflow.
|
// Normalize the product factor to prevent underflow.
|
||||||
product = product / (*normalization);
|
product = product / (*denominator);
|
||||||
|
|
||||||
return product;
|
return product;
|
||||||
}
|
}
|
||||||
|
@ -139,7 +138,7 @@ namespace gtsam {
|
||||||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||||
EliminateForMPE(const DiscreteFactorGraph& factors,
|
EliminateForMPE(const DiscreteFactorGraph& factors,
|
||||||
const Ordering& frontalKeys) {
|
const Ordering& frontalKeys) {
|
||||||
DecisionTreeFactor product = ProductAndNormalize(factors);
|
DecisionTreeFactor product = DiscreteProduct(factors);
|
||||||
|
|
||||||
// max out frontals, this is the factor on the separator
|
// max out frontals, this is the factor on the separator
|
||||||
gttic(max);
|
gttic(max);
|
||||||
|
@ -207,8 +206,7 @@ namespace gtsam {
|
||||||
return dag.argmax();
|
return dag.argmax();
|
||||||
}
|
}
|
||||||
|
|
||||||
DiscreteValues DiscreteFactorGraph::optimize(
|
DiscreteValues DiscreteFactorGraph::optimize(const Ordering& ordering) const {
|
||||||
const Ordering& ordering) const {
|
|
||||||
gttic(DiscreteFactorGraph_optimize);
|
gttic(DiscreteFactorGraph_optimize);
|
||||||
DiscreteLookupDAG dag = maxProduct(ordering);
|
DiscreteLookupDAG dag = maxProduct(ordering);
|
||||||
return dag.argmax();
|
return dag.argmax();
|
||||||
|
@ -218,7 +216,7 @@ namespace gtsam {
|
||||||
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
std::pair<DiscreteConditional::shared_ptr, DiscreteFactor::shared_ptr> //
|
||||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||||
const Ordering& frontalKeys) {
|
const Ordering& frontalKeys) {
|
||||||
DecisionTreeFactor product = ProductAndNormalize(factors);
|
DecisionTreeFactor product = DiscreteProduct(factors);
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
gttic(sum);
|
gttic(sum);
|
||||||
|
|
|
@ -252,41 +252,12 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||||
DiscreteKeys dkeys = discreteKeys();
|
DiscreteKeys dkeys = discreteKeys();
|
||||||
|
|
||||||
// Record key assignment and value pairs in pair_table.
|
std::vector<double> table;
|
||||||
// The assignments are stored in descending order of keys so that the order of
|
|
||||||
// the values matches what is expected by a DecisionTree.
|
|
||||||
// This is why we reverse the keys and then
|
|
||||||
// query for the key value/assignment.
|
|
||||||
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
|
|
||||||
std::vector<std::pair<uint64_t, double>> pair_table;
|
|
||||||
for (auto i = 0; i < sparse_table_.size(); i++) {
|
for (auto i = 0; i < sparse_table_.size(); i++) {
|
||||||
std::stringstream ss;
|
table.push_back(sparse_table_.coeff(i));
|
||||||
for (auto&& [key, _] : rdkeys) {
|
|
||||||
ss << keyValueForIndex(key, i);
|
|
||||||
}
|
|
||||||
// k will be in reverse key order already
|
|
||||||
uint64_t k;
|
|
||||||
ss >> k;
|
|
||||||
pair_table.push_back(std::make_pair(k, sparse_table_.coeff(i)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort the pair_table (of assignment-value pairs) based on assignment so we
|
AlgebraicDecisionTree<Key> tree(dkeys, table);
|
||||||
// get values in reverse key order.
|
|
||||||
std::sort(
|
|
||||||
pair_table.begin(), pair_table.end(),
|
|
||||||
[](const std::pair<uint64_t, double>& a,
|
|
||||||
const std::pair<uint64_t, double>& b) { return a.first < b.first; });
|
|
||||||
|
|
||||||
// Create the table vector by extracting the values from pair_table.
|
|
||||||
// The pair_table has already been sorted in the desired order,
|
|
||||||
// so the values will be in descending key order.
|
|
||||||
std::vector<double> table;
|
|
||||||
std::for_each(pair_table.begin(), pair_table.end(),
|
|
||||||
[&table](const std::pair<uint64_t, double>& pair) {
|
|
||||||
table.push_back(pair.second);
|
|
||||||
});
|
|
||||||
|
|
||||||
AlgebraicDecisionTree<Key> tree(rdkeys, table);
|
|
||||||
DecisionTreeFactor f(dkeys, tree);
|
DecisionTreeFactor f(dkeys, tree);
|
||||||
return f;
|
return f;
|
||||||
}
|
}
|
||||||
|
|
|
@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||||
|
|
||||||
// Normalize newFactor by max for comparison with expected
|
// Normalize newFactor by max for comparison with expected
|
||||||
auto normalization = newFactor.max(newFactor.size());
|
auto normalizer = newFactor.max(newFactor.size());
|
||||||
|
|
||||||
newFactor = newFactor / *normalization;
|
newFactor = newFactor / *normalizer;
|
||||||
|
|
||||||
// Check Conditional
|
// Check Conditional
|
||||||
CHECK(conditional);
|
CHECK(conditional);
|
||||||
|
@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
CHECK(&newFactor);
|
CHECK(&newFactor);
|
||||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||||
// Normalize by max.
|
// Normalize by max.
|
||||||
normalization = expectedFactor.max(expectedFactor.size());
|
normalizer = expectedFactor.max(expectedFactor.size());
|
||||||
// Ensure normalization is correct.
|
// Ensure normalizer is correct.
|
||||||
expectedFactor = expectedFactor / *normalization;
|
expectedFactor = expectedFactor / *normalizer;
|
||||||
EXPECT(assert_equal(expectedFactor, newFactor));
|
EXPECT(assert_equal(expectedFactor, newFactor));
|
||||||
|
|
||||||
// Test using elimination tree
|
// Test using elimination tree
|
||||||
|
|
|
@ -282,7 +282,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||||
auto dc = hc->asDiscrete();
|
auto dc = hc->asDiscrete();
|
||||||
if (!dc) throwRuntimeError("discreteElimination", dc);
|
if (!dc) throwRuntimeError("discreteElimination", dc);
|
||||||
dfg.push_back(hc->asDiscrete());
|
dfg.push_back(dc);
|
||||||
} else {
|
} else {
|
||||||
throwRuntimeError("discreteElimination", f);
|
throwRuntimeError("discreteElimination", f);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue