From 6cd20fba4d866092c3da220bc1323ebae5a8b901 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 27 May 2022 18:34:13 -0400 Subject: [PATCH] break up EliminateHybrid into smaller functions --- gtsam/hybrid/HybridFactorGraph.cpp | 310 +++++++++++++++-------------- 1 file changed, 161 insertions(+), 149 deletions(-) diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index bffb0327c..426c86fa3 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -55,14 +55,6 @@ namespace gtsam { template class EliminateableFactorGraph; -static std::string BLACK_BOLD = "\033[1;30m"; -static std::string RED_BOLD = "\033[1;31m"; -static std::string GREEN = "\033[0;32m"; -static std::string GREEN_BOLD = "\033[1;32m"; -static std::string RESET = "\033[0m"; - -constexpr bool DEBUG = false; - /* ************************************************************************ */ static GaussianMixtureFactor::Sum &addGaussian( GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { @@ -84,6 +76,163 @@ static GaussianMixtureFactor::Sum &addGaussian( return sum; } +/* ************************************************************************ */ +std::pair +continuousElimination(const HybridFactorGraph &factors, + const Ordering &frontalKeys) { + GaussianFactorGraph gfg; + for (auto &fp : factors) { + auto ptr = boost::dynamic_pointer_cast(fp); + if (ptr) { + gfg.push_back(ptr->inner()); + } else { + auto p = boost::static_pointer_cast(fp)->inner(); + if (p) { + gfg.push_back(boost::static_pointer_cast(p)); + } else { + // It is an orphan wrapped conditional + } + } + } + + auto result = EliminatePreferCholesky(gfg, frontalKeys); + return std::make_pair( + boost::make_shared(result.first), + boost::make_shared(result.second)); +} + +/* ************************************************************************ */ +std::pair +discreteElimination(const HybridFactorGraph &factors, + const Ordering &frontalKeys) { + DiscreteFactorGraph dfg; + for (auto &fp : factors) { + auto ptr = boost::dynamic_pointer_cast(fp); + if (ptr) { + dfg.push_back(ptr->inner()); + } else { + auto p = boost::static_pointer_cast(fp)->inner(); + if (p) { + dfg.push_back(boost::static_pointer_cast(p)); + } else { + // It is an orphan wrapper + } + } + } + + auto result = EliminateDiscrete(dfg, frontalKeys); + + return std::make_pair( + boost::make_shared(result.first), + boost::make_shared(result.second)); +} + +/* ************************************************************************ */ +std::pair +hybridElimination(const HybridFactorGraph &factors, const Ordering &frontalKeys, + const KeySet &continuousSeparator, + const std::set &discreteSeparatorSet) { + // NOTE: since we use the special JunctionTree, + // only possiblity is continuous conditioned on discrete. + DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), + discreteSeparatorSet.end()); + + // sum out frontals, this is the factor on the separator + gttic(sum); + + GaussianMixtureFactor::Sum sum; + std::vector deferredFactors; + + for (auto &f : factors) { + if (f->isHybrid()) { + auto cgmf = boost::dynamic_pointer_cast(f); + if (cgmf) { + sum = cgmf->add(sum); + } + + auto gm = boost::dynamic_pointer_cast(f); + if (gm) { + sum = gm->asMixture()->add(sum); + } + + } else if (f->isContinuous()) { + deferredFactors.push_back( + boost::dynamic_pointer_cast(f)->inner()); + } else { + // We need to handle the case where the object is actually an + // BayesTreeOrphanWrapper! + auto orphan = boost::dynamic_pointer_cast< + BayesTreeOrphanWrapper>(f); + if (!orphan) { + auto &fr = *f; + throw std::invalid_argument( + std::string("factor is discrete in continuous elimination") + + typeid(fr).name()); + } + } + } + + for (auto &f : deferredFactors) { + sum = addGaussian(sum, f); + } + + gttoc(sum); + + using EliminationPair = GaussianFactorGraph::EliminationResult; + + KeyVector keysOfEliminated; // Not the ordering + KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? + + auto eliminate = [&](const GaussianFactorGraph &graph) + -> GaussianFactorGraph::EliminationResult { + if (graph.empty()) { + return {nullptr, nullptr}; + } + auto result = EliminatePreferCholesky(graph, frontalKeys); + if (keysOfEliminated.empty()) { + keysOfEliminated = + result.first->keys(); // Initialize the keysOfEliminated to be the + } + // keysOfEliminated of the GaussianConditional + if (keysOfSeparator.empty()) { + keysOfSeparator = result.second->keys(); + } + return result; + }; + + DecisionTree eliminationResults(sum, eliminate); + + auto pair = unzip(eliminationResults); + + const GaussianMixtureFactor::Factors &separatorFactors = pair.second; + + // Create the GaussianMixtureConditional from the conditionals + auto conditional = boost::make_shared( + frontalKeys, keysOfSeparator, discreteSeparator, pair.first); + + // If there are no more continuous parents, then we should create here a + // DiscreteFactor, with the error for each discrete choice. + if (keysOfSeparator.empty()) { + VectorValues empty_values; + auto factorError = [&](const GaussianFactor::shared_ptr &factor) { + if (!factor) return 0.0; // TODO(fan): does this make sense? + return exp(-factor->error(empty_values)); + }; + DecisionTree fdt(separatorFactors, factorError); + auto discreteFactor = + boost::make_shared(discreteSeparator, fdt); + + return {boost::make_shared(conditional), + boost::make_shared(discreteFactor)}; + + } else { + // Create a resulting DCGaussianMixture on the separator. + auto factor = boost::make_shared( + KeyVector(continuousSeparator.begin(), continuousSeparator.end()), + discreteSeparator, separatorFactors); + return {boost::make_shared(conditional), factor}; + } +} /* ************************************************************************ */ std::pair // EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { @@ -177,154 +326,17 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { // Case 1: we are only dealing with continuous if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) { - GaussianFactorGraph gfg; - for (auto &fp : factors) { - auto ptr = boost::dynamic_pointer_cast(fp); - if (ptr) { - gfg.push_back(ptr->inner()); - } else { - auto p = boost::static_pointer_cast(fp)->inner(); - if (p) { - gfg.push_back(boost::static_pointer_cast(p)); - } else { - // It is an orphan wrapped conditional - } - } - } - - auto result = EliminatePreferCholesky(gfg, frontalKeys); - return std::make_pair( - boost::make_shared(result.first), - boost::make_shared(result.second)); + return continuousElimination(factors, frontalKeys); } // Case 2: we are only dealing with discrete if (allContinuousKeys.empty()) { - DiscreteFactorGraph dfg; - for (auto &fp : factors) { - auto ptr = boost::dynamic_pointer_cast(fp); - if (ptr) { - dfg.push_back(ptr->inner()); - } else { - auto p = boost::static_pointer_cast(fp)->inner(); - if (p) { - dfg.push_back(boost::static_pointer_cast(p)); - } else { - // It is an orphan wrapper - } - } - } - - auto result = EliminateDiscrete(dfg, frontalKeys); - - return std::make_pair( - boost::make_shared(result.first), - boost::make_shared(result.second)); + return discreteElimination(factors, frontalKeys); } // Case 3: We are now in the hybrid land! - // NOTE: since we use the special JunctionTree, only possiblity is cont. - // conditioned on disc. - DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), - discreteSeparatorSet.end()); - - // sum out frontals, this is the factor on the separator - gttic(sum); - - GaussianMixtureFactor::Sum sum; - - std::vector deferredFactors; - - for (auto &f : factors) { - if (f->isHybrid()) { - auto cgmf = boost::dynamic_pointer_cast(f); - if (cgmf) { - sum = cgmf->add(sum); - } - - auto gm = boost::dynamic_pointer_cast(f); - if (gm) { - sum = gm->asMixture()->add(sum); - } - - } else if (f->isContinuous()) { - deferredFactors.push_back( - boost::dynamic_pointer_cast(f)->inner()); - } else { - // We need to handle the case where the object is actually an - // BayesTreeOrphanWrapper! - auto orphan = boost::dynamic_pointer_cast< - BayesTreeOrphanWrapper>(f); - if (!orphan) { - auto &fr = *f; - throw std::invalid_argument( - std::string("factor is discrete in continuous elimination") + - typeid(fr).name()); - } - } - } - - for (auto &f : deferredFactors) { - sum = addGaussian(sum, f); - } - - gttoc(sum); - - using EliminationPair = GaussianFactorGraph::EliminationResult; - - KeyVector keysOfEliminated; // Not the ordering - KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? - - auto eliminate = [&](const GaussianFactorGraph &graph) - -> GaussianFactorGraph::EliminationResult { - if (graph.empty()) { - return {nullptr, nullptr}; - } - auto result = EliminatePreferCholesky(graph, frontalKeys); - if (keysOfEliminated.empty()) { - keysOfEliminated = - result.first->keys(); // Initialize the keysOfEliminated to be the - } - // keysOfEliminated of the GaussianConditional - if (keysOfSeparator.empty()) { - keysOfSeparator = result.second->keys(); - } - return result; - }; - - DecisionTree eliminationResults(sum, eliminate); - - auto pair = unzip(eliminationResults); - - const GaussianMixtureConditional::Conditionals &conditionals = pair.first; - const GaussianMixtureFactor::Factors &separatorFactors = pair.second; - - // Create the GaussianMixtureConditional from the conditionals - auto conditional = boost::make_shared( - frontalKeys, keysOfSeparator, discreteSeparator, conditionals); - - // If there are no more continuous parents, then we should create here a - // DiscreteFactor, with the error for each discrete choice. - if (keysOfSeparator.empty()) { - VectorValues empty_values; - auto factorError = [&](const GaussianFactor::shared_ptr &factor) { - if (!factor) return 0.0; // TODO(fan): does this make sense? - return exp(-factor->error(empty_values)); - }; - DecisionTree fdt(separatorFactors, factorError); - auto discreteFactor = - boost::make_shared(discreteSeparator, fdt); - - return {boost::make_shared(conditional), - boost::make_shared(discreteFactor)}; - - } else { - // Create a resulting DCGaussianMixture on the separator. - auto factor = boost::make_shared( - KeyVector(continuousSeparator.begin(), continuousSeparator.end()), - discreteSeparator, separatorFactors); - return {boost::make_shared(conditional), factor}; - } + return hybridElimination(factors, frontalKeys, continuousSeparator, + discreteSeparatorSet); } /* ************************************************************************ */