improved model selection code

release/4.3a0
Varun Agrawal 2023-12-26 16:33:43 -05:00
parent 6f4343ca94
commit 409938f4b4
3 changed files with 37 additions and 41 deletions

View File

@ -248,7 +248,7 @@ static GaussianBayesNetTree addGaussian(
}
/* ************************************************************************ */
GaussianBayesNetTree HybridBayesNet::assembleTree() const {
GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
GaussianBayesNetTree result;
for (auto &f : factors_) {
@ -276,23 +276,17 @@ GaussianBayesNetTree HybridBayesNet::assembleTree() const {
}
}
return result;
GaussianBayesNetValTree resultTree(result, [](const GaussianBayesNet &gbn) {
return std::make_pair(gbn, 0.0);
});
return resultTree;
}
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Collect all the discrete factors to compute MPE
DiscreteFactorGraph discrete_fg;
VectorValues continuousValues;
std::set<DiscreteKey> discreteKeySet;
// this->print();
GaussianBayesNetTree bnTree = assembleTree();
// bnTree.print("", DefaultKeyFormatter, [](const GaussianBayesNet &gbn) {
// gbn.print();
// return "";
// });
/*
Perform the integration of L(X;M,Z)P(X|M)
which is the model selection term.
@ -316,43 +310,35 @@ HybridValues HybridBayesNet::optimize() const {
So we compute (error + log(k)) and exponentiate later
*/
// Compute the X* of each assignment and use that as the MAP.
DecisionTree<Key, VectorValues> x_map(
bnTree, [](const GaussianBayesNet &gbn) { return gbn.optimize(); });
// Only compute logNormalizationConstant for now
AlgebraicDecisionTree<Key> log_norm_constants =
DecisionTree<Key, double>(bnTree, [](const GaussianBayesNet &gbn) {
std::set<DiscreteKey> discreteKeySet;
GaussianBayesNetValTree bnTree = assembleTree();
GaussianBayesNetValTree bn_error = bnTree.apply(
[this](const Assignment<Key> &assignment,
const std::pair<GaussianBayesNet, double> &gbnAndValue) {
// Compute the X* of each assignment
VectorValues mu = gbnAndValue.first.optimize();
// Compute the error for X* and the assignment
double error =
this->error(HybridValues(mu, DiscreteValues(assignment)));
return std::make_pair(gbnAndValue.first, error);
});
auto trees = unzip(bn_error);
AlgebraicDecisionTree<Key> errorTree = trees.second;
// Only compute logNormalizationConstant
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
GaussianBayesNet gbn = gbnAndValue.first;
if (gbn.size() == 0) {
return 0.0;
}
return gbn.logNormalizationConstant();
});
// Compute errors as VectorValues
DecisionTree<Key, VectorValues> errorVectors = x_map.apply(
[this](const Assignment<Key> &assignment, const VectorValues &mu) {
double error = 0.0;
for (auto &&f : *this) {
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
} else if (auto g = hc->asGaussian()) {
error += g->error(mu);
}
}
}
VectorValues e;
e.insert(0, Vector1(error));
return e;
});
AlgebraicDecisionTree<Key> errorTree = DecisionTree<Key, double>(
errorVectors, [](const VectorValues &v) { return v[0](0); });
// Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> model_selection_term =
(errorTree + log_norm_constants) * -1;

View File

@ -118,6 +118,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
return evaluate(values);
}
GaussianBayesNetValTree assembleTree() const;
/**
* @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on

View File

@ -33,6 +33,14 @@ class HybridValues;
/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
/// Alias for DecisionTree of GaussianBayesNets
using GaussianBayesNetTree = DecisionTree<Key, GaussianBayesNet>;
/**
* Alias for DecisionTree of (GaussianBayesNet, double) pairs.
* Used for model selection in BayesNet::optimize
*/
using GaussianBayesNetValTree =
DecisionTree<Key, std::pair<GaussianBayesNet, double>>;
KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);