small improvements to HybridGaussianFactorGraph

release/4.3a0
Varun Agrawal 2025-01-24 13:42:04 -05:00
parent 59539ffe6c
commit 26642f1ba0
2 changed files with 5 additions and 1 deletions

View File

@ -186,6 +186,7 @@ DiscreteValues HybridBayesNet::mpe() const {
} }
} }
} }
return discrete_fg.optimize(); return discrete_fg.optimize();
} }

View File

@ -327,6 +327,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
gttic_(EliminateDiscreteFormDiscreteConditional); gttic_(EliminateDiscreteFormDiscreteConditional);
#endif #endif
// Check type of product, and get as TableFactor for efficiency. // Check type of product, and get as TableFactor for efficiency.
// Use object instead of pointer since we need it
// for the TableDistribution constructor.
TableFactor p; TableFactor p;
if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) { if (auto tf = std::dynamic_pointer_cast<TableFactor>(product)) {
p = *tf; p = *tf;
@ -334,11 +336,12 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
p = TableFactor(product->toDecisionTreeFactor()); p = TableFactor(product->toDecisionTreeFactor());
} }
auto conditional = std::make_shared<TableDistribution>(p); auto conditional = std::make_shared<TableDistribution>(p);
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteFormDiscreteConditional); gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif #endif
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); DiscreteFactor::shared_ptr sum = p.sum(frontalKeys);
return {std::make_shared<HybridConditional>(conditional), sum}; return {std::make_shared<HybridConditional>(conditional), sum};