update other classes with correct types
parent
9360165ef6
commit
0ee9aac434
|
@ -219,7 +219,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
||||||
|
|
||||||
const DiscreteKeys discreteParentKeys = discreteKeys();
|
const DiscreteKeys discreteParentKeys = discreteKeys();
|
||||||
const KeyVector continuousParentKeys = continuousParents();
|
const KeyVector continuousParentKeys = continuousParents();
|
||||||
const HybridGaussianFactor::Factors likelihoods(
|
const HybridGaussianFactor::FactorValuePairs likelihoods(
|
||||||
conditionals_,
|
conditionals_,
|
||||||
[&](const GaussianConditional::shared_ptr &conditional)
|
[&](const GaussianConditional::shared_ptr &conditional)
|
||||||
-> GaussianFactorValuePair {
|
-> GaussianFactorValuePair {
|
||||||
|
|
|
@ -97,9 +97,7 @@ void HybridGaussianFactorGraph::printErrors(
|
||||||
std::cout << "nullptr"
|
std::cout << "nullptr"
|
||||||
<< "\n";
|
<< "\n";
|
||||||
} else {
|
} else {
|
||||||
auto [factor, val] = hgf->operator()(values.discrete());
|
hgf->operator()(values.discrete())->print(ss.str(), keyFormatter);
|
||||||
factor->print(ss.str(), keyFormatter);
|
|
||||||
std::cout << "value: " << val << std::endl;
|
|
||||||
std::cout << "error = " << factor->error(values) << std::endl;
|
std::cout << "error = " << factor->error(values) << std::endl;
|
||||||
}
|
}
|
||||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||||
|
@ -263,11 +261,10 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||||
// Case where we have a HybridGaussianFactor with no continuous keys.
|
// Case where we have a HybridGaussianFactor with no continuous keys.
|
||||||
// In this case, compute discrete probabilities.
|
// In this case, compute discrete probabilities.
|
||||||
auto logProbability = [&](const GaussianFactorValuePair &fv) -> double {
|
auto logProbability =
|
||||||
auto [factor, val] = fv;
|
[&](const GaussianFactor::shared_ptr &factor) -> double {
|
||||||
double v = 0.5 * val * val;
|
if (!factor) return 0.0;
|
||||||
if (!factor) return -v;
|
return -factor->error(VectorValues());
|
||||||
return -(factor->error(VectorValues()) + v);
|
|
||||||
};
|
};
|
||||||
AlgebraicDecisionTree<Key> logProbabilities =
|
AlgebraicDecisionTree<Key> logProbabilities =
|
||||||
DecisionTree<Key, double>(gmf->factors(), logProbability);
|
DecisionTree<Key, double>(gmf->factors(), logProbability);
|
||||||
|
@ -601,7 +598,7 @@ GaussianFactorGraph HybridGaussianFactorGraph::operator()(
|
||||||
} else if (auto gc = std::dynamic_pointer_cast<GaussianConditional>(f)) {
|
} else if (auto gc = std::dynamic_pointer_cast<GaussianConditional>(f)) {
|
||||||
gfg.push_back(gf);
|
gfg.push_back(gf);
|
||||||
} else if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
} else if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||||
gfg.push_back((*hgf)(assignment).first);
|
gfg.push_back((*hgf)(assignment));
|
||||||
} else if (auto hgc = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
|
} else if (auto hgc = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
|
||||||
gfg.push_back((*hgc)(assignment));
|
gfg.push_back((*hgc)(assignment));
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Reference in New Issue