diff --git a/gtsam/inference/BayesTree-inl.h b/gtsam/inference/BayesTree-inl.h index 7f77fa898..1c696d3ef 100644 --- a/gtsam/inference/BayesTree-inl.h +++ b/gtsam/inference/BayesTree-inl.h @@ -284,7 +284,7 @@ namespace gtsam { const typename BayesTree::sharedClique& v1, const typename BayesTree::sharedClique& v2 ) { - return v1->equals(*v2); + return (!v1 && !v2) || (v1 && v2 && v1->equals(*v2)); } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesTreeCliqueBase-inl.h b/gtsam/inference/BayesTreeCliqueBase-inl.h index d2d2cec65..dfc8864f7 100644 --- a/gtsam/inference/BayesTreeCliqueBase-inl.h +++ b/gtsam/inference/BayesTreeCliqueBase-inl.h @@ -66,7 +66,7 @@ namespace gtsam { /* ************************************************************************* */ template void BayesTreeCliqueBase::printTree(const std::string& indent) const { - print(indent); + asDerived(this)->print(indent); BOOST_FOREACH(const derived_ptr& child, children_) child->printTree(indent+" "); } @@ -94,7 +94,7 @@ namespace gtsam { } #endif if(changed) { - BOOST_FOREACH(const shared_ptr& child, children_) { + BOOST_FOREACH(const derived_ptr& child, children_) { (void)child->permuteSeparatorWithInverse(inversePermutation); } } @@ -115,7 +115,7 @@ namespace gtsam { // A first base case is when this clique or its parent is the root, // in which case we return an empty Bayes net. - shared_ptr parent(parent_.lock()); + derived_ptr parent(parent_.lock()); if (R.get()==this || parent==R) { BayesNet empty; diff --git a/gtsam/inference/BayesTreeCliqueBase.h b/gtsam/inference/BayesTreeCliqueBase.h index fe7c1b81b..66be72f86 100644 --- a/gtsam/inference/BayesTreeCliqueBase.h +++ b/gtsam/inference/BayesTreeCliqueBase.h @@ -132,7 +132,7 @@ namespace gtsam { bool equals(const This& other, double tol=1e-9) const { return (!conditional_ && !other.conditional()) || - conditional_->equals(*(other.conditional()), tol); + conditional_->equals(*other.conditional(), tol); } friend class BayesTree; @@ -150,12 +150,13 @@ namespace gtsam { }; // \struct Clique template - typename BayesTreeCliqueBase::derived_ptr asDerived(const BayesTreeCliqueBase& base) { -#ifndef NDEBUG - return boost::dynamic_pointer_cast(base); -#else - return boost::static_pointer_cast(base); -#endif + const DERIVED* asDerived(const BayesTreeCliqueBase* base) { + return static_cast(base); + } + + template + DERIVED* asDerived(BayesTreeCliqueBase* base) { + return static_cast(base); } } diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 37d9c1add..168bfec97 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -162,7 +162,7 @@ void GaussianConditional::print(const string &s) const } gtsam::print(Vector(get_d()),"d"); gtsam::print(sigmas_,"sigmas"); - cout << "Permutation: " << permutation_.indices() << endl; + cout << "Permutation: " << permutation_.indices().transpose() << endl; } /* ************************************************************************* */ diff --git a/gtsam/nonlinear/ISAM2-impl-inl.h b/gtsam/nonlinear/ISAM2-impl-inl.h index 85b28d744..17c24e23a 100644 --- a/gtsam/nonlinear/ISAM2-impl-inl.h +++ b/gtsam/nonlinear/ISAM2-impl-inl.h @@ -337,6 +337,8 @@ ISAM2::Impl::PartialSolve(GaussianFactorGraph& facto JunctionTree jt(factors, affectedFactorsIndex); result.bayesTree = jt.eliminate(EliminatePreferLDL); if(debug && result.bayesTree) { + if(boost::dynamic_pointer_cast >(result.bayesTree)) + cout << "Is an ISAM2 clique" << endl; cout << "Re-eliminated BT:\n"; result.bayesTree->printTree(""); } diff --git a/gtsam/nonlinear/ISAM2.h b/gtsam/nonlinear/ISAM2.h index 8663aae41..44df3035f 100644 --- a/gtsam/nonlinear/ISAM2.h +++ b/gtsam/nonlinear/ISAM2.h @@ -146,14 +146,26 @@ struct ISAM2Clique : public BayesTreeCliqueBase, CONDIT /** Access the gradient contribution */ const Vector& gradientContribution() const { return gradientContribution_; } + bool equals(const This& other, double tol=1e-9) const { + return Base::equals(other) && ((!cachedFactor_ && !other.cachedFactor_) || (cachedFactor_ && other.cachedFactor_ && cachedFactor_->equals(*other.cachedFactor_, tol))); + } + + /** print this node */ + void print(const std::string& s = "") const { + Base::print(s); + if(cachedFactor_) cachedFactor_->print(s + "Cached: "); + else cout << s << "Cached empty" << endl; + } + void permuteWithInverse(const Permutation& inversePermutation) { if(cachedFactor_) cachedFactor_->permuteWithInverse(inversePermutation); Base::permuteWithInverse(inversePermutation); } bool permuteSeparatorWithInverse(const Permutation& inversePermutation) { - if(cachedFactor_) cachedFactor_->permuteWithInverse(inversePermutation); - return Base::permuteSeparatorWithInverse(inversePermutation); + bool changed = Base::permuteSeparatorWithInverse(inversePermutation); + if(changed) if(cachedFactor_) cachedFactor_->permuteWithInverse(inversePermutation); + return changed; } private: diff --git a/tests/testGaussianISAM2.cpp b/tests/testGaussianISAM2.cpp index 8bdacbf0f..d56d2edbc 100644 --- a/tests/testGaussianISAM2.cpp +++ b/tests/testGaussianISAM2.cpp @@ -6,6 +6,7 @@ #include #include // for operator += +#include using namespace boost::assign; #include @@ -401,6 +402,67 @@ TEST_UNSAFE(ISAM2, clone) { CHECK(assert_equal(isam, *isam2)); } +/* ************************************************************************* */ +TEST(ISAM2, permute_cached) { + typedef ISAM2Clique Clique; + typedef boost::shared_ptr > sharedClique; + + // Construct expected permuted BayesTree (variable 2 has been changed to 1) + BayesTree expected; + expected.insert(sharedClique(new Clique(make_pair( + boost::make_shared(pair_list_of + (3, Matrix_(1,1,1.0)) + (4, Matrix_(1,1,2.0)), + 2, Vector_(1,1.0), Vector_(1,1.0)), // p(3,4) + HessianFactor::shared_ptr())))); // Cached: empty + expected.insert(sharedClique(new Clique(make_pair( + boost::make_shared(pair_list_of + (2, Matrix_(1,1,1.0)) + (3, Matrix_(1,1,2.0)), + 1, Vector_(1,1.0), Vector_(1,1.0)), // p(2|3) + boost::make_shared(3, Matrix_(1,1,1.0), Vector_(1,1.0), 0.0))))); // Cached: p(3) + expected.insert(sharedClique(new Clique(make_pair( + boost::make_shared(pair_list_of + (0, Matrix_(1,1,1.0)) + (2, Matrix_(1,1,2.0)), + 1, Vector_(1,1.0), Vector_(1,1.0)), // p(0|2) + boost::make_shared(1, Matrix_(1,1,1.0), Vector_(1,1.0), 0.0))))); // Cached: p(1) + // Change variable 2 to 1 + expected.root()->children().front()->conditional()->keys()[0] = 1; + expected.root()->children().front()->children().front()->conditional()->keys()[1] = 1; + + // Construct unpermuted BayesTree + BayesTree actual; + actual.insert(sharedClique(new Clique(make_pair( + boost::make_shared(pair_list_of + (3, Matrix_(1,1,1.0)) + (4, Matrix_(1,1,2.0)), + 2, Vector_(1,1.0), Vector_(1,1.0)), // p(3,4) + HessianFactor::shared_ptr())))); // Cached: empty + actual.insert(sharedClique(new Clique(make_pair( + boost::make_shared(pair_list_of + (2, Matrix_(1,1,1.0)) + (3, Matrix_(1,1,2.0)), + 1, Vector_(1,1.0), Vector_(1,1.0)), // p(2|3) + boost::make_shared(3, Matrix_(1,1,1.0), Vector_(1,1.0), 0.0))))); // Cached: p(3) + actual.insert(sharedClique(new Clique(make_pair( + boost::make_shared(pair_list_of + (0, Matrix_(1,1,1.0)) + (2, Matrix_(1,1,2.0)), + 1, Vector_(1,1.0), Vector_(1,1.0)), // p(0|2) + boost::make_shared(2, Matrix_(1,1,1.0), Vector_(1,1.0), 0.0))))); // Cached: p(2) + + // Create permutation that changes variable 2 -> 0 + Permutation permutation = Permutation::Identity(5); + permutation[2] = 1; + + // Permute BayesTree + actual.root()->permuteWithInverse(permutation); + + // Check + EXPECT(assert_equal(expected, actual)); +} + /* ************************************************************************* */ int main() { TestResult tr; return TestRegistry::runAllTests(tr);} /* ************************************************************************* */