kill TableProduct in favor of DiscreteFactorGraph::scaledProduct

release/4.3a0
Varun Agrawal 2025-01-06 22:18:14 -05:00
parent 82dba6322f
commit 9960f2d8dc
3 changed files with 19 additions and 51 deletions

View File

@ -45,9 +45,16 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
/* ************************************************************************* */
DiscreteValues HybridBayesTree::discreteMaxProduct(
const DiscreteFactorGraph& dfg) const {
TableFactor product = TableProduct(dfg);
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
DiscreteValues assignment = TableDistribution(product).argmax();
// Check type of product, and get as TableFactor for efficiency.
TableFactor p;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
p = *tf;
} else {
p = TableFactor(product->toDecisionTreeFactor());
}
DiscreteValues assignment = TableDistribution(p).argmax();
return assignment;
}

View File

@ -255,43 +255,6 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors(
return std::make_shared<TableFactor>(discreteKeys, potentials);
}
/* ************************************************************************ */
TableFactor TableProduct(const DiscreteFactorGraph &factors) {
// PRODUCT: multiply all factors
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteProduct);
#endif
TableFactor product;
for (auto &&factor : factors) {
if (factor) {
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(factor)) {
product = product * dtc->table();
} else if (auto f = std::dynamic_pointer_cast<TableFactor>(factor)) {
product = product * (*f);
} else if (auto dtf =
std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
product = product * TableFactor(*dtf);
}
}
}
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteProduct);
#endif
#if GTSAM_HYBRID_TIMING
gttic_(DiscreteNormalize);
#endif
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / *std::dynamic_pointer_cast<TableFactor>(denominator);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
return product;
}
/* ************************************************************************ */
static DiscreteFactorGraph CollectDiscreteFactors(
const HybridGaussianFactorGraph &factors) {
@ -357,17 +320,24 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
// so we can use the TableFactor for efficiency.
if (frontalKeys.size() == dfg.keys().size()) {
// Get product factor
TableFactor product = TableProduct(dfg);
DiscreteFactor::shared_ptr product = dfg.scaledProduct();
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif
auto conditional = std::make_shared<TableDistribution>(product);
// Check type of product, and get as TableFactor for efficiency.
TableFactor p;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
p = *tf;
} else {
p = TableFactor(product->toDecisionTreeFactor());
}
auto conditional = std::make_shared<TableDistribution>(p);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif
DiscreteFactor::shared_ptr sum = product.sum(frontalKeys);
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
return {std::make_shared<HybridConditional>(conditional), sum};

View File

@ -271,13 +271,4 @@ template <>
struct traits<HybridGaussianFactorGraph>
: public Testable<HybridGaussianFactorGraph> {};
/**
* @brief Multiply all the `factors` and normalize the
* product to prevent underflow.
*
* @param factors The factors to multiply as a DiscreteFactorGraph.
* @return TableFactor
*/
TableFactor TableProduct(const DiscreteFactorGraph& factors);
} // namespace gtsam