Better way of handling assignments

release/4.3a0
Varun Agrawal 2023-12-25 18:47:44 -05:00
parent ebcf958d69
commit 1e298be3b3
1 changed files with 43 additions and 33 deletions

View File

@ -285,8 +285,6 @@ HybridValues HybridBayesNet::optimize() const {
DiscreteFactorGraph discrete_fg;
VectorValues continuousValues;
// Error values for each hybrid factor
AlgebraicDecisionTree<Key> error(0.0);
std::set<DiscreteKey> discreteKeySet;
// this->print();
@ -313,7 +311,8 @@ HybridValues HybridBayesNet::optimize() const {
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
= exp(log(q) - log(k)) = exp(-error - log(k))
= exp(-(error + log(k)))
= exp(-(error + log(k))),
where error is computed at the corresponding MAP point, gbn.error(mu).
So we compute (error + log(k)) and exponentiate later
*/
@ -325,29 +324,45 @@ HybridValues HybridBayesNet::optimize() const {
AlgebraicDecisionTree<Key> log_norm_constants =
DecisionTree<Key, double>(bnTree, [](const GaussianBayesNet &gbn) {
if (gbn.size() == 0) {
return -std::numeric_limits<double>::max();
return 0.0;
}
return -gbn.logNormalizationConstant();
return gbn.logNormalizationConstant();
});
// Compute unnormalized error term and compute model selection term
AlgebraicDecisionTree<Key> model_selection_term = log_norm_constants.apply(
[this, &x_map](const Assignment<Key> &assignment, double x) {
double error = 0.0;
for (auto &&f : *this) {
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
error += gm->error(
HybridValues(x_map(assignment), DiscreteValues(assignment)));
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
error += gm->error(
HybridValues(x_map(assignment), DiscreteValues(assignment)));
} else if (auto g = hc->asGaussian()) {
error += g->error(x_map(assignment));
}
}
// Compute unnormalized error term
std::vector<DiscreteKey> labels;
for (auto &&key : x_map.labels()) {
labels.push_back(std::make_pair(key, 2));
}
std::vector<double> errors;
x_map.visitWith([this, &errors](const Assignment<Key> &assignment,
const VectorValues &mu) {
double error = 0.0;
for (auto &&f : *this) {
if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) {
error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
} else if (auto g = hc->asGaussian()) {
error += g->error(mu);
}
return -(error + x);
}
}
errors.push_back(error);
});
AlgebraicDecisionTree<Key> errorTree =
DecisionTree<Key, double>(labels, errors);
// Compute model selection term
AlgebraicDecisionTree<Key> model_selection_term = errorTree.apply(
[&log_norm_constants](const Assignment<Key> assignment, double err) {
return -(err + log_norm_constants(assignment));
});
// std::cout << "model selection term" << std::endl;
// model_selection_term.print("", DefaultKeyFormatter);
double max_log = model_selection_term.max();
@ -355,6 +370,7 @@ HybridValues HybridBayesNet::optimize() const {
model_selection_term,
[&max_log](const double &x) { return std::exp(x - max_log); });
model_selection = model_selection.normalize(model_selection.sum());
// std::cout << "normalized model selection" << std::endl;
// model_selection.print("", DefaultKeyFormatter);
@ -363,17 +379,11 @@ HybridValues HybridBayesNet::optimize() const {
discrete_fg.push_back(conditional->asDiscrete());
} else {
if (conditional->isContinuous()) {
// /*
// If we are here, it means there are no discrete variables in
// the Bayes net (due to strong elimination ordering).
// This is a continuous-only problem hence model selection doesn't matter.
// */
// auto gc = conditional->asGaussian();
// for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
// frontal != gc->endFrontals(); ++frontal) {
// continuousValues.insert_or_assign(*frontal,
// Vector::Zero(gc->getDim(frontal)));
// }
/*
If we are here, it means there are no discrete variables in
the Bayes net (due to strong elimination ordering).
This is a continuous-only problem hence model selection doesn't matter.
*/
} else if (conditional->isHybrid()) {
auto gm = conditional->asMixture();