helper functions for computing leaf error and normalization constants used for model selection

release/4.3a0
Varun Agrawal 2024-01-16 14:59:47 -05:00
parent c5bfd524e0
commit f3e84004bf
2 changed files with 87 additions and 35 deletions

View File

@ -227,24 +227,6 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn; return gbn;
} }
/* ************************************************************************ */
static GaussianBayesNetTree addGaussian(
const GaussianBayesNetTree &gfgTree,
const GaussianConditional::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (gfgTree.empty()) {
GaussianBayesNet result{factor};
return GaussianBayesNetTree(result);
} else {
auto add = [&factor](const GaussianBayesNet &graph) {
auto result = graph;
result.push_back(factor);
return result;
};
return gfgTree.apply(add);
}
}
/* ************************************************************************ */ /* ************************************************************************ */
GaussianBayesNetValTree HybridBayesNet::assembleTree() const { GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
GaussianBayesNetTree result; GaussianBayesNetTree result;
@ -320,25 +302,12 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
AlgebraicDecisionTree<Key> errorTree = trees.second; AlgebraicDecisionTree<Key> errorTree = trees.second;
// Only compute logNormalizationConstant // Only compute logNormalizationConstant
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>( AlgebraicDecisionTree<Key> log_norm_constants =
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) { computeLogNormConstants(bnTree);
GaussianBayesNet gbn = gbnAndValue.first;
if (gbn.size() == 0) {
return 0.0;
}
return gbn.logNormalizationConstant();
});
// Compute model selection term (with help from ADT methods) // Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> modelSelectionTerm = AlgebraicDecisionTree<Key> modelSelectionTerm =
(errorTree + log_norm_constants) * -1; computeModelSelectionTerm(errorTree, log_norm_constants);
double max_log = modelSelectionTerm.max();
modelSelectionTerm = DecisionTree<Key, double>(
modelSelectionTerm,
[&max_log](const double &x) { return std::exp(x - max_log); });
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
return modelSelectionTerm; return modelSelectionTerm;
} }
@ -530,4 +499,52 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
return fg; return fg;
} }
/* ************************************************************************ */
GaussianBayesNetTree addGaussian(
const GaussianBayesNetTree &gbnTree,
const GaussianConditional::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (gbnTree.empty()) {
GaussianBayesNet result{factor};
return GaussianBayesNetTree(result);
} else {
auto add = [&factor](const GaussianBayesNet &graph) {
auto result = graph;
result.push_back(factor);
return result;
};
return gbnTree.apply(add);
}
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> computeLogNormConstants(
const GaussianBayesNetValTree &bnTree) {
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
GaussianBayesNet gbn = gbnAndValue.first;
if (gbn.size() == 0) {
return 0.0;
}
return gbn.logNormalizationConstant();
});
return log_norm_constants;
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
const AlgebraicDecisionTree<Key> &errorTree,
const AlgebraicDecisionTree<Key> &log_norm_constants) {
AlgebraicDecisionTree<Key> modelSelectionTerm =
(errorTree + log_norm_constants) * -1;
double max_log = modelSelectionTerm.max();
modelSelectionTerm = DecisionTree<Key, double>(
modelSelectionTerm,
[&max_log](const double &x) { return std::exp(x - max_log); });
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
return modelSelectionTerm;
}
} // namespace gtsam } // namespace gtsam

View File

@ -280,4 +280,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
template <> template <>
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {}; struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
/**
* @brief Add a Gaussian conditional to each node of the GaussianBayesNetTree
*
* @param gbnTree
* @param factor
* @return GaussianBayesNetTree
*/
GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree,
const GaussianConditional::shared_ptr &factor);
/**
* @brief Compute the (logarithmic) normalization constant for each Bayes
* network in the tree.
*
* @param bnTree A tree of Bayes networks in each leaf. The tree encodes a
* discrete assignment yielding the Bayes net.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> computeLogNormConstants(
const GaussianBayesNetValTree &bnTree);
/**
* @brief Compute the model selection term L(M; Z, X) given the error
* and log normalization constants.
*
* Perform normalization to handle underflow issues.
*
* @param errorTree
* @param log_norm_constants
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
const AlgebraicDecisionTree<Key> &errorTree,
const AlgebraicDecisionTree<Key> &log_norm_constants);
} // namespace gtsam } // namespace gtsam