updates to get things working

release/4.3a0
Varun Agrawal 2025-01-06 21:06:22 -05:00
parent 8658f25edd
commit 5913fd120d
4 changed files with 6 additions and 5 deletions

View File

@ -221,7 +221,7 @@ class GTSAM_EXPORT DiscreteConditional
* @param keys The keys to sum over. * @param keys The keys to sum over.
* @return DiscreteFactor::shared_ptr * @return DiscreteFactor::shared_ptr
*/ */
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const; virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface

View File

@ -47,7 +47,8 @@ static Eigen::SparseVector<double> normalizeSparseTable(
TableDistribution::TableDistribution(const TableFactor& f) TableDistribution::TableDistribution(const TableFactor& f)
: BaseConditional(f.keys().size(), : BaseConditional(f.keys().size(),
DecisionTreeFactor(f.discreteKeys(), ADT())), DecisionTreeFactor(f.discreteKeys(), ADT())),
table_(f / (*f.sum(f.keys().size()))) {} table_(f / (*std::dynamic_pointer_cast<TableFactor>(
f.sum(f.keys().size())))) {}
/* ************************************************************************** */ /* ************************************************************************** */
TableDistribution::TableDistribution(const DiscreteKeys& keys, TableDistribution::TableDistribution(const DiscreteKeys& keys,

View File

@ -140,7 +140,7 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
} }
/// Get the number of non-zero values. /// Get the number of non-zero values.
size_t nrValues() const { return table_.sparseTable().nonZeros(); } uint64_t nrValues() const override { return table_.sparseTable().nonZeros(); }
/// @} /// @}

View File

@ -284,7 +284,7 @@ TableFactor TableProduct(const DiscreteFactorGraph &factors) {
// Max over all the potentials by pretending all keys are frontal: // Max over all the potentials by pretending all keys are frontal:
auto denominator = product.max(product.size()); auto denominator = product.max(product.size());
// Normalize the product factor to prevent underflow. // Normalize the product factor to prevent underflow.
product = product / (*denominator); product = product / *std::dynamic_pointer_cast<TableFactor>(denominator);
#if GTSAM_HYBRID_TIMING #if GTSAM_HYBRID_TIMING
gttoc_(DiscreteNormalize); gttoc_(DiscreteNormalize);
#endif #endif
@ -367,7 +367,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
gttoc_(EliminateDiscreteFormDiscreteConditional); gttoc_(EliminateDiscreteFormDiscreteConditional);
#endif #endif
TableFactor::shared_ptr sum = product.sum(frontalKeys); DiscreteFactor::shared_ptr sum = product.sum(frontalKeys);
return {std::make_shared<HybridConditional>(conditional), sum}; return {std::make_shared<HybridConditional>(conditional), sum};