park changes so I can come back to them later
parent
1dfb388587
commit
ea24a2c7e8
|
@ -204,12 +204,18 @@ namespace gtsam {
|
|||
std::pair<DiscreteConditional::shared_ptr, DecisionTreeFactor::shared_ptr> //
|
||||
EliminateDiscrete(const DiscreteFactorGraph& factors,
|
||||
const Ordering& frontalKeys) {
|
||||
factors.print("The Factors to eliminate:");
|
||||
// PRODUCT: multiply all factors
|
||||
gttic(product);
|
||||
DecisionTreeFactor product;
|
||||
for (auto&& factor : factors) product = (*factor) * product;
|
||||
gttoc(product);
|
||||
|
||||
std::cout << "\n\n==========" << std::endl;
|
||||
std::cout << "Product" << std::endl;
|
||||
std::cout << std::endl;
|
||||
product.print();
|
||||
|
||||
// Max over all the potentials by pretending all keys are frontal:
|
||||
auto normalization = product.max(product.size());
|
||||
|
||||
|
@ -221,6 +227,10 @@ namespace gtsam {
|
|||
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
gttoc(sum);
|
||||
|
||||
std::cout << "\n->Sum" << std::endl;
|
||||
sum->print();
|
||||
std::cout << "----------------------" << std::endl;
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
Ordering orderedKeys;
|
||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
|
||||
|
|
|
@ -22,13 +22,15 @@
|
|||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, constructors) {
|
||||
TEST_DISABLED(DiscreteConditional, constructors) {
|
||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||
|
||||
DiscreteConditional actual(X | Y = "1/1 2/3 1/4");
|
||||
|
@ -49,7 +51,7 @@ TEST(DiscreteConditional, constructors) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, constructors_alt_interface) {
|
||||
TEST_DISABLED(DiscreteConditional, constructors_alt_interface) {
|
||||
DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering !
|
||||
|
||||
const Signature::Row r1{1, 1}, r2{2, 3}, r3{1, 4};
|
||||
|
@ -68,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, constructors2) {
|
||||
TEST_DISABLED(DiscreteConditional, constructors2) {
|
||||
DiscreteKey C(0, 2), B(1, 2);
|
||||
Signature signature((C | B) = "4/1 3/1");
|
||||
DiscreteConditional actual(signature);
|
||||
|
@ -78,7 +80,7 @@ TEST(DiscreteConditional, constructors2) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, constructors3) {
|
||||
TEST_DISABLED(DiscreteConditional, constructors3) {
|
||||
DiscreteKey C(0, 2), B(1, 2), A(2, 2);
|
||||
Signature signature((C | B, A) = "4/1 1/1 1/1 1/4");
|
||||
DiscreteConditional actual(signature);
|
||||
|
@ -89,7 +91,7 @@ TEST(DiscreteConditional, constructors3) {
|
|||
|
||||
/* ****************************************************************************/
|
||||
// Test evaluate for a discrete Prior P(Asia).
|
||||
TEST(DiscreteConditional, PriorProbability) {
|
||||
TEST_DISABLED(DiscreteConditional, PriorProbability) {
|
||||
constexpr Key asiaKey = 0;
|
||||
const DiscreteKey Asia(asiaKey, 2);
|
||||
DiscreteConditional dc(Asia, "4/6");
|
||||
|
@ -100,7 +102,7 @@ TEST(DiscreteConditional, PriorProbability) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check that error, logProbability, evaluate all work as expected.
|
||||
TEST(DiscreteConditional, probability) {
|
||||
TEST_DISABLED(DiscreteConditional, probability) {
|
||||
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||
|
||||
|
@ -114,7 +116,7 @@ TEST(DiscreteConditional, probability) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of joint P(A,B)
|
||||
TEST(DiscreteConditional, Multiply) {
|
||||
TEST_DISABLED(DiscreteConditional, Multiply) {
|
||||
DiscreteKey A(1, 2), B(0, 2);
|
||||
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
|
@ -139,7 +141,7 @@ TEST(DiscreteConditional, Multiply) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of conditional joint P(A,B|C)
|
||||
TEST(DiscreteConditional, Multiply2) {
|
||||
TEST_DISABLED(DiscreteConditional, Multiply2) {
|
||||
DiscreteKey A(0, 2), B(1, 2), C(2, 2);
|
||||
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||
DiscreteConditional B_given_C(B | C = "1/3 3/1");
|
||||
|
@ -159,7 +161,7 @@ TEST(DiscreteConditional, Multiply2) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of conditional joint P(A,B|C), double check keys
|
||||
TEST(DiscreteConditional, Multiply3) {
|
||||
TEST_DISABLED(DiscreteConditional, Multiply3) {
|
||||
DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!!
|
||||
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||
DiscreteConditional B_given_C(B | C = "1/3 3/1");
|
||||
|
@ -179,7 +181,7 @@ TEST(DiscreteConditional, Multiply3) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)
|
||||
TEST(DiscreteConditional, Multiply4) {
|
||||
TEST_DISABLED(DiscreteConditional, Multiply4) {
|
||||
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2);
|
||||
DiscreteConditional A_given_B(A | B = "1/3 3/1");
|
||||
DiscreteConditional B_given_D(B | D = "1/3 3/1");
|
||||
|
@ -203,7 +205,7 @@ TEST(DiscreteConditional, Multiply4) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of marginals for joint P(A,B)
|
||||
TEST(DiscreteConditional, marginals) {
|
||||
TEST_DISABLED(DiscreteConditional, marginals) {
|
||||
DiscreteKey A(1, 2), B(0, 2);
|
||||
DiscreteConditional conditional(A | B = "1/2 2/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
|
@ -225,7 +227,7 @@ TEST(DiscreteConditional, marginals) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check calculation of marginals in case branches are pruned
|
||||
TEST(DiscreteConditional, marginals2) {
|
||||
TEST_DISABLED(DiscreteConditional, marginals2) {
|
||||
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
|
||||
DiscreteConditional conditional(A | B = "2/2 3/1");
|
||||
DiscreteConditional prior(B % "1/2");
|
||||
|
@ -241,7 +243,7 @@ TEST(DiscreteConditional, marginals2) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, likelihood) {
|
||||
TEST_DISABLED(DiscreteConditional, likelihood) {
|
||||
DiscreteKey X(0, 2), Y(1, 3);
|
||||
DiscreteConditional conditional(X | Y = "2/8 4/6 5/5");
|
||||
|
||||
|
@ -256,7 +258,7 @@ TEST(DiscreteConditional, likelihood) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check choose on P(C|D,E)
|
||||
TEST(DiscreteConditional, choose) {
|
||||
TEST_DISABLED(DiscreteConditional, choose) {
|
||||
DiscreteKey C(2, 2), D(4, 2), E(3, 2);
|
||||
DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4");
|
||||
|
||||
|
@ -284,7 +286,7 @@ TEST(DiscreteConditional, choose) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, no parents.
|
||||
TEST(DiscreteConditional, markdown_prior) {
|
||||
TEST_DISABLED(DiscreteConditional, markdown_prior) {
|
||||
DiscreteKey A(Symbol('x', 1), 3);
|
||||
DiscreteConditional conditional(A % "1/2/2");
|
||||
string expected =
|
||||
|
@ -300,7 +302,7 @@ TEST(DiscreteConditional, markdown_prior) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, no parents + names.
|
||||
TEST(DiscreteConditional, markdown_prior_names) {
|
||||
TEST_DISABLED(DiscreteConditional, markdown_prior_names) {
|
||||
Symbol x1('x', 1);
|
||||
DiscreteKey A(x1, 3);
|
||||
DiscreteConditional conditional(A % "1/2/2");
|
||||
|
@ -318,7 +320,7 @@ TEST(DiscreteConditional, markdown_prior_names) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, multivalued.
|
||||
TEST(DiscreteConditional, markdown_multivalued) {
|
||||
TEST_DISABLED(DiscreteConditional, markdown_multivalued) {
|
||||
DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5);
|
||||
DiscreteConditional conditional(
|
||||
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
|
||||
|
@ -337,7 +339,7 @@ TEST(DiscreteConditional, markdown_multivalued) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check markdown representation looks as expected, two parents + names.
|
||||
TEST(DiscreteConditional, markdown) {
|
||||
TEST_DISABLED(DiscreteConditional, markdown) {
|
||||
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
|
||||
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
||||
string expected =
|
||||
|
@ -360,7 +362,7 @@ TEST(DiscreteConditional, markdown) {
|
|||
|
||||
/* ************************************************************************* */
|
||||
// Check html representation looks as expected, two parents + names.
|
||||
TEST(DiscreteConditional, html) {
|
||||
TEST_DISABLED(DiscreteConditional, html) {
|
||||
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
|
||||
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
|
||||
string expected =
|
||||
|
@ -388,6 +390,72 @@ TEST(DiscreteConditional, html) {
|
|||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteConditional, NrAssignments) {
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
string expected = R"( P( 0 1 2 ):
|
||||
Choice(2)
|
||||
0 Choice(1)
|
||||
0 0 Leaf [2] 0
|
||||
0 1 Choice(0)
|
||||
0 1 0 Leaf [1] 0.27527634
|
||||
0 1 1 Leaf [1] 0
|
||||
1 Choice(1)
|
||||
1 0 Leaf [2] 0
|
||||
1 1 Choice(0)
|
||||
1 1 0 Leaf [1] 0.44944733
|
||||
1 1 1 Leaf [1] 0.27527634
|
||||
|
||||
)";
|
||||
#else
|
||||
string expected = R"( P( 0 1 2 ):
|
||||
Choice(2)
|
||||
0 Choice(1)
|
||||
0 0 Choice(0)
|
||||
0 0 0 Leaf [1] 0
|
||||
0 0 1 Leaf [1] 0
|
||||
0 1 Choice(0)
|
||||
0 1 0 Leaf [1] 0.27527634
|
||||
0 1 1 Leaf [1] 0.44944733
|
||||
1 Choice(1)
|
||||
1 0 Choice(0)
|
||||
1 0 0 Leaf [1] 0
|
||||
1 0 1 Leaf [1] 0
|
||||
1 1 Choice(0)
|
||||
1 1 0 Leaf [1] 0
|
||||
1 1 1 Leaf [1] 0.27527634
|
||||
|
||||
)";
|
||||
#endif
|
||||
|
||||
DiscreteKeys d0{{0, 2}, {1, 2}, {2, 2}};
|
||||
std::vector<double> p0 = {0, 0, 0.17054468, 0.27845056, 0, 0, 0, 0.17054468};
|
||||
AlgebraicDecisionTree<Key> dt(d0, p0);
|
||||
DecisionTreeFactor dtf(d0, dt);
|
||||
DiscreteConditional f0(3, dtf);
|
||||
|
||||
EXPECT(assert_print_equal(expected, f0));
|
||||
|
||||
DiscreteFactorGraph dfg{f0};
|
||||
dfg.print();
|
||||
auto dbn = dfg.eliminateSequential();
|
||||
dbn->print();
|
||||
|
||||
// DiscreteKeys d0{{0, 2}, {1, 2}};
|
||||
// std::vector<double> p0 = {0, 1, 0, 2};
|
||||
// AlgebraicDecisionTree<Key> dt0(d0, p0);
|
||||
// dt0.print("", DefaultKeyFormatter);
|
||||
|
||||
// DiscreteKeys d1{{0, 2}};
|
||||
// std::vector<double> p1 = {1, 1, 1, 1};
|
||||
// AlgebraicDecisionTree<Key> dt1(d0, p1);
|
||||
// dt1.print("", DefaultKeyFormatter);
|
||||
|
||||
// auto dd = dt0 / dt1;
|
||||
// dd.print("", DefaultKeyFormatter);
|
||||
}
|
||||
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
@ -140,15 +140,26 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
|||
for (size_t i = 0; i < this->size(); i++) {
|
||||
auto conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
std::cout << ">>>" << std::endl;
|
||||
conditional->print();
|
||||
discreteProbs = discreteProbs * (*conditional->asDiscrete());
|
||||
// discreteProbs.print();
|
||||
// std::cout << "================\n" << std::endl;
|
||||
|
||||
Ordering conditional_keys(conditional->frontals());
|
||||
discrete_frontals += conditional_keys;
|
||||
discrete_factor_idxs.push_back(i);
|
||||
}
|
||||
}
|
||||
std::cout << "Original Joint Prob:" << std::endl;
|
||||
std::cout << discreteProbs.nrAssignments() << std::endl;
|
||||
discreteProbs.print();
|
||||
const DecisionTreeFactor prunedDiscreteProbs =
|
||||
discreteProbs.prune(maxNrLeaves);
|
||||
std::cout << "Pruned Joint Prob:" << std::endl;
|
||||
std::cout << prunedDiscreteProbs.nrAssignments() << std::endl;
|
||||
prunedDiscreteProbs.print();
|
||||
std::cout << "\n\n\n";
|
||||
gttoc_(HybridBayesNet_PruneDiscreteConditionals);
|
||||
|
||||
// Eliminate joint probability back into conditionals
|
||||
|
@ -159,6 +170,8 @@ DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
|
|||
// Assign pruned discrete conditionals back at the correct indices.
|
||||
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
|
||||
size_t idx = discrete_factor_idxs.at(i);
|
||||
// std::cout << i << std::endl;
|
||||
// dbn->at(i)->print();
|
||||
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
|
||||
}
|
||||
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
|
||||
|
|
|
@ -178,7 +178,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
throwRuntimeError("continuousElimination", f);
|
||||
}
|
||||
}
|
||||
|
||||
dfg.print("The DFG to eliminate");
|
||||
// NOTE: This does sum-product. For max-product, use EliminateForMPE.
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
|
||||
|
|
Loading…
Reference in New Issue