updates to get things working

release/4.3a0
Varun Agrawal 2025-01-06 21:06:22 -05:00
parent 8658f25edd
commit 5913fd120d
4 changed files with 6 additions and 5 deletions

View File

@ -221,7 +221,7 @@ class GTSAM_EXPORT DiscreteConditional
* @param keys The keys to sum over.
* @return DiscreteFactor::shared_ptr
*/
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const;
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
/// @}
/// @name Advanced Interface

View File

@ -47,7 +47,8 @@ static Eigen::SparseVector<double> normalizeSparseTable(
TableDistribution::TableDistribution(const TableFactor& f)
: BaseConditional(f.keys().size(),
DecisionTreeFactor(f.discreteKeys(), ADT())),
table_(f / (*f.sum(f.keys().size()))) {}
table_(f / (*std::dynamic_pointer_cast<TableFactor>(
f.sum(f.keys().size())))) {}
/* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys,

View File

@ -140,7 +140,7 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
}
/// Get the number of non-zero values.
size_t nrValues() const { return table_.sparseTable().nonZeros(); }
uint64_t nrValues() const override { return table_.sparseTable().nonZeros(); }
/// @}

View File

@ -284,7 +284,7 @@ TableFactor TableProduct(const DiscreteFactorGraph &factors) {
// Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*denominator);
product = product / *std::dynamic_pointer_cast<TableFactor>(denominator);
#if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize);
#endif
@ -367,7 +367,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif
TableFactor::shared_ptr sum = product.sum(frontalKeys);
DiscreteFactor::shared_ptr sum = product.sum(frontalKeys);
return {std::make_shared<HybridConditional>(conditional), sum};