Added handling of naked Gaussian factors added in python
parent
88f27a210a
commit
a46c53de3e
|
@ -98,7 +98,9 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
|||
|
||||
for (auto &f : factors_) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||
result = addGaussian(result, gf);
|
||||
} else if (auto gm = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
result = gm->add(result);
|
||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
if (auto gm = hc->asMixture()) {
|
||||
|
@ -107,6 +109,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
|||
result = addGaussian(result, g);
|
||||
} else {
|
||||
// Has to be discrete.
|
||||
// TODO(dellaert): in C++20, we can use std::visit.
|
||||
continue;
|
||||
}
|
||||
} else if (auto gf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
|
@ -486,7 +489,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
// If factor at `idx` is discrete-only, we skip.
|
||||
continue;
|
||||
} else {
|
||||
throwRuntimeError("HybridGaussianFactorGraph::error", f);
|
||||
throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -497,13 +500,15 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||
double error = 0.0;
|
||||
for (auto &f : factors_) {
|
||||
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
|
||||
if (auto hf = dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||
error += hf->error(values.continuous());
|
||||
} else if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
|
||||
// TODO(dellaert): needs to change when we discard other wrappers.
|
||||
error += hf->error(values);
|
||||
} else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
error -= log((*dtf)(values.discrete()));
|
||||
} else {
|
||||
throwRuntimeError("HybridGaussianFactorGraph::error", f);
|
||||
throwRuntimeError("HybridGaussianFactorGraph::error(HV)", f);
|
||||
}
|
||||
}
|
||||
return error;
|
||||
|
|
Loading…
Reference in New Issue