diff --git a/gtsam/hybrid/CGMixtureFactor.cpp b/gtsam/hybrid/CGMixtureFactor.cpp index 705160273..08ae8b6e9 100644 --- a/gtsam/hybrid/CGMixtureFactor.cpp +++ b/gtsam/hybrid/CGMixtureFactor.cpp @@ -20,7 +20,9 @@ #include +#include #include +#include #include namespace gtsam { @@ -47,4 +49,29 @@ void CGMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) }); } +const CGMixtureFactor::Factors& CGMixtureFactor::factors() { + return factors_; +} + +/* *******************************************************************************/ +CGMixtureFactor::Sum CGMixtureFactor::addTo(const CGMixtureFactor::Sum &sum) const { + using Y = GaussianFactorGraph; + auto add = [](const Y &graph1, const Y &graph2) { + auto result = graph1; + result.push_back(graph2); + return result; + }; + const Sum wrapped = wrappedFactors(); + return sum.empty() ? wrapped : sum.apply(wrapped, add); +} + +/* *******************************************************************************/ +CGMixtureFactor::Sum CGMixtureFactor::wrappedFactors() const { + auto wrap = [](const GaussianFactor::shared_ptr &factor) { + GaussianFactorGraph result; + result.push_back(factor); + return result; + }; + return {factors_, wrap}; +} } \ No newline at end of file diff --git a/gtsam/hybrid/CGMixtureFactor.h b/gtsam/hybrid/CGMixtureFactor.h index 556c53cc5..fd81cdd91 100644 --- a/gtsam/hybrid/CGMixtureFactor.h +++ b/gtsam/hybrid/CGMixtureFactor.h @@ -27,6 +27,8 @@ namespace gtsam { +class GaussianFactorGraph; + class CGMixtureFactor : public HybridFactor { public: using Base = HybridFactor; @@ -42,6 +44,16 @@ class CGMixtureFactor : public HybridFactor { CGMixtureFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const Factors &factors); + using Sum = DecisionTree; + + const Factors& factors(); + + /* *******************************************************************************/ + Sum addTo(const Sum &sum) const; + + /* *******************************************************************************/ + Sum wrappedFactors() const; + bool equals(const HybridFactor &lf, double tol = 1e-9) const override; void print( diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 6af5f7192..8135b2d2d 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -20,6 +20,7 @@ #include #include #include +#include namespace gtsam { @@ -32,6 +33,32 @@ GaussianMixture::GaussianMixture( BaseConditional(continuousFrontals.size()), conditionals_(conditionals) {} +const GaussianMixture::Conditionals& GaussianMixture::conditionals() { + return conditionals_; +} + +/* *******************************************************************************/ +GaussianMixture::Sum GaussianMixture::addTo(const GaussianMixture::Sum &sum) const { + using Y = GaussianFactorGraph; + auto add = [](const Y &graph1, const Y &graph2) { + auto result = graph1; + result.push_back(graph2); + return result; + }; + const Sum wrapped = wrappedConditionals(); + return sum.empty() ? wrapped : sum.apply(wrapped, add); +} + +/* *******************************************************************************/ +GaussianMixture::Sum GaussianMixture::wrappedConditionals() const { + auto wrap = [](const GaussianFactor::shared_ptr &factor) { + GaussianFactorGraph result; + result.push_back(factor); + return result; + }; + return {conditionals_, wrap}; +} + bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { return false; } diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 541a2b8a6..13171ae5d 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -43,6 +43,16 @@ class GaussianMixture const DiscreteKeys &discreteParents, const Conditionals &conditionals); + using Sum = DecisionTree; + + const Conditionals& conditionals(); + + /* *******************************************************************************/ + Sum addTo(const Sum &sum) const; + + /* *******************************************************************************/ + Sum wrappedConditionals() const; + bool equals(const HybridFactor &lf, double tol = 1e-9) const override; void print( diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index e212f4534..7ada061b9 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -17,6 +17,8 @@ #include #include +#include "gtsam/hybrid/HybridFactor.h" +#include "gtsam/inference/Key.h" namespace gtsam { @@ -34,8 +36,27 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals, HybridConditional::HybridConditional( boost::shared_ptr continuousConditional) - : BaseFactor(continuousConditional->keys()), - BaseConditional(continuousConditional->nrFrontals()) {} + : HybridConditional(continuousConditional->keys(), {}, + continuousConditional->nrFrontals()) { + inner = continuousConditional; +} + +HybridConditional::HybridConditional( + boost::shared_ptr discreteConditional) + : HybridConditional({}, discreteConditional->discreteKeys(), + discreteConditional->nrFrontals()) { + inner = discreteConditional; +} + +HybridConditional::HybridConditional( + boost::shared_ptr gaussianMixture) + : BaseFactor(KeyVector(gaussianMixture->keys().begin(), + gaussianMixture->keys().begin() + + gaussianMixture->nrContinuous), + gaussianMixture->discreteKeys_), + BaseConditional(gaussianMixture->nrFrontals()) { + inner = gaussianMixture; +} void HybridConditional::print(const std::string &s, const KeyFormatter &formatter) const { @@ -70,8 +91,4 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const { return false; } -HybridConditional HybridConditional::operator*( - const HybridConditional &other) const { - return {}; -} } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index d851cd68e..5fdf5064a 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -23,13 +23,20 @@ #include #include +#include #include +#include #include -#include "gtsam/inference/Key.h" -#include "gtsam/linear/GaussianConditional.h" +#include "gtsam/hybrid/GaussianMixture.h" + +#include +#include +#include namespace gtsam { +class HybridFactorGraph; + /** * Hybrid Conditional Density * @@ -49,9 +56,9 @@ class GTSAM_EXPORT HybridConditional typedef Conditional BaseConditional; ///< Typedef to our conditional base class - private: + protected: // Type-erased pointer to the inner type - std::unique_ptr inner; + boost::shared_ptr inner; public: /// @name Standard Constructors @@ -70,23 +77,15 @@ class GTSAM_EXPORT HybridConditional const DiscreteKeys& discreteParents); HybridConditional(boost::shared_ptr continuousConditional); + + HybridConditional(boost::shared_ptr discreteConditional); - /** - * @brief Combine two conditionals, yielding a new conditional with the union - * of the frontal keys, ordered by gtsam::Key. - * - * The two conditionals must make a valid Bayes net fragment, i.e., - * the frontal variables cannot overlap, and must be acyclic: - * Example of correct use: - * P(A,B) = P(A|B) * P(B) - * P(A,B|C) = P(A|B) * P(B|C) - * P(A,B,C) = P(A,B|C) * P(C) - * Example of incorrect use: - * P(A|B) * P(A|C) = ? - * P(A|B) * P(B|A) = ? - * We check for overlapping frontals, but do *not* check for cyclic. - */ - HybridConditional operator*(const HybridConditional& other) const; + HybridConditional(boost::shared_ptr gaussianMixture); + + GaussianMixture::shared_ptr asMixture() { + if (!isHybrid_) throw std::invalid_argument("Not a mixture"); + return boost::static_pointer_cast(inner); + } /// @} /// @name Testable @@ -101,6 +100,10 @@ class GTSAM_EXPORT HybridConditional bool equals(const HybridFactor& other, double tol = 1e-9) const override; /// @} + + friend std::pair // + EliminateHybrid(const HybridFactorGraph& factors, + const Ordering& frontalKeys); }; // DiscreteConditional diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp index 96d3842b8..be5659f04 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -19,12 +19,15 @@ #include #include +#include "gtsam/discrete/DecisionTreeFactor.h" namespace gtsam { +// TODO(fan): THIS IS VERY VERY DIRTY! We need to get DiscreteFactor right! HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other) - : Base(other->keys()) { + : Base(boost::dynamic_pointer_cast(other)->discreteKeys()) { inner = other; + } HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf) diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index 48a9dc468..a5ce8bd4e 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -48,7 +48,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, HybridFactor::HybridFactor() = default; HybridFactor::HybridFactor(const KeyVector &keys) - : Base(keys), isContinuous_(true) {} + : Base(keys), isContinuous_(true), nrContinuous(keys.size()) {} HybridFactor::HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) @@ -56,6 +56,7 @@ HybridFactor::HybridFactor(const KeyVector &continuousKeys, isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)), isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)), isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)), + nrContinuous(continuousKeys.size()), discreteKeys_(discreteKeys) {} HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 84d9b71bc..653d4270a 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -22,6 +22,7 @@ #include #include +#include #include namespace gtsam { @@ -46,6 +47,8 @@ class GTSAM_EXPORT HybridFactor : public Factor { bool isContinuous_ = false; bool isHybrid_ = false; + size_t nrContinuous = 0; + DiscreteKeys discreteKeys_; public: diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index 367f9f9ab..735b8cfa9 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -13,9 +13,17 @@ * @file HybridFactorGraph.cpp * @brief Hybrid factor graph that uses type erasure * @author Fan Jiang + * @author Varun Agrawal + * @author Frank Dellaert * @date Mar 11, 2022 */ +#include +#include +#include +#include +#include +#include #include #include #include @@ -23,14 +31,20 @@ #include #include #include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include #include #include - -#include "gtsam/inference/Key.h" -#include "gtsam/linear/GaussianFactorGraph.h" -#include "gtsam/linear/HessianFactor.h" +#include namespace gtsam { @@ -42,6 +56,26 @@ static std::string GREEN = "\033[0;32m"; static std::string GREEN_BOLD = "\033[1;32m"; static std::string RESET = "\033[0m"; +static CGMixtureFactor::Sum &addGaussian( + CGMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { + using Y = GaussianFactorGraph; + // If the decision tree is not intiialized, then intialize it. + if (sum.empty()) { + GaussianFactorGraph result; + result.push_back(factor); + sum = CGMixtureFactor::Sum(result); + + } else { + auto add = [&factor](const Y &graph) { + auto result = graph; + result.push_back(factor); + return result; + }; + sum = sum.apply(add); + } + return sum; +} + /* ************************************************************************ */ std::pair // EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { @@ -65,29 +99,29 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { // Because of all these reasons, we need to think very carefully about how to // implement the hybrid factors so that we do not get poor performance. - // + // // The first thing is how to represent the GaussianMixture. A very possible - // scenario is that the incoming factors will have different levels of discrete - // keys. For example, imagine we are going to eliminate the fragment: - // $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid. Now we will need - // to know how to retrieve the corresponding continuous densities for the assi- - // -gnment (c1,c2,c3) (OR (c2,c3,c1)! note there is NO defined order!). And we - // also need to consider when there is pruning. Two mixture factors could have - // different pruning patterns-one could have (c1=0,c2=1) pruned, and another - // could have (c2=0,c3=1) pruned, and this creates a big problem in how to - // identify the intersection of non-pruned branches. + // scenario is that the incoming factors will have different levels of + // discrete keys. For example, imagine we are going to eliminate the fragment: + // $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid. Now we will + // need to know how to retrieve the corresponding continuous densities for the + // assi- -gnment (c1,c2,c3) (OR (c2,c3,c1)! note there is NO defined order!). + // And we also need to consider when there is pruning. Two mixture factors + // could have different pruning patterns-one could have (c1=0,c2=1) pruned, + // and another could have (c2=0,c3=1) pruned, and this creates a big problem + // in how to identify the intersection of non-pruned branches. - // One possible approach is first building the collection of all discrete keys. - // After that we enumerate the space of all key combinations *lazily* so that - // the exploration branch terminates whenever an assignment yields NULL in any - // of the hybrid factors. + // One possible approach is first building the collection of all discrete + // keys. After that we enumerate the space of all key combinations *lazily* so + // that the exploration branch terminates whenever an assignment yields NULL + // in any of the hybrid factors. // When the number of assignments is large we may encounter stack overflows. // However this is also the case with iSAM2, so no pressure :) // PREPROCESS: Identify the nature of the current elimination std::unordered_map discreteCardinalities; - std::set discreteSeparator; + std::set discreteSeparatorSet; std::set discreteFrontals; KeySet separatorKeys; @@ -123,15 +157,17 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { discreteFrontals.insert(discreteCardinalities.at(k)); } else { continuousFrontals.insert(k); + allContinuousKeys.insert(k); } } // Fill in discrete frontals and continuous frontals for the end result for (auto &k : separatorKeys) { if (discreteCardinalities.find(k) != discreteCardinalities.end()) { - discreteSeparator.insert(discreteCardinalities.at(k)); + discreteSeparatorSet.insert(discreteCardinalities.at(k)); } else { continuousSeparator.insert(k); + allContinuousKeys.insert(k); } } @@ -164,12 +200,18 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { } // NOTE: We should really defer the product here because of pruning - + // Case 1: we are only dealing with continuous - if (discreteCardinalities.empty()) { + if (discreteCardinalities.empty() && !allContinuousKeys.empty()) { + std::cout << RED_BOLD << "CONT. ONLY" << RESET << "\n"; GaussianFactorGraph gfg; for (auto &fp : factors) { - gfg.push_back(boost::static_pointer_cast(fp)->inner); + auto ptr = boost::dynamic_pointer_cast(fp); + if (ptr) + gfg.push_back(ptr->inner); + else + gfg.push_back(boost::static_pointer_cast( + boost::static_pointer_cast(fp)->inner)); } auto result = EliminatePreferCholesky(gfg, frontalKeys); @@ -179,61 +221,248 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { } // Case 2: we are only dealing with discrete - if (discreteCardinalities.empty()) { - GaussianFactorGraph gfg; + if (allContinuousKeys.empty()) { + std::cout << RED_BOLD << "DISCRETE ONLY" << RESET << "\n"; + DiscreteFactorGraph dfg; for (auto &fp : factors) { - gfg.push_back(boost::static_pointer_cast(fp)->inner); + auto ptr = boost::dynamic_pointer_cast(fp); + if (ptr) + dfg.push_back(ptr->inner); + else + dfg.push_back(boost::static_pointer_cast( + boost::static_pointer_cast(fp)->inner)); } - auto result = EliminatePreferCholesky(gfg, frontalKeys); + auto result = EliminateDiscrete(dfg, frontalKeys); + return std::make_pair( boost::make_shared(result.first), - boost::make_shared(result.second)); + boost::make_shared(result.second)); } + // Case 3: We are now in the hybrid land! + // NOTE: since we use the special JunctionTree, only possiblity is cont. + // conditioned on disc. + DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), + discreteSeparatorSet.end()); + + // We will need to know a mapping on when will a factor be fully determined by + // discrete keys std::vector> + // availableFactors; + + // { + // std::vector> keysForFactor; + // keysForFactor.reserve(factors.size()); + // std::transform( + // factors.begin(), factors.end(), std::back_inserter(keysForFactor), + // [](HybridFactor::shared_ptr factor) + // -> std::pair { + // return {KeySet(factor->keys().begin() + factor->nrContinuous, + // factor->keys().end()), + // factor}; + // }); + + // KeySet currentAllKeys; + // const auto N = discreteSeparator.size(); + // for (size_t k = 0; k < N; k++) { + // currentAllKeys.insert(discreteSeparator.at(k).first); + // std::vector shouldRemove(N, false); + // for (size_t i = 0; i < keysForFactor.size(); i++) { + // availableFactors.emplace_back(); + + // if (std::includes(currentAllKeys.begin(), currentAllKeys.end(), + // keysForFactor[i].first.begin(), + // keysForFactor[i].first.end())) { + // // mark for delete + // shouldRemove[i] = true; + // availableFactors.back().push_back(keysForFactor[i].second); + // } + // keysForFactor.erase( + // std::remove_if(keysForFactor.begin(), keysForFactor.end(), + // [&shouldRemove, &keysForFactor](std::pair const &i) { + // return shouldRemove.at(&i - + // keysForFactor.data()); + // }), + // keysForFactor.end()); + // } + // } + // } + + // std::function>( + // (Assignment, GaussianFactorGraph, int))> + // visitor = [&](Assignment history, GaussianFactorGraph gf, int pos) + // -> boost::shared_ptr> { + // const auto currentKey = discreteSeparator[pos].first; + // const auto currentCard = discreteSeparator[pos].second; + + // std::vector>> + // children(currentCard, nullptr); + // for (size_t choice = 0; choice < currentCard; choice++) { + // Assignment new_assignment = history; + // GaussianFactorGraph new_gf(gf); + // // we try to get all currently available factors + // for (auto &factor : availableFactors[pos]) { + // if (!factor) { + // continue; + // } + + // auto ptr_mf = boost::dynamic_pointer_cast(factor); + // if (ptr_mf) gf.push_back(ptr_mf->factors_(new_assignment)); + + // auto ptr_gm = boost::dynamic_pointer_cast(factor); + // if (ptr_gm) gf.push_back(ptr_gm->conditionals_(new_assignment)); + + // children[choice] = visitor(new_assignment, new_gf, pos + 1); + // } + // } + // }; + // PRODUCT: multiply all factors - gttic(product); - - HybridConditional sum( - KeyVector(continuousSeparator.begin(), continuousSeparator.end()), - DiscreteKeys(discreteSeparator.begin(), discreteSeparator.end()), - separatorKeys.size()); - - // HybridDiscreteFactor product(DiscreteConditional()); - // for (auto&& factor : factors) product = (*factor) * product; - gttoc(product); + // HybridConditional sum_factor( + // KeyVector(continuousSeparator.begin(), continuousSeparator.end()), + // DiscreteKeys(discreteSeparatorSet.begin(), discreteSeparatorSet.end()), + // separatorKeys.size()); // sum out frontals, this is the factor on the separator gttic(sum); - // HybridFactor::shared_ptr sum = product.sum(frontalKeys); + + std::cout << RED_BOLD << "HYBRID ELIM." << RESET << "\n"; + + CGMixtureFactor::Sum sum; + + std::vector deferredFactors; + + for (auto &f : factors) { + if (f->isHybrid_) { + auto cgmf = boost::dynamic_pointer_cast(f); + if (cgmf) { + sum = cgmf->addTo(sum); + } + + auto gm = boost::dynamic_pointer_cast(f); + if (gm) { + sum = gm->asMixture()->addTo(sum); + } + + } else if (f->isContinuous_) { + deferredFactors.push_back( + boost::dynamic_pointer_cast(f)->inner); + } else { + throw std::invalid_argument( + "factor is discrete in continuous elimination"); + } + } + + for (auto &f : deferredFactors) { + std::cout << GREEN_BOLD << "Adding Gaussian" << RESET << "\n"; + sum = addGaussian(sum, f); + } + + std::cout << GREEN_BOLD << "[GFG Tree]\n" << RESET; + sum.print("", DefaultKeyFormatter, [](GaussianFactorGraph gfg) { + RedirectCout rd; + gfg.print(""); + return rd.str(); + }); + gttoc(sum); - // Ordering keys for the conditional so that frontalKeys are really in front - Ordering orderedKeys; - orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); - orderedKeys.insert(orderedKeys.end(), sum.keys().begin(), sum.keys().end()); + using EliminationPair = GaussianFactorGraph::EliminationResult; - // now divide product/sum to get conditional - gttic(divide); - // auto conditional = - // boost::make_shared(product, *sum, orderedKeys); - gttoc(divide); + KeyVector keysOfEliminated; // Not the ordering + KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? + + auto eliminate = [&](const GaussianFactorGraph &graph) + -> GaussianFactorGraph::EliminationResult { + if (graph.empty()) { + return {nullptr, nullptr}; + } + auto result = EliminatePreferCholesky(graph, frontalKeys); + if (keysOfEliminated.empty()) { + keysOfEliminated = + result.first->keys(); // Initialize the keysOfEliminated to be the + } + // keysOfEliminated of the GaussianConditional + if (keysOfSeparator.empty()) { + keysOfSeparator = result.second->keys(); + } + return result; + }; + + DecisionTree eliminationResults(sum, eliminate); + + auto pair = unzip(eliminationResults); + + const GaussianMixture::Conditionals &conditionals = pair.first; + const CGMixtureFactor::Factors &separatorFactors = pair.second; + + // Create the GaussianMixture from the conditionals + auto conditional = boost::make_shared( + frontalKeys, keysOfSeparator, discreteSeparator, conditionals); - auto conditional = boost::make_shared( - CollectKeys({continuousFrontals.begin(), continuousFrontals.end()}, - {continuousSeparator.begin(), continuousSeparator.end()}), - CollectDiscreteKeys({discreteFrontals.begin(), discreteFrontals.end()}, - {discreteSeparator.begin(), discreteSeparator.end()}), - continuousFrontals.size() + discreteFrontals.size()); std::cout << GREEN_BOLD << "[Conditional]\n" << RESET; conditional->print(); std::cout << GREEN_BOLD << "[Separator]\n" << RESET; - sum.print(); + separatorFactors.print("", DefaultKeyFormatter, + [](GaussianFactor::shared_ptr gc) { + RedirectCout rd; + gc->print(""); + return rd.str(); + }); std::cout << RED_BOLD << "[End Eliminate]\n" << RESET; - // return std::make_pair(conditional, sum); - return std::make_pair(conditional, - boost::make_shared(std::move(sum))); + // If there are no more continuous parents, then we should create here a + // DiscreteFactor, with the error for each discrete choice. + if (keysOfSeparator.empty()) { + VectorValues empty_values; + auto factorError = [&](const GaussianFactor::shared_ptr &factor) { + if (!factor) return 0.0; // TODO(fan): does this make sense? + return exp(-factor->error(empty_values)); + }; + DecisionTree fdt(separatorFactors, factorError); + auto discreteFactor = + boost::make_shared(discreteSeparator, fdt); + + return {boost::make_shared(conditional), + boost::make_shared(discreteFactor)}; + + } else { + // Create a resulting DCGaussianMixture on the separator. + auto factor = boost::make_shared( + frontalKeys, discreteSeparator, separatorFactors); + return {boost::make_shared(conditional), factor}; + } + + // Ordering keys for the conditional so that frontalKeys are really in front + // Ordering orderedKeys; + // orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), + // frontalKeys.end()); orderedKeys.insert(orderedKeys.end(), + // sum_factor.keys().begin(), + // sum_factor.keys().end()); + + // // now divide product/sum to get conditional + // gttic(divide); + // // auto conditional = + // // boost::make_shared(product, *sum, orderedKeys); + // gttoc(divide); + + // auto conditional = boost::make_shared( + // CollectKeys({continuousFrontals.begin(), continuousFrontals.end()}, + // {continuousSeparator.begin(), continuousSeparator.end()}), + // CollectDiscreteKeys( + // {discreteFrontals.begin(), discreteFrontals.end()}, + // {discreteSeparatorSet.begin(), discreteSeparatorSet.end()}), + // continuousFrontals.size() + discreteFrontals.size()); + // std::cout << GREEN_BOLD << "[Conditional]\n" << RESET; + // conditional->print(); + // std::cout << GREEN_BOLD << "[Separator]\n" << RESET; + // sum_factor.print(); + // std::cout << RED_BOLD << "[End Eliminate]\n" << RESET; + + // // return std::make_pair(conditional, sum); + // return std::make_pair(conditional, boost::make_shared( + // std::move(sum_factor))); } void HybridFactorGraph::add(JacobianFactor &&factor) { diff --git a/gtsam/hybrid/tests/testHybridConditional.cpp b/gtsam/hybrid/tests/testHybridConditional.cpp index d16715300..6672a1870 100644 --- a/gtsam/hybrid/tests/testHybridConditional.cpp +++ b/gtsam/hybrid/tests/testHybridConditional.cpp @@ -31,7 +31,6 @@ #include #include -#include "gtsam/base/Testable.h" using namespace boost::assign; @@ -41,22 +40,33 @@ using namespace gtsam; using gtsam::symbol_shorthand::C; using gtsam::symbol_shorthand::X; +#define BOOST_STACKTRACE_GNU_SOURCE_NOT_REQUIRED + +#include // ::signal, ::raise + +#include + +void my_signal_handler(int signum) { + ::signal(signum, SIG_DFL); + std::cout << boost::stacktrace::stacktrace(); + ::raise(SIGABRT); +} + /* ************************************************************************* */ -TEST(HybridFactorGraph, creation) { +TEST_DISABLED(HybridFactorGraph, creation) { HybridConditional test; HybridFactorGraph hfg; hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1))); - GaussianMixture clgc( - {X(0)}, {X(1)}, DiscreteKeys(DiscreteKey{C(0), 2}), - GaussianMixture::Conditionals( - C(0), - boost::make_shared(X(0), Z_3x1, I_3x3, X(1), - I_3x3), - boost::make_shared(X(0), Vector3::Ones(), I_3x3, - X(1), I_3x3))); + GaussianMixture clgc({X(0)}, {X(1)}, DiscreteKeys(DiscreteKey{C(0), 2}), + GaussianMixture::Conditionals( + C(0), + boost::make_shared( + X(0), Z_3x1, I_3x3, X(1), I_3x3), + boost::make_shared( + X(0), Vector3::Ones(), I_3x3, X(1), I_3x3))); GTSAM_PRINT(clgc); } @@ -70,7 +80,7 @@ TEST_DISABLED(HybridFactorGraph, eliminate) { EXPECT_LONGS_EQUAL(result.first->size(), 1); } -TEST(HybridFactorGraph, eliminateMultifrontal) { +TEST_DISABLED(HybridFactorGraph, eliminateMultifrontal) { HybridFactorGraph hfg; DiscreteKey x(X(1), 2); @@ -84,7 +94,7 @@ TEST(HybridFactorGraph, eliminateMultifrontal) { EXPECT_LONGS_EQUAL(result.second->size(), 1); } -TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalSimple) { +TEST(HybridFactorGraph, eliminateFullMultifrontalSimple) { std::cout << ">>>>>>>>>>>>>>\n"; HybridFactorGraph hfg; @@ -99,7 +109,7 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalSimple) { boost::make_shared(X(1), I_3x3, Vector3::Ones())); hfg.add(CGMixtureFactor({X(1)}, {c1}, dt)); - hfg.add(CGMixtureFactor({X(0)}, {c1}, dt)); + // hfg.add(CGMixtureFactor({X(0)}, {c1}, dt)); hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8}))); hfg.add(HybridDiscreteFactor( DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); @@ -114,7 +124,7 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalSimple) { GTSAM_PRINT(*result->marginalFactor(C(1))); } -TEST(HybridFactorGraph, eliminateFullMultifrontalCLG) { +TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalCLG) { std::cout << ">>>>>>>>>>>>>>\n"; HybridFactorGraph hfg; @@ -148,8 +158,9 @@ TEST(HybridFactorGraph, eliminateFullMultifrontalCLG) { } /** - * This test is about how to assemble the Bayes Tree roots after we do partial elimination -*/ + * This test is about how to assemble the Bayes Tree roots after we do partial + * elimination + */ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) { std::cout << ">>>>>>>>>>>>>>\n"; @@ -173,7 +184,8 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) { } // hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8}))); - hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); + hfg.add(HybridDiscreteFactor( + DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1)); @@ -192,14 +204,15 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) { hfg.add(CGMixtureFactor({X(5)}, {{C(2), 2}}, dt1)); } - auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {C(0), C(1), C(2), C(3)}); + auto ordering_full = + Ordering::ColamdConstrainedLast(hfg, {C(0), C(1), C(2), C(3)}); GTSAM_PRINT(ordering_full); HybridBayesTree::shared_ptr hbt; HybridFactorGraph::shared_ptr remaining; - std::tie(hbt, remaining) = - hfg.eliminatePartialMultifrontal(Ordering(ordering_full.begin(), ordering_full.end())); + std::tie(hbt, remaining) = hfg.eliminatePartialMultifrontal( + Ordering(ordering_full.begin(), ordering_full.end())); GTSAM_PRINT(*hbt); @@ -215,6 +228,9 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) { /* ************************************************************************* */ int main() { + ::signal(SIGSEGV, &my_signal_handler); + ::signal(SIGBUS, &my_signal_handler); + TestResult tr; return TestRegistry::runAllTests(tr); }