helper functions for computing leaf error and normalization constants used for model selection
parent c5bfd524e0
commit f3e84004bf
gtsam/hybrid/HybridBayesNet.cpp

@@ -227,24 +227,6 @@ GaussianBayesNet HybridBayesNet::choose(
   return gbn;
 }
 
-/* ************************************************************************ */
-static GaussianBayesNetTree addGaussian(
-    const GaussianBayesNetTree &gfgTree,
-    const GaussianConditional::shared_ptr &factor) {
-  // If the decision tree is not initialized, then initialize it.
-  if (gfgTree.empty()) {
-    GaussianBayesNet result{factor};
-    return GaussianBayesNetTree(result);
-  } else {
-    auto add = [&factor](const GaussianBayesNet &graph) {
-      auto result = graph;
-      result.push_back(factor);
-      return result;
-    };
-    return gfgTree.apply(add);
-  }
-}
-
 /* ************************************************************************ */
 GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
   GaussianBayesNetTree result;
@@ -320,25 +302,12 @@ AlgebraicDecisionTree<Key> HybridBayesNet::modelSelection() const {
   AlgebraicDecisionTree<Key> errorTree = trees.second;
 
   // Only compute logNormalizationConstant
-  AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
-      bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
-        GaussianBayesNet gbn = gbnAndValue.first;
-        if (gbn.size() == 0) {
-          return 0.0;
-        }
-        return gbn.logNormalizationConstant();
-      });
+  AlgebraicDecisionTree<Key> log_norm_constants =
+      computeLogNormConstants(bnTree);
 
   // Compute model selection term (with help from ADT methods)
   AlgebraicDecisionTree<Key> modelSelectionTerm =
-      (errorTree + log_norm_constants) * -1;
-
-  double max_log = modelSelectionTerm.max();
-  modelSelectionTerm = DecisionTree<Key, double>(
-      modelSelectionTerm,
-      [&max_log](const double &x) { return std::exp(x - max_log); });
-  modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
+      computeModelSelectionTerm(errorTree, log_norm_constants);
 
   return modelSelectionTerm;
 }
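Note: for each discrete assignment, the refactored modelSelection() computes a weight proportional to exp(-(error + logNormalizationConstant)). Exponentiating large negative log-values directly would underflow, so computeModelSelectionTerm subtracts the maximum before exponentiating and then normalizes the leaves to sum to one. The following standalone sketch shows that same max-shift stabilization on plain doubles; it is illustrative only and uses no GTSAM types, and the sample values are made up.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  // Per-mode values of -(error + logNormalizationConstant).
  const std::vector<double> logScores = {-1005.2, -1001.7, -1003.9};

  // Subtract the maximum before exponentiating to avoid underflow
  // (std::exp(-1005.2) alone would flush to zero in double precision).
  const double maxLog = *std::max_element(logScores.begin(), logScores.end());
  std::vector<double> weights;
  double sum = 0.0;
  for (const double x : logScores) {
    weights.push_back(std::exp(x - maxLog));
    sum += weights.back();
  }
  // Normalize so the weights sum to one, as modelSelectionTerm.normalize(...) does.
  for (double &w : weights) w /= sum;

  for (const double w : weights) std::cout << w << "\n";
  return 0;
}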
@@ -530,4 +499,52 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
   return fg;
 }
 
+/* ************************************************************************ */
+GaussianBayesNetTree addGaussian(
+    const GaussianBayesNetTree &gbnTree,
+    const GaussianConditional::shared_ptr &factor) {
+  // If the decision tree is not initialized, then initialize it.
+  if (gbnTree.empty()) {
+    GaussianBayesNet result{factor};
+    return GaussianBayesNetTree(result);
+  } else {
+    auto add = [&factor](const GaussianBayesNet &graph) {
+      auto result = graph;
+      result.push_back(factor);
+      return result;
+    };
+    return gbnTree.apply(add);
+  }
+}
+
+/* ************************************************************************* */
+AlgebraicDecisionTree<Key> computeLogNormConstants(
+    const GaussianBayesNetValTree &bnTree) {
+  AlgebraicDecisionTree<Key> log_norm_constants = DecisionTree<Key, double>(
+      bnTree, [](const std::pair<GaussianBayesNet, double> &gbnAndValue) {
+        GaussianBayesNet gbn = gbnAndValue.first;
+        if (gbn.size() == 0) {
+          return 0.0;
+        }
+        return gbn.logNormalizationConstant();
+      });
+  return log_norm_constants;
+}
+
+/* ************************************************************************* */
+AlgebraicDecisionTree<Key> computeModelSelectionTerm(
+    const AlgebraicDecisionTree<Key> &errorTree,
+    const AlgebraicDecisionTree<Key> &log_norm_constants) {
+  AlgebraicDecisionTree<Key> modelSelectionTerm =
+      (errorTree + log_norm_constants) * -1;
+
+  double max_log = modelSelectionTerm.max();
+  modelSelectionTerm = DecisionTree<Key, double>(
+      modelSelectionTerm,
+      [&max_log](const double &x) { return std::exp(x - max_log); });
+  modelSelectionTerm = modelSelectionTerm.normalize(modelSelectionTerm.sum());
+
+  return modelSelectionTerm;
+}
+
 }  // namespace gtsam
gtsam/hybrid/HybridBayesNet.h

@@ -141,7 +141,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
   This can be computed by multiplying all the exponentiated errors
   of each of the conditionals.
 
   Return a tree where each leaf value is L(M_i;Z).
   */
   AlgebraicDecisionTree<Key> modelSelection() const;
 
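A hypothetical usage sketch of modelSelection(), not part of this change: evaluate the returned decision tree at one discrete assignment. It assumes hbn is a HybridBayesNet built elsewhere, that M(0) is its only discrete mode key, and that the helper name modeWeight is made up for illustration.

#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/inference/Symbol.h>

using namespace gtsam;
using symbol_shorthand::M;

// Look up the normalized model selection weight for M(0) = mode.
double modeWeight(const HybridBayesNet &hbn, size_t mode) {
  const AlgebraicDecisionTree<Key> weights = hbn.modelSelection();
  DiscreteValues assignment;
  assignment[M(0)] = mode;  // pick the branch for this mode
  return weights(assignment);
}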
@@ -280,4 +280,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 template <>
 struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
 
+/**
+ * @brief Add a Gaussian conditional to each node of the GaussianBayesNetTree
+ *
+ * @param gbnTree
+ * @param factor
+ * @return GaussianBayesNetTree
+ */
+GaussianBayesNetTree addGaussian(const GaussianBayesNetTree &gbnTree,
+                                 const GaussianConditional::shared_ptr &factor);
+
+/**
+ * @brief Compute the (logarithmic) normalization constant for each Bayes
+ * network in the tree.
+ *
+ * @param bnTree A tree of Bayes networks in each leaf. The tree encodes a
+ * discrete assignment yielding the Bayes net.
+ * @return AlgebraicDecisionTree<Key>
+ */
+AlgebraicDecisionTree<Key> computeLogNormConstants(
+    const GaussianBayesNetValTree &bnTree);
+
+/**
+ * @brief Compute the model selection term L(M; Z, X) given the error
+ * and log normalization constants.
+ *
+ * Perform normalization to handle underflow issues.
+ *
+ * @param errorTree
+ * @param log_norm_constants
+ * @return AlgebraicDecisionTree<Key>
+ */
+AlgebraicDecisionTree<Key> computeModelSelectionTerm(
+    const AlgebraicDecisionTree<Key> &errorTree,
+    const AlgebraicDecisionTree<Key> &log_norm_constants);
+
 }  // namespace gtsam
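A minimal sketch of how the two new free functions compose, assuming the declarations this change adds to HybridBayesNet.h are in place. The function name modeWeights is hypothetical, and bnTree/errorTree are assumed to come from the caller (errorTree corresponds to the per-leaf error tree used inside modelSelection()).

#include <gtsam/hybrid/HybridBayesNet.h>

using namespace gtsam;

// Combine per-leaf errors and log normalization constants into normalized
// model weights, mirroring the body of HybridBayesNet::modelSelection().
AlgebraicDecisionTree<Key> modeWeights(
    const GaussianBayesNetValTree &bnTree,
    const AlgebraicDecisionTree<Key> &errorTree) {
  const AlgebraicDecisionTree<Key> logConstants = computeLogNormConstants(bnTree);
  return computeModelSelectionTerm(errorTree, logConstants);
}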