From c374a26b45ed1f25c597ad08d12594ee4414282e Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 20 Nov 2023 13:31:07 -0500 Subject: [PATCH] nicer HybridBayesNet::optimize with normalized errors --- gtsam/hybrid/HybridBayesNet.cpp | 75 ++++++++++++++++++++++----------- 1 file changed, 51 insertions(+), 24 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index ba869a6f5..cd1157576 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -223,22 +223,38 @@ HybridValues HybridBayesNet::optimize() const { DiscreteFactorGraph discrete_fg; VectorValues continuousValues; + // Error values for each hybrid factor + AlgebraicDecisionTree error(0.0); + std::set discreteKeySet; + for (auto &&conditional : *this) { if (conditional->isDiscrete()) { discrete_fg.push_back(conditional->asDiscrete()); } else { /* - Perform the integration of L(X;M, Z)P(X|M) which is the model selection - term. - TODO(Varun) Write better comments detailing the whole process. + Perform the integration of L(X;M,Z)P(X|M) + which is the model selection term. + + By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M), + hence L(X;M,Z)P(X|M) is the unnormalized probability of + the joint Gaussian distribution. + + This can be computed by multiplying all the exponentiated errors + of each of the conditionals, which we do below in hybrid case. */ if (conditional->isContinuous()) { + /* + If we are here, it means there are no discrete variables in + the Bayes net (due to strong elimination ordering). + This is a continuous-only problem hence model selection doesn't matter. 
+ */ auto gc = conditional->asGaussian(); for (GaussianConditional::const_iterator frontal = gc->beginFrontals(); frontal != gc->endFrontals(); ++frontal) { continuousValues.insert_or_assign(*frontal, Vector::Zero(gc->getDim(frontal))); } + } else if (conditional->isHybrid()) { auto gm = conditional->asMixture(); gm->conditionals().apply( @@ -253,36 +269,47 @@ HybridValues HybridBayesNet::optimize() const { return gc; }); - DecisionTree error = gm->error(continuousValues); + /* + To perform model selection, we need: + q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma)) - // Functional to take error and compute the probability - auto integrate = [&gm](const double &error) { - // q(mu; M, Z) = exp(-error) - // k = 1.0 / sqrt((2*pi)^n*det(Sigma)) - // thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q) - log(k)) - // = exp(-error - log(k)) - double prob = std::exp(-error - gm->logNormalizationConstant()); - if (prob > 1e-12) { - return prob; - } else { - return 1.0; - } + If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma)) + thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q/k)) + = exp(log(q) - log(k)) = exp(-error - log(k)) + = exp(-(error + log(k))) + + So let's compute (error + log(k)) and exponentiate later + */ + error = error + gm->error(continuousValues); + + // Add the logNormalization constant to the error + // Also compute the mean for normalization (for numerical stability) + double mean = 0.0; + auto addConstant = [&gm, &mean](const double &error) { + double e = error + gm->logNormalizationConstant(); + mean += e; + return e; }; - AlgebraicDecisionTree model_selection = - DecisionTree(error, integrate); + error = error.apply(addConstant); + // Normalize by the mean + error = error.apply([&mean](double x) { return x / mean; }); - std::cout << "\n\nmodel selection"; - model_selection.print("", DefaultKeyFormatter); - discrete_fg.push_back( - DecisionTreeFactor(gm->discreteKeys(), model_selection)); + // Include the discrete keys + 
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(), + std::inserter(discreteKeySet, discreteKeySet.end())); } } } + AlgebraicDecisionTree model_selection = DecisionTree( + error, [](const double &error) { return std::exp(-error); }); + + discrete_fg.push_back(DecisionTreeFactor( + DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()), + model_selection)); + // Solve for the MPE - discrete_fg.print(); DiscreteValues mpe = discrete_fg.optimize(); - mpe.print("mpe"); // Given the MPE, compute the optimal continuous values. return HybridValues(optimize(mpe), mpe);