From ea24a2c7e8e8ca10cdd02ddea663748905cf8db4 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 20 Jul 2023 15:47:58 -0400 Subject: [PATCH] park changes so I can come back to them later --- gtsam/discrete/DiscreteFactorGraph.cpp | 10 ++ .../tests/testDiscreteConditional.cpp | 106 ++++++++++++++---- gtsam/hybrid/HybridBayesNet.cpp | 13 +++ gtsam/hybrid/HybridGaussianFactorGraph.cpp | 2 +- 4 files changed, 111 insertions(+), 20 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index 4ededbb8b..e0d2ed4de 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -204,12 +204,18 @@ namespace gtsam { std::pair // EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { + factors.print("The Factors to eliminate:"); // PRODUCT: multiply all factors gttic(product); DecisionTreeFactor product; for (auto&& factor : factors) product = (*factor) * product; gttoc(product); + std::cout << "\n\n==========" << std::endl; + std::cout << "Product" << std::endl; + std::cout << std::endl; + product.print(); + // Max over all the potentials by pretending all keys are frontal: auto normalization = product.max(product.size()); @@ -221,6 +227,10 @@ namespace gtsam { DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys); gttoc(sum); + std::cout << "\n->Sum" << std::endl; + sum->print(); + std::cout << "----------------------" << std::endl; + // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index aa393d74c..397e7ff0c 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -22,13 +22,15 @@ #include #include #include +#include +#include using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST(DiscreteConditional, constructors) { +TEST_DISABLED(DiscreteConditional, constructors) { DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! DiscreteConditional actual(X | Y = "1/1 2/3 1/4"); @@ -49,7 +51,7 @@ TEST(DiscreteConditional, constructors) { } /* ************************************************************************* */ -TEST(DiscreteConditional, constructors_alt_interface) { +TEST_DISABLED(DiscreteConditional, constructors_alt_interface) { DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! const Signature::Row r1{1, 1}, r2{2, 3}, r3{1, 4}; @@ -68,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { } /* ************************************************************************* */ -TEST(DiscreteConditional, constructors2) { +TEST_DISABLED(DiscreteConditional, constructors2) { DiscreteKey C(0, 2), B(1, 2); Signature signature((C | B) = "4/1 3/1"); DiscreteConditional actual(signature); @@ -78,7 +80,7 @@ TEST(DiscreteConditional, constructors2) { } /* ************************************************************************* */ -TEST(DiscreteConditional, constructors3) { +TEST_DISABLED(DiscreteConditional, constructors3) { DiscreteKey C(0, 2), B(1, 2), A(2, 2); Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); DiscreteConditional actual(signature); @@ -89,7 +91,7 @@ TEST(DiscreteConditional, constructors3) { /* ****************************************************************************/ // Test evaluate for a discrete Prior P(Asia). -TEST(DiscreteConditional, PriorProbability) { +TEST_DISABLED(DiscreteConditional, PriorProbability) { constexpr Key asiaKey = 0; const DiscreteKey Asia(asiaKey, 2); DiscreteConditional dc(Asia, "4/6"); @@ -100,7 +102,7 @@ TEST(DiscreteConditional, PriorProbability) { /* ************************************************************************* */ // Check that error, logProbability, evaluate all work as expected. -TEST(DiscreteConditional, probability) { +TEST_DISABLED(DiscreteConditional, probability) { DiscreteKey C(2, 2), D(4, 2), E(3, 2); DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); @@ -114,7 +116,7 @@ TEST(DiscreteConditional, probability) { /* ************************************************************************* */ // Check calculation of joint P(A,B) -TEST(DiscreteConditional, Multiply) { +TEST_DISABLED(DiscreteConditional, Multiply) { DiscreteKey A(1, 2), B(0, 2); DiscreteConditional conditional(A | B = "1/2 2/1"); DiscreteConditional prior(B % "1/2"); @@ -139,7 +141,7 @@ TEST(DiscreteConditional, Multiply) { /* ************************************************************************* */ // Check calculation of conditional joint P(A,B|C) -TEST(DiscreteConditional, Multiply2) { +TEST_DISABLED(DiscreteConditional, Multiply2) { DiscreteKey A(0, 2), B(1, 2), C(2, 2); DiscreteConditional A_given_B(A | B = "1/3 3/1"); DiscreteConditional B_given_C(B | C = "1/3 3/1"); @@ -159,7 +161,7 @@ TEST(DiscreteConditional, Multiply2) { /* ************************************************************************* */ // Check calculation of conditional joint P(A,B|C), double check keys -TEST(DiscreteConditional, Multiply3) { +TEST_DISABLED(DiscreteConditional, Multiply3) { DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!! DiscreteConditional A_given_B(A | B = "1/3 3/1"); DiscreteConditional B_given_C(B | C = "1/3 3/1"); @@ -179,7 +181,7 @@ TEST(DiscreteConditional, Multiply3) { /* ************************************************************************* */ // Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E) -TEST(DiscreteConditional, Multiply4) { +TEST_DISABLED(DiscreteConditional, Multiply4) { DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2); DiscreteConditional A_given_B(A | B = "1/3 3/1"); DiscreteConditional B_given_D(B | D = "1/3 3/1"); @@ -203,7 +205,7 @@ TEST(DiscreteConditional, Multiply4) { /* ************************************************************************* */ // Check calculation of marginals for joint P(A,B) -TEST(DiscreteConditional, marginals) { +TEST_DISABLED(DiscreteConditional, marginals) { DiscreteKey A(1, 2), B(0, 2); DiscreteConditional conditional(A | B = "1/2 2/1"); DiscreteConditional prior(B % "1/2"); @@ -225,7 +227,7 @@ TEST(DiscreteConditional, marginals) { /* ************************************************************************* */ // Check calculation of marginals in case branches are pruned -TEST(DiscreteConditional, marginals2) { +TEST_DISABLED(DiscreteConditional, marginals2) { DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen! DiscreteConditional conditional(A | B = "2/2 3/1"); DiscreteConditional prior(B % "1/2"); @@ -241,7 +243,7 @@ TEST(DiscreteConditional, marginals2) { } /* ************************************************************************* */ -TEST(DiscreteConditional, likelihood) { +TEST_DISABLED(DiscreteConditional, likelihood) { DiscreteKey X(0, 2), Y(1, 3); DiscreteConditional conditional(X | Y = "2/8 4/6 5/5"); @@ -256,7 +258,7 @@ TEST(DiscreteConditional, likelihood) { /* ************************************************************************* */ // Check choose on P(C|D,E) -TEST(DiscreteConditional, choose) { +TEST_DISABLED(DiscreteConditional, choose) { DiscreteKey C(2, 2), D(4, 2), E(3, 2); DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); @@ -284,7 +286,7 @@ TEST(DiscreteConditional, choose) { /* ************************************************************************* */ // Check markdown representation looks as expected, no parents. -TEST(DiscreteConditional, markdown_prior) { +TEST_DISABLED(DiscreteConditional, markdown_prior) { DiscreteKey A(Symbol('x', 1), 3); DiscreteConditional conditional(A % "1/2/2"); string expected = @@ -300,7 +302,7 @@ TEST(DiscreteConditional, markdown_prior) { /* ************************************************************************* */ // Check markdown representation looks as expected, no parents + names. -TEST(DiscreteConditional, markdown_prior_names) { +TEST_DISABLED(DiscreteConditional, markdown_prior_names) { Symbol x1('x', 1); DiscreteKey A(x1, 3); DiscreteConditional conditional(A % "1/2/2"); @@ -318,7 +320,7 @@ TEST(DiscreteConditional, markdown_prior_names) { /* ************************************************************************* */ // Check markdown representation looks as expected, multivalued. -TEST(DiscreteConditional, markdown_multivalued) { +TEST_DISABLED(DiscreteConditional, markdown_multivalued) { DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5); DiscreteConditional conditional( A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3"); @@ -337,7 +339,7 @@ TEST(DiscreteConditional, markdown_multivalued) { /* ************************************************************************* */ // Check markdown representation looks as expected, two parents + names. -TEST(DiscreteConditional, markdown) { +TEST_DISABLED(DiscreteConditional, markdown) { DiscreteKey A(2, 2), B(1, 2), C(0, 3); DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); string expected = @@ -360,7 +362,7 @@ TEST(DiscreteConditional, markdown) { /* ************************************************************************* */ // Check html representation looks as expected, two parents + names. -TEST(DiscreteConditional, html) { +TEST_DISABLED(DiscreteConditional, html) { DiscreteKey A(2, 2), B(1, 2), C(0, 3); DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); string expected = @@ -388,6 +390,72 @@ TEST(DiscreteConditional, html) { EXPECT(actual == expected); } +/* ************************************************************************* */ +TEST(DiscreteConditional, NrAssignments) { +#ifdef GTSAM_DT_MERGING + string expected = R"( P( 0 1 2 ): + Choice(2) + 0 Choice(1) + 0 0 Leaf [2] 0 + 0 1 Choice(0) + 0 1 0 Leaf [1] 0.27527634 + 0 1 1 Leaf [1] 0 + 1 Choice(1) + 1 0 Leaf [2] 0 + 1 1 Choice(0) + 1 1 0 Leaf [1] 0.44944733 + 1 1 1 Leaf [1] 0.27527634 + +)"; +#else + string expected = R"( P( 0 1 2 ): + Choice(2) + 0 Choice(1) + 0 0 Choice(0) + 0 0 0 Leaf [1] 0 + 0 0 1 Leaf [1] 0 + 0 1 Choice(0) + 0 1 0 Leaf [1] 0.27527634 + 0 1 1 Leaf [1] 0.44944733 + 1 Choice(1) + 1 0 Choice(0) + 1 0 0 Leaf [1] 0 + 1 0 1 Leaf [1] 0 + 1 1 Choice(0) + 1 1 0 Leaf [1] 0 + 1 1 1 Leaf [1] 0.27527634 + +)"; +#endif + + DiscreteKeys d0{{0, 2}, {1, 2}, {2, 2}}; + std::vector p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468}; + AlgebraicDecisionTree dt(d0, p0); + DecisionTreeFactor dtf(d0, dt); + DiscreteConditional f0(3, dtf); + + EXPECT(assert_print_equal(expected, f0)); + + DiscreteFactorGraph dfg{f0}; + dfg.print(); + auto dbn = dfg.eliminateSequential(); + dbn->print(); + + // DiscreteKeys d0{{0, 2}, {1, 2}}; + // std::vector p0 = {0, 1, 0, 2}; + // AlgebraicDecisionTree dt0(d0, p0); + // dt0.print("", DefaultKeyFormatter); + + // DiscreteKeys d1{{0, 2}}; + // std::vector p1 = {1, 1, 1, 1}; + // AlgebraicDecisionTree dt1(d0, p1); + // dt1.print("", DefaultKeyFormatter); + + // auto dd = dt0 / dt1; + // dd.print("", DefaultKeyFormatter); +} + + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index b4bf61220..a6163fd3c 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -140,15 +140,26 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( for (size_t i = 0; i < this->size(); i++) { auto conditional = this->at(i); if (conditional->isDiscrete()) { + std::cout << ">>>" << std::endl; + conditional->print(); discreteProbs = discreteProbs * (*conditional->asDiscrete()); + // discreteProbs.print(); + // std::cout << "================\n" << std::endl; Ordering conditional_keys(conditional->frontals()); discrete_frontals += conditional_keys; discrete_factor_idxs.push_back(i); } } + std::cout << "Original Joint Prob:" << std::endl; + std::cout << discreteProbs.nrAssignments() << std::endl; + discreteProbs.print(); const DecisionTreeFactor prunedDiscreteProbs = discreteProbs.prune(maxNrLeaves); + std::cout << "Pruned Joint Prob:" << std::endl; + std::cout << prunedDiscreteProbs.nrAssignments() << std::endl; + prunedDiscreteProbs.print(); + std::cout << "\n\n\n"; gttoc_(HybridBayesNet_PruneDiscreteConditionals); // Eliminate joint probability back into conditionals @@ -159,6 +170,8 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( // Assign pruned discrete conditionals back at the correct indices. for (size_t i = 0; i < discrete_factor_idxs.size(); i++) { size_t idx = discrete_factor_idxs.at(i); + // std::cout << i << std::endl; + // dbn->at(i)->print(); this->at(idx) = std::make_shared(dbn->at(i)); } gttoc_(HybridBayesNet_UpdateDiscreteConditionals); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 2d4ac83f6..444cef439 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -178,7 +178,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors, throwRuntimeError("continuousElimination", f); } } - + dfg.print("The DFG to eliminate"); // NOTE: This does sum-product. For max-product, use EliminateForMPE. auto result = EliminateDiscrete(dfg, frontalKeys);