diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index cfa5fa90c..307d1486c 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -536,5 +536,11 @@ namespace gtsam { return DecisionTreeFactor(this->discreteKeys(), thresholded); } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr DecisionTreeFactor::restrict( + const DiscreteValues& assignment) const { + throw std::runtime_error("DecisionTreeFactor::restrict not implemented"); +} + /* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 716c43b63..63f0384aa 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -220,6 +220,10 @@ namespace gtsam { return combine(keys, Ring::max); } + /// Restrict the factor to the given assignment. + DiscreteFactor::shared_ptr restrict( + const DiscreteValues& assignment) const override; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 6fa074379..a30383942 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -167,8 +167,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /** * @brief Scale the factor values by the maximum * to prevent underflow/overflow. - * - * @return DiscreteFactor::shared_ptr + * + * @return DiscreteFactor::shared_ptr */ DiscreteFactor::shared_ptr scale() const; @@ -178,6 +178,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { */ virtual uint64_t nrValues() const = 0; + /// Restrict the factor to the given assignment. + virtual DiscreteFactor::shared_ptr restrict( + const DiscreteValues& assignment) const = 0; + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index 25acae06e..b5d3193e4 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -391,12 +391,12 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const { /* ************************************************************************ */ DiscreteFactor::shared_ptr TableFactor::sum(size_t nrFrontals) const { -return combine(nrFrontals, Ring::add); + return combine(nrFrontals, Ring::add); } /* ************************************************************************ */ DiscreteFactor::shared_ptr TableFactor::sum(const Ordering& keys) const { -return combine(keys, Ring::add); + return combine(keys, Ring::add); } /* ************************************************************************ */ @@ -418,7 +418,6 @@ DiscreteFactor::shared_ptr TableFactor::max(const Ordering& keys) const { return combine(keys, Ring::max); } - /* ************************************************************************ */ TableFactor TableFactor::apply(Unary op) const { // Initialize new factor. @@ -781,5 +780,11 @@ TableFactor TableFactor::prune(size_t maxNrAssignments) const { return TableFactor(this->discreteKeys(), pruned_vec); } +/* ************************************************************************ */ +DiscreteFactor::shared_ptr TableFactor::restrict( + const DiscreteValues& assignment) const { + throw std::runtime_error("TableFactor::restrict not implemented"); +} + /* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index ce58d14bc..72c2861a2 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -342,6 +342,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { */ uint64_t nrValues() const override { return sparse_table_.nonZeros(); } + /// Restrict the factor to the given assignment. + DiscreteFactor::shared_ptr restrict( + const DiscreteValues& assignment) const override; + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index dbf94ea59..fb8b215cf 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -69,11 +69,13 @@ HybridBayesNet HybridBayesNet::prune( // Go through all the Gaussian conditionals, restrict them according to // fixed values, and then prune further. - for (std::shared_ptr conditional : *this) { + for (std::shared_ptr conditional : *this) { if (conditional->isDiscrete()) continue; // No-op if not a HybridGaussianConditional. - if (marginalThreshold) conditional = conditional->restrict(fixed); + if (marginalThreshold) + conditional = std::static_pointer_cast( + conditional->restrict(fixed)); // Now decide on type what to do: if (auto hgc = conditional->asHybrid()) { diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 257eca314..7fffb06d3 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -170,8 +170,8 @@ double HybridConditional::evaluate(const HybridValues &values) const { } /* ************************************************************************ */ -HybridConditional::shared_ptr HybridConditional::restrict( - const DiscreteValues &discreteValues) const { +std::shared_ptr HybridConditional::restrict( + const DiscreteValues &assignment) const { if (auto gc = asGaussian()) { return std::make_shared(gc); } else if (auto dc = asDiscrete()) { @@ -184,21 +184,20 @@ HybridConditional::shared_ptr HybridConditional::restrict( "HybridConditional::restrict: conditional type not handled"); // Case 1: Fully determined, return corresponding Gaussian conditional - auto parentValues = discreteValues.filter(discreteKeys_); + auto parentValues = assignment.filter(discreteKeys_); if (parentValues.size() == discreteKeys_.size()) { return std::make_shared(hgc->choose(parentValues)); } // Case 2: Some live parents remain, build a new tree - auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_); - if (!unspecifiedParentKeys.empty()) { + auto remainingKeys = assignment.missingKeys(discreteKeys_); + if (!remainingKeys.empty()) { auto newTree = hgc->factors(); for (const auto &[key, value] : parentValues) { newTree = newTree.choose(key, value); } return std::make_shared( - std::make_shared(unspecifiedParentKeys, - newTree)); + std::make_shared(remainingKeys, newTree)); } // Case 3: No changes needed, return original diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 075fbe411..45b00969b 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -153,7 +153,8 @@ class GTSAM_EXPORT HybridConditional * @return HybridGaussianConditional::shared_ptr otherwise */ HybridGaussianConditional::shared_ptr asHybrid() const { - return std::dynamic_pointer_cast(inner_); + if (!isHybrid()) return nullptr; + return std::static_pointer_cast(inner_); } /** @@ -162,7 +163,8 @@ class GTSAM_EXPORT HybridConditional * @return GaussianConditional::shared_ptr otherwise */ GaussianConditional::shared_ptr asGaussian() const { - return std::dynamic_pointer_cast(inner_); + if (!isContinuous()) return nullptr; + return std::static_pointer_cast(inner_); } /** @@ -172,7 +174,8 @@ class GTSAM_EXPORT HybridConditional */ template typename T::shared_ptr asDiscrete() const { - return std::dynamic_pointer_cast(inner_); + if (!isDiscrete()) return nullptr; + return std::static_pointer_cast(inner_); } /// Get the type-erased pointer to the inner type @@ -221,7 +224,8 @@ class GTSAM_EXPORT HybridConditional * which is just a GaussianConditional. If this conditional is *not* a hybrid * conditional, just return that. */ - shared_ptr restrict(const DiscreteValues& discreteValues) const; + std::shared_ptr restrict( + const DiscreteValues& assignment) const override; /// @} diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 9fc280322..4147420bd 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -133,10 +133,14 @@ class GTSAM_EXPORT HybridFactor : public Factor { /// Return only the continuous keys for this factor. const KeyVector &continuousKeys() const { return continuousKeys_; } - /// Virtual class to compute tree of linear errors. + /// Compute tree of linear errors. virtual AlgebraicDecisionTree errorTree( const VectorValues &values) const = 0; + /// Restrict the factor to the given discrete values. + virtual std::shared_ptr restrict( + const DiscreteValues &discreteValues) const = 0; + /// @} private: diff --git a/gtsam/hybrid/HybridGaussianConditional.cpp b/gtsam/hybrid/HybridGaussianConditional.cpp index 0546ff16b..cb569efac 100644 --- a/gtsam/hybrid/HybridGaussianConditional.cpp +++ b/gtsam/hybrid/HybridGaussianConditional.cpp @@ -363,4 +363,12 @@ double HybridGaussianConditional::evaluate(const HybridValues &values) const { return conditional->evaluate(values.continuous()); } +/* ************************************************************************ */ +std::shared_ptr HybridGaussianConditional::restrict( + const DiscreteValues &assignment) const { + throw std::runtime_error( + "HybridGaussianConditional::restrict not implemented"); +} + +/* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianConditional.h b/gtsam/hybrid/HybridGaussianConditional.h index 3b95e0277..6fc60a482 100644 --- a/gtsam/hybrid/HybridGaussianConditional.h +++ b/gtsam/hybrid/HybridGaussianConditional.h @@ -241,6 +241,10 @@ class GTSAM_EXPORT HybridGaussianConditional /// Return true if the conditional has already been pruned. bool pruned() const { return pruned_; } + /// Restrict to the given discrete values. + std::shared_ptr restrict( + const DiscreteValues &discreteValues) const override; + /// @} private: diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index fd9bd2fd4..616e015f6 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -199,4 +199,12 @@ double HybridGaussianFactor::error(const HybridValues& values) const { return PotentiallyPrunedComponentError(pair, values.continuous()); } +/* ************************************************************************ */ +std::shared_ptr HybridGaussianFactor::restrict( + const DiscreteValues& assignment) const { + throw std::runtime_error("HybridGaussianFactor::restrict not implemented"); +} + +/* ************************************************************************ */ + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index efbba9e51..0bf38effa 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -157,6 +157,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { */ virtual HybridGaussianProductFactor asProductFactor() const; + /// Restrict the factor to the given discrete values. + std::shared_ptr restrict( + const DiscreteValues &discreteValues) const override; + /// @} private: diff --git a/gtsam/hybrid/HybridNonlinearFactor.cpp b/gtsam/hybrid/HybridNonlinearFactor.cpp index fa22051e5..900102e5d 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.cpp +++ b/gtsam/hybrid/HybridNonlinearFactor.cpp @@ -239,4 +239,21 @@ HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune( return std::make_shared(discreteKeys(), prunedFactors); } +/* ************************************************************************ */ +std::shared_ptr HybridNonlinearFactor::restrict( + const DiscreteValues& assignment) const { + auto restrictedFactors = factors_.restrict(assignment); + auto filtered = assignment.filter(discreteKeys_); + if (filtered.size() == discreteKeys_.size()) { + auto [nonlinearFactor, val] = factors_(filtered); + return nonlinearFactor; + } else { + auto remainingKeys = assignment.missingKeys(discreteKeys()); + return std::make_shared(remainingKeys, + factors_.restrict(filtered)); + } +} + +/* ************************************************************************ */ + } // namespace gtsam diff --git a/gtsam/hybrid/HybridNonlinearFactor.h b/gtsam/hybrid/HybridNonlinearFactor.h index e264b1d10..9fe08a364 100644 --- a/gtsam/hybrid/HybridNonlinearFactor.h +++ b/gtsam/hybrid/HybridNonlinearFactor.h @@ -80,6 +80,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { } public: + /// @name Constructors + /// @{ + /// Default constructor, mainly for serialization. HybridNonlinearFactor() = default; @@ -137,7 +140,7 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { * @return double The error of this factor. */ double error(const Values& continuousValues, - const DiscreteValues& discreteValues) const; + const DiscreteValues& assignment) const; /** * @brief Compute error of factor given hybrid values. @@ -154,7 +157,8 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { */ size_t dim() const; - /// Testable + /// @} + /// @name Testable /// @{ /// print to stdout @@ -165,15 +169,16 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { bool equals(const HybridFactor& other, double tol = 1e-9) const override; /// @} + /// @name Standard API + /// @{ /// Getter for NonlinearFactor decision tree const FactorValuePairs& factors() const { return factors_; } /// Linearize specific nonlinear factors based on the assignment in /// discreteValues. - GaussianFactor::shared_ptr linearize( - const Values& continuousValues, - const DiscreteValues& discreteValues) const; + GaussianFactor::shared_ptr linearize(const Values& continuousValues, + const DiscreteValues& assignment) const; /// Linearize all the continuous factors to get a HybridGaussianFactor. std::shared_ptr linearize( @@ -183,6 +188,12 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor { HybridNonlinearFactor::shared_ptr prune( const DecisionTreeFactor& discreteProbs) const; + /// Restrict the factor to the given discrete values. + std::shared_ptr restrict( + const DiscreteValues& assignment) const override; + + /// @} + private: /// Helper struct to assist private constructor below. struct ConstructorHelper; diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 2f5031cf2..b42676aac 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -221,5 +221,30 @@ AlgebraicDecisionTree HybridNonlinearFactorGraph::discretePosterior( return p / p.sum(); } +/* ************************************************************************ */ +HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict( + const DiscreteValues& discreteValues) const { + using std::dynamic_pointer_cast; + + HybridNonlinearFactorGraph result; + result.reserve(size()); + for (auto& f : factors_) { + // First check if it is a valid factor + if (!f) { + continue; + } + // Check if it is a hybrid factor + if (auto hf = dynamic_pointer_cast(f)) { + result.push_back(hf->restrict(discreteValues)); + } else if (auto df = dynamic_pointer_cast(f)) { + result.push_back(df->restrict(discreteValues)); + } else { + result.push_back(f); // Everything else is just added as is + } + } + + return result; +} + /* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.h b/gtsam/hybrid/HybridNonlinearFactorGraph.h index f79f7b452..9f91a74b9 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.h +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.h @@ -116,6 +116,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { AlgebraicDecisionTree discretePosterior( const Values& continuousValues) const; + /// Restrict all factors in the graph to the given discrete values. + HybridNonlinearFactorGraph restrict( + const DiscreteValues& assignment) const; + /// @} }; diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactor.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactor.cpp index 2b441ab13..e7ef9d7d9 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactor.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactor.cpp @@ -131,6 +131,18 @@ TEST(HybridNonlinearFactor, Dim) { EXPECT_LONGS_EQUAL(1, hybridFactor.dim()); } +/* ************************************************************************* */ +// Test restrict method +TEST(HybridNonlinearFactor, Restrict) { + using namespace test_constructor; + HybridNonlinearFactor factor(m1, {f0, f1}); + DiscreteValues assignment = {{m1.first, 0}}; + auto restricted = factor.restrict(assignment); + auto betweenFactor = dynamic_pointer_cast>(restricted); + CHECK(betweenFactor); + EXPECT(assert_equal(*f0, *betweenFactor)); +} + /* ************************************************************************* */ int main() { TestResult tr;