(in branch) bug fix and unit test in permutation bug introduced during BayesTree Clique refactoring
parent
e75e4321af
commit
88c3e81a7d
|
@ -284,7 +284,7 @@ namespace gtsam {
|
|||
const typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique& v1,
|
||||
const typename BayesTree<CONDITIONAL,CLIQUE>::sharedClique& v2
|
||||
) {
|
||||
return v1->equals(*v2);
|
||||
return (!v1 && !v2) || (v1 && v2 && v1->equals(*v2));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -66,7 +66,7 @@ namespace gtsam {
|
|||
/* ************************************************************************* */
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
void BayesTreeCliqueBase<DERIVED,CONDITIONAL>::printTree(const std::string& indent) const {
|
||||
print(indent);
|
||||
asDerived(this)->print(indent);
|
||||
BOOST_FOREACH(const derived_ptr& child, children_)
|
||||
child->printTree(indent+" ");
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ namespace gtsam {
|
|||
}
|
||||
#endif
|
||||
if(changed) {
|
||||
BOOST_FOREACH(const shared_ptr& child, children_) {
|
||||
BOOST_FOREACH(const derived_ptr& child, children_) {
|
||||
(void)child->permuteSeparatorWithInverse(inversePermutation);
|
||||
}
|
||||
}
|
||||
|
@ -115,7 +115,7 @@ namespace gtsam {
|
|||
// A first base case is when this clique or its parent is the root,
|
||||
// in which case we return an empty Bayes net.
|
||||
|
||||
shared_ptr parent(parent_.lock());
|
||||
derived_ptr parent(parent_.lock());
|
||||
|
||||
if (R.get()==this || parent==R) {
|
||||
BayesNet<ConditionalType> empty;
|
||||
|
|
|
@ -132,7 +132,7 @@ namespace gtsam {
|
|||
|
||||
bool equals(const This& other, double tol=1e-9) const {
|
||||
return (!conditional_ && !other.conditional()) ||
|
||||
conditional_->equals(*(other.conditional()), tol);
|
||||
conditional_->equals(*other.conditional(), tol);
|
||||
}
|
||||
|
||||
friend class BayesTree<ConditionalType, DerivedType>;
|
||||
|
@ -150,12 +150,13 @@ namespace gtsam {
|
|||
}; // \struct Clique
|
||||
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
typename BayesTreeCliqueBase<DERIVED,CONDITIONAL>::derived_ptr asDerived(const BayesTreeCliqueBase<DERIVED,CONDITIONAL>& base) {
|
||||
#ifndef NDEBUG
|
||||
return boost::dynamic_pointer_cast<DERIVED>(base);
|
||||
#else
|
||||
return boost::static_pointer_cast<DERIVED>(base);
|
||||
#endif
|
||||
const DERIVED* asDerived(const BayesTreeCliqueBase<DERIVED,CONDITIONAL>* base) {
|
||||
return static_cast<const DERIVED*>(base);
|
||||
}
|
||||
|
||||
template<class DERIVED, class CONDITIONAL>
|
||||
DERIVED* asDerived(BayesTreeCliqueBase<DERIVED,CONDITIONAL>* base) {
|
||||
return static_cast<DERIVED*>(base);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -162,7 +162,7 @@ void GaussianConditional::print(const string &s) const
|
|||
}
|
||||
gtsam::print(Vector(get_d()),"d");
|
||||
gtsam::print(sigmas_,"sigmas");
|
||||
cout << "Permutation: " << permutation_.indices() << endl;
|
||||
cout << "Permutation: " << permutation_.indices().transpose() << endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -337,6 +337,8 @@ ISAM2<CONDITIONAL, VALUES, GRAPH>::Impl::PartialSolve(GaussianFactorGraph& facto
|
|||
JunctionTree<GaussianFactorGraph, typename ISAM2Type::Clique> jt(factors, affectedFactorsIndex);
|
||||
result.bayesTree = jt.eliminate(EliminatePreferLDL);
|
||||
if(debug && result.bayesTree) {
|
||||
if(boost::dynamic_pointer_cast<ISAM2Clique<CONDITIONAL> >(result.bayesTree))
|
||||
cout << "Is an ISAM2 clique" << endl;
|
||||
cout << "Re-eliminated BT:\n";
|
||||
result.bayesTree->printTree("");
|
||||
}
|
||||
|
|
|
@ -146,14 +146,26 @@ struct ISAM2Clique : public BayesTreeCliqueBase<ISAM2Clique<CONDITIONAL>, CONDIT
|
|||
/** Access the gradient contribution */
|
||||
const Vector& gradientContribution() const { return gradientContribution_; }
|
||||
|
||||
bool equals(const This& other, double tol=1e-9) const {
|
||||
return Base::equals(other) && ((!cachedFactor_ && !other.cachedFactor_) || (cachedFactor_ && other.cachedFactor_ && cachedFactor_->equals(*other.cachedFactor_, tol)));
|
||||
}
|
||||
|
||||
/** print this node */
|
||||
void print(const std::string& s = "") const {
|
||||
Base::print(s);
|
||||
if(cachedFactor_) cachedFactor_->print(s + "Cached: ");
|
||||
else cout << s << "Cached empty" << endl;
|
||||
}
|
||||
|
||||
void permuteWithInverse(const Permutation& inversePermutation) {
|
||||
if(cachedFactor_) cachedFactor_->permuteWithInverse(inversePermutation);
|
||||
Base::permuteWithInverse(inversePermutation);
|
||||
}
|
||||
|
||||
bool permuteSeparatorWithInverse(const Permutation& inversePermutation) {
|
||||
if(cachedFactor_) cachedFactor_->permuteWithInverse(inversePermutation);
|
||||
return Base::permuteSeparatorWithInverse(inversePermutation);
|
||||
bool changed = Base::permuteSeparatorWithInverse(inversePermutation);
|
||||
if(changed) if(cachedFactor_) cachedFactor_->permuteWithInverse(inversePermutation);
|
||||
return changed;
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include <boost/foreach.hpp>
|
||||
#include <boost/assign/std/list.hpp> // for operator +=
|
||||
#include <boost/assign.hpp>
|
||||
using namespace boost::assign;
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
@ -401,6 +402,67 @@ TEST_UNSAFE(ISAM2, clone) {
|
|||
CHECK(assert_equal(isam, *isam2));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(ISAM2, permute_cached) {
|
||||
typedef ISAM2Clique<GaussianConditional> Clique;
|
||||
typedef boost::shared_ptr<ISAM2Clique<GaussianConditional> > sharedClique;
|
||||
|
||||
// Construct expected permuted BayesTree (variable 2 has been changed to 1)
|
||||
BayesTree<GaussianConditional, Clique> expected;
|
||||
expected.insert(sharedClique(new Clique(make_pair(
|
||||
boost::make_shared<GaussianConditional>(pair_list_of
|
||||
(3, Matrix_(1,1,1.0))
|
||||
(4, Matrix_(1,1,2.0)),
|
||||
2, Vector_(1,1.0), Vector_(1,1.0)), // p(3,4)
|
||||
HessianFactor::shared_ptr())))); // Cached: empty
|
||||
expected.insert(sharedClique(new Clique(make_pair(
|
||||
boost::make_shared<GaussianConditional>(pair_list_of
|
||||
(2, Matrix_(1,1,1.0))
|
||||
(3, Matrix_(1,1,2.0)),
|
||||
1, Vector_(1,1.0), Vector_(1,1.0)), // p(2|3)
|
||||
boost::make_shared<HessianFactor>(3, Matrix_(1,1,1.0), Vector_(1,1.0), 0.0))))); // Cached: p(3)
|
||||
expected.insert(sharedClique(new Clique(make_pair(
|
||||
boost::make_shared<GaussianConditional>(pair_list_of
|
||||
(0, Matrix_(1,1,1.0))
|
||||
(2, Matrix_(1,1,2.0)),
|
||||
1, Vector_(1,1.0), Vector_(1,1.0)), // p(0|2)
|
||||
boost::make_shared<HessianFactor>(1, Matrix_(1,1,1.0), Vector_(1,1.0), 0.0))))); // Cached: p(1)
|
||||
// Change variable 2 to 1
|
||||
expected.root()->children().front()->conditional()->keys()[0] = 1;
|
||||
expected.root()->children().front()->children().front()->conditional()->keys()[1] = 1;
|
||||
|
||||
// Construct unpermuted BayesTree
|
||||
BayesTree<GaussianConditional, Clique> actual;
|
||||
actual.insert(sharedClique(new Clique(make_pair(
|
||||
boost::make_shared<GaussianConditional>(pair_list_of
|
||||
(3, Matrix_(1,1,1.0))
|
||||
(4, Matrix_(1,1,2.0)),
|
||||
2, Vector_(1,1.0), Vector_(1,1.0)), // p(3,4)
|
||||
HessianFactor::shared_ptr())))); // Cached: empty
|
||||
actual.insert(sharedClique(new Clique(make_pair(
|
||||
boost::make_shared<GaussianConditional>(pair_list_of
|
||||
(2, Matrix_(1,1,1.0))
|
||||
(3, Matrix_(1,1,2.0)),
|
||||
1, Vector_(1,1.0), Vector_(1,1.0)), // p(2|3)
|
||||
boost::make_shared<HessianFactor>(3, Matrix_(1,1,1.0), Vector_(1,1.0), 0.0))))); // Cached: p(3)
|
||||
actual.insert(sharedClique(new Clique(make_pair(
|
||||
boost::make_shared<GaussianConditional>(pair_list_of
|
||||
(0, Matrix_(1,1,1.0))
|
||||
(2, Matrix_(1,1,2.0)),
|
||||
1, Vector_(1,1.0), Vector_(1,1.0)), // p(0|2)
|
||||
boost::make_shared<HessianFactor>(2, Matrix_(1,1,1.0), Vector_(1,1.0), 0.0))))); // Cached: p(2)
|
||||
|
||||
// Create permutation that changes variable 2 -> 0
|
||||
Permutation permutation = Permutation::Identity(5);
|
||||
permutation[2] = 1;
|
||||
|
||||
// Permute BayesTree
|
||||
actual.root()->permuteWithInverse(permutation);
|
||||
|
||||
// Check
|
||||
EXPECT(assert_equal(expected, actual));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr);}
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue