better, more correct version of model selection

release/4.3a0
Varun Agrawal 2023-12-25 16:04:48 -05:00
parent c6584f63ce
commit ebcf958d69
1 changed file with 142 additions and 66 deletions


@@ -26,6 +26,18 @@ static std::mt19937_64 kRandomNumberGenerator(42);
 namespace gtsam {

+using std::dynamic_pointer_cast;
+
+/* ************************************************************************ */
+// Throw a runtime exception for method specified in string s,
+// and conditional f:
+static void throwRuntimeError(const std::string &s,
+                              const std::shared_ptr<HybridConditional> &f) {
+  auto &fr = *f;
+  throw std::runtime_error(s + " not implemented for conditional type " +
+                           demangle(typeid(fr).name()) + ".");
+}
+
 /* ************************************************************************* */
 void HybridBayesNet::print(const std::string &s,
                            const KeyFormatter &formatter) const {
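
A note on the helper just added: binding auto &fr = *f before calling typeid is what makes the error message report the dynamic (most-derived) conditional type, since typeid only inspects the runtime type when given a glvalue of polymorphic class type. A minimal self-contained sketch of that behavior, using only the standard library (Base and Derived are illustrative names, not GTSAM types):

#include <iostream>
#include <memory>
#include <typeinfo>

struct Base { virtual ~Base() = default; };  // polymorphic: virtual dtor
struct Derived : Base {};

int main() {
  std::shared_ptr<Base> f = std::make_shared<Derived>();
  auto &fr = *f;  // same pattern as throwRuntimeError above
  // Prints the mangled name of Derived (the dynamic type), not Base;
  // gtsam::demangle turns such names into human-readable ones.
  std::cout << typeid(fr).name() << "\n";
  return 0;
}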
@@ -217,6 +229,56 @@ GaussianBayesNet HybridBayesNet::choose(
   return gbn;
 }
+
+/* ************************************************************************ */
+static GaussianBayesNetTree addGaussian(
+    const GaussianBayesNetTree &gfgTree,
+    const GaussianConditional::shared_ptr &factor) {
+  // If the decision tree is not initialized, then initialize it.
+  if (gfgTree.empty()) {
+    GaussianBayesNet result{factor};
+    return GaussianBayesNetTree(result);
+  } else {
+    auto add = [&factor](const GaussianBayesNet &graph) {
+      auto result = graph;
+      result.push_back(factor);
+      return result;
+    };
+    return gfgTree.apply(add);
+  }
+}
+
+/* ************************************************************************ */
+GaussianBayesNetTree HybridBayesNet::assembleTree() const {
+  GaussianBayesNetTree result;
+
+  for (auto &f : factors_) {
+    // TODO(dellaert): just use a virtual method defined in HybridFactor.
+    if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
+      result = gm->add(result);
+    } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
+      if (auto gm = hc->asMixture()) {
+        result = gm->add(result);
+      } else if (auto g = hc->asGaussian()) {
+        result = addGaussian(result, g);
+      } else {
+        // Has to be discrete.
+        // TODO(dellaert): in C++20, we can use std::visit.
+        continue;
+      }
+    } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
+      // Don't do anything for discrete-only factors
+      // since we want to evaluate continuous values only.
+      continue;
+    } else {
+      // We need to handle the case where the object is actually a
+      // BayesTreeOrphanWrapper!
+      throwRuntimeError("HybridBayesNet::assembleTree", f);
+    }
+  }
+
+  return result;
+}
+
 /* ************************************************************************* */
 HybridValues HybridBayesNet::optimize() const {
   // Collect all the discrete factors to compute MPE
@@ -227,48 +289,23 @@ HybridValues HybridBayesNet::optimize() const {
   AlgebraicDecisionTree<Key> error(0.0);
   std::set<DiscreteKey> discreteKeySet;
-  for (auto &&conditional : *this) {
-    if (conditional->isDiscrete()) {
-      discrete_fg.push_back(conditional->asDiscrete());
-    } else {
+  // this->print();
+  GaussianBayesNetTree bnTree = assembleTree();
+  // bnTree.print("", DefaultKeyFormatter, [](const GaussianBayesNet &gbn) {
+  //   gbn.print();
+  //   return "";
+  // });
   /*
   Perform the integration of L(X;M,Z)P(X|M)
   which is the model selection term.

-  By Bayes' rule, P(X|M) ∝ L(X;M,Z)P(X|M),
+  By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
   hence L(X;M,Z)P(X|M) is the unnormalized probability of
   the joint Gaussian distribution.

   This can be computed by multiplying all the exponentiated errors
   of each of the conditionals, which we do below in the hybrid case.
   */
-      if (conditional->isContinuous()) {
-        /*
-        If we are here, it means there are no discrete variables in
-        the Bayes net (due to strong elimination ordering).
-        This is a continuous-only problem hence model selection doesn't matter.
-        */
-        auto gc = conditional->asGaussian();
-        for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
-             frontal != gc->endFrontals(); ++frontal) {
-          continuousValues.insert_or_assign(*frontal,
-                                            Vector::Zero(gc->getDim(frontal)));
-        }
-      } else if (conditional->isHybrid()) {
-        auto gm = conditional->asMixture();
-        gm->conditionals().apply(
-            [&continuousValues](const GaussianConditional::shared_ptr &gc) {
-              if (gc) {
-                for (GaussianConditional::const_iterator frontal = gc->begin();
-                     frontal != gc->end(); ++frontal) {
-                  continuousValues.insert_or_assign(
-                      *frontal, Vector::Zero(gc->getDim(frontal)));
-                }
-              }
-              return gc;
-            });
   /*
   To perform model selection, we need:
   q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
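
Spelled out, the identity the surrounding comments rely on (editor's restatement, assuming, as elsewhere in GTSAM, that q(mu; M, Z) = exp(-error(mu)) and writing k = 1/sqrt((2*pi)^n det(Sigma)) for the Gaussian normalization constant):

\[
  q(\mu; M, Z)\,\sqrt{(2\pi)^n \det(\Sigma)}
  = \frac{q}{k}
  = \exp(\log q - \log k)
  = \exp\bigl(-(\mathrm{error}(\mu) + \log k)\bigr)
\]

which is why the code below accumulates error + log(k) per discrete assignment and exponentiates only once at the end.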
@@ -280,21 +317,66 @@ HybridValues HybridBayesNet::optimize() const {
   So we compute (error + log(k)) and exponentiate later
   */
+
+  // Compute the X* of each assignment and use that as the MAP.
+  DecisionTree<Key, VectorValues> x_map(
+      bnTree, [](const GaussianBayesNet &gbn) { return gbn.optimize(); });
-      // Add the error and the logNormalization constant to the error
-      auto err = gm->error(continuousValues) + gm->logNormalizationConstant();
+  // Only compute logNormalizationConstant for now
+  AlgebraicDecisionTree<Key> log_norm_constants =
+      DecisionTree<Key, double>(bnTree, [](const GaussianBayesNet &gbn) {
+        if (gbn.size() == 0) {
+          return -std::numeric_limits<double>::max();
+        }
+        return -gbn.logNormalizationConstant();
+      });
+
+  // Compute the unnormalized error term and the model selection term
+  AlgebraicDecisionTree<Key> model_selection_term = log_norm_constants.apply(
+      [this, &x_map](const Assignment<Key> &assignment, double x) {
+        double error = 0.0;
+        for (auto &&f : *this) {
+          if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
+            error += gm->error(
+                HybridValues(x_map(assignment), DiscreteValues(assignment)));
+          } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
+            if (auto gm = hc->asMixture()) {
+              error += gm->error(
+                  HybridValues(x_map(assignment), DiscreteValues(assignment)));
+            } else if (auto g = hc->asGaussian()) {
+              error += g->error(x_map(assignment));
+            }
+          }
+        }
+        return -(error + x);
+      });
+  // model_selection_term.print("", DefaultKeyFormatter);
-      // Also compute the sum for discrete probability normalization
-      // (normalization trick for numerical stability)
-      double sum = 0.0;
-      auto absSum = [&sum](const double &e) {
-        sum += std::abs(e);
-        return e;
-      };
-      err.visit(absSum);
-
-      // Normalize by the sum to prevent overflow
-      error = error + err.normalize(sum);
+
+  double max_log = model_selection_term.max();
+  AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
+      model_selection_term,
+      [&max_log](const double &x) { return std::exp(x - max_log); });
+  model_selection = model_selection.normalize(model_selection.sum());
+  // std::cout << "normalized model selection" << std::endl;
+  // model_selection.print("", DefaultKeyFormatter);
+
+  for (auto &&conditional : *this) {
+    if (conditional->isDiscrete()) {
+      discrete_fg.push_back(conditional->asDiscrete());
+    } else {
+      if (conditional->isContinuous()) {
+        // /*
+        // If we are here, it means there are no discrete variables in
+        // the Bayes net (due to strong elimination ordering).
+        // This is a continuous-only problem hence model selection
+        // doesn't matter.
+        // */
+        // auto gc = conditional->asGaussian();
+        // for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
+        //      frontal != gc->endFrontals(); ++frontal) {
+        //   continuousValues.insert_or_assign(*frontal,
+        //                                     Vector::Zero(gc->getDim(frontal)));
+        // }
+      } else if (conditional->isHybrid()) {
+        auto gm = conditional->asMixture();
         // Include the discrete keys
         std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
                   std::inserter(discreteKeySet, discreteKeySet.end()));
@@ -302,12 +384,6 @@ HybridValues HybridBayesNet::optimize() const {
     }
   }
-  error = error * -1;
-  double max_log = error.max();
-  AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
-      error, [&max_log](const double &x) { return std::exp(x - max_log); });
-  model_selection = model_selection.normalize(model_selection.sum());
-
   // Only add model_selection if we have discrete keys
   if (discreteKeySet.size() > 0) {
     discrete_fg.push_back(DecisionTreeFactor(