improved model selection code
parent
6f4343ca94
commit
409938f4b4
|
|
@ -248,7 +248,7 @@ static GaussianBayesNetTree addGaussian(
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
GaussianBayesNetTree HybridBayesNet::assembleTree() const {
|
||||
GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
|
||||
GaussianBayesNetTree result;
|
||||
|
||||
for (auto &f : factors_) {
|
||||
|
|
@ -276,23 +276,17 @@ GaussianBayesNetTree HybridBayesNet::assembleTree() const {
|
|||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet &gbn) {
|
||||
return std::make_pair(gbn, 0.0);
|
||||
});
|
||||
return resultTree;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridValues HybridBayesNet::optimize() const {
|
||||
// Collect all the discrete factors to compute MPE
|
||||
DiscreteFactorGraph discrete_fg;
|
||||
VectorValues continuousValues;
|
||||
|
||||
std::set<DiscreteKey> discreteKeySet;
|
||||
|
||||
// this->print();
|
||||
GaussianBayesNetTree bnTree = assembleTree();
|
||||
// bnTree.print("", DefaultKeyFormatter, [](const GaussianBayesNet &gbn) {
|
||||
// gbn.print();
|
||||
// return "";
|
||||
// });
|
||||
/*
|
||||
Perform the integration of L(X;M,Z)P(X|M)
|
||||
which is the model selection term.
|
||||
|
|
@ -316,43 +310,35 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
|
||||
So we compute (error + log(k)) and exponentiate later
|
||||
*/
|
||||
// Compute the X* of each assignment and use that as the MAP.
|
||||
DecisionTree<Key, VectorValues> x_map(
|
||||
bnTree, [](const GaussianBayesNet &gbn) { return gbn.optimize(); });
|
||||
|
||||
// Only compute logNormalizationConstant for now
|
||||
AlgebraicDecisionTree<Key> log_norm_constants =
|
||||
DecisionTree<Key, double>(bnTree, [](const GaussianBayesNet &gbn) {
|
||||
std::set<DiscreteKey> discreteKeySet;
|
||||
GaussianBayesNetValTree bnTree = assembleTree();
|
||||
|
||||
GaussianBayesNetValTree bn_error = bnTree.apply(
|
||||
[this](const Assignment<Key> &assignment,
|
||||
const std::pair<GaussianBayesNet, double> &gbnAndValue) {
|
||||
// Compute the X* of each assignment
|
||||
VectorValues mu = gbnAndValue.first.optimize();
|
||||
// Compute the error for X* and the assignment
|
||||
double error =
|
||||
this->error(HybridValues(mu, DiscreteValues(assignment)));
|
||||
|
||||
return std::make_pair(gbnAndValue.first, error);
|
||||
});
|
||||
|
||||
auto trees = unzip(bn_error);
|
||||
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
||||
|
||||
// Only compute logNormalizationConstant
|
||||
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
|
||||
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
|
||||
GaussianBayesNet gbn = gbnAndValue.first;
|
||||
if (gbn.size() == 0) {
|
||||
return 0.0;
|
||||
}
|
||||
return gbn.logNormalizationConstant();
|
||||
});
|
||||
|
||||
// Compute errors as VectorValues
|
||||
DecisionTree<Key, VectorValues> errorVectors = x_map.apply(
|
||||
[this](const Assignment<Key> &assignment, const VectorValues &mu) {
|
||||
double error = 0.0;
|
||||
for (auto &&f : *this) {
|
||||
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
|
||||
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
|
||||
|
||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
if (auto gm = hc->asMixture()) {
|
||||
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
|
||||
|
||||
} else if (auto g = hc->asGaussian()) {
|
||||
error += g->error(mu);
|
||||
}
|
||||
}
|
||||
}
|
||||
VectorValues e;
|
||||
e.insert(0, Vector1(error));
|
||||
return e;
|
||||
});
|
||||
AlgebraicDecisionTree<Key> errorTree = DecisionTree<Key, double>(
|
||||
errorVectors, [](const VectorValues &v) { return v[0](0); });
|
||||
|
||||
// Compute model selection term (with help from ADT methods)
|
||||
AlgebraicDecisionTree<Key> model_selection_term =
|
||||
(errorTree + log_norm_constants) * -1;
|
||||
|
|
|
|||
|
|
@ -118,6 +118,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
return evaluate(values);
|
||||
}
|
||||
|
||||
GaussianBayesNetValTree assembleTree() const;
|
||||
|
||||
/**
|
||||
* @brief Solve the HybridBayesNet by first computing the MPE of all the
|
||||
* discrete variables and then optimizing the continuous variables based on
|
||||
|
|
|
|||
|
|
@ -33,6 +33,14 @@ class HybridValues;
|
|||
|
||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
|
||||
/// Alias for DecisionTree of GaussianBayesNets
|
||||
using GaussianBayesNetTree = DecisionTree<Key, GaussianBayesNet>;
|
||||
/**
|
||||
* Alias for DecisionTree of (GaussianBayesNet, double) pairs.
|
||||
* Used for model selection in BayesNet::optimize
|
||||
*/
|
||||
using GaussianBayesNetValTree =
|
||||
DecisionTree<Key, std::pair<GaussianBayesNet, double>>;
|
||||
|
||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys);
|
||||
|
|
|
|||
Loading…
Reference in New Issue