diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 3bafe5a9c..b02967555 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -281,6 +281,36 @@ HybridValues HybridBayesNet::sample() const { return sample(&kRandomNumberGenerator); } +/* ************************************************************************* */ +AlgebraicDecisionTree HybridBayesNet::errorTree( + const VectorValues &continuousValues) const { + AlgebraicDecisionTree result(0.0); + + // Iterate over each conditional. + for (auto &&conditional : *this) { + if (auto gm = conditional->asMixture()) { + // If conditional is hybrid, compute error for all assignments. + result = result + gm->errorTree(continuousValues); + + } else if (auto gc = conditional->asGaussian()) { + // If continuous, get the error and add it to the result + double error = gc->error(continuousValues); + // Add the computed error to every leaf of the result tree. + result = result.apply( + [error](double leaf_value) { return leaf_value + error; }); + + } else if (auto dc = conditional->asDiscrete()) { + // If discrete, add the discrete error in the right branch + result = result.apply( + [dc](const Assignment &assignment, double leaf_value) { + return leaf_value + dc->error(DiscreteValues(assignment)); + }); + } + } + + return result; +} + /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::logProbability( const VectorValues &continuousValues) const { diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index e71cfe9b4..032cd55b9 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -187,6 +187,23 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @param continuousValues Continuous values at which to compute the error. * @return AlgebraicDecisionTree */ + AlgebraicDecisionTree errorTree( + const VectorValues &continuousValues) const; + + /** + * @brief Error method using HybridValues which returns specific error for + * assignment. + */ + using Base::error; + + /** + * @brief Compute log probability for each discrete assignment, + * and return as a tree. + * + * @param continuousValues Continuous values at which + * to compute the log probability. + * @return AlgebraicDecisionTree + */ AlgebraicDecisionTree logProbability( const VectorValues &continuousValues) const; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index bdfac8468..b764dc9e0 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -74,6 +74,85 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) { index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); } +/* ************************************************************************ */ +void HybridGaussianFactorGraph::printErrors( + const HybridValues &values, const std::string &str, + const KeyFormatter &keyFormatter, + const std::function + &printCondition) const { + std::cout << str << "size: " << size() << std::endl << std::endl; + + std::stringstream ss; + + for (size_t i = 0; i < factors_.size(); i++) { + auto &&factor = factors_[i]; + std::cout << "Factor " << i << ": "; + + // Clear the stringstream + ss.str(std::string()); + + if (auto gmf = std::dynamic_pointer_cast(factor)) { + if (factor == nullptr) { + std::cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + std::cout << "error = "; + gmf->errorTree(values.continuous()).print("", keyFormatter); + std::cout << std::endl; + } + } else if (auto hc = std::dynamic_pointer_cast(factor)) { + if (factor == nullptr) { + std::cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + + if (hc->isContinuous()) { + std::cout << "error = " << hc->asGaussian()->error(values) << "\n"; + } else if (hc->isDiscrete()) { + std::cout << "error = "; + hc->asDiscrete()->errorTree().print("", keyFormatter); + std::cout << "\n"; + } else { + // Is hybrid + std::cout << "error = "; + hc->asMixture()->errorTree(values.continuous()).print(); + std::cout << "\n"; + } + } + } else if (auto gf = std::dynamic_pointer_cast(factor)) { + const double errorValue = (factor != nullptr ? gf->error(values) : .0); + if (!printCondition(factor.get(), errorValue, i)) + continue; // User-provided filter did not pass + + if (factor == nullptr) { + std::cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + std::cout << "error = " << errorValue << "\n"; + } + } else if (auto df = std::dynamic_pointer_cast(factor)) { + if (factor == nullptr) { + std::cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + std::cout << "error = "; + df->errorTree().print("", keyFormatter); + } + + } else { + continue; + } + + std::cout << "\n"; + } + std::cout.flush(); +} + /* ************************************************************************ */ static GaussianFactorGraphTree addGaussian( const GaussianFactorGraphTree &gfgTree, diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index f924b7a1c..1708ff64b 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -140,9 +140,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph /// @{ // TODO(dellaert): customize print and equals. - // void print(const std::string& s = "HybridGaussianFactorGraph", - // const KeyFormatter& keyFormatter = DefaultKeyFormatter) const - // override; + // void print( + // const std::string& s = "HybridGaussianFactorGraph", + // const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + + void printErrors( + const HybridValues& values, + const std::string& str = "HybridGaussianFactorGraph: ", + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const std::function& + printCondition = + [](const Factor*, double, size_t) { return true; }) const; + // bool equals(const This& fg, double tol = 1e-9) const override; /// @} diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index e51adb9cd..cdd448412 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -42,6 +42,98 @@ void HybridNonlinearFactorGraph::print(const std::string& s, } } +/* ************************************************************************* */ +void HybridNonlinearFactorGraph::printErrors( + const HybridValues& values, const std::string& str, + const KeyFormatter& keyFormatter, + const std::function& printCondition) const { + std::cout << str << "size: " << size() << std::endl << std::endl; + + std::stringstream ss; + + for (size_t i = 0; i < factors_.size(); i++) { + auto&& factor = factors_[i]; + std::cout << "Factor " << i << ": "; + + // Clear the stringstream + ss.str(std::string()); + + if (auto mf = std::dynamic_pointer_cast(factor)) { + if (factor == nullptr) { + std::cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + std::cout << "error = "; + mf->errorTree(values.nonlinear()).print("", keyFormatter); + std::cout << std::endl; + } + } else if (auto gmf = + std::dynamic_pointer_cast(factor)) { + if (factor == nullptr) { + std::cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + std::cout << "error = "; + gmf->errorTree(values.continuous()).print("", keyFormatter); + std::cout << std::endl; + } + } else if (auto gm = std::dynamic_pointer_cast(factor)) { + if (factor == nullptr) { + std::cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + std::cout << "error = "; + gm->errorTree(values.continuous()).print("", keyFormatter); + std::cout << std::endl; + } + } else if (auto nf = std::dynamic_pointer_cast(factor)) { + const double errorValue = (factor != nullptr ? nf->error(values) : .0); + if (!printCondition(factor.get(), errorValue, i)) + continue; // User-provided filter did not pass + + if (factor == nullptr) { + std::cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + std::cout << "error = " << errorValue << "\n"; + } + } else if (auto gf = std::dynamic_pointer_cast(factor)) { + const double errorValue = (factor != nullptr ? gf->error(values) : .0); + if (!printCondition(factor.get(), errorValue, i)) + continue; // User-provided filter did not pass + + if (factor == nullptr) { + std::cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + std::cout << "error = " << errorValue << "\n"; + } + } else if (auto df = std::dynamic_pointer_cast(factor)) { + if (factor == nullptr) { + std::cout << "nullptr" + << "\n"; + } else { + factor->print(ss.str(), keyFormatter); + std::cout << "error = "; + df->errorTree().print("", keyFormatter); + std::cout << std::endl; + } + + } else { + continue; + } + + std::cout << "\n"; + } + std::cout.flush(); +} + /* ************************************************************************* */ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize( const Values& continuousValues) const { diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.h b/gtsam/hybrid/HybridNonlinearFactorGraph.h index 89b4fb391..54dc9e93f 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.h +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.h @@ -34,7 +34,7 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { protected: public: using Base = HybridFactorGraph; - using This = HybridNonlinearFactorGraph; ///< this class + using This = HybridNonlinearFactorGraph; ///< this class using shared_ptr = std::shared_ptr; ///< shared_ptr to This using Values = gtsam::Values; ///< backwards compatibility @@ -63,6 +63,16 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph { const std::string& s = "HybridNonlinearFactorGraph", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + /** print errors along with factors*/ + void printErrors( + const HybridValues& values, + const std::string& str = "HybridNonlinearFactorGraph: ", + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const std::function& + printCondition = + [](const Factor*, double, size_t) { return true; }) const; + /// @} /// @name Standard Interface /// @{ diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 5248fce01..00dc36cd0 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -153,6 +153,45 @@ TEST(HybridBayesNet, Choose) { *gbn.at(3))); } +/* ****************************************************************************/ +// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia). +TEST(HybridBayesNet, Error) { + const auto continuousConditional = GaussianConditional::sharedMeanAndStddev( + X(0), 2 * I_1x1, X(1), Vector1(-4.0), 5.0); + + const SharedDiagonal model0 = noiseModel::Diagonal::Sigmas(Vector1(2.0)), + model1 = noiseModel::Diagonal::Sigmas(Vector1(3.0)); + + const auto conditional0 = std::make_shared( + X(1), Vector1::Constant(5), I_1x1, model0), + conditional1 = std::make_shared( + X(1), Vector1::Constant(2), I_1x1, model1); + + auto gm = + new GaussianMixture({X(1)}, {}, {Asia}, {conditional0, conditional1}); + // Create hybrid Bayes net. + HybridBayesNet bayesNet; + bayesNet.push_back(continuousConditional); + bayesNet.emplace_back(gm); + bayesNet.emplace_back(new DiscreteConditional(Asia, "99/1")); + + // Create values at which to evaluate. + HybridValues values; + values.insert(asiaKey, 0); + values.insert(X(0), Vector1(-6)); + values.insert(X(1), Vector1(1)); + + AlgebraicDecisionTree actual_errors = + bayesNet.errorTree(values.continuous()); + + // Regression. + // Manually added all the error values from the 3 conditional types. + AlgebraicDecisionTree expected_errors( + {Asia}, std::vector{2.33005033585, 5.38619084965}); + + EXPECT(assert_equal(expected_errors, actual_errors)); +} + /* ****************************************************************************/ // Test Bayes net optimize TEST(HybridBayesNet, OptimizeAssignment) { diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 98a8a794f..5be2f2742 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -658,7 +658,7 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements, } /* ****************************************************************************/ -// Check that the factor graph unnormalized probability is proportional to the +// Check that the bayes net unnormalized probability is proportional to the // Bayes net probability for the given measurements. bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements, const HybridBayesNet &posterior, size_t num_samples = 100) { diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index a493de5c5..93081d309 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -327,8 +327,8 @@ GaussianFactorGraph::shared_ptr batchGFG(double between, NonlinearFactorGraph graph; graph.addPrior(X(0), 0, Isotropic::Sigma(1, 0.1)); - auto between_x0_x1 = std::make_shared( - X(0), X(1), between, Isotropic::Sigma(1, 1.0)); + auto between_x0_x1 = std::make_shared(X(0), X(1), between, + Isotropic::Sigma(1, 1.0)); graph.push_back(between_x0_x1); @@ -397,6 +397,25 @@ TEST(HybridFactorGraph, Partial_Elimination) { EXPECT(remainingFactorGraph->at(2)->keys() == KeyVector({M(0), M(1)})); } +TEST(HybridFactorGraph, PrintErrors) { + Switching self(3); + + // Get nonlinear factor graph and add linear factors to be holistic + HybridNonlinearFactorGraph fg = self.nonlinearFactorGraph; + fg.add(self.linearizedFactorGraph); + + // Optimize to get HybridValues + HybridBayesNet::shared_ptr bn = + self.linearizedFactorGraph.eliminateSequential(); + HybridValues hv = bn->optimize(); + + // Print and verify + // fg.print(); + // std::cout << "\n\n\n" << std::endl; + // fg.printErrors( + // HybridValues(hv.continuous(), DiscreteValues(), self.linearizationPoint)); +} + /**************************************************************************** * Test full elimination */ @@ -564,7 +583,7 @@ factor 6: P( m1 | m0 ): )"; #else -string expected_hybridFactorGraph = R"( + string expected_hybridFactorGraph = R"( size: 7 factor 0: A[x0] = [ @@ -759,9 +778,9 @@ TEST(HybridFactorGraph, DefaultDecisionTree) { KeyVector contKeys = {X(0), X(1)}; auto noise_model = noiseModel::Isotropic::Sigma(3, 1.0); auto still = std::make_shared(X(0), X(1), Pose2(0, 0, 0), - noise_model), + noise_model), moving = std::make_shared(X(0), X(1), odometry, - noise_model); + noise_model); std::vector motion_models = {still, moving}; fg.emplace_shared( contKeys, DiscreteKeys{gtsam::DiscreteKey(M(1), 2)}, motion_models); @@ -788,7 +807,7 @@ TEST(HybridFactorGraph, DefaultDecisionTree) { initialEstimate.insert(L(1), Point2(4.1, 1.8)); // We want to eliminate variables not connected to DCFactors first. - const Ordering ordering {L(0), L(1), X(0), X(1)}; + const Ordering ordering{L(0), L(1), X(0), X(1)}; HybridGaussianFactorGraph linearized = *fg.linearize(initialEstimate);