Improved HybridBayesNet::optimize with proper model selection
parent
cd5c13065b
commit
7695fd6de3
|
@ -220,15 +220,69 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridValues HybridBayesNet::optimize() const {
|
HybridValues HybridBayesNet::optimize() const {
|
||||||
// Collect all the discrete factors to compute MPE
|
// Collect all the discrete factors to compute MPE
|
||||||
DiscreteBayesNet discrete_bn;
|
DiscreteFactorGraph discrete_fg;
|
||||||
|
VectorValues continuousValues;
|
||||||
|
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
discrete_bn.push_back(conditional->asDiscrete());
|
discrete_fg.push_back(conditional->asDiscrete());
|
||||||
|
} else {
|
||||||
|
/*
|
||||||
|
Perform the integration of L(X;M, Z)P(X|M) which is the model selection
|
||||||
|
term.
|
||||||
|
TODO(Varun) Write better comments detailing the whole process.
|
||||||
|
*/
|
||||||
|
if (conditional->isContinuous()) {
|
||||||
|
auto gc = conditional->asGaussian();
|
||||||
|
for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
|
||||||
|
frontal != gc->endFrontals(); ++frontal) {
|
||||||
|
continuousValues.insert_or_assign(*frontal,
|
||||||
|
Vector::Zero(gc->getDim(frontal)));
|
||||||
|
}
|
||||||
|
} else if (conditional->isHybrid()) {
|
||||||
|
auto gm = conditional->asMixture();
|
||||||
|
gm->conditionals().apply(
|
||||||
|
[&continuousValues](const GaussianConditional::shared_ptr &gc) {
|
||||||
|
if (gc) {
|
||||||
|
for (GaussianConditional::const_iterator frontal = gc->begin();
|
||||||
|
frontal != gc->end(); ++frontal) {
|
||||||
|
continuousValues.insert_or_assign(
|
||||||
|
*frontal, Vector::Zero(gc->getDim(frontal)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return gc;
|
||||||
|
});
|
||||||
|
|
||||||
|
DecisionTree<Key, double> error = gm->error(continuousValues);
|
||||||
|
|
||||||
|
// Functional to take error and compute the probability
|
||||||
|
auto integrate = [&gm](const double &error) {
|
||||||
|
// q(mu; M, Z) = exp(-error)
|
||||||
|
// k = 1.0 / sqrt((2*pi)^n*det(Sigma))
|
||||||
|
// thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q) - log(k))
|
||||||
|
// = exp(-error - log(k))
|
||||||
|
double prob = std::exp(-error - gm->logNormalizationConstant());
|
||||||
|
if (prob > 1e-12) {
|
||||||
|
return prob;
|
||||||
|
} else {
|
||||||
|
return 1.0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
AlgebraicDecisionTree<Key> model_selection =
|
||||||
|
DecisionTree<Key, double>(error, integrate);
|
||||||
|
|
||||||
|
std::cout << "\n\nmodel selection";
|
||||||
|
model_selection.print("", DefaultKeyFormatter);
|
||||||
|
discrete_fg.push_back(
|
||||||
|
DecisionTreeFactor(gm->discreteKeys(), model_selection));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Solve for the MPE
|
// Solve for the MPE
|
||||||
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
|
discrete_fg.print();
|
||||||
|
DiscreteValues mpe = discrete_fg.optimize();
|
||||||
|
mpe.print("mpe");
|
||||||
|
|
||||||
// Given the MPE, compute the optimal continuous values.
|
// Given the MPE, compute the optimal continuous values.
|
||||||
return HybridValues(optimize(mpe), mpe);
|
return HybridValues(optimize(mpe), mpe);
|
||||||
|
|
Loading…
Reference in New Issue