better, more correct version of model selection

release/4.3a0
Varun Agrawal 2023-12-25 16:04:48 -05:00
parent c6584f63ce
commit ebcf958d69
1 changed files with 142 additions and 66 deletions

View File

@ -26,6 +26,18 @@ static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam {
using std::dynamic_pointer_cast;
/* ************************************************************************ */
// Throw a runtime exception for method specified in string s,
// and conditional f:
static void throwRuntimeError(const std::string &s,
const std::shared_ptr<HybridConditional> &f) {
auto &fr = *f;
throw std::runtime_error(s + " not implemented for conditional type " +
demangle(typeid(fr).name()) + ".");
}
/* ************************************************************************* */
void HybridBayesNet::print(const std::string &s,
const KeyFormatter &formatter) const {
@ -217,6 +229,56 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn;
}
/* ************************************************************************ */
static GaussianBayesNetTree addGaussian(
const GaussianBayesNetTree &gfgTree,
const GaussianConditional::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (gfgTree.empty()) {
GaussianBayesNet result{factor};
return GaussianBayesNetTree(result);
} else {
auto add = [&factor](const GaussianBayesNet &graph) {
auto result = graph;
result.push_back(factor);
return result;
};
return gfgTree.apply(add);
}
}
/* ************************************************************************ */
GaussianBayesNetTree HybridBayesNet::assembleTree() const {
GaussianBayesNetTree result;
for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
result = gm->add(result);
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
result = gm->add(result);
} else if (auto g = hc->asGaussian()) {
result = addGaussian(result, g);
} else {
// Has to be discrete.
// TODO(dellaert): in C++20, we can use std::visit.
continue;
}
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// Don't do anything for discrete-only factors
// since we want to evaluate continuous values only.
continue;
} else {
// We need to handle the case where the object is actually an
// BayesTreeOrphanWrapper!
throwRuntimeError("HybridBayesNet::assembleTree", f);
}
}
return result;
}
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Collect all the discrete factors to compute MPE
@ -227,74 +289,94 @@ HybridValues HybridBayesNet::optimize() const {
AlgebraicDecisionTree<Key> error(0.0);
std::set<DiscreteKey> discreteKeySet;
// this->print();
GaussianBayesNetTree bnTree = assembleTree();
// bnTree.print("", DefaultKeyFormatter, [](const GaussianBayesNet &gbn) {
// gbn.print();
// return "";
// });
/*
Perform the integration of L(X;M,Z)P(X|M)
which is the model selection term.
By Bayes' rule, P(X|M,Z) L(X;M,Z)P(X|M),
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
the joint Gaussian distribution.
This can be computed by multiplying all the exponentiated errors
of each of the conditionals, which we do below in hybrid case.
*/
/*
To perform model selection, we need:
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
= exp(log(q) - log(k)) = exp(-error - log(k))
= exp(-(error + log(k)))
So we compute (error + log(k)) and exponentiate later
*/
// Compute the X* of each assignment and use that as the MAP.
DecisionTree<Key, VectorValues> x_map(
bnTree, [](const GaussianBayesNet &gbn) { return gbn.optimize(); });
// Only compute logNormalizationConstant for now
AlgebraicDecisionTree<Key> log_norm_constants =
DecisionTree<Key, double>(bnTree, [](const GaussianBayesNet &gbn) {
if (gbn.size() == 0) {
return -std::numeric_limits<double>::max();
}
return -gbn.logNormalizationConstant();
});
// Compute unnormalized error term and compute model selection term
AlgebraicDecisionTree<Key> model_selection_term = log_norm_constants.apply(
[this, &x_map](const Assignment<Key> &assignment, double x) {
double error = 0.0;
for (auto &&f : *this) {
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
error += gm->error(
HybridValues(x_map(assignment), DiscreteValues(assignment)));
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
error += gm->error(
HybridValues(x_map(assignment), DiscreteValues(assignment)));
} else if (auto g = hc->asGaussian()) {
error += g->error(x_map(assignment));
}
}
}
return -(error + x);
});
// model_selection_term.print("", DefaultKeyFormatter);
double max_log = model_selection_term.max();
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
model_selection_term,
[&max_log](const double &x) { return std::exp(x - max_log); });
model_selection = model_selection.normalize(model_selection.sum());
// std::cout << "normalized model selection" << std::endl;
// model_selection.print("", DefaultKeyFormatter);
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_fg.push_back(conditional->asDiscrete());
} else {
/*
Perform the integration of L(X;M,Z)P(X|M)
which is the model selection term.
By Bayes' rule, P(X|M) L(X;M,Z)P(X|M),
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
the joint Gaussian distribution.
This can be computed by multiplying all the exponentiated errors
of each of the conditionals, which we do below in hybrid case.
*/
if (conditional->isContinuous()) {
/*
If we are here, it means there are no discrete variables in
the Bayes net (due to strong elimination ordering).
This is a continuous-only problem hence model selection doesn't matter.
*/
auto gc = conditional->asGaussian();
for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
frontal != gc->endFrontals(); ++frontal) {
continuousValues.insert_or_assign(*frontal,
Vector::Zero(gc->getDim(frontal)));
}
// /*
// If we are here, it means there are no discrete variables in
// the Bayes net (due to strong elimination ordering).
// This is a continuous-only problem hence model selection doesn't matter.
// */
// auto gc = conditional->asGaussian();
// for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
// frontal != gc->endFrontals(); ++frontal) {
// continuousValues.insert_or_assign(*frontal,
// Vector::Zero(gc->getDim(frontal)));
// }
} else if (conditional->isHybrid()) {
auto gm = conditional->asMixture();
gm->conditionals().apply(
[&continuousValues](const GaussianConditional::shared_ptr &gc) {
if (gc) {
for (GaussianConditional::const_iterator frontal = gc->begin();
frontal != gc->end(); ++frontal) {
continuousValues.insert_or_assign(
*frontal, Vector::Zero(gc->getDim(frontal)));
}
}
return gc;
});
/*
To perform model selection, we need:
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
= exp(log(q) - log(k)) = exp(-error - log(k))
= exp(-(error + log(k)))
So we compute (error + log(k)) and exponentiate later
*/
// Add the error and the logNormalization constant to the error
auto err = gm->error(continuousValues) + gm->logNormalizationConstant();
// Also compute the sum for discrete probability normalization
// (normalization trick for numerical stability)
double sum = 0.0;
auto absSum = [&sum](const double &e) {
sum += std::abs(e);
return e;
};
err.visit(absSum);
// Normalize by the sum to prevent overflow
error = error + err.normalize(sum);
// Include the discrete keys
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
std::inserter(discreteKeySet, discreteKeySet.end()));
@ -302,12 +384,6 @@ HybridValues HybridBayesNet::optimize() const {
}
}
error = error * -1;
double max_log = error.max();
AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
error, [&max_log](const double &x) { return std::exp(x - max_log); });
model_selection = model_selection.normalize(model_selection.sum());
// Only add model_selection if we have discrete keys
if (discreteKeySet.size() > 0) {
discrete_fg.push_back(DecisionTreeFactor(