update MixtureFactor so that all tests pass
parent
5c375f6d03
commit
d834897b14
|
|
@ -107,8 +107,12 @@ class MixtureFactor : public HybridFactor {
|
||||||
std::copy(f->keys().begin(), f->keys().end(),
|
std::copy(f->keys().begin(), f->keys().end(),
|
||||||
std::inserter(factor_keys_set, factor_keys_set.end()));
|
std::inserter(factor_keys_set, factor_keys_set.end()));
|
||||||
|
|
||||||
nonlinear_factors.push_back(
|
if (auto nf = boost::dynamic_pointer_cast<NonlinearFactor>(f)) {
|
||||||
boost::dynamic_pointer_cast<NonlinearFactor>(f));
|
nonlinear_factors.push_back(nf);
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"Factors passed into MixtureFactor need to be nonlinear!");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
factors_ = Factors(discreteKeys, nonlinear_factors);
|
factors_ = Factors(discreteKeys, nonlinear_factors);
|
||||||
|
|
||||||
|
|
@ -125,10 +129,10 @@ class MixtureFactor : public HybridFactor {
|
||||||
* @brief Compute error of the MixtureFactor as a tree.
|
* @brief Compute error of the MixtureFactor as a tree.
|
||||||
*
|
*
|
||||||
* @param continuousVals The continuous values for which to compute the error.
|
* @param continuousVals The continuous values for which to compute the error.
|
||||||
* @return DecisionTree<Key, double> A decision tree with corresponding keys
|
* @return AlgebraicDecisionTree<Key> A decision tree with corresponding keys
|
||||||
* as the factor but leaf values as the error.
|
* as the factor but leaf values as the error.
|
||||||
*/
|
*/
|
||||||
DecisionTree<Key, double> error(const Values& continuousVals) const {
|
AlgebraicDecisionTree<Key> error(const Values& continuousVals) const {
|
||||||
// functor to convert from sharedFactor to double error value.
|
// functor to convert from sharedFactor to double error value.
|
||||||
auto errorFunc = [continuousVals](const sharedFactor& factor) {
|
auto errorFunc = [continuousVals](const sharedFactor& factor) {
|
||||||
return factor->error(continuousVals);
|
return factor->error(continuousVals);
|
||||||
|
|
@ -165,7 +169,7 @@ class MixtureFactor : public HybridFactor {
|
||||||
|
|
||||||
/// print to stdout
|
/// print to stdout
|
||||||
void print(
|
void print(
|
||||||
const std::string& s = "MixtureFactor",
|
const std::string& s = "",
|
||||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override {
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override {
|
||||||
std::cout << (s.empty() ? "" : s + " ");
|
std::cout << (s.empty() ? "" : s + " ");
|
||||||
Base::print("", keyFormatter);
|
Base::print("", keyFormatter);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue