diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 7f42e6986..8278549d0 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -71,29 +71,27 @@ GaussianMixture::GaussianMixture( Conditionals(discreteParents, conditionals)) {} /* *******************************************************************************/ -// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from -// GaussianMixtureFactor, no? -GaussianFactorGraphTree GaussianMixture::add( - const GaussianFactorGraphTree &sum) const { - using Y = GaussianFactorGraph; - auto add = [](const Y &graph1, const Y &graph2) { - auto result = graph1; - result.push_back(graph2); - return result; - }; - const auto tree = asGaussianFactorGraphTree(); - return sum.empty() ? tree : sum.apply(tree, add); -} - -/* *******************************************************************************/ -GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const { +GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const { auto wrap = [](const GaussianConditional::shared_ptr &gc) { - return GaussianFactorGraph{gc}; + if (gc) { + return GaussianBayesNet{gc}; + } else { + return GaussianBayesNet(); + } }; return {conditionals_, wrap}; } /* *******************************************************************************/ +GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const { + auto wrap = [](const GaussianBayesNet &gbn) { + return GaussianFactorGraph(gbn); + }; + return {this->asGaussianBayesNetTree(), wrap}; +} + +/* +*******************************************************************************/ GaussianBayesNetTree GaussianMixture::add( const GaussianBayesNetTree &sum) const { using Y = GaussianBayesNet; @@ -110,15 +108,18 @@ GaussianBayesNetTree GaussianMixture::add( } /* *******************************************************************************/ -GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const { - auto wrap = [](const GaussianConditional::shared_ptr &gc) { - if (gc) { - return GaussianBayesNet{gc}; - } else { - return GaussianBayesNet(); - } +// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from +// GaussianMixtureFactor, no? +GaussianFactorGraphTree GaussianMixture::add( + const GaussianFactorGraphTree &sum) const { + using Y = GaussianFactorGraph; + auto add = [](const Y &graph1, const Y &graph2) { + auto result = graph1; + result.push_back(graph2); + return result; }; - return {conditionals_, wrap}; + const auto tree = asGaussianFactorGraphTree(); + return sum.empty() ? tree : sum.apply(tree, add); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 29a850da8..1d01baed2 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -26,16 +26,6 @@ static std::mt19937_64 kRandomNumberGenerator(42); namespace gtsam { -/* ************************************************************************ */ -// Throw a runtime exception for method specified in string s, -// and conditional f: -static void throwRuntimeError(const std::string &s, - const std::shared_ptr &f) { - auto &fr = *f; - throw std::runtime_error(s + " not implemented for conditional type " + - demangle(typeid(fr).name()) + "."); -} - /* ************************************************************************* */ void HybridBayesNet::print(const std::string &s, const KeyFormatter &formatter) const { @@ -227,141 +217,17 @@ GaussianBayesNet HybridBayesNet::choose( return gbn; } -/* ************************************************************************ */ -static GaussianBayesNetTree addGaussian( - const GaussianBayesNetTree &gfgTree, - const GaussianConditional::shared_ptr &factor) { - // If the decision tree is not initialized, then initialize it. - if (gfgTree.empty()) { - GaussianBayesNet result{factor}; - return GaussianBayesNetTree(result); - } else { - auto add = [&factor](const GaussianBayesNet &graph) { - auto result = graph; - result.push_back(factor); - return result; - }; - return gfgTree.apply(add); - } -} - -/* ************************************************************************ */ -GaussianBayesNetValTree HybridBayesNet::assembleTree() const { - GaussianBayesNetTree result; - - for (auto &f : factors_) { - // TODO(dellaert): just use a virtual method defined in HybridFactor. - if (auto gm = std::dynamic_pointer_cast(f)) { - result = gm->add(result); - } else if (auto hc = std::dynamic_pointer_cast(f)) { - if (auto gm = hc->asMixture()) { - result = gm->add(result); - } else if (auto g = hc->asGaussian()) { - result = addGaussian(result, g); - } else { - // Has to be discrete. - // TODO(dellaert): in C++20, we can use std::visit. - continue; - } - } else if (std::dynamic_pointer_cast(f)) { - // Don't do anything for discrete-only factors - // since we want to evaluate continuous values only. - continue; - } else { - // We need to handle the case where the object is actually an - // BayesTreeOrphanWrapper! - throwRuntimeError("HybridBayesNet::assembleTree", f); - } - } - - GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet &gbn) { - return std::make_pair(gbn, 0.0); - }); - return resultTree; -} - -/* ************************************************************************* */ -AlgebraicDecisionTree HybridBayesNet::modelSelection() const { - /* - To perform model selection, we need: - q(mu; M, Z) = exp(-error) - where error is computed at the corresponding MAP point, gbn.error(mu). - - So we compute `error` and exponentiate after. - */ - - GaussianBayesNetValTree bnTree = assembleTree(); - - GaussianBayesNetValTree bn_error = bnTree.apply( - [this](const Assignment &assignment, - const std::pair &gbnAndValue) { - // Compute the X* of each assignment - VectorValues mu = gbnAndValue.first.optimize(); - - // mu is empty if gbn had nullptrs - if (mu.size() == 0) { - return std::make_pair(gbnAndValue.first, - std::numeric_limits::max()); - } - - // Compute the error for X* and the assignment - double error = - this->error(HybridValues(mu, DiscreteValues(assignment))); - - return std::make_pair(gbnAndValue.first, error); - }); - - // Compute model selection term (with help from ADT methods) - auto trees = unzip(bn_error); - AlgebraicDecisionTree errorTree = trees.second; - AlgebraicDecisionTree modelSelectionTerm = errorTree * -1; - - double max_log = modelSelectionTerm.max(); - modelSelectionTerm = DecisionTree( - modelSelectionTerm, - [&max_log](const double &x) { return std::exp(x - max_log); }); - modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum()); - - return modelSelectionTerm; -} - /* ************************************************************************* */ HybridValues HybridBayesNet::optimize() const { // Collect all the discrete factors to compute MPE DiscreteFactorGraph discrete_fg; - // Compute model selection term - AlgebraicDecisionTree modelSelectionTerm = modelSelection(); - - // Get the set of all discrete keys involved in model selection - std::set discreteKeySet; for (auto &&conditional : *this) { if (conditional->isDiscrete()) { discrete_fg.push_back(conditional->asDiscrete()); - } else { - if (conditional->isContinuous()) { - /* - If we are here, it means there are no discrete variables in - the Bayes net (due to strong elimination ordering). - This is a continuous-only problem hence model selection doesn't matter. - */ - - } else if (conditional->isHybrid()) { - auto gm = conditional->asMixture(); - // Include the discrete keys - std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(), - std::inserter(discreteKeySet, discreteKeySet.end())); - } } } - // Only add model_selection if we have discrete keys - if (discreteKeySet.size() > 0) { - discrete_fg.push_back(DecisionTreeFactor( - DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()), - modelSelectionTerm)); - } - // Solve for the MPE DiscreteValues mpe = discrete_fg.optimize(); diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index fe9010b1f..032cd55b9 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -118,34 +118,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { return evaluate(values); } - /** - * @brief Assemble a DecisionTree of (GaussianBayesNet, double) leaves for - * each discrete assignment. - * The included double value is used to make - * constructing the model selection term cleaner and more efficient. - * - * @return GaussianBayesNetValTree - */ - GaussianBayesNetValTree assembleTree() const; - - /* - Compute L(M;Z), the likelihood of the discrete model M - given the measurements Z. - This is called the model selection term. - - To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M). - - By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M), - hence L(X;M,Z)P(X|M) is the unnormalized probabilty of - the joint Gaussian distribution. - - This can be computed by multiplying all the exponentiated errors - of each of the conditionals. - - Return a tree where each leaf value is L(M_i;Z). - */ - AlgebraicDecisionTree modelSelection() const; - /** * @brief Solve the HybridBayesNet by first computing the MPE of all the * discrete variables and then optimizing the continuous variables based on diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 8828a9172..418489d66 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -13,6 +13,7 @@ * @file HybridFactor.h * @date Mar 11, 2022 * @author Fan Jiang + * @author Varun Agrawal */ #pragma once @@ -35,12 +36,6 @@ class HybridValues; using GaussianFactorGraphTree = DecisionTree; /// Alias for DecisionTree of GaussianBayesNets using GaussianBayesNetTree = DecisionTree; -/** - * Alias for DecisionTree of (GaussianBayesNet, double) pairs. - * Used for model selection in BayesNet::optimize - */ -using GaussianBayesNetValTree = - DecisionTree>; KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys);