update other classes with correct types

release/4.3a0
Varun Agrawal 2024-09-14 14:54:49 -04:00
parent 9360165ef6
commit 0ee9aac434
2 changed files with 7 additions and 10 deletions

View File

@ -219,7 +219,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
const DiscreteKeys discreteParentKeys = discreteKeys(); const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents(); const KeyVector continuousParentKeys = continuousParents();
const HybridGaussianFactor::Factors likelihoods( const HybridGaussianFactor::FactorValuePairs likelihoods(
conditionals_, conditionals_,
[&](const GaussianConditional::shared_ptr &conditional) [&](const GaussianConditional::shared_ptr &conditional)
-> GaussianFactorValuePair { -> GaussianFactorValuePair {

View File

@ -97,9 +97,7 @@ void HybridGaussianFactorGraph::printErrors(
std::cout << "nullptr" std::cout << "nullptr"
<< "\n"; << "\n";
} else { } else {
auto [factor, val] = hgf->operator()(values.discrete()); hgf->operator()(values.discrete())->print(ss.str(), keyFormatter);
factor->print(ss.str(), keyFormatter);
std::cout << "value: " << val << std::endl;
std::cout << "error = " << factor->error(values) << std::endl; std::cout << "error = " << factor->error(values) << std::endl;
} }
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) { } else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
@ -263,11 +261,10 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Case where we have a HybridGaussianFactor with no continuous keys. // Case where we have a HybridGaussianFactor with no continuous keys.
// In this case, compute discrete probabilities. // In this case, compute discrete probabilities.
auto logProbability = [&](const GaussianFactorValuePair &fv) -> double { auto logProbability =
auto [factor, val] = fv; [&](const GaussianFactor::shared_ptr &factor) -> double {
double v = 0.5 * val * val; if (!factor) return 0.0;
if (!factor) return -v; return -factor->error(VectorValues());
return -(factor->error(VectorValues()) + v);
}; };
AlgebraicDecisionTree<Key> logProbabilities = AlgebraicDecisionTree<Key> logProbabilities =
DecisionTree<Key, double>(gmf->factors(), logProbability); DecisionTree<Key, double>(gmf->factors(), logProbability);
@ -601,7 +598,7 @@ GaussianFactorGraph HybridGaussianFactorGraph::operator()(
} else if (auto gc = std::dynamic_pointer_cast<GaussianConditional>(f)) { } else if (auto gc = std::dynamic_pointer_cast<GaussianConditional>(f)) {
gfg.push_back(gf); gfg.push_back(gf);
} else if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(f)) { } else if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
gfg.push_back((*hgf)(assignment).first); gfg.push_back((*hgf)(assignment));
} else if (auto hgc = dynamic_pointer_cast<HybridGaussianConditional>(f)) { } else if (auto hgc = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
gfg.push_back((*hgc)(assignment)); gfg.push_back((*hgc)(assignment));
} else { } else {