nicer HybridBayesNet::optimize with normalized errors

release/4.3a0
Varun Agrawal 2023-11-20 13:31:07 -05:00
parent 39f7ac20a1
commit c374a26b45
1 changed file with 51 additions and 24 deletions


@@ -223,22 +223,38 @@ HybridValues HybridBayesNet::optimize() const {
   DiscreteFactorGraph discrete_fg;
   VectorValues continuousValues;
 
+  // Error values for each hybrid factor
+  AlgebraicDecisionTree<Key> error(0.0);
+  std::set<DiscreteKey> discreteKeySet;
+
   for (auto &&conditional : *this) {
     if (conditional->isDiscrete()) {
       discrete_fg.push_back(conditional->asDiscrete());
     } else {
       /*
-      Perform the integration of L(X;M, Z)P(X|M) which is the model selection
-      term.
-      TODO(Varun) Write better comments detailing the whole process.
+      Perform the integration of L(X;M,Z)P(X|M),
+      which is the model selection term.
+
+      By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
+      hence L(X;M,Z)P(X|M) is the unnormalized probability of
+      the joint Gaussian distribution.
+
+      This can be computed by multiplying all the exponentiated errors
+      of each of the conditionals, which we do below in the hybrid case.
       */
       if (conditional->isContinuous()) {
+        /*
+        If we are here, it means there are no discrete variables in
+        the Bayes net (due to strong elimination ordering).
+        This is a continuous-only problem, hence model selection doesn't matter.
+        */
         auto gc = conditional->asGaussian();
         for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
              frontal != gc->endFrontals(); ++frontal) {
           continuousValues.insert_or_assign(*frontal,
                                             Vector::Zero(gc->getDim(frontal)));
         }
       } else if (conditional->isHybrid()) {
         auto gm = conditional->asMixture();
         gm->conditionals().apply(
@@ -253,36 +269,47 @@ HybridValues HybridBayesNet::optimize() const {
               return gc;
             });
 
-        DecisionTree<Key, double> error = gm->error(continuousValues);
-        // Functional to take error and compute the probability
-        auto integrate = [&gm](const double &error) {
-          // q(mu; M, Z) = exp(-error)
-          // k = 1.0 / sqrt((2*pi)^n*det(Sigma))
-          // thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q) - log(k))
-          //     = exp(-error - log(k))
-          double prob = std::exp(-error - gm->logNormalizationConstant());
-          if (prob > 1e-12) {
-            return prob;
-          } else {
-            return 1.0;
-          }
-        };
-        AlgebraicDecisionTree<Key> model_selection =
-            DecisionTree<Key, double>(error, integrate);
-        std::cout << "\n\nmodel selection";
-        model_selection.print("", DefaultKeyFormatter);
-        discrete_fg.push_back(
-            DecisionTreeFactor(gm->discreteKeys(), model_selection));
+        /*
+        To perform model selection, we need:
+        q(mu; M, Z) * sqrt((2*pi)^n * det(Sigma))
+
+        If q(mu; M, Z) = exp(-error) and k = 1.0 / sqrt((2*pi)^n * det(Sigma)),
+        then q * sqrt(|2*pi*Sigma|) = q/k = exp(log(q/k))
+            = exp(log(q) - log(k)) = exp(-error - log(k))
+            = exp(-(error + log(k))),
+        so we compute (error + log(k)) now and exponentiate later.
+        */
+        error = error + gm->error(continuousValues);
+
+        // Add the log normalization constant to the error.
+        // Also compute the mean for normalization (for numerical stability).
+        double mean = 0.0;
+        auto addConstant = [&gm, &mean](const double &error) {
+          double e = error + gm->logNormalizationConstant();
+          mean += e;
+          return e;
+        };
+        error = error.apply(addConstant);
+
+        // Normalize by the mean
+        error = error.apply([&mean](double x) { return x / mean; });
+
+        // Include the discrete keys
+        std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
+                  std::inserter(discreteKeySet, discreteKeySet.end()));
       }
     }
   }
 
+  AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
+      error, [](const double &error) { return std::exp(-error); });
+
+  discrete_fg.push_back(DecisionTreeFactor(
+      DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
+      model_selection));
+
   // Solve for the MPE
-  discrete_fg.print();
   DiscreteValues mpe = discrete_fg.optimize();
-  mpe.print("mpe");
 
   // Given the MPE, compute the optimal continuous values.
   return HybridValues(optimize(mpe), mpe);
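
For reference, the algebra in the new comments can be written out in full. Here error(M) is the accumulated mixture error at the conditional means for a discrete assignment M, and k is the Gaussian normalization constant:

\[
q(\mu; M, Z) = e^{-\mathrm{error}(M)}, \qquad
k = \frac{1}{\sqrt{(2\pi)^n \det\Sigma}},
\]
\[
q(\mu; M, Z)\,\sqrt{(2\pi)^n \det\Sigma}
  = \frac{q}{k}
  = \exp(\log q - \log k)
  = \exp\!\big(-(\mathrm{error}(M) + \log k)\big).
\]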
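The reason for accumulating (error + log(k)) first and exponentiating only at the end is numerical: the summed values can be large enough that exp(-e) underflows to zero for every assignment. Below is a minimal standalone sketch of that idea with hypothetical helper names (not GTSAM API). It rescales by subtracting a common constant, the usual log-sum-exp-style shift, whereas the commit divides by the accumulated total; the sketch is an illustration of the underflow problem, not the commit's exact normalization.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Hypothetical illustration: turn per-assignment values e_i = error_i + log(k_i)
// into model-selection weights proportional to exp(-e_i). Raw exp(-e_i) would
// underflow for large e_i, so shift all values by a common constant first;
// the shift scales every weight equally and keeps the argmax unchanged.
std::vector<double> modelSelectionWeights(const std::vector<double>& e) {
  const double shift = *std::min_element(e.begin(), e.end());
  std::vector<double> weights;
  weights.reserve(e.size());
  for (double ei : e) weights.push_back(std::exp(-(ei - shift)));
  return weights;
}

int main() {
  // Two discrete assignments whose raw values would both underflow exp(-e).
  const std::vector<double> e = {1000.0, 1003.0};
  for (double w : modelSelectionWeights(e))
    std::cout << w << "\n";  // prints 1 and exp(-3), roughly 0.0498
}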
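In symbols, writing w(M) for the exponentiated model-selection tree pushed onto discrete_fg, the final two steps compute

\[
M^* = \operatorname*{arg\,max}_{M}\; w(M)\,\prod_i P(M_i \mid \cdot),
\qquad
X^* = \operatorname*{arg\,max}_{X}\; P(X \mid M^*),
\]

that is, the most probable discrete assignment once the model-selection factor is included, followed by the optimal continuous values for that assignment.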