Improved HybridBayesNet::optimize with proper model selection

release/4.3a0
Varun Agrawal 2023-11-20 13:24:39 -05:00
parent cd5c13065b
commit 7695fd6de3
1 changed files with 57 additions and 3 deletions

View File

@ -220,15 +220,69 @@ GaussianBayesNet HybridBayesNet::choose(
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Collect all the discrete factors to compute MPE
DiscreteBayesNet discrete_bn;
DiscreteFactorGraph discrete_fg;
VectorValues continuousValues;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscrete());
discrete_fg.push_back(conditional->asDiscrete());
} else {
/*
Perform the integration of L(X;M, Z)P(X|M) which is the model selection
term.
TODO(Varun) Write better comments detailing the whole process.
*/
if (conditional->isContinuous()) {
auto gc = conditional->asGaussian();
for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
frontal != gc->endFrontals(); ++frontal) {
continuousValues.insert_or_assign(*frontal,
Vector::Zero(gc->getDim(frontal)));
}
} else if (conditional->isHybrid()) {
auto gm = conditional->asMixture();
gm->conditionals().apply(
[&continuousValues](const GaussianConditional::shared_ptr &gc) {
if (gc) {
for (GaussianConditional::const_iterator frontal = gc->begin();
frontal != gc->end(); ++frontal) {
continuousValues.insert_or_assign(
*frontal, Vector::Zero(gc->getDim(frontal)));
}
}
return gc;
});
DecisionTree<Key, double> error = gm->error(continuousValues);
// Functional to take error and compute the probability
auto integrate = [&gm](const double &error) {
// q(mu; M, Z) = exp(-error)
// k = 1.0 / sqrt((2*pi)^n*det(Sigma))
// thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q) - log(k))
// = exp(-error - log(k))
double prob = std::exp(-error - gm->logNormalizationConstant());
if (prob > 1e-12) {
return prob;
} else {
return 1.0;
}
};
AlgebraicDecisionTree<Key> model_selection =
DecisionTree<Key, double>(error, integrate);
std::cout << "\n\nmodel selection";
model_selection.print("", DefaultKeyFormatter);
discrete_fg.push_back(
DecisionTreeFactor(gm->discreteKeys(), model_selection));
}
}
}
// Solve for the MPE
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
discrete_fg.print();
DiscreteValues mpe = discrete_fg.optimize();
mpe.print("mpe");
// Given the MPE, compute the optimal continuous values.
return HybridValues(optimize(mpe), mpe);