remove model selection code
parent
fef929f266
commit
654bad7381
|
|
@ -71,29 +71,27 @@ GaussianMixture::GaussianMixture(
|
|||
Conditionals(discreteParents, conditionals)) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
|
||||
// GaussianMixtureFactor, no?
|
||||
GaussianFactorGraphTree GaussianMixture::add(
|
||||
const GaussianFactorGraphTree &sum) const {
|
||||
using Y = GaussianFactorGraph;
|
||||
auto add = [](const Y &graph1, const Y &graph2) {
|
||||
auto result = graph1;
|
||||
result.push_back(graph2);
|
||||
return result;
|
||||
};
|
||||
const auto tree = asGaussianFactorGraphTree();
|
||||
return sum.empty() ? tree : sum.apply(tree, add);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
|
||||
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
|
||||
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
|
||||
return GaussianFactorGraph{gc};
|
||||
if (gc) {
|
||||
return GaussianBayesNet{gc};
|
||||
} else {
|
||||
return GaussianBayesNet();
|
||||
}
|
||||
};
|
||||
return {conditionals_, wrap};
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
|
||||
auto wrap = [](const GaussianBayesNet &gbn) {
|
||||
return GaussianFactorGraph(gbn);
|
||||
};
|
||||
return {this->asGaussianBayesNetTree(), wrap};
|
||||
}
|
||||
|
||||
/*
|
||||
*******************************************************************************/
|
||||
GaussianBayesNetTree GaussianMixture::add(
|
||||
const GaussianBayesNetTree &sum) const {
|
||||
using Y = GaussianBayesNet;
|
||||
|
|
@ -110,15 +108,18 @@ GaussianBayesNetTree GaussianMixture::add(
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
|
||||
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
|
||||
if (gc) {
|
||||
return GaussianBayesNet{gc};
|
||||
} else {
|
||||
return GaussianBayesNet();
|
||||
}
|
||||
// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
|
||||
// GaussianMixtureFactor, no?
|
||||
GaussianFactorGraphTree GaussianMixture::add(
|
||||
const GaussianFactorGraphTree &sum) const {
|
||||
using Y = GaussianFactorGraph;
|
||||
auto add = [](const Y &graph1, const Y &graph2) {
|
||||
auto result = graph1;
|
||||
result.push_back(graph2);
|
||||
return result;
|
||||
};
|
||||
return {conditionals_, wrap};
|
||||
const auto tree = asGaussianFactorGraphTree();
|
||||
return sum.empty() ? tree : sum.apply(tree, add);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
|||
|
|
@ -26,16 +26,6 @@ static std::mt19937_64 kRandomNumberGenerator(42);
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************ */
|
||||
// Throw a runtime exception for method specified in string s,
|
||||
// and conditional f:
|
||||
static void throwRuntimeError(const std::string &s,
|
||||
const std::shared_ptr<HybridConditional> &f) {
|
||||
auto &fr = *f;
|
||||
throw std::runtime_error(s + " not implemented for conditional type " +
|
||||
demangle(typeid(fr).name()) + ".");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesNet::print(const std::string &s,
|
||||
const KeyFormatter &formatter) const {
|
||||
|
|
@ -227,141 +217,17 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
return gbn;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static GaussianBayesNetTree addGaussian(
|
||||
const GaussianBayesNetTree &gfgTree,
|
||||
const GaussianConditional::shared_ptr &factor) {
|
||||
// If the decision tree is not initialized, then initialize it.
|
||||
if (gfgTree.empty()) {
|
||||
GaussianBayesNet result{factor};
|
||||
return GaussianBayesNetTree(result);
|
||||
} else {
|
||||
auto add = [&factor](const GaussianBayesNet &graph) {
|
||||
auto result = graph;
|
||||
result.push_back(factor);
|
||||
return result;
|
||||
};
|
||||
return gfgTree.apply(add);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
|
||||
GaussianBayesNetTree result;
|
||||
|
||||
for (auto &f : factors_) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(f)) {
|
||||
result = gm->add(result);
|
||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
if (auto gm = hc->asMixture()) {
|
||||
result = gm->add(result);
|
||||
} else if (auto g = hc->asGaussian()) {
|
||||
result = addGaussian(result, g);
|
||||
} else {
|
||||
// Has to be discrete.
|
||||
// TODO(dellaert): in C++20, we can use std::visit.
|
||||
continue;
|
||||
}
|
||||
} else if (std::dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||
// Don't do anything for discrete-only factors
|
||||
// since we want to evaluate continuous values only.
|
||||
continue;
|
||||
} else {
|
||||
// We need to handle the case where the object is actually an
|
||||
// BayesTreeOrphanWrapper!
|
||||
throwRuntimeError("HybridBayesNet::assembleTree", f);
|
||||
}
|
||||
}
|
||||
|
||||
GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet &gbn) {
|
||||
return std::make_pair(gbn, 0.0);
|
||||
});
|
||||
return resultTree;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
|
||||
/*
|
||||
To perform model selection, we need:
|
||||
q(mu; M, Z) = exp(-error)
|
||||
where error is computed at the corresponding MAP point, gbn.error(mu).
|
||||
|
||||
So we compute `error` and exponentiate after.
|
||||
*/
|
||||
|
||||
GaussianBayesNetValTree bnTree = assembleTree();
|
||||
|
||||
GaussianBayesNetValTree bn_error = bnTree.apply(
|
||||
[this](const Assignment<Key> &assignment,
|
||||
const std::pair<GaussianBayesNet, double> &gbnAndValue) {
|
||||
// Compute the X* of each assignment
|
||||
VectorValues mu = gbnAndValue.first.optimize();
|
||||
|
||||
// mu is empty if gbn had nullptrs
|
||||
if (mu.size() == 0) {
|
||||
return std::make_pair(gbnAndValue.first,
|
||||
std::numeric_limits<double>::max());
|
||||
}
|
||||
|
||||
// Compute the error for X* and the assignment
|
||||
double error =
|
||||
this->error(HybridValues(mu, DiscreteValues(assignment)));
|
||||
|
||||
return std::make_pair(gbnAndValue.first, error);
|
||||
});
|
||||
|
||||
// Compute model selection term (with help from ADT methods)
|
||||
auto trees = unzip(bn_error);
|
||||
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm = errorTree * -1;
|
||||
|
||||
double max_log = modelSelectionTerm.max();
|
||||
modelSelectionTerm = DecisionTree<Key, double>(
|
||||
modelSelectionTerm,
|
||||
[&max_log](const double &x) { return std::exp(x - max_log); });
|
||||
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
|
||||
|
||||
return modelSelectionTerm;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesNet::optimize() const {
|
||||
// Collect all the discrete factors to compute MPE
|
||||
DiscreteFactorGraph discrete_fg;
|
||||
|
||||
// Compute model selection term
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm = modelSelection();
|
||||
|
||||
// Get the set of all discrete keys involved in model selection
|
||||
std::set<DiscreteKey> discreteKeySet;
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discrete_fg.push_back(conditional->asDiscrete());
|
||||
} else {
|
||||
if (conditional->isContinuous()) {
|
||||
/*
|
||||
If we are here, it means there are no discrete variables in
|
||||
the Bayes net (due to strong elimination ordering).
|
||||
This is a continuous-only problem hence model selection doesn't matter.
|
||||
*/
|
||||
|
||||
} else if (conditional->isHybrid()) {
|
||||
auto gm = conditional->asMixture();
|
||||
// Include the discrete keys
|
||||
std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
|
||||
std::inserter(discreteKeySet, discreteKeySet.end()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only add model_selection if we have discrete keys
|
||||
if (discreteKeySet.size() > 0) {
|
||||
discrete_fg.push_back(DecisionTreeFactor(
|
||||
DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
|
||||
modelSelectionTerm));
|
||||
}
|
||||
|
||||
// Solve for the MPE
|
||||
DiscreteValues mpe = discrete_fg.optimize();
|
||||
|
||||
|
|
|
|||
|
|
@ -118,34 +118,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
return evaluate(values);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Assemble a DecisionTree of (GaussianBayesNet, double) leaves for
|
||||
* each discrete assignment.
|
||||
* The included double value is used to make
|
||||
* constructing the model selection term cleaner and more efficient.
|
||||
*
|
||||
* @return GaussianBayesNetValTree
|
||||
*/
|
||||
GaussianBayesNetValTree assembleTree() const;
|
||||
|
||||
/*
|
||||
Compute L(M;Z), the likelihood of the discrete model M
|
||||
given the measurements Z.
|
||||
This is called the model selection term.
|
||||
|
||||
To do so, we perform the integration of L(M;Z) ∝ L(X;M,Z)P(X|M).
|
||||
|
||||
By Bayes' rule, P(X|M,Z) ∝ L(X;M,Z)P(X|M),
|
||||
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
|
||||
the joint Gaussian distribution.
|
||||
|
||||
This can be computed by multiplying all the exponentiated errors
|
||||
of each of the conditionals.
|
||||
|
||||
Return a tree where each leaf value is L(M_i;Z).
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> modelSelection() const;
|
||||
|
||||
/**
|
||||
* @brief Solve the HybridBayesNet by first computing the MPE of all the
|
||||
* discrete variables and then optimizing the continuous variables based on
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
* @file HybridFactor.h
|
||||
* @date Mar 11, 2022
|
||||
* @author Fan Jiang
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
|
@ -35,12 +36,6 @@ class HybridValues;
|
|||
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
|
||||
/// Alias for DecisionTree of GaussianBayesNets
|
||||
using GaussianBayesNetTree = DecisionTree<Key, GaussianBayesNet>;
|
||||
/**
|
||||
* Alias for DecisionTree of (GaussianBayesNet, double) pairs.
|
||||
* Used for model selection in BayesNet::optimize
|
||||
*/
|
||||
using GaussianBayesNetValTree =
|
||||
DecisionTree<Key, std::pair<GaussianBayesNet, double>>;
|
||||
|
||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys);
|
||||
|
|
|
|||
Loading…
Reference in New Issue