update MixtureFactor so that all tests pass

release/4.3a0
Varun Agrawal 2022-10-31 15:38:23 -04:00
parent 5c375f6d03
commit d834897b14
1 changed files with 9 additions and 5 deletions

View File

@ -107,8 +107,12 @@ class MixtureFactor : public HybridFactor {
std::copy(f->keys().begin(), f->keys().end(), std::copy(f->keys().begin(), f->keys().end(),
std::inserter(factor_keys_set, factor_keys_set.end())); std::inserter(factor_keys_set, factor_keys_set.end()));
nonlinear_factors.push_back( if (auto nf = boost::dynamic_pointer_cast<NonlinearFactor>(f)) {
boost::dynamic_pointer_cast<NonlinearFactor>(f)); nonlinear_factors.push_back(nf);
} else {
throw std::runtime_error(
"Factors passed into MixtureFactor need to be nonlinear!");
}
} }
factors_ = Factors(discreteKeys, nonlinear_factors); factors_ = Factors(discreteKeys, nonlinear_factors);
@ -125,10 +129,10 @@ class MixtureFactor : public HybridFactor {
* @brief Compute error of the MixtureFactor as a tree. * @brief Compute error of the MixtureFactor as a tree.
* *
* @param continuousVals The continuous values for which to compute the error. * @param continuousVals The continuous values for which to compute the error.
* @return DecisionTree<Key, double> A decision tree with corresponding keys * @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
* as the factor but leaf values as the error. * as the factor but leaf values as the error.
*/ */
DecisionTree<Key, double> error(const Values& continuousVals) const { AlgebraicDecisionTree<Key> error(const Values& continuousVals) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = [continuousVals](const sharedFactor& factor) { auto errorFunc = [continuousVals](const sharedFactor& factor) {
return factor->error(continuousVals); return factor->error(continuousVals);
@ -165,7 +169,7 @@ class MixtureFactor : public HybridFactor {
/// print to stdout /// print to stdout
void print( void print(
const std::string& s = "MixtureFactor", const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override {
std::cout << (s.empty() ? "" : s + " "); std::cout << (s.empty() ? "" : s + " ");
Base::print("", keyFormatter); Base::print("", keyFormatter);