helper functions for computing leaf error and normalization constants used for model selection
parent
c5bfd524e0
commit
f3e84004bf
|
|
@ -227,24 +227,6 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
return gbn;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static GaussianBayesNetTree addGaussian(
|
||||
const GaussianBayesNetTree &gfgTree,
|
||||
const GaussianConditional::shared_ptr &factor) {
|
||||
// If the decision tree is not initialized, then initialize it.
|
||||
if (gfgTree.empty()) {
|
||||
GaussianBayesNet result{factor};
|
||||
return GaussianBayesNetTree(result);
|
||||
} else {
|
||||
auto add = [&factor](const GaussianBayesNet &graph) {
|
||||
auto result = graph;
|
||||
result.push_back(factor);
|
||||
return result;
|
||||
};
|
||||
return gfgTree.apply(add);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
|
||||
GaussianBayesNetTree result;
|
||||
|
|
@ -320,25 +302,12 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
|
|||
AlgebraicDecisionTree<Key> errorTree = trees.second;
|
||||
|
||||
// Only compute logNormalizationConstant
|
||||
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
|
||||
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
|
||||
GaussianBayesNet gbn = gbnAndValue.first;
|
||||
if (gbn.size() == 0) {
|
||||
return 0.0;
|
||||
}
|
||||
return gbn.logNormalizationConstant();
|
||||
});
|
||||
AlgebraicDecisionTree<Key> log_norm_constants =
|
||||
computeLogNormConstants(bnTree);
|
||||
|
||||
// Compute model selection term (with help from ADT methods)
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm =
|
||||
(errorTree + log_norm_constants) * -1;
|
||||
|
||||
double max_log = modelSelectionTerm.max();
|
||||
modelSelectionTerm = DecisionTree<Key, double>(
|
||||
modelSelectionTerm,
|
||||
[&max_log](const double &x) { return std::exp(x - max_log); });
|
||||
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
|
||||
|
||||
computeModelSelectionTerm(errorTree, log_norm_constants);
|
||||
return modelSelectionTerm;
|
||||
}
|
||||
|
||||
|
|
@ -530,4 +499,52 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
|
|||
return fg;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
GaussianBayesNetTree addGaussian(
|
||||
const GaussianBayesNetTree &gbnTree,
|
||||
const GaussianConditional::shared_ptr &factor) {
|
||||
// If the decision tree is not initialized, then initialize it.
|
||||
if (gbnTree.empty()) {
|
||||
GaussianBayesNet result{factor};
|
||||
return GaussianBayesNetTree(result);
|
||||
} else {
|
||||
auto add = [&factor](const GaussianBayesNet &graph) {
|
||||
auto result = graph;
|
||||
result.push_back(factor);
|
||||
return result;
|
||||
};
|
||||
return gbnTree.apply(add);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> computeLogNormConstants(
|
||||
const GaussianBayesNetValTree &bnTree) {
|
||||
AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
|
||||
bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
|
||||
GaussianBayesNet gbn = gbnAndValue.first;
|
||||
if (gbn.size() == 0) {
|
||||
return 0.0;
|
||||
}
|
||||
return gbn.logNormalizationConstant();
|
||||
});
|
||||
return log_norm_constants;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
|
||||
const AlgebraicDecisionTree<Key> &errorTree,
|
||||
const AlgebraicDecisionTree<Key> &log_norm_constants) {
|
||||
AlgebraicDecisionTree<Key> modelSelectionTerm =
|
||||
(errorTree + log_norm_constants) * -1;
|
||||
|
||||
double max_log = modelSelectionTerm.max();
|
||||
modelSelectionTerm = DecisionTree<Key, double>(
|
||||
modelSelectionTerm,
|
||||
[&max_log](const double &x) { return std::exp(x - max_log); });
|
||||
modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
|
||||
|
||||
return modelSelectionTerm;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -280,4 +280,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
template <>
|
||||
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
|
||||
|
||||
/**
|
||||
* @brief Add a Gaussian conditional to each node of the GaussianBayesNetTree
|
||||
*
|
||||
* @param gbnTree
|
||||
* @param factor
|
||||
* @return GaussianBayesNetTree
|
||||
*/
|
||||
GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree,
|
||||
const GaussianConditional::shared_ptr &factor);
|
||||
|
||||
/**
|
||||
* @brief Compute the (logarithmic) normalization constant for each Bayes
|
||||
* network in the tree.
|
||||
*
|
||||
* @param bnTree A tree of Bayes networks in each leaf. The tree encodes a
|
||||
* discrete assignment yielding the Bayes net.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> computeLogNormConstants(
|
||||
const GaussianBayesNetValTree &bnTree);
|
||||
|
||||
/**
|
||||
* @brief Compute the model selection term L(M; Z, X) given the error
|
||||
* and log normalization constants.
|
||||
*
|
||||
* Perform normalization to handle underflow issues.
|
||||
*
|
||||
* @param errorTree
|
||||
* @param log_norm_constants
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> computeModelSelectionTerm(
|
||||
const AlgebraicDecisionTree<Key> &errorTree,
|
||||
const AlgebraicDecisionTree<Key> &log_norm_constants);
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
Loading…
Reference in New Issue