Improved HybridBayesNet::optimize with proper model selection
parent
cd5c13065b
commit
7695fd6de3
|
@ -220,15 +220,69 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
/* ************************************************************************* */
|
||||
/**
 * @brief Solve the hybrid Bayes net for the most probable explanation (MPE).
 *
 * Collects all purely-discrete conditionals into a discrete factor graph and
 * augments it with model-selection factors derived from the hybrid (Gaussian
 * mixture) conditionals, solves for the discrete MPE, and then back-substitutes
 * to obtain the optimal continuous values conditioned on that MPE.
 *
 * @return HybridValues holding both the optimal continuous assignment and the
 *         discrete MPE assignment.
 */
HybridValues HybridBayesNet::optimize() const {
  // Collect all the discrete factors to compute MPE.
  DiscreteFactorGraph discrete_fg;
  // Continuous values at which the mixture errors are evaluated: every
  // frontal variable is zero-filled below, so errors are evaluated at the
  // origin of the canonical parameterization.
  VectorValues continuousValues;

  for (auto &&conditional : *this) {
    if (conditional->isDiscrete()) {
      discrete_fg.push_back(conditional->asDiscrete());
    } else {
      /*
      Perform the integration of L(X;M, Z)P(X|M) which is the model selection
      term.
      TODO(Varun) Write better comments detailing the whole process.
      */
      if (conditional->isContinuous()) {
        // Zero-fill the frontal variables of the Gaussian conditional.
        auto gc = conditional->asGaussian();
        for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
             frontal != gc->endFrontals(); ++frontal) {
          continuousValues.insert_or_assign(*frontal,
                                            Vector::Zero(gc->getDim(frontal)));
        }
      } else if (conditional->isHybrid()) {
        auto gm = conditional->asMixture();
        // Zero-fill the variables of every mixture component.
        // NOTE(review): this iterates begin()..end() (frontals AND parents),
        // unlike the continuous branch above which uses beginFrontals() —
        // confirm whether zero-filling parents here is intentional.
        gm->conditionals().apply(
            [&continuousValues](const GaussianConditional::shared_ptr &gc) {
              if (gc) {
                for (GaussianConditional::const_iterator frontal = gc->begin();
                     frontal != gc->end(); ++frontal) {
                  continuousValues.insert_or_assign(
                      *frontal, Vector::Zero(gc->getDim(frontal)));
                }
              }
              return gc;
            });

        // Error of each mixture component at the zero assignment, as a
        // decision tree over the mixture's discrete keys.
        DecisionTree<Key, double> error = gm->error(continuousValues);

        // Functional to take error and compute the probability.
        auto integrate = [&gm](const double &error) {
          // q(mu; M, Z) = exp(-error)
          // k = 1.0 / sqrt((2*pi)^n*det(Sigma))
          // thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q) - log(k))
          //                                  = exp(-error - log(k))
          const double prob =
              std::exp(-error - gm->logNormalizationConstant());
          // Guard against numerically-zero probabilities so the resulting
          // factor never multiplies an assignment down to exactly zero.
          return (prob > 1e-12) ? prob : 1.0;
        };
        AlgebraicDecisionTree<Key> model_selection =
            DecisionTree<Key, double>(error, integrate);

        // Add the model-selection term as a discrete factor over the
        // mixture's discrete keys.
        discrete_fg.push_back(
            DecisionTreeFactor(gm->discreteKeys(), model_selection));
      }
    }
  }

  // Solve for the MPE over the discrete variables.
  // (Fixes the original's duplicate `mpe` declaration — only the factor
  // graph that includes the model-selection factors is solved — and drops
  // the leftover debug printing.)
  DiscreteValues mpe = discrete_fg.optimize();

  // Given the MPE, compute the optimal continuous values.
  return HybridValues(optimize(mpe), mpe);
}
|
||||
|
|
Loading…
Reference in New Issue