From 7695fd6de3fc00cf7d348dc2f30cd6e6f4548f1a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 20 Nov 2023 13:24:39 -0500 Subject: [PATCH] Improved HybridBayesNet::optimize with proper model selection --- gtsam/hybrid/HybridBayesNet.cpp | 60 +++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 31177ddb7..ba869a6f5 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -220,15 +220,69 @@ GaussianBayesNet HybridBayesNet::choose( /* ************************************************************************* */ HybridValues HybridBayesNet::optimize() const { // Collect all the discrete factors to compute MPE - DiscreteBayesNet discrete_bn; + DiscreteFactorGraph discrete_fg; + VectorValues continuousValues; + for (auto &&conditional : *this) { if (conditional->isDiscrete()) { - discrete_bn.push_back(conditional->asDiscrete()); + discrete_fg.push_back(conditional->asDiscrete()); + } else { + /* + Perform the integration of L(X;M, Z)P(X|M) which is the model selection + term. + TODO(Varun) Write better comments detailing the whole process. + */ + if (conditional->isContinuous()) { + auto gc = conditional->asGaussian(); + for (GaussianConditional::const_iterator frontal = gc->beginFrontals(); + frontal != gc->endFrontals(); ++frontal) { + continuousValues.insert_or_assign(*frontal, + Vector::Zero(gc->getDim(frontal))); + } + } else if (conditional->isHybrid()) { + auto gm = conditional->asMixture(); + gm->conditionals().apply( + [&continuousValues](const GaussianConditional::shared_ptr &gc) { + if (gc) { + for (GaussianConditional::const_iterator frontal = gc->begin(); + frontal != gc->end(); ++frontal) { + continuousValues.insert_or_assign( + *frontal, Vector::Zero(gc->getDim(frontal))); + } + } + return gc; + }); + + DecisionTree error = gm->error(continuousValues); + + // Functional to take error and compute the probability + auto integrate = [&gm](const double &error) { + // q(mu; M, Z) = exp(-error) + // k = 1.0 / sqrt((2*pi)^n*det(Sigma)) + // thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q) - log(k)) + // = exp(-error - log(k)) + double prob = std::exp(-error - gm->logNormalizationConstant()); + if (prob > 1e-12) { + return prob; + } else { + return 1.0; + } + }; + AlgebraicDecisionTree model_selection = + DecisionTree(error, integrate); + + std::cout << "\n\nmodel selection"; + model_selection.print("", DefaultKeyFormatter); + discrete_fg.push_back( + DecisionTreeFactor(gm->discreteKeys(), model_selection)); + } } } // Solve for the MPE - DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize(); + discrete_fg.print(); + DiscreteValues mpe = discrete_fg.optimize(); + mpe.print("mpe"); // Given the MPE, compute the optimal continuous values. return HybridValues(optimize(mpe), mpe);