helper functions for computing leaf error and normalization constants used for model selection

release/4.3a0
Varun Agrawal 2024-01-16 14:59:47 -05:00
parent c5bfd524e0
commit f3e84004bf
2 changed files with 87 additions and 35 deletions

View File

@ -227,24 +227,6 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn; return gbn;
} }
/* ************************************************************************ */
static GaussianBayesNetTree addGaussian(
const GaussianBayesNetTree &gfgTree,
const GaussianConditional::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (gfgTree.empty()) {
GaussianBayesNet result{factor};
return GaussianBayesNetTree(result);
} else {
auto add = [&factor](const GaussianBayesNet &graph) {
auto result = graph;
result.push_back(factor);
return result;
};
return gfgTree.apply(add);
}
}
/* ************************************************************************ */ /* ************************************************************************ */
GaussianBayesNetValTree HybridBayesNet::assembleTree() const { GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
GaussianBayesNetTree result; GaussianBayesNetTree result;
@ -320,25 +302,12 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
AlgebraicDecisionTree<Key> errorTree = trees.second; AlgebraicDecisionTree<Key> errorTree = trees.second;
// Only compute logNormalizationConstant // Only compute logNormalizationConstant
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>( AlgebraicDecisionTree<Key> log_norm_constants =
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) { computeLogNormConstants(bnTree);
GaussianBayesNet gbn = gbnAndValue.first;
if (gbn.size() == 0) {
return 0.0;
}
return gbn.logNormalizationConstant();
});
// Compute model selection term (with help from ADT methods) // Compute model selection term (with help from ADT methods)
AlgebraicDecisionTree<Key> modelSelectionTerm = AlgebraicDecisionTree<Key> modelSelectionTerm =
(errorTree + log_norm_constants) * -1; computeModelSelectionTerm(errorTree, log_norm_constants);
double max_log = modelSelectionTerm.max();
modelSelectionTerm = DecisionTree<Key, double>(
modelSelectionTerm,
[&max_log](const double &x) { return std::exp(x - max_log); });
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
return modelSelectionTerm; return modelSelectionTerm;
} }
@ -530,4 +499,52 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
return fg; return fg;
} }
/* ************************************************************************ */
GaussianBayesNetTree addGaussian(
const GaussianBayesNetTree &gbnTree,
const GaussianConditional::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (gbnTree.empty()) {
GaussianBayesNet result{factor};
return GaussianBayesNetTree(result);
} else {
auto add = [&factor](const GaussianBayesNet &graph) {
auto result = graph;
result.push_back(factor);
return result;
};
return gbnTree.apply(add);
}
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> computeLogNormConstants(
const GaussianBayesNetValTree &bnTree) {
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
GaussianBayesNet gbn = gbnAndValue.first;
if (gbn.size() == 0) {
return 0.0;
}
return gbn.logNormalizationConstant();
});
return log_norm_constants;
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
const AlgebraicDecisionTree<Key> &errorTree,
const AlgebraicDecisionTree<Key> &log_norm_constants) {
AlgebraicDecisionTree<Key> modelSelectionTerm =
(errorTree + log_norm_constants) * -1;
double max_log = modelSelectionTerm.max();
modelSelectionTerm = DecisionTree<Key, double>(
modelSelectionTerm,
[&max_log](const double &x) { return std::exp(x - max_log); });
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
return modelSelectionTerm;
}
} // namespace gtsam } // namespace gtsam

View File

@ -141,7 +141,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
This can be computed by multiplying all the exponentiated errors This can be computed by multiplying all the exponentiated errors
of each of the conditionals. of each of the conditionals.
Return a tree where each leaf value is L(M_i;Z). Return a tree where each leaf value is L(M_i;Z).
*/ */
AlgebraicDecisionTree<Key> modelSelection() const; AlgebraicDecisionTree<Key> modelSelection() const;
@ -280,4 +280,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
template <> template <>
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {}; struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
/**
* @brief Add a Gaussian conditional to each node of the GaussianBayesNetTree
*
* @param gbnTree
* @param factor
* @return GaussianBayesNetTree
*/
GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree,
const GaussianConditional::shared_ptr &factor);
/**
* @brief Compute the (logarithmic) normalization constant for each Bayes
* network in the tree.
*
* @param bnTree A tree of Bayes networks in each leaf. The tree encodes a
* discrete assignment yielding the Bayes net.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> computeLogNormConstants(
const GaussianBayesNetValTree &bnTree);
/**
* @brief Compute the model selection term L(M; Z, X) given the error
* and log normalization constants.
*
* Perform normalization to handle underflow issues.
*
* @param errorTree
* @param log_norm_constants
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
const AlgebraicDecisionTree<Key> &errorTree,
const AlgebraicDecisionTree<Key> &log_norm_constants);
} // namespace gtsam } // namespace gtsam