Added handling of naked Gaussian factors added in python

release/4.3a0
Frank Dellaert 2023-01-06 23:22:56 -08:00
parent 88f27a210a
commit a46c53de3e
1 changed files with 9 additions and 4 deletions

View File

@ -98,7 +98,9 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
for (auto &f : factors_) { for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor. // TODO(dellaert): just use a virtual method defined in HybridFactor.
if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) { if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
result = addGaussian(result, gf);
} else if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
result = gm->add(result); result = gm->add(result);
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) { } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asMixture()) { if (auto gm = hc->asMixture()) {
@ -107,6 +109,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
result = addGaussian(result, g); result = addGaussian(result, g);
} else { } else {
// Has to be discrete. // Has to be discrete.
// TODO(dellaert): in C++20, we can use std::visit.
continue; continue;
} }
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { } else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
@ -486,7 +489,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// If factor at `idx` is discrete-only, we skip. // If factor at `idx` is discrete-only, we skip.
continue; continue;
} else { } else {
throwRuntimeError("HybridGaussianFactorGraph::error", f); throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
} }
} }
@ -497,13 +500,15 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
double HybridGaussianFactorGraph::error(const HybridValues &values) const { double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0; double error = 0.0;
for (auto &f : factors_) { for (auto &f : factors_) {
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) { if (auto hf = dynamic_pointer_cast<GaussianFactor>(f)) {
error += hf->error(values.continuous());
} else if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
// TODO(dellaert): needs to change when we discard other wrappers. // TODO(dellaert): needs to change when we discard other wrappers.
error += hf->error(values); error += hf->error(values);
} else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) { } else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
error -= log((*dtf)(values.discrete())); error -= log((*dtf)(values.discrete()));
} else { } else {
throwRuntimeError("HybridGaussianFactorGraph::error", f); throwRuntimeError("HybridGaussianFactorGraph::error(HV)", f);
} }
} }
return error; return error;