improved model selection code

parent 6f4343ca94
commit 409938f4b4
@@ -248,7 +248,7 @@ static GaussianBayesNetTree addGaussian(
 }
 
 /* ************************************************************************ */
-GaussianBayesNetTree HybridBayesNet::assembleTree() const {
+GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
   GaussianBayesNetTree result;
 
   for (auto &f : factors_) {
@@ -276,23 +276,17 @@ GaussianBayesNetTree HybridBayesNet::assembleTree() const {
     }
   }
 
-  return result;
+  GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet &gbn) {
+    return std::make_pair(gbn, 0.0);
+  });
+  return resultTree;
 }
 
 /* ************************************************************************* */
 HybridValues HybridBayesNet::optimize() const {
   // Collect all the discrete factors to compute MPE
   DiscreteFactorGraph discrete_fg;
-  VectorValues continuousValues;
-
-  std::set<DiscreteKey> discreteKeySet;
-
-  // this->print();
-  GaussianBayesNetTree bnTree = assembleTree();
-  // bnTree.print("", DefaultKeyFormatter, [](const GaussianBayesNet &gbn) {
-  //   gbn.print();
-  //   return "";
-  // });
+
   /*
   Perform the integration of L(X;M,Z)P(X|M)
   which is the model selection term.
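
The rewritten assembleTree() builds its pair-valued result with DecisionTree's converting constructor, which maps every leaf of an existing tree through a functor. A minimal sketch of that same pattern on a toy int-valued tree (the mode key m0 and the leaf values are illustrative only, not taken from this change):

#include <utility>

#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/inference/Symbol.h>

using namespace gtsam;

int main() {
  // A toy tree over a single (hypothetical) mode key m0, with int leaves.
  DecisionTree<Key, int> ints(Symbol('m', 0), 1, 2);

  // Convert leaf-by-leaf, pairing each value with a placeholder 0.0,
  // mirroring how assembleTree() now wraps each GaussianBayesNet.
  DecisionTree<Key, std::pair<int, double>> pairs(
      ints, [](int i) { return std::make_pair(i, 0.0); });

  return 0;
}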
@@ -316,43 +310,35 @@ HybridValues HybridBayesNet::optimize() const {
 
   So we compute (error + log(k)) and exponentiate later
   */
-  // Compute the X* of each assignment and use that as the MAP.
-  DecisionTree<Key, VectorValues> x_map(
-      bnTree, [](const GaussianBayesNet &gbn) { return gbn.optimize(); });
-
-  // Only compute logNormalizationConstant for now
-  AlgebraicDecisionTree<Key> log_norm_constants =
-      DecisionTree<Key, double>(bnTree, [](const GaussianBayesNet &gbn) {
+  std::set<DiscreteKey> discreteKeySet;
+
+  GaussianBayesNetValTree bnTree = assembleTree();
+
+  GaussianBayesNetValTree bn_error = bnTree.apply(
+      [this](const Assignment<Key> &assignment,
+             const std::pair<GaussianBayesNet, double> &gbnAndValue) {
+        // Compute the X* of each assignment
+        VectorValues mu = gbnAndValue.first.optimize();
+        // Compute the error for X* and the assignment
+        double error =
+            this->error(HybridValues(mu, DiscreteValues(assignment)));
+
+        return std::make_pair(gbnAndValue.first, error);
+      });
+
+  auto trees = unzip(bn_error);
+  AlgebraicDecisionTree<Key> errorTree = trees.second;
+
+  // Only compute logNormalizationConstant
+  AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
+      bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
+        GaussianBayesNet gbn = gbnAndValue.first;
         if (gbn.size() == 0) {
           return 0.0;
         }
         return gbn.logNormalizationConstant();
       });
 
-  // Compute errors as VectorValues
-  DecisionTree<Key, VectorValues> errorVectors = x_map.apply(
-      [this](const Assignment<Key> &assignment, const VectorValues &mu) {
-        double error = 0.0;
-        for (auto &&f : *this) {
-          if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
-            error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
-
-          } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
-            if (auto gm = hc->asMixture()) {
-              error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
-
-            } else if (auto g = hc->asGaussian()) {
-              error += g->error(mu);
-            }
-          }
-        }
-        VectorValues e;
-        e.insert(0, Vector1(error));
-        return e;
-      });
-  AlgebraicDecisionTree<Key> errorTree = DecisionTree<Key, double>(
-      errorVectors, [](const VectorValues &v) { return v[0](0); });
-
   // Compute model selection term (with help from ADT methods)
   AlgebraicDecisionTree<Key> model_selection_term =
       (errorTree + log_norm_constants) * -1;
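
Spelled out, the block above assigns to each discrete assignment M the negated sum of the error at the per-assignment optimum and the log-normalization constant. In equation form (the shorthand X*_M, err_M, and k_M is mine: the per-assignment optimum, the hybrid error at that optimum, and exp(logNormalizationConstant()) of the corresponding Bayes net):

\[
  X^{*}_{M} = \arg\min_{X} \operatorname{err}_{M}(X), \qquad
  \texttt{model\_selection\_term}(M) = -\bigl(\operatorname{err}_{M}(X^{*}_{M}) + \log k_{M}\bigr)
\]
\[
  \exp\bigl(\texttt{model\_selection\_term}(M)\bigr)
    = \frac{\exp\bigl(-\operatorname{err}_{M}(X^{*}_{M})\bigr)}{k_{M}}
\]

which is the "(error + log(k)), exponentiate later" quantity described in the comment, with the sign supplied by the final * -1.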
@@ -118,6 +118,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
     return evaluate(values);
   }
 
+  GaussianBayesNetValTree assembleTree() const;
+
   /**
    * @brief Solve the HybridBayesNet by first computing the MPE of all the
    * discrete variables and then optimizing the continuous variables based on
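
With assembleTree() now exposed in the header, callers can inspect the per-mode Bayes nets directly. A hypothetical fragment (hbn stands for an existing HybridBayesNet; the print call mirrors the commented-out one removed from optimize() above):

GaussianBayesNetValTree bnTree = hbn.assembleTree();
bnTree.print("assembled", DefaultKeyFormatter,
             [](const std::pair<GaussianBayesNet, double> &p) {
               p.first.print();
               return std::to_string(p.second);
             });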
@@ -33,6 +33,14 @@ class HybridValues;
 
 /// Alias for DecisionTree of GaussianFactorGraphs
 using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
+/// Alias for DecisionTree of GaussianBayesNets
+using GaussianBayesNetTree = DecisionTree<Key, GaussianBayesNet>;
+/**
+ * Alias for DecisionTree of (GaussianBayesNet, double) pairs.
+ * Used for model selection in BayesNet::optimize
+ */
+using GaussianBayesNetValTree =
+    DecisionTree<Key, std::pair<GaussianBayesNet, double>>;
 
 KeyVector CollectKeys(const KeyVector &continuousKeys,
                       const DiscreteKeys &discreteKeys);
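
For orientation, a small self-contained sketch of how the new pair-valued tree composes with DecisionTree::apply and unzip, the two calls the rewritten optimize() relies on. The one-conditional Bayes nets and the mode key m0 below are made up for illustration; only calls that already appear in this change are assumed.

#include <utility>

#include <gtsam/base/Matrix.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianConditional.h>

using namespace gtsam;

int main() {
  // Two toy one-conditional Bayes nets, standing in for the per-mode nets
  // that assembleTree() collects.
  GaussianBayesNet gbn0, gbn1;
  gbn0.emplace_shared<GaussianConditional>(Symbol('x', 0), Vector1(0.0), I_1x1);
  gbn1.emplace_shared<GaussianConditional>(Symbol('x', 0), Vector1(1.0), I_1x1);

  // Pair-valued tree (same shape as GaussianBayesNetValTree) over a single
  // made-up mode key m0, with the value slot initialized to 0.0.
  DecisionTree<Key, std::pair<GaussianBayesNet, double>> tree(
      Symbol('m', 0), {gbn0, 0.0}, {gbn1, 0.0});

  // Fill the value slot per leaf, here with the log-normalization constant.
  auto withLogK = tree.apply([](const std::pair<GaussianBayesNet, double> &p) {
    return std::make_pair(p.first, p.first.logNormalizationConstant());
  });

  // Split the pair tree into a tree of nets and a scalar tree, as optimize()
  // does with unzip(bn_error).
  auto trees = unzip(withLogK);
  AlgebraicDecisionTree<Key> logKs = trees.second;

  return 0;
}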