diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f93d21651..f898178c2 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -125,8 +125,14 @@ GaussianBayesNet HybridBayesNet::choose( const DiscreteValues &assignment) const { GaussianBayesNet gbn; for (size_t idx = 0; idx < size(); idx++) { - GaussianMixture gm = *this->atGaussian(idx); - gbn.push_back(gm(assignment)); + try { + GaussianMixture gm = *this->atGaussian(idx); + gbn.push_back(gm(assignment)); + + } catch (std::exception &exc) { + // if factor at `idx` is discrete-only, just continue. + continue; + } } return gbn; } diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index c024c1255..94e33890d 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -135,9 +135,9 @@ continuousElimination(const HybridGaussianFactorGraph &factors, for (auto &fp : factors) { if (auto ptr = boost::dynamic_pointer_cast(fp)) { gfg.push_back(ptr->inner()); - } else if (auto p = - boost::static_pointer_cast(fp)->inner()) { - gfg.push_back(boost::static_pointer_cast(p)); + } else if (auto ptr = boost::static_pointer_cast(fp)) { + gfg.push_back( + boost::static_pointer_cast(ptr->inner())); } else { // It is an orphan wrapped conditional } @@ -401,4 +401,20 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { FactorGraph::add(boost::make_shared(factor)); } +/* ************************************************************************ */ +const Ordering HybridGaussianFactorGraph::getHybridOrdering( + OptionalOrderingType orderingType) const { + KeySet discrete_keys; + for (auto &factor : factors_) { + for (const DiscreteKey &k : factor->discreteKeys()) { + discrete_keys.insert(k.first); + } + } + + const VariableIndex index(factors_); + Ordering ordering = Ordering::ColamdConstrainedLast( + index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); + return ordering; +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 936710330..56f9b7e07 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -160,6 +160,15 @@ class GTSAM_EXPORT HybridGaussianFactorGraph Base::push_back(sharedFactor); } } + + /** + * @brief + * + * @param orderingType + * @return const Ordering + */ + const Ordering getHybridOrdering( + OptionalOrderingType orderingType = boost::none) const; }; } // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index d447bcce2..4602e8bac 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -19,6 +19,7 @@ */ #include +#include #include "Switching.h" @@ -87,7 +88,7 @@ TEST(HybridBayesNet, Choose) { /* ****************************************************************************/ // Test bayes net optimize -TEST(HybridBayesNet, Optimize) { +TEST(HybridBayesNet, OptimizeAssignment) { Switching s(4); Ordering ordering; @@ -119,6 +120,42 @@ TEST(HybridBayesNet, Optimize) { EXPECT(assert_equal(expected_delta, delta)); } +/* ****************************************************************************/ +// Test bayes net optimize +TEST(HybridBayesNet, Optimize) { + Switching s(4); + + Ordering ordering; + for (auto&& kvp : s.linearizationPoint) { + ordering += kvp.key; + } + + Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet::shared_ptr hybridBayesNet = + s.linearizedFactorGraph.eliminateSequential(hybridOrdering); + + HybridValues delta = hybridBayesNet->optimize(); + + delta.print(); + VectorValues correct; + correct.insert(X(1), 0 * Vector1::Ones()); + correct.insert(X(2), 1 * Vector1::Ones()); + correct.insert(X(3), 2 * Vector1::Ones()); + correct.insert(X(4), 3 * Vector1::Ones()); + + DiscreteValues assignment111; + assignment111[M(1)] = 1; + assignment111[M(2)] = 1; + assignment111[M(3)] = 1; + std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl; + + DiscreteValues assignment101; + assignment101[M(1)] = 1; + assignment101[M(2)] = 0; + assignment101[M(3)] = 1; + std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl; +} + /* ************************************************************************* */ int main() { TestResult tr;