park changes so I can come back to them later

release/4.3a0
Varun Agrawal 2023-07-20 15:47:58 -04:00
parent 1dfb388587
commit ea24a2c7e8
4 changed files with 111 additions and 20 deletions

View File

@ -204,12 +204,18 @@ namespace gtsam {
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
EliminateDiscrete(const DiscreteFactorGraph& factors,
const Ordering& frontalKeys) {
factors.print("The Factors to eliminate:");
// PRODUCT: multiply all factors
gttic(product);
DecisionTreeFactor product;
for (auto&& factor : factors) product = (*factor) * product;
gttoc(product);
std::cout << "\n\n==========" << std::endl;
std::cout << "Product" << std::endl;
std::cout << std::endl;
product.print();
// Max over all the potentials by pretending all keys are frontal:
auto normalization = product.max(product.size());
@ -221,6 +227,10 @@ namespace gtsam {
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
gttoc(sum);
std::cout << "\n->Sum" << std::endl;
sum->print();
std::cout << "----------------------" << std::endl;
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),

View File

@ -22,13 +22,15 @@
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
using namespace std;
using namespace gtsam;
/* ************************************************************************* */
TEST(DiscreteConditional, constructors) {
TEST_DISABLED(DiscreteConditional, constructors) {
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
DiscreteConditional actual(X | Y = "1/1 2/3 1/4");
@ -49,7 +51,7 @@ TEST(DiscreteConditional, constructors) {
}
/* ************************************************************************* */
TEST(DiscreteConditional, constructors_alt_interface) {
TEST_DISABLED(DiscreteConditional, constructors_alt_interface) {
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
const Signature::Row r1{1, 1}, r2{2, 3}, r3{1, 4};
@ -68,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
}
/* ************************************************************************* */
TEST(DiscreteConditional, constructors2) {
TEST_DISABLED(DiscreteConditional, constructors2) {
DiscreteKey C(0, 2), B(1, 2);
Signature signature((C | B) = "4/1 3/1");
DiscreteConditional actual(signature);
@ -78,7 +80,7 @@ TEST(DiscreteConditional, constructors2) {
}
/* ************************************************************************* */
TEST(DiscreteConditional, constructors3) {
TEST_DISABLED(DiscreteConditional, constructors3) {
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
DiscreteConditional actual(signature);
@ -89,7 +91,7 @@ TEST(DiscreteConditional, constructors3) {
/* ****************************************************************************/
// Test evaluate for a discrete Prior P(Asia).
TEST(DiscreteConditional, PriorProbability) {
TEST_DISABLED(DiscreteConditional, PriorProbability) {
constexpr Key asiaKey = 0;
const DiscreteKey Asia(asiaKey, 2);
DiscreteConditional dc(Asia, "4/6");
@ -100,7 +102,7 @@ TEST(DiscreteConditional, PriorProbability) {
/* ************************************************************************* */
// Check that error, logProbability, evaluate all work as expected.
TEST(DiscreteConditional, probability) {
TEST_DISABLED(DiscreteConditional, probability) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
@ -114,7 +116,7 @@ TEST(DiscreteConditional, probability) {
/* ************************************************************************* */
// Check calculation of joint P(A,B)
TEST(DiscreteConditional, Multiply) {
TEST_DISABLED(DiscreteConditional, Multiply) {
DiscreteKey A(1, 2), B(0, 2);
DiscreteConditional conditional(A | B = "1/2 2/1");
DiscreteConditional prior(B % "1/2");
@ -139,7 +141,7 @@ TEST(DiscreteConditional, Multiply) {
/* ************************************************************************* */
// Check calculation of conditional joint P(A,B|C)
TEST(DiscreteConditional, Multiply2) {
TEST_DISABLED(DiscreteConditional, Multiply2) {
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
DiscreteConditional A_given_B(A | B = "1/3 3/1");
DiscreteConditional B_given_C(B | C = "1/3 3/1");
@ -159,7 +161,7 @@ TEST(DiscreteConditional, Multiply2) {
/* ************************************************************************* */
// Check calculation of conditional joint P(A,B|C), double check keys
TEST(DiscreteConditional, Multiply3) {
TEST_DISABLED(DiscreteConditional, Multiply3) {
DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!!
DiscreteConditional A_given_B(A | B = "1/3 3/1");
DiscreteConditional B_given_C(B | C = "1/3 3/1");
@ -179,7 +181,7 @@ TEST(DiscreteConditional, Multiply3) {
/* ************************************************************************* */
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
TEST(DiscreteConditional, Multiply4) {
TEST_DISABLED(DiscreteConditional, Multiply4) {
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional A_given_B(A | B = "1/3 3/1");
DiscreteConditional B_given_D(B | D = "1/3 3/1");
@ -203,7 +205,7 @@ TEST(DiscreteConditional, Multiply4) {
/* ************************************************************************* */
// Check calculation of marginals for joint P(A,B)
TEST(DiscreteConditional, marginals) {
TEST_DISABLED(DiscreteConditional, marginals) {
DiscreteKey A(1, 2), B(0, 2);
DiscreteConditional conditional(A | B = "1/2 2/1");
DiscreteConditional prior(B % "1/2");
@ -225,7 +227,7 @@ TEST(DiscreteConditional, marginals) {
/* ************************************************************************* */
// Check calculation of marginals in case branches are pruned
TEST(DiscreteConditional, marginals2) {
TEST_DISABLED(DiscreteConditional, marginals2) {
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
@ -241,7 +243,7 @@ TEST(DiscreteConditional, marginals2) {
}
/* ************************************************************************* */
TEST(DiscreteConditional, likelihood) {
TEST_DISABLED(DiscreteConditional, likelihood) {
DiscreteKey X(0, 2), Y(1, 3);
DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
@ -256,7 +258,7 @@ TEST(DiscreteConditional, likelihood) {
/* ************************************************************************* */
// Check choose on P(C|D,E)
TEST(DiscreteConditional, choose) {
TEST_DISABLED(DiscreteConditional, choose) {
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
@ -284,7 +286,7 @@ TEST(DiscreteConditional, choose) {
/* ************************************************************************* */
// Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) {
TEST_DISABLED(DiscreteConditional, markdown_prior) {
DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/2/2");
string expected =
@ -300,7 +302,7 @@ TEST(DiscreteConditional, markdown_prior) {
/* ************************************************************************* */
// Check markdown representation looks as expected, no parents + names.
TEST(DiscreteConditional, markdown_prior_names) {
TEST_DISABLED(DiscreteConditional, markdown_prior_names) {
Symbol x1('x', 1);
DiscreteKey A(x1, 3);
DiscreteConditional conditional(A % "1/2/2");
@ -318,7 +320,7 @@ TEST(DiscreteConditional, markdown_prior_names) {
/* ************************************************************************* */
// Check markdown representation looks as expected, multivalued.
TEST(DiscreteConditional, markdown_multivalued) {
TEST_DISABLED(DiscreteConditional, markdown_multivalued) {
DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5);
DiscreteConditional conditional(
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
@ -337,7 +339,7 @@ TEST(DiscreteConditional, markdown_multivalued) {
/* ************************************************************************* */
// Check markdown representation looks as expected, two parents + names.
TEST(DiscreteConditional, markdown) {
TEST_DISABLED(DiscreteConditional, markdown) {
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
string expected =
@ -360,7 +362,7 @@ TEST(DiscreteConditional, markdown) {
/* ************************************************************************* */
// Check html representation looks as expected, two parents + names.
TEST(DiscreteConditional, html) {
TEST_DISABLED(DiscreteConditional, html) {
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
string expected =
@ -388,6 +390,72 @@ TEST(DiscreteConditional, html) {
EXPECT(actual == expected);
}
/* ************************************************************************* */
TEST(DiscreteConditional, NrAssignments) {
#ifdef GTSAM_DT_MERGING
string expected = R"( P( 0 1 2 ):
Choice(2)
0 Choice(1)
0 0 Leaf [2] 0
0 1 Choice(0)
0 1 0 Leaf [1] 0.27527634
0 1 1 Leaf [1] 0
1 Choice(1)
1 0 Leaf [2] 0
1 1 Choice(0)
1 1 0 Leaf [1] 0.44944733
1 1 1 Leaf [1] 0.27527634
)";
#else
string expected = R"( P( 0 1 2 ):
Choice(2)
0 Choice(1)
0 0 Choice(0)
0 0 0 Leaf [1] 0
0 0 1 Leaf [1] 0
0 1 Choice(0)
0 1 0 Leaf [1] 0.27527634
0 1 1 Leaf [1] 0.44944733
1 Choice(1)
1 0 Choice(0)
1 0 0 Leaf [1] 0
1 0 1 Leaf [1] 0
1 1 Choice(0)
1 1 0 Leaf [1] 0
1 1 1 Leaf [1] 0.27527634
)";
#endif
DiscreteKeys d0{{0, 2}, {1, 2}, {2, 2}};
std::vector<double> p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468};
AlgebraicDecisionTree<Key> dt(d0, p0);
DecisionTreeFactor dtf(d0, dt);
DiscreteConditional f0(3, dtf);
EXPECT(assert_print_equal(expected, f0));
DiscreteFactorGraph dfg{f0};
dfg.print();
auto dbn = dfg.eliminateSequential();
dbn->print();
// DiscreteKeys d0{{0, 2}, {1, 2}};
// std::vector<double> p0 = {0, 1, 0, 2};
// AlgebraicDecisionTree<Key> dt0(d0, p0);
// dt0.print("", DefaultKeyFormatter);
// DiscreteKeys d1{{0, 2}};
// std::vector<double> p1 = {1, 1, 1, 1};
// AlgebraicDecisionTree<Key> dt1(d0, p1);
// dt1.print("", DefaultKeyFormatter);
// auto dd = dt0 / dt1;
// dd.print("", DefaultKeyFormatter);
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -140,15 +140,26 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
for (size_t i = 0; i < this->size(); i++) {
auto conditional = this->at(i);
if (conditional->isDiscrete()) {
std::cout << ">>>" << std::endl;
conditional->print();
discreteProbs = discreteProbs * (*conditional->asDiscrete());
// discreteProbs.print();
// std::cout << "================\n" << std::endl;
Ordering conditional_keys(conditional->frontals());
discrete_frontals += conditional_keys;
discrete_factor_idxs.push_back(i);
}
}
std::cout << "Original Joint Prob:" << std::endl;
std::cout << discreteProbs.nrAssignments() << std::endl;
discreteProbs.print();
const DecisionTreeFactor prunedDiscreteProbs =
discreteProbs.prune(maxNrLeaves);
std::cout << "Pruned Joint Prob:" << std::endl;
std::cout << prunedDiscreteProbs.nrAssignments() << std::endl;
prunedDiscreteProbs.print();
std::cout << "\n\n\n";
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
// Eliminate joint probability back into conditionals
@ -159,6 +170,8 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
// Assign pruned discrete conditionals back at the correct indices.
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
size_t idx = discrete_factor_idxs.at(i);
// std::cout << i << std::endl;
// dbn->at(i)->print();
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);

View File

@ -178,7 +178,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
throwRuntimeError("continuousElimination", f);
}
}
dfg.print("The DFG to eliminate");
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);