remove model selection code

release/4.3a0
Varun Agrawal 2024-08-20 14:04:37 -04:00
parent fef929f266
commit 654bad7381
4 changed files with 27 additions and 193 deletions

View File

@@ -71,29 +71,27 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionals)) {}
/* *******************************************************************************/
// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
// GaussianMixtureFactor, no?
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
};
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}
/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
return GaussianFactorGraph{gc};
if (gc) {
return GaussianBayesNet{gc};
} else {
return GaussianBayesNet();
}
};
return {conditionals_, wrap};
}
/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
auto wrap = [](const GaussianBayesNet &gbn) {
return GaussianFactorGraph(gbn);
};
return {this->asGaussianBayesNetTree(), wrap};
}
/*
*******************************************************************************/
GaussianBayesNetTree GaussianMixture::add(
const GaussianBayesNetTree &sum) const {
using Y = GaussianBayesNet;
@@ -110,15 +108,18 @@ GaussianBayesNetTree GaussianMixture::add(
}
/* *******************************************************************************/
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
if (gc) {
return GaussianBayesNet{gc};
} else {
return GaussianBayesNet();
}
// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
// GaussianMixtureFactor, no?
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
};
return {conditionals_, wrap};
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}
/* *******************************************************************************/

View File

@@ -26,16 +26,6 @@ static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam {
/* ************************************************************************ */
// Throw a runtime exception for method specified in string s,
// and conditional f:
static void throwRuntimeError(const std::string &s,
const std::shared_ptr<HybridConditional> &f) {
auto &fr = *f;
throw std::runtime_error(s + " not implemented for conditional type " +
demangle(typeid(fr).name()) + ".");
}
/* ************************************************************************* */
void HybridBayesNet::print(const std::string &s,
const KeyFormatter &formatter) const {
@@ -227,141 +217,17 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn;
}
/* ************************************************************************ */
static GaussianBayesNetTree addGaussian(
const GaussianBayesNetTree &gfgTree,
const GaussianConditional::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (gfgTree.empty()) {
GaussianBayesNet result{factor};
return GaussianBayesNetTree(result);
} else {
auto add = [&factor](const GaussianBayesNet &graph) {
auto result = graph;
result.push_back(factor);
return result;
};
return gfgTree.apply(add);
}
}
/* ************************************************************************ */
GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
GaussianBayesNetTree result;
for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(f)) {
result = gm->add(result);
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
result = gm->add(result);
} else if (auto g = hc->asGaussian()) {
result = addGaussian(result, g);
} else {
// Has to be discrete.
// TODO(dellaert): in C++20, we can use std::visit.
continue;
}
} else if (std::dynamic_pointer_cast<DiscreteFactor>(f)) {
// Don't do anything for discrete-only factors
// since we want to evaluate continuous values only.
continue;
} else {
// We need to handle the case where the object is actually an
// BayesTreeOrphanWrapper!
throwRuntimeError("HybridBayesNet::assembleTree", f);
}
}
GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet &gbn) {
return std::make_pair(gbn, 0.0);
});
return resultTree;
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
/*
To perform model selection, we need:
q(mu; M, Z) = exp(-error)
where error is computed at the corresponding MAP point, gbn.error(mu).
So we compute `error` and exponentiate after.
*/
GaussianBayesNetValTree bnTree = assembleTree();
GaussianBayesNetValTree bn_error = bnTree.apply(
[this](const Assignment<Key> &assignment,
const std::pair<GaussianBayesNet, double> &gbnAndValue) {
// Compute the X* of each assignment
VectorValues mu = gbnAndValue.first.optimize();
// mu is empty if gbn had nullptrs
if (mu.size() == 0) {
return std::make_pair(gbnAndValue.first,
std::numeric_limits<double>::max());
}
// Compute the error for X* and the assignment
double error =
this->error(HybridValues(mu, DiscreteValues(assignment)));
return std::make_pair(gbnAndValue.first, error);
});
// Compute model selection term (with help from ADT methods)
auto trees = unzip(bn_error);
AlgebraicDecisionTree<Key> errorTree = trees.second;
AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
double max_log = modelSelectionTerm.max();
modelSelectionTerm = DecisionTree<Key, double>(
modelSelectionTerm,
[&max_log](const double &x) { return std::exp(x - max_log); });
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
return modelSelectionTerm;
}
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Collect all the discrete factors to compute MPE
DiscreteFactorGraph discrete_fg;
// Compute model selection term
AlgebraicDecisionTree<Key> modelSelectionTerm = modelSelection();
// Get the set of all discrete keys involved in model selection
std::set<DiscreteKey> discreteKeySet;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_fg.push_back(conditional->asDiscrete());
} else {
if (conditional->isContinuous()) {
/*
If we are here, it means there are no discrete variables in
the Bayes net (due to strong elimination ordering).
This is a continuous-only problem hence model selection doesn't matter.
*/
} else if (conditional->isHybrid()) {
auto gm = conditional->asMixture();
// Include the discrete keys
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
std::inserter(discreteKeySet, discreteKeySet.end()));
}
}
}
// Only add model_selection if we have discrete keys
if (discreteKeySet.size() > 0) {
discrete_fg.push_back(DecisionTreeFactor(
DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
modelSelectionTerm));
}
// Solve for the MPE
DiscreteValues mpe = discrete_fg.optimize();

View File

@@ -118,34 +118,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
return evaluate(values);
}
/**
* @brief Assemble a DecisionTree of (GaussianBayesNet, double) leaves for
* each discrete assignment.
* The included double value is used to make
* constructing the model selection term cleaner and more efficient.
*
* @return GaussianBayesNetValTree
*/
GaussianBayesNetValTree assembleTree() const;
/*
Compute L(M;Z), the likelihood of the discrete model M
given the measurements Z.
This is called the model selection term.
To do so, we perform the integration of L(M;Z) L(X;M,Z)P(X|M).
By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
hence L(X;M,Z)P(X|M) is the unnormalized probability of
the joint Gaussian distribution.
This can be computed by multiplying all the exponentiated errors
of each of the conditionals.
Return a tree where each leaf value is L(M_i;Z).
*/
AlgebraicDecisionTree<Key> modelSelection() const;
/**
* @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on

View File

@@ -13,6 +13,7 @@
* @file HybridFactor.h
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/
#pragma once
@@ -35,12 +36,6 @@ class HybridValues;
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
/// Alias for DecisionTree of GaussianBayesNets
using GaussianBayesNetTree = DecisionTree<Key, GaussianBayesNet>;
/**
* Alias for DecisionTree of (GaussianBayesNet, double) pairs.
* Used for model selection in BayesNet::optimize
*/
using GaussianBayesNetValTree =
DecisionTree<Key, std::pair<GaussianBayesNet, double>>;
KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);