Added handling of naked Gaussian factors added in python
parent
88f27a210a
commit
a46c53de3e
|
@ -98,7 +98,9 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
|
|
||||||
for (auto &f : factors_) {
|
for (auto &f : factors_) {
|
||||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||||
if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||||
|
result = addGaussian(result, gf);
|
||||||
|
} else if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
result = gm->add(result);
|
result = gm->add(result);
|
||||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||||
if (auto gm = hc->asMixture()) {
|
if (auto gm = hc->asMixture()) {
|
||||||
|
@ -107,6 +109,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
result = addGaussian(result, g);
|
result = addGaussian(result, g);
|
||||||
} else {
|
} else {
|
||||||
// Has to be discrete.
|
// Has to be discrete.
|
||||||
|
// TODO(dellaert): in C++20, we can use std::visit.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||||
|
@ -486,7 +489,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
// If factor at `idx` is discrete-only, we skip.
|
// If factor at `idx` is discrete-only, we skip.
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
throwRuntimeError("HybridGaussianFactorGraph::error", f);
|
throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -497,13 +500,15 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||||
double error = 0.0;
|
double error = 0.0;
|
||||||
for (auto &f : factors_) {
|
for (auto &f : factors_) {
|
||||||
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
|
if (auto hf = dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||||
|
error += hf->error(values.continuous());
|
||||||
|
} else if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
|
||||||
// TODO(dellaert): needs to change when we discard other wrappers.
|
// TODO(dellaert): needs to change when we discard other wrappers.
|
||||||
error += hf->error(values);
|
error += hf->error(values);
|
||||||
} else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
} else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||||
error -= log((*dtf)(values.discrete()));
|
error -= log((*dtf)(values.discrete()));
|
||||||
} else {
|
} else {
|
||||||
throwRuntimeError("HybridGaussianFactorGraph::error", f);
|
throwRuntimeError("HybridGaussianFactorGraph::error(HV)", f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return error;
|
return error;
|
||||||
|
|
Loading…
Reference in New Issue