nicer HybridBayesNet::optimize with normalized errors
parent
39f7ac20a1
commit
c374a26b45
|
@ -223,22 +223,38 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
DiscreteFactorGraph discrete_fg;
|
||||
VectorValues continuousValues;
|
||||
|
||||
// Error values for each hybrid factor
|
||||
AlgebraicDecisionTree<Key> error(0.0);
|
||||
std::set<DiscreteKey> discreteKeySet;
|
||||
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discrete_fg.push_back(conditional->asDiscrete());
|
||||
} else {
|
||||
/*
|
||||
Perform the integration of L(X;M, Z)P(X|M) which is the model selection
|
||||
term.
|
||||
TODO(Varun) Write better comments detailing the whole process.
|
||||
Perform the integration of L(X;M,Z)P(X|M)
|
||||
which is the model selection term.
|
||||
|
||||
By Bayes' rule, P(X|M) ∝ L(X;M,Z)P(X|M),
|
||||
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
|
||||
the joint Gaussian distribution.
|
||||
|
||||
This can be computed by multiplying all the exponentiaed errors
|
||||
of each of the conditionals, which we do below in hybrid case.
|
||||
*/
|
||||
if (conditional->isContinuous()) {
|
||||
/*
|
||||
If we are here, it means there are no discrete variables in
|
||||
the Bayes net (due to strong elimination ordering).
|
||||
This is a continuous-only problem hence model selection doesn't matter.
|
||||
*/
|
||||
auto gc = conditional->asGaussian();
|
||||
for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
|
||||
frontal != gc->endFrontals(); ++frontal) {
|
||||
continuousValues.insert_or_assign(*frontal,
|
||||
Vector::Zero(gc->getDim(frontal)));
|
||||
}
|
||||
|
||||
} else if (conditional->isHybrid()) {
|
||||
auto gm = conditional->asMixture();
|
||||
gm->conditionals().apply(
|
||||
|
@ -253,36 +269,47 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
return gc;
|
||||
});
|
||||
|
||||
DecisionTree<Key, double> error = gm->error(continuousValues);
|
||||
/*
|
||||
To perform model selection, we need:
|
||||
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
|
||||
|
||||
// Functional to take error and compute the probability
|
||||
auto integrate = [&gm](const double &error) {
|
||||
// q(mu; M, Z) = exp(-error)
|
||||
// k = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||
// thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q) - log(k))
|
||||
// = exp(-error - log(k))
|
||||
double prob = std::exp(-error - gm->logNormalizationConstant());
|
||||
if (prob > 1e-12) {
|
||||
return prob;
|
||||
} else {
|
||||
return 1.0;
|
||||
}
|
||||
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||
thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q/k))
|
||||
= exp(log(q) - log(k)) = exp(-error - log(k))
|
||||
= exp(-(error + log(k)))
|
||||
|
||||
So let's compute (error + log(k)) and exponentiate later
|
||||
*/
|
||||
error = error + gm->error(continuousValues);
|
||||
|
||||
// Add the logNormalization constant to the error
|
||||
// Also compute the mean for normalization (for numerical stability)
|
||||
double mean = 0.0;
|
||||
auto addConstant = [&gm, &mean](const double &error) {
|
||||
double e = error + gm->logNormalizationConstant();
|
||||
mean += e;
|
||||
return e;
|
||||
};
|
||||
AlgebraicDecisionTree<Key> model_selection =
|
||||
DecisionTree<Key, double>(error, integrate);
|
||||
error = error.apply(addConstant);
|
||||
// Normalize by the mean
|
||||
error = error.apply([&mean](double x) { return x / mean; });
|
||||
|
||||
std::cout << "\n\nmodel selection";
|
||||
model_selection.print("", DefaultKeyFormatter);
|
||||
discrete_fg.push_back(
|
||||
DecisionTreeFactor(gm->discreteKeys(), model_selection));
|
||||
// Include the discrete keys
|
||||
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
|
||||
std::inserter(discreteKeySet, discreteKeySet.end()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
|
||||
error, [](const double &error) { return std::exp(-error); });
|
||||
|
||||
discrete_fg.push_back(DecisionTreeFactor(
|
||||
DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
|
||||
model_selection));
|
||||
|
||||
// Solve for the MPE
|
||||
discrete_fg.print();
|
||||
DiscreteValues mpe = discrete_fg.optimize();
|
||||
mpe.print("mpe");
|
||||
|
||||
// Given the MPE, compute the optimal continuous values.
|
||||
return HybridValues(optimize(mpe), mpe);
|
||||
|
|
Loading…
Reference in New Issue