diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 383346ab1..a8ec66f73 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -57,6 +57,9 @@ namespace gtsam { AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {} + /// Constructor which accepts root pointer + AlgebraicDecisionTree(const typename Base::NodePtr root) : Base(root) {} + // Explicitly non-explicit constructor AlgebraicDecisionTree(const Base& add) : Base(add) {} diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 19dcdc729..f5ad2b98a 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -77,6 +77,13 @@ DiscreteConditional::DiscreteConditional(const Signature& signature) /* ************************************************************************** */ DiscreteConditional DiscreteConditional::operator*( const DiscreteConditional& other) const { + // If the root is a nullptr, we have a TableDistribution + // TODO(Varun) Revisit this hack after RSS2025 submission + if (!other.root_) { + DiscreteConditional dc(other.nrFrontals(), other.toDecisionTreeFactor()); + return dc * (*this); + } + // Take union of frontal keys std::set newFrontals; for (auto&& key : this->frontals()) newFrontals.insert(key); @@ -479,6 +486,19 @@ double DiscreteConditional::evaluate(const HybridValues& x) const { return this->operator()(x.discrete()); } +/* ************************************************************************* */ +DiscreteFactor::shared_ptr DiscreteConditional::max( + const Ordering& keys) const { + return BaseFactor::max(keys); +} + +/* ************************************************************************* */ +void DiscreteConditional::prune(size_t maxNrAssignments) { + // Get as DiscreteConditional so the probabilities are normalized + DiscreteConditional pruned(nrFrontals(), BaseFactor::prune(maxNrAssignments)); + this->root_ = pruned.root_; +} + /* ************************************************************************* */ double DiscreteConditional::negLogConstant() const { return 0.0; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 67f8a0056..1bca0b09f 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional * @param parentsValues Known values of the parents * @return sample from conditional */ - size_t sample(const DiscreteValues& parentsValues) const; + virtual size_t sample(const DiscreteValues& parentsValues) const; /// Single parent version. size_t sample(size_t parent_value) const; @@ -214,6 +214,15 @@ class GTSAM_EXPORT DiscreteConditional */ size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; + /** + * @brief Create new factor by maximizing over all + * values with the same separator. + * + * @param keys The keys to sum over. + * @return DiscreteFactor::shared_ptr + */ + virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const override; + /// @} /// @name Advanced Interface /// @{ @@ -267,6 +276,9 @@ class GTSAM_EXPORT DiscreteConditional */ double negLogConstant() const override; + /// Prune the conditional + virtual void prune(size_t maxNrAssignments); + /// @} protected: diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 9b1774f49..7e059c5e5 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -118,30 +118,18 @@ namespace gtsam { // } // } - /** - * @brief Multiply all the `factors`. - * - * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return DiscreteFactor::shared_ptr - */ - static DiscreteFactor::shared_ptr DiscreteProduct( - const DiscreteFactorGraph& factors) { + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DiscreteFactorGraph::scaledProduct() const { // PRODUCT: multiply all factors gttic(product); - DiscreteFactor::shared_ptr product = factors.product(); + DiscreteFactor::shared_ptr product = this->product(); gttoc(product); -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteNormalize); -#endif // Max over all the potentials by pretending all keys are frontal: auto denominator = product->max(product->size()); // Normalize the product factor to prevent underflow. product = product->operator/(denominator); -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteNormalize); -#endif return product; } @@ -151,7 +139,7 @@ namespace gtsam { std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DiscreteFactor::shared_ptr product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = factors.scaledProduct(); // max out frontals, this is the factor on the separator gttic(max); @@ -229,7 +217,7 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - DiscreteFactor::shared_ptr product = DiscreteProduct(factors); + DiscreteFactor::shared_ptr product = factors.scaledProduct(); // sum out frontals, this is the factor on the separator gttic(sum); diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 3d9e86cd1..f4d1a1833 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -150,6 +150,15 @@ class GTSAM_EXPORT DiscreteFactorGraph /** return product of all factors as a single factor */ DiscreteFactor::shared_ptr product() const; + /** + * @brief Return product of all `factors` as a single factor, + * which is scaled by the max value to prevent underflow + * + * @param factors The factors to multiply as a DiscreteFactorGraph. + * @return DiscreteFactor::shared_ptr + */ + DiscreteFactor::shared_ptr scaledProduct() const; + /** * Evaluates the factor graph given values, returns the joint probability of * the factor graph given specific instantiation of values diff --git a/gtsam/discrete/TableDistribution.cpp b/gtsam/discrete/TableDistribution.cpp new file mode 100644 index 000000000..e8696c5b1 --- /dev/null +++ b/gtsam/discrete/TableDistribution.cpp @@ -0,0 +1,174 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableDistribution.cpp + * @date Dec 22, 2024 + * @author Varun Agrawal + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using std::pair; +using std::stringstream; +using std::vector; +namespace gtsam { + +/// Normalize sparse_table +static Eigen::SparseVector normalizeSparseTable( + const Eigen::SparseVector& sparse_table) { + return sparse_table / sparse_table.sum(); +} + +/* ************************************************************************** */ +TableDistribution::TableDistribution(const TableFactor& f) + : BaseConditional(f.keys().size(), f.discreteKeys(), ADT(nullptr)), + table_(f / (*std::dynamic_pointer_cast( + f.sum(f.keys().size())))) {} + +/* ************************************************************************** */ +TableDistribution::TableDistribution(const DiscreteKeys& keys, + const std::vector& potentials) + : BaseConditional(keys.size(), keys, ADT(nullptr)), + table_(TableFactor( + keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) { +} + +/* ************************************************************************** */ +TableDistribution::TableDistribution(const DiscreteKeys& keys, + const std::string& potentials) + : BaseConditional(keys.size(), keys, ADT(nullptr)), + table_(TableFactor( + keys, normalizeSparseTable(TableFactor::Convert(keys, potentials)))) { +} + +/* ************************************************************************** */ +void TableDistribution::print(const string& s, + const KeyFormatter& formatter) const { + cout << s << " P( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + cout << "):\n"; + table_.print("", formatter); + cout << endl; +} + +/* ************************************************************************** */ +bool TableDistribution::equals(const DiscreteFactor& other, double tol) const { + auto dtc = dynamic_cast(&other); + if (!dtc) { + return false; + } else { + const DiscreteConditional& f( + static_cast(other)); + return table_.equals(dtc->table_, tol) && + DiscreteConditional::BaseConditional::equals(f, tol); + } +} + +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::sum(size_t nrFrontals) const { + return table_.sum(nrFrontals); +} + +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::sum(const Ordering& keys) const { + return table_.sum(keys); +} + +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::max(size_t nrFrontals) const { + return table_.max(nrFrontals); +} + +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const { + return table_.max(keys); +} + +/* ****************************************************************************/ +DiscreteFactor::shared_ptr TableDistribution::operator/( + const DiscreteFactor::shared_ptr& f) const { + return table_ / f; +} + +/* ************************************************************************ */ +DiscreteValues TableDistribution::argmax() const { + uint64_t maxIdx = 0; + double maxValue = 0.0; + + Eigen::SparseVector sparseTable = table_.sparseTable(); + + for (SparseIt it(sparseTable); it; ++it) { + if (it.value() > maxValue) { + maxIdx = it.index(); + maxValue = it.value(); + } + } + + return table_.findAssignments(maxIdx); +} + +/* ****************************************************************************/ +void TableDistribution::prune(size_t maxNrAssignments) { + table_ = table_.prune(maxNrAssignments); +} + +/* ****************************************************************************/ +size_t TableDistribution::sample(const DiscreteValues& parentsValues) const { + static mt19937 rng(2); // random number generator + + DiscreteKeys parentsKeys; + for (auto&& [key, _] : parentsValues) { + parentsKeys.push_back({key, table_.cardinality(key)}); + } + + // Get the correct conditional distribution: P(F|S=parentsValues) + TableFactor pFS = table_.choose(parentsValues, parentsKeys); + + // TODO(Duy): only works for one key now, seems horribly slow this way + if (nrFrontals() != 1) { + throw std::invalid_argument( + "TableDistribution::sample can only be called on single variable " + "conditionals"); + } + Key key = firstFrontalKey(); + size_t nj = cardinality(key); + vector p(nj); + DiscreteValues frontals; + for (size_t value = 0; value < nj; value++) { + frontals[key] = value; + p[value] = pFS(frontals); // P(F=value|S=parentsValues) + if (p[value] == 1.0) { + return value; // shortcut exit + } + } + std::discrete_distribution distribution(p.begin(), p.end()); + return distribution(rng); +} + +} // namespace gtsam diff --git a/gtsam/discrete/TableDistribution.h b/gtsam/discrete/TableDistribution.h new file mode 100644 index 000000000..72786a515 --- /dev/null +++ b/gtsam/discrete/TableDistribution.h @@ -0,0 +1,177 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file TableDistribution.h + * @date Dec 22, 2024 + * @author Varun Agrawal + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace gtsam { + +/** + * Distribution which uses a SparseVector as the internal + * representation, similar to the TableFactor. + * + * This is primarily used in the case when we have a clique in the BayesTree + * which consists of all the discrete variables, e.g. in hybrid elimination. + * + * @ingroup discrete + */ +class GTSAM_EXPORT TableDistribution : public DiscreteConditional { + private: + TableFactor table_; + + typedef Eigen::SparseVector::InnerIterator SparseIt; + + public: + // typedefs needed to play nice with gtsam + typedef TableDistribution This; ///< Typedef to this class + typedef std::shared_ptr shared_ptr; ///< shared_ptr to this class + typedef DiscreteConditional + BaseConditional; ///< Typedef to our conditional base class + + using Values = DiscreteValues; ///< backwards compatibility + + /// @name Standard Constructors + /// @{ + + /// Default constructor needed for serialization. + TableDistribution() {} + + /// Construct from TableFactor. + TableDistribution(const TableFactor& f); + + /** + * Construct from DiscreteKeys and std::vector. + */ + TableDistribution(const DiscreteKeys& keys, + const std::vector& potentials); + + /** + * Construct from single DiscreteKey and std::vector. + */ + TableDistribution(const DiscreteKey& key, + const std::vector& potentials) + : TableDistribution(DiscreteKeys(key), potentials) {} + + /** + * Construct from DiscreteKey and std::string. + */ + TableDistribution(const DiscreteKeys& keys, const std::string& potentials); + + /** + * Construct from single DiscreteKey and std::string. + */ + TableDistribution(const DiscreteKey& key, const std::string& potentials) + : TableDistribution(DiscreteKeys(key), potentials) {} + + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Table Distribution: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// GTSAM-style equals + bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; + + /// @} + /// @name Standard Interface + /// @{ + + /// Return the underlying TableFactor + TableFactor table() const { return table_; } + + using BaseConditional::evaluate; // HybridValues version + + /// Evaluate the conditional given the values. + virtual double evaluate(const Assignment& values) const override { + return table_.evaluate(values); + } + + /// Create new factor by summing all values with the same separator values + DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override; + + /// Create new factor by summing all values with the same separator values + DiscreteFactor::shared_ptr sum(const Ordering& keys) const override; + + /// Create new factor by maximizing over all values with the same separator. + DiscreteFactor::shared_ptr max(size_t nrFrontals) const override; + + /// Create new factor by maximizing over all values with the same separator. + DiscreteFactor::shared_ptr max(const Ordering& keys) const override; + + /// divide by DiscreteFactor::shared_ptr f (safely) + DiscreteFactor::shared_ptr operator/( + const DiscreteFactor::shared_ptr& f) const override; + + /** + * @brief Return assignment that maximizes value. + * + * @return maximizing assignment for the variables. + */ + DiscreteValues argmax() const; + + /** + * sample + * @param parentsValues Known values of the parents + * @return sample from conditional + */ + virtual size_t sample(const DiscreteValues& parentsValues) const override; + + /// @} + /// @name Advanced Interface + /// @{ + + /// Prune the conditional + virtual void prune(size_t maxNrAssignments) override; + + /// Get a DecisionTreeFactor representation. + DecisionTreeFactor toDecisionTreeFactor() const override { + return table_.toDecisionTreeFactor(); + } + + /// Get the number of non-zero values. + uint64_t nrValues() const override { return table_.sparseTable().nonZeros(); } + + /// @} + + private: +#if GTSAM_ENABLE_BOOST_SERIALIZATION + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + ar& BOOST_SERIALIZATION_NVP(table_); + } +#endif +}; +// TableDistribution + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 5a804b6a6..1cb9eda8b 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -87,6 +87,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); } + public: /** * Convert probability table given as doubles to SparseVector. * Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} @@ -98,7 +99,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { static Eigen::SparseVector Convert(const DiscreteKeys& keys, const std::string& table); - public: // typedefs needed to play nice with gtsam typedef TableFactor This; typedef DiscreteFactor Base; ///< Typedef to base class @@ -211,7 +211,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { DecisionTreeFactor toDecisionTreeFactor() const override; /// Create a TableFactor that is a subset of this TableFactor - TableFactor choose(const DiscreteValues assignments, + TableFactor choose(const DiscreteValues parentAssignments, DiscreteKeys parent_keys) const; /// Create new factor by summing all values with the same separator values diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index b2e2524f8..40f1822cf 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -168,6 +168,43 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional { std::vector pmf() const; }; +#include +virtual class TableFactor : gtsam::DiscreteFactor { + TableFactor(); + TableFactor(const gtsam::DiscreteKeys& keys, + const gtsam::TableFactor& potentials); + TableFactor(const gtsam::DiscreteKeys& keys, std::vector& table); + TableFactor(const gtsam::DiscreteKeys& keys, string spec); + TableFactor(const gtsam::DiscreteKeys& keys, + const gtsam::DecisionTreeFactor& dtf); + TableFactor(const gtsam::DecisionTreeFactor& dtf); + + void print(string s = "TableFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + + double evaluate(const gtsam::DiscreteValues& values) const; + double error(const gtsam::DiscreteValues& values) const; +}; + +#include +virtual class TableDistribution : gtsam::DiscreteConditional { + TableDistribution(); + TableDistribution(const gtsam::TableFactor& f); + TableDistribution(const gtsam::DiscreteKey& key, std::vector spec); + TableDistribution(const gtsam::DiscreteKeys& keys, std::vector spec); + TableDistribution(const gtsam::DiscreteKeys& keys, string spec); + TableDistribution(const gtsam::DiscreteKey& key, string spec); + + void print(string s = "Table Distribution\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + + gtsam::TableFactor table() const; + double evaluate(const gtsam::DiscreteValues& values) const; + size_t nrValues() const; +}; + #include class DiscreteBayesNet { DiscreteBayesNet(); diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 623b82eea..8668bedd6 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -55,12 +56,15 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { joint = joint * (*conditional); } - // Prune the joint. NOTE: again, possibly quite expensive. - const DecisionTreeFactor pruned = joint.prune(maxNrLeaves); - - // Create a the result starting with the pruned joint. + // Create the result starting with the pruned joint. HybridBayesNet result; - result.emplace_shared(pruned.size(), pruned); + result.emplace_shared(joint); + // Prune the joint. NOTE: imperative and, again, possibly quite expensive. + result.back()->asDiscrete()->prune(maxNrLeaves); + + // Get pruned discrete probabilities so + // we can prune HybridGaussianConditionals. + DiscreteConditional pruned = *result.back()->asDiscrete(); /* To prune, we visitWith every leaf in the HybridGaussianConditional. * For each leaf, using the assignment we can check the discrete decision tree @@ -126,7 +130,14 @@ HybridValues HybridBayesNet::optimize() const { for (auto &&conditional : *this) { if (conditional->isDiscrete()) { - discrete_fg.push_back(conditional->asDiscrete()); + if (auto dtc = conditional->asDiscrete()) { + // The number of keys should be small so should not + // be expensive to convert to DiscreteConditional. + discrete_fg.push_back(DiscreteConditional(dtc->nrFrontals(), + dtc->toDecisionTreeFactor())); + } else { + discrete_fg.push_back(conditional->asDiscrete()); + } } } diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 1b633e024..31d256d6f 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +42,22 @@ bool HybridBayesTree::equals(const This& other, double tol) const { return Base::equals(other, tol); } +/* ************************************************************************* */ +DiscreteValues HybridBayesTree::discreteMaxProduct( + const DiscreteFactorGraph& dfg) const { + DiscreteFactor::shared_ptr product = dfg.scaledProduct(); + + // Check type of product, and get as TableFactor for efficiency. + TableFactor p; + if (auto tf = std::dynamic_pointer_cast(product)) { + p = *tf; + } else { + p = TableFactor(product->toDecisionTreeFactor()); + } + DiscreteValues assignment = TableDistribution(p).argmax(); + return assignment; +} + /* ************************************************************************* */ HybridValues HybridBayesTree::optimize() const { DiscreteFactorGraph discrete_fg; @@ -52,8 +69,9 @@ HybridValues HybridBayesTree::optimize() const { // The root should be discrete only, we compute the MPE if (root_conditional->isDiscrete()) { - discrete_fg.push_back(root_conditional->asDiscrete()); - mpe = discrete_fg.optimize(); + auto discrete = root_conditional->asDiscrete(); + discrete_fg.push_back(discrete); + mpe = discreteMaxProduct(discrete_fg); } else { throw std::runtime_error( "HybridBayesTree root is not discrete-only. Please check elimination " @@ -179,16 +197,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { /* ************************************************************************* */ void HybridBayesTree::prune(const size_t maxNrLeaves) { - auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete(); + auto prunedDiscreteProbs = + this->roots_.at(0)->conditional()->asDiscrete(); - DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves); - discreteProbs->root_ = prunedDiscreteProbs.root_; + // Imperative pruning + prunedDiscreteProbs->prune(maxNrLeaves); /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData { /// The discrete decision tree after pruning. - DecisionTreeFactor prunedDiscreteProbs; - HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs, + DiscreteConditional::shared_ptr prunedDiscreteProbs; + HybridPrunerData(const DiscreteConditional::shared_ptr& prunedDiscreteProbs, const HybridBayesTree::sharedNode& parentClique) : prunedDiscreteProbs(prunedDiscreteProbs) {} @@ -213,7 +232,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { if (!hybridGaussianCond->pruned()) { // Imperative clique->conditional() = std::make_shared( - hybridGaussianCond->prune(parentData.prunedDiscreteProbs)); + hybridGaussianCond->prune(*parentData.prunedDiscreteProbs)); } } return parentData; diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 06d880f02..ec29f7b1e 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -115,6 +115,10 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { /// @} private: + /// Helper method to compute the max product assignment + /// given a DiscreteFactorGraph + DiscreteValues discreteMaxProduct(const DiscreteFactorGraph& dfg) const; + #if GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access; diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index a08b3a6ee..3cf5b80e5 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -166,12 +166,13 @@ class GTSAM_EXPORT HybridConditional } /** - * @brief Return conditional as a DiscreteConditional + * @brief Return conditional as a DiscreteConditional or specified type T. * @return nullptr if not a DiscreteConditional * @return DiscreteConditional::shared_ptr */ - DiscreteConditional::shared_ptr asDiscrete() const { - return std::dynamic_pointer_cast(inner_); + template + typename T::shared_ptr asDiscrete() const { + return std::dynamic_pointer_cast(inner_); } /// Get the type-erased pointer to the inner type diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 54346679e..78e1f5324 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -304,7 +304,7 @@ std::set DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { /* *******************************************************************************/ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( - const DecisionTreeFactor &discreteProbs) const { + const DiscreteConditional &discreteProbs) const { // Find keys in discreteProbs.keys() but not in this->keys(): std::set mine(this->keys().begin(), this->keys().end()); std::set theirs(discreteProbs.keys().begin(), diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index e769662ed..3b95e0277 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -235,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional * @return Shared pointer to possibly a pruned HybridGaussianConditional */ HybridGaussianConditional::shared_ptr prune( - const DecisionTreeFactor &discreteProbs) const; + const DiscreteConditional &discreteProbs) const; /// Return true if the conditional has already been pruned. bool pruned() const { return pruned_; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 8be5a8af4..cf56b52ed 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -255,46 +256,6 @@ static TableFactor::shared_ptr DiscreteFactorFromErrors( return std::make_shared(discreteKeys, potentials); } -/** - * @brief Multiply all the `factors` using the machinery of the TableFactor. - * - * @param factors The factors to multiply as a DiscreteFactorGraph. - * @return TableFactor - */ -static TableFactor TableProduct(const DiscreteFactorGraph &factors) { - // PRODUCT: multiply all factors -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteProduct); -#endif - TableFactor product; - for (auto &&factor : factors) { - if (factor) { - if (auto f = std::dynamic_pointer_cast(factor)) { - product = product * (*f); - } else if (auto dtf = - std::dynamic_pointer_cast(factor)) { - product = product * TableFactor(*dtf); - } - } - } -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteProduct); -#endif - -#if GTSAM_HYBRID_TIMING - gttic_(DiscreteNormalize); -#endif - // Max over all the potentials by pretending all keys are frontal: - auto denominator = product.max(product.size()); - // Normalize the product factor to prevent underflow. - product = product / (*denominator); -#if GTSAM_HYBRID_TIMING - gttoc_(DiscreteNormalize); -#endif - - return product; -} - /* ************************************************************************ */ static DiscreteFactorGraph CollectDiscreteFactors( const HybridGaussianFactorGraph &factors) { @@ -325,12 +286,17 @@ static DiscreteFactorGraph CollectDiscreteFactors( #if GTSAM_HYBRID_TIMING gttic_(ConvertConditionalToTableFactor); #endif - // Convert DiscreteConditional to TableFactor - auto tdc = std::make_shared(*dc); + if (auto dtc = std::dynamic_pointer_cast(dc)) { + /// Get the underlying TableFactor + dfg.push_back(dtc->table()); + } else { + // Convert DiscreteConditional to TableFactor + auto tdc = std::make_shared(*dc); + dfg.push_back(tdc); + } #if GTSAM_HYBRID_TIMING gttoc_(ConvertConditionalToTableFactor); #endif - dfg.push_back(tdc); } else { throwRuntimeError("discreteElimination", f); } @@ -355,21 +321,24 @@ discreteElimination(const HybridGaussianFactorGraph &factors, // so we can use the TableFactor for efficiency. if (frontalKeys.size() == dfg.keys().size()) { // Get product factor - TableFactor product = TableProduct(dfg); + DiscreteFactor::shared_ptr product = dfg.scaledProduct(); #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); #endif - auto conditional = std::make_shared( - frontalKeys.size(), product.toDecisionTreeFactor()); + // Check type of product, and get as TableFactor for efficiency. + TableFactor p; + if (auto tf = std::dynamic_pointer_cast(product)) { + p = *tf; + } else { + p = TableFactor(product->toDecisionTreeFactor()); + } + auto conditional = std::make_shared(p); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif - TableFactor::shared_ptr sum = product.sum(frontalKeys); -#if GTSAM_HYBRID_TIMING - gttoc_(EliminateDiscrete); -#endif + DiscreteFactor::shared_ptr sum = product->sum(frontalKeys); return {std::make_shared(conditional), sum}; @@ -378,6 +347,9 @@ discreteElimination(const HybridGaussianFactorGraph &factors, auto result = EliminateDiscrete(dfg, frontalKeys); return {std::make_shared(result.first), result.second}; } +#if GTSAM_HYBRID_TIMING + gttoc_(EliminateDiscrete); +#endif } /* ************************************************************************ */ diff --git a/gtsam/hybrid/HybridNonlinearISAM.cpp b/gtsam/hybrid/HybridNonlinearISAM.cpp index 29e467d86..3b4856dfb 100644 --- a/gtsam/hybrid/HybridNonlinearISAM.cpp +++ b/gtsam/hybrid/HybridNonlinearISAM.cpp @@ -15,6 +15,7 @@ * @author Varun Agrawal */ +#include #include #include #include @@ -65,7 +66,14 @@ void HybridNonlinearISAM::reorderRelinearize() { // Obtain the new linearization point const Values newLinPoint = estimate(); - auto discreteProbs = *(isam_.roots().at(0)->conditional()->asDiscrete()); + DiscreteConditional::shared_ptr discreteProbabilities; + + auto discreteRoot = isam_.roots().at(0)->conditional(); + if (discreteRoot->asDiscrete()) { + discreteProbabilities = discreteRoot->asDiscrete(); + } else { + discreteProbabilities = discreteRoot->asDiscrete(); + } isam_.clear(); @@ -73,7 +81,7 @@ void HybridNonlinearISAM::reorderRelinearize() { HybridNonlinearFactorGraph pruned_factors; for (auto&& factor : factors_) { if (auto nf = std::dynamic_pointer_cast(factor)) { - pruned_factors.push_back(nf->prune(discreteProbs)); + pruned_factors.push_back(nf->prune(*discreteProbabilities)); } else { pruned_factors.push_back(factor); } diff --git a/gtsam/hybrid/tests/testGaussianMixture.cpp b/gtsam/hybrid/tests/testGaussianMixture.cpp index 698c1bbf6..c98485fea 100644 --- a/gtsam/hybrid/tests/testGaussianMixture.cpp +++ b/gtsam/hybrid/tests/testGaussianMixture.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -79,8 +80,8 @@ TEST(GaussianMixture, GaussianMixtureModel) { double midway = mu1 - mu0; auto eliminationResult = gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential(); - auto pMid = *eliminationResult->at(0)->asDiscrete(); - EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid)); + auto pMid = eliminationResult->at(0)->asDiscrete(); + EXPECT(assert_equal(TableDistribution(m, "60 40"), *pMid)); // Everywhere else, the result should be a sigmoid. for (const double shift : {-4, -2, 0, 2, 4}) { @@ -90,7 +91,8 @@ TEST(GaussianMixture, GaussianMixtureModel) { // Workflow 1: convert HBN to HFG and solve auto eliminationResult1 = gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential(); - auto posterior1 = *eliminationResult1->at(0)->asDiscrete(); + auto posterior1 = + *eliminationResult1->at(0)->asDiscrete(); EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8); // Workflow 2: directly specify HFG and solve @@ -99,7 +101,8 @@ TEST(GaussianMixture, GaussianMixtureModel) { m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)}); hfg1.push_back(mixing); auto eliminationResult2 = hfg1.eliminateSequential(); - auto posterior2 = *eliminationResult2->at(0)->asDiscrete(); + auto posterior2 = + *eliminationResult2->at(0)->asDiscrete(); EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); } } @@ -133,13 +136,13 @@ TEST(GaussianMixture, GaussianMixtureModel2) { // Eliminate the graph! auto eliminationResultMax = gfg.eliminateSequential(); - // Equality of posteriors asserts that the elimination is correct (same ratios - // for all modes) + // Equality of posteriors asserts that the elimination is correct + // (same ratios for all modes) EXPECT(assert_equal(expectedDiscretePosterior, eliminationResultMax->discretePosterior(vv))); - auto pMax = *eliminationResultMax->at(0)->asDiscrete(); - EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4)); + auto pMax = *eliminationResultMax->at(0)->asDiscrete(); + EXPECT(assert_equal(TableDistribution(m, "42 58"), pMax, 1e-4)); // Everywhere else, the result should be a bell curve like function. for (const double shift : {-4, -2, 0, 2, 4}) { @@ -149,7 +152,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) { // Workflow 1: convert HBN to HFG and solve auto eliminationResult1 = gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential(); - auto posterior1 = *eliminationResult1->at(0)->asDiscrete(); + auto posterior1 = + *eliminationResult1->at(0)->asDiscrete(); EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8); // Workflow 2: directly specify HFG and solve @@ -158,7 +162,8 @@ TEST(GaussianMixture, GaussianMixtureModel2) { m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)}); hfg.push_back(mixing); auto eliminationResult2 = hfg.eliminateSequential(); - auto posterior2 = *eliminationResult2->at(0)->asDiscrete(); + auto posterior2 = + *eliminationResult2->at(0)->asDiscrete(); EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); } } diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 135da5dc7..989694b26 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -454,7 +455,8 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { } size_t maxNrLeaves = 3; - auto prunedDecisionTree = joint.prune(maxNrLeaves); + DiscreteConditional prunedDecisionTree(joint); + prunedDecisionTree.prune(maxNrLeaves); #ifdef GTSAM_DT_MERGING EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index 5b27e2b41..ef2ae9c41 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -16,6 +16,7 @@ */ #include +#include #include #include #include @@ -464,14 +465,14 @@ TEST(HybridEstimation, EliminateSequentialRegression) { // Create expected discrete conditional on m0. DiscreteKey m(M(0), 2); - DiscreteConditional expected(m % "0.51341712/1"); // regression + TableDistribution expected(m, "0.51341712 1"); // regression // Eliminate into BN using one ordering const Ordering ordering1{X(0), X(1), M(0)}; HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1); // Check that the discrete conditional matches the expected. - auto dc1 = bn1->back()->asDiscrete(); + auto dc1 = bn1->back()->asDiscrete(); EXPECT(assert_equal(expected, *dc1, 1e-9)); // Eliminate into BN using a different ordering @@ -479,7 +480,7 @@ TEST(HybridEstimation, EliminateSequentialRegression) { HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2); // Check that the discrete conditional matches the expected. - auto dc2 = bn2->back()->asDiscrete(); + auto dc2 = bn2->back()->asDiscrete(); EXPECT(assert_equal(expected, *dc2, 1e-9)); } diff --git a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp index 350bc9184..8bb83cac4 100644 --- a/gtsam/hybrid/tests/testHybridGaussianConditional.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianConditional.cpp @@ -261,7 +261,8 @@ TEST(HybridGaussianConditional, Prune) { potentials[i] = 1; const DecisionTreeFactor decisionTreeFactor(keys, potentials); // Prune the HybridGaussianConditional - const auto pruned = hgc.prune(decisionTreeFactor); + const auto pruned = + hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 1 conditional EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); } @@ -271,7 +272,8 @@ TEST(HybridGaussianConditional, Prune) { 0, 0, 0.5, 0}; const DecisionTreeFactor decisionTreeFactor(keys, potentials); - const auto pruned = hgc.prune(decisionTreeFactor); + const auto pruned = + hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 2 conditionals EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); @@ -286,7 +288,8 @@ TEST(HybridGaussianConditional, Prune) { 0, 0, 0.5, 0}; const DecisionTreeFactor decisionTreeFactor(keys, potentials); - const auto pruned = hgc.prune(decisionTreeFactor); + const auto pruned = + hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); // Check that the pruned HybridGaussianConditional has 3 conditionals EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 1942e9234..c8735c40a 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -650,7 +651,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) { mode, std::vector{conditional0, conditional1}); // Add prior on mode. - expectedBayesNet.emplace_shared(mode, "74/26"); + expectedBayesNet.emplace_shared(mode, "74 26"); // Test elimination const auto posterior = fg.eliminateSequential(); @@ -700,11 +701,11 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) { m1, std::vector{conditional0, conditional1}); // Add prior on m1. - expectedBayesNet.emplace_shared(m1, "1/1"); + expectedBayesNet.emplace_shared(m1, "0.188638 0.811362"); // Test elimination const auto posterior = fg.eliminateSequential(); - // EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); + EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01)); EXPECT(ratioTest(bn, measurements, *posterior)); @@ -736,7 +737,9 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) { mode, std::vector{conditional0, conditional1}); // Add prior on mode. - expectedBayesNet.emplace_shared(mode, "23/77"); + // Since this is the only discrete conditional, it is added as a + // TableDistribution. + expectedBayesNet.emplace_shared(mode, "23 77"); // Test elimination const auto posterior = fg.eliminateSequential(); diff --git a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp index 04b44f904..54964f6f7 100644 --- a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -141,7 +142,7 @@ TEST(HybridGaussianISAM, IncrementalInference) { expectedRemainingGraph->eliminateMultifrontal(discreteOrdering); // Test the probability values with regression tests. - auto discrete = isam[M(1)]->conditional()->asDiscrete(); + auto discrete = isam[M(1)]->conditional()->asDiscrete(); EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5)); EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5)); EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5)); @@ -221,16 +222,12 @@ TEST(HybridGaussianISAM, ApproxInference) { 1 1 1 Leaf 0.5 */ - auto discreteConditional_m0 = *dynamic_pointer_cast( + auto discreteConditional_m0 = *dynamic_pointer_cast( incrementalHybrid[M(0)]->conditional()->inner()); EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)})); - // Get the number of elements which are greater than 0. - auto count = [](const double &value, int count) { - return value > 0 ? count + 1 : count; - }; // Check that the number of leaves after pruning is 5. - EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0)); + EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues()); // Check that the hybrid nodes of the bayes net match those of the pre-pruning // bayes net, at the same positions. @@ -477,7 +474,9 @@ TEST(HybridGaussianISAM, NonTrivial) { // Test if the optimal discrete mode assignment is (1, 1, 1). DiscreteFactorGraph discreteGraph; - discreteGraph.push_back(discreteTree); + // discreteTree is a TableDistribution, so we convert to + // DecisionTreeFactor for the DiscreteFactorGraph + discreteGraph.push_back(discreteTree->toDecisionTreeFactor()); DiscreteValues optimal_assignment = discreteGraph.optimize(); DiscreteValues expected_assignment; diff --git a/gtsam/hybrid/tests/testHybridMotionModel.cpp b/gtsam/hybrid/tests/testHybridMotionModel.cpp index 747a1b688..a4de6a17b 100644 --- a/gtsam/hybrid/tests/testHybridMotionModel.cpp +++ b/gtsam/hybrid/tests/testHybridMotionModel.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -143,8 +144,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) { // Since no measurement on x1, we hedge our bets // Importance sampling run with 100k samples gives 50.051/49.949 // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "50/50"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()))); + TableDistribution expected(m1, "50 50"); + EXPECT( + assert_equal(expected, *(bn->at(2)->asDiscrete()))); } { @@ -160,8 +162,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel) { // Since we have a measurement on x1, we get a definite result // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "44.3854/55.6146"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); + TableDistribution expected(m1, "44.3854 55.6146"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), + 0.02)); } } @@ -248,8 +251,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) { // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "48.3158/51.6842"); - EXPECT(assert_equal(expected, *(eliminated->at(2)->asDiscrete()), 0.002)); + TableDistribution expected(m1, "48.3158 51.6842"); + EXPECT(assert_equal( + expected, *(eliminated->at(2)->asDiscrete()), 0.02)); } { @@ -263,8 +267,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel2) { // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "55.396/44.604"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); + TableDistribution expected(m1, "55.396 44.604"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), + 0.02)); } } @@ -340,8 +345,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) { // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "51.7762/48.2238"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); + TableDistribution expected(m1, "51.7762 48.2238"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), + 0.02)); } { @@ -355,8 +361,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel3) { // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "49.0762/50.9238"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.005)); + TableDistribution expected(m1, "49.0762 50.9238"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), + 0.05)); } } @@ -381,8 +388,9 @@ TEST(HybridGaussianFactorGraph, TwoStateModel4) { // Values taken from an importance sampling run with 100k samples: // approximateDiscreteMarginal(hbn, hybridMotionModel, given); - DiscreteConditional expected(m1, "8.91527/91.0847"); - EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), 0.002)); + TableDistribution expected(m1, "8.91527 91.0847"); + EXPECT(assert_equal(expected, *(bn->at(2)->asDiscrete()), + 0.01)); } /* ************************************************************************* */ @@ -537,8 +545,8 @@ TEST(HybridGaussianFactorGraph, DifferentCovariances) { DiscreteValues dv0{{M(1), 0}}; DiscreteValues dv1{{M(1), 1}}; - DiscreteConditional expected_m1(m1, "0.5/0.5"); - DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete()); + TableDistribution expected_m1(m1, "0.5 0.5"); + TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete()); EXPECT(assert_equal(expected_m1, actual_m1)); } diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index 6e844dbcb..3df03021b 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -512,9 +513,10 @@ TEST(HybridNonlinearFactorGraph, Full_Elimination) { // P(m1) EXPECT(hybridBayesNet->at(4)->frontals() == KeyVector{M(1)}); EXPECT_LONGS_EQUAL(0, hybridBayesNet->at(4)->nrParents()); - EXPECT( - dynamic_pointer_cast(hybridBayesNet->at(4)->inner()) - ->equals(*discreteBayesNet.at(1))); + TableDistribution dtc = + *hybridBayesNet->at(4)->asDiscrete(); + EXPECT(DiscreteConditional(dtc.nrFrontals(), dtc.toDecisionTreeFactor()) + .equals(*discreteBayesNet.at(1))); } /**************************************************************************** @@ -1061,8 +1063,8 @@ TEST(HybridNonlinearFactorGraph, DifferentCovariances) { DiscreteValues dv0{{M(1), 0}}; DiscreteValues dv1{{M(1), 1}}; - DiscreteConditional expected_m1(m1, "0.5/0.5"); - DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete()); + TableDistribution expected_m1(m1, "0.5 0.5"); + TableDistribution actual_m1 = *(hbn->at(2)->asDiscrete()); EXPECT(assert_equal(expected_m1, actual_m1)); } diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp index 67cec8319..b32860cff 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -265,16 +266,12 @@ TEST(HybridNonlinearISAM, ApproxInference) { 1 1 1 Leaf 0.5 */ - auto discreteConditional_m0 = *dynamic_pointer_cast( + auto discreteConditional_m0 = *dynamic_pointer_cast( bayesTree[M(0)]->conditional()->inner()); EXPECT(discreteConditional_m0.keys() == KeyVector({M(0), M(1), M(2)})); - // Get the number of elements which are greater than 0. - auto count = [](const double &value, int count) { - return value > 0 ? count + 1 : count; - }; // Check that the number of leaves after pruning is 5. - EXPECT_LONGS_EQUAL(5, discreteConditional_m0.fold(count, 0)); + EXPECT_LONGS_EQUAL(5, discreteConditional_m0.nrValues()); // Check that the hybrid nodes of the bayes net match those of the pre-pruning // bayes net, at the same positions. @@ -520,12 +517,13 @@ TEST(HybridNonlinearISAM, NonTrivial) { // The final discrete graph should not be empty since we have eliminated // all continuous variables. - auto discreteTree = bayesTree[M(3)]->conditional()->asDiscrete(); + auto discreteTree = + bayesTree[M(3)]->conditional()->asDiscrete(); EXPECT_LONGS_EQUAL(3, discreteTree->size()); // Test if the optimal discrete mode assignment is (1, 1, 1). DiscreteFactorGraph discreteGraph; - discreteGraph.push_back(discreteTree); + discreteGraph.push_back(discreteTree->toDecisionTreeFactor()); DiscreteValues optimal_assignment = discreteGraph.optimize(); DiscreteValues expected_assignment; diff --git a/gtsam/hybrid/tests/testSerializationHybrid.cpp b/gtsam/hybrid/tests/testSerializationHybrid.cpp index 3be96b751..9aabe309b 100644 --- a/gtsam/hybrid/tests/testSerializationHybrid.cpp +++ b/gtsam/hybrid/tests/testSerializationHybrid.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -44,6 +45,7 @@ BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor"); BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor"); BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional"); BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional"); +BOOST_CLASS_EXPORT_GUID(TableDistribution, "gtsam_TableDistribution"); BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); using ADT = AlgebraicDecisionTree; diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py index 6d609deb0..6edab3449 100644 --- a/python/gtsam/tests/test_HybridFactorGraph.py +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -13,14 +13,14 @@ Author: Fan Jiang, Varun Agrawal, Frank Dellaert import unittest import numpy as np -from gtsam.symbol_shorthand import C, M, X, Z -from gtsam.utils.test_case import GtsamTestCase import gtsam -from gtsam import (DiscreteConditional, GaussianConditional, - HybridBayesNet, HybridGaussianConditional, - HybridGaussianFactor, HybridGaussianFactorGraph, - HybridValues, JacobianFactor, noiseModel) +from gtsam import (DiscreteConditional, GaussianConditional, HybridBayesNet, + HybridGaussianConditional, HybridGaussianFactor, + HybridGaussianFactorGraph, HybridValues, JacobianFactor, + TableDistribution, noiseModel) +from gtsam.symbol_shorthand import C, M, X, Z +from gtsam.utils.test_case import GtsamTestCase DEBUG_MARGINALS = False @@ -51,7 +51,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase): self.assertEqual(len(hybridCond.keys()), 2) discrete_conditional = hbn.at(hbn.size() - 1).inner() - self.assertIsInstance(discrete_conditional, DiscreteConditional) + self.assertIsInstance(discrete_conditional, TableDistribution) def test_optimize(self): """Test construction of hybrid factor graph."""