From 02ecc80ecfae7f26752e469de7cc8e17c469d790 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 24 Jul 2023 20:46:30 -0400 Subject: [PATCH 1/3] additional ordering test --- gtsam/inference/tests/testOrdering.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gtsam/inference/tests/testOrdering.cpp b/gtsam/inference/tests/testOrdering.cpp index 328d383d8..b6cfcb6ed 100644 --- a/gtsam/inference/tests/testOrdering.cpp +++ b/gtsam/inference/tests/testOrdering.cpp @@ -219,6 +219,11 @@ TEST(Ordering, AppendVector) { Ordering expected{X(0), X(1), X(2)}; EXPECT(assert_equal(expected, actual)); + + actual = Ordering(); + Ordering addl{X(0), X(1), X(2)}; + actual += addl; + EXPECT(assert_equal(expected, actual)); } /* ************************************************************************* */ From 5f93febcbe0d9f98d9fce4b13a498e14f10ae4df Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 25 Jul 2023 11:35:09 -0400 Subject: [PATCH 2/3] keyformatter for NonlinearFactorGraph::printErrors in python --- gtsam/nonlinear/nonlinear.i | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gtsam/nonlinear/nonlinear.i b/gtsam/nonlinear/nonlinear.i index 06451ab1f..3f5fb1dd5 100644 --- a/gtsam/nonlinear/nonlinear.i +++ b/gtsam/nonlinear/nonlinear.i @@ -86,7 +86,10 @@ class NonlinearFactorGraph { const gtsam::noiseModel::Base* noiseModel); // NonlinearFactorGraph - void printErrors(const gtsam::Values& values) const; + void printErrors(const gtsam::Values& values, + const string& str = "NonlinearFactorGraph: ", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; double error(const gtsam::Values& values) const; double probPrime(const gtsam::Values& values) const; gtsam::Ordering orderingCOLAMD() const; From b51ff749645d2b6ecc8b3d4ed8434e7a2276caaa Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 25 Jul 2023 14:28:14 -0400 Subject: [PATCH 3/3] discrete conditional from vector of doubles --- gtsam/discrete/DiscreteConditional.h | 12 ++++++++++++ gtsam/discrete/DiscreteKey.h | 6 ++++++ gtsam/discrete/tests/testDiscreteConditional.cpp | 5 +++++ 3 files changed, 23 insertions(+) diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 183cf8561..50fa6e161 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -77,6 +77,18 @@ class GTSAM_EXPORT DiscreteConditional const Signature::Table& table) : DiscreteConditional(Signature(key, parents, table)) {} + /** + * Construct from key, parents, and a vector specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * Example: DiscreteConditional P(D, {B,E}, table); + */ + DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const std::vector& table) + : DiscreteConditional(1, DiscreteKeys{key} & parents, + ADT(DiscreteKeys{key} & parents, table)) {} + /** * Construct from key, parents, and a string specifying the conditional * probability table (CPT) in 00 01 10 11 order. For three-valued, it would diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index 3a626c6b3..44cc192ef 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -74,6 +74,12 @@ namespace gtsam { return *this; } + /// Add multiple keys (non-const!) + DiscreteKeys& operator&(const DiscreteKeys& keys) { + this->insert(this->end(), keys.begin(), keys.end()); + return *this; + } + /// Print the keys and cardinalities. void print(const std::string& s = "", const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index aa393d74c..9439f5653 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -46,6 +46,11 @@ TEST(DiscreteConditional, constructors) { DiscreteConditional actual2(1, f2); DecisionTreeFactor expected2 = f2 / *f2.sum(1); EXPECT(assert_equal(expected2, static_cast(actual2))); + + std::vector probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75}; + DiscreteConditional actual3(X, {Y, Z}, probs); + DecisionTreeFactor expected3 = f2; + EXPECT(assert_equal(expected3, static_cast(actual3))); } /* ************************************************************************* */