update hybrid elimination and corresponding tests

release/4.3a0
Varun Agrawal 2022-12-10 10:35:46 +05:30
parent 0596b2f543
commit 62bc9f20a3
7 changed files with 25 additions and 55 deletions

View File

@ -172,8 +172,13 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} }
} }
// std::cout << "Eliminate For MPE" << std::endl;
auto result = EliminateForMPE(dfg, frontalKeys); auto result = EliminateForMPE(dfg, frontalKeys);
// std::cout << "discrete elimination done!" << std::endl;
// dfg.print();
// std::cout << "\n\n\n" << std::endl;
// result.first->print();
// result.second->print();
return {boost::make_shared<HybridConditional>(result.first), return {boost::make_shared<HybridConditional>(result.first),
boost::make_shared<HybridDiscreteFactor>(result.second)}; boost::make_shared<HybridDiscreteFactor>(result.second)};
} }
@ -262,7 +267,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
if (!factor) { if (!factor) {
return 0.0; // If nullptr, return 0.0 probability return 0.0; // If nullptr, return 0.0 probability
} else { } else {
return 1.0; double error =
0.5 * std::abs(factor->augmentedInformation().determinant());
return std::exp(-error);
} }
}; };
DecisionTree<Key, double> fdt(separatorFactors, factorProb); DecisionTree<Key, double> fdt(separatorFactors, factorProb);
@ -550,5 +557,4 @@ HybridGaussianFactorGraph::separateContinuousDiscreteOrdering(
return std::make_pair(continuous_ordering, discrete_ordering); return std::make_pair(continuous_ordering, discrete_ordering);
} }
} // namespace gtsam } // namespace gtsam

View File

@ -32,7 +32,7 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
addConditionals(graph, hybridBayesNet_, ordering); addConditionals(graph, hybridBayesNet_, ordering);
// Eliminate. // Eliminate.
auto bayesNetFragment = graph.eliminateSequential(); auto bayesNetFragment = graph.eliminateSequential(ordering);
/// Prune /// Prune
if (maxNrLeaves) { if (maxNrLeaves) {

View File

@ -15,6 +15,7 @@
* @author Varun Agrawal * @author Varun Agrawal
*/ */
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/geometry/Pose2.h> #include <gtsam/geometry/Pose2.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h> #include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
@ -280,8 +281,10 @@ TEST(HybridEstimation, Probability) {
VectorValues values = bayes_net->optimize(); VectorValues values = bayes_net->optimize();
expected_errors.push_back(linear_graph->error(values)); double error = linear_graph->error(values);
expected_prob_primes.push_back(linear_graph->probPrime(values)); expected_errors.push_back(error);
double prob_prime = linear_graph->probPrime(values);
expected_prob_primes.push_back(prob_prime);
} }
// Switching example of robot moving in 1D with given measurements and equal // Switching example of robot moving in 1D with given measurements and equal
@ -291,51 +294,21 @@ TEST(HybridEstimation, Probability) {
auto graph = switching.linearizedFactorGraph; auto graph = switching.linearizedFactorGraph;
Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph()); Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());
AlgebraicDecisionTree<Key> expected_probPrimeTree = probPrimeTree(graph); HybridBayesNet::shared_ptr bayesNet = graph.eliminateSequential(ordering);
auto discreteConditional = bayesNet->atDiscrete(bayesNet->size() - 3);
// Eliminate continuous
Ordering continuous_ordering(graph.continuousKeys());
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
graph.eliminatePartialSequential(continuous_ordering);
// Get the last continuous conditional which will have all the discrete keys
auto last_conditional = bayesNet->at(bayesNet->size() - 1);
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());
// Create a decision tree of all the different VectorValues
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
graph.continuousDelta(discrete_keys, bayesNet, assignments);
AlgebraicDecisionTree<Key> probPrimeTree =
graph.continuousProbPrimes(discrete_keys, bayesNet);
EXPECT(assert_equal(expected_probPrimeTree, probPrimeTree));
// Test if the probPrimeTree matches the probability of // Test if the probPrimeTree matches the probability of
// the individual factor graphs // the individual factor graphs
for (size_t i = 0; i < pow(2, K - 1); i++) { for (size_t i = 0; i < pow(2, K - 1); i++) {
Assignment<Key> discrete_assignment; DiscreteValues discrete_assignment;
for (size_t v = 0; v < discrete_seq_map[i].size(); v++) { for (size_t v = 0; v < discrete_seq_map[i].size(); v++) {
discrete_assignment[M(v)] = discrete_seq_map[i][v]; discrete_assignment[M(v)] = discrete_seq_map[i][v];
} }
EXPECT_DOUBLES_EQUAL(expected_prob_primes.at(i), double discrete_transition_prob = 0.25;
probPrimeTree(discrete_assignment), 1e-8); EXPECT_DOUBLES_EQUAL(expected_prob_primes.at(i) * discrete_transition_prob,
(*discreteConditional)(discrete_assignment), 1e-8);
} }
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
Ordering discrete(graph.discreteKeys());
auto discreteBayesNet = discreteGraph->eliminateSequential();
bayesNet->add(*discreteBayesNet);
HybridValues hybrid_values = bayesNet->optimize(); HybridValues hybrid_values = bayesNet->optimize();
// This is the correct sequence as designed // This is the correct sequence as designed

View File

@ -277,7 +277,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
std::tie(hbt, remaining) = hfg.eliminatePartialMultifrontal(ordering_full); std::tie(hbt, remaining) = hfg.eliminatePartialMultifrontal(ordering_full);
// 9 cliques in the bayes tree and 0 remaining variables to eliminate. // 9 cliques in the bayes tree and 0 remaining variables to eliminate.
EXPECT_LONGS_EQUAL(7, hbt->size()); EXPECT_LONGS_EQUAL(9, hbt->size());
EXPECT_LONGS_EQUAL(0, remaining->size()); EXPECT_LONGS_EQUAL(0, remaining->size());
/* /*

View File

@ -178,7 +178,7 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// Test the probability values with regression tests. // Test the probability values with regression tests.
DiscreteValues assignment; DiscreteValues assignment;
EXPECT(assert_equal(0.166667, m00_prob, 1e-5)); EXPECT(assert_equal(0.0619233, m00_prob, 1e-5));
assignment[M(0)] = 0; assignment[M(0)] = 0;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5));

View File

@ -372,8 +372,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
dynamic_pointer_cast<DecisionTreeFactor>(hybridDiscreteFactor->inner()); dynamic_pointer_cast<DecisionTreeFactor>(hybridDiscreteFactor->inner());
CHECK(discreteFactor); CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
// All leaves should be probability 1 since this is not P*(X|M,Z) EXPECT(discreteFactor->root_->isLeaf() == false);
EXPECT(discreteFactor->root_->isLeaf());
// TODO(Varun) Test emplace_discrete // TODO(Varun) Test emplace_discrete
} }
@ -441,14 +440,6 @@ TEST(HybridFactorGraph, Full_Elimination) {
discrete_fg.push_back(df->inner()); discrete_fg.push_back(df->inner());
} }
// Get the probabilit P*(X | M, Z)
DiscreteKeys discrete_keys =
remainingFactorGraph_partial->at(2)->discreteKeys();
AlgebraicDecisionTree<Key> probPrimeTree =
linearizedFactorGraph.continuousProbPrimes(discrete_keys,
hybridBayesNet_partial);
discrete_fg.add(DecisionTreeFactor(discrete_keys, probPrimeTree));
ordering.clear(); ordering.clear();
for (size_t k = 0; k < self.K - 1; k++) ordering += M(k); for (size_t k = 0; k < self.K - 1; k++) ordering += M(k);
discreteBayesNet = discreteBayesNet =

View File

@ -197,7 +197,7 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
// Test the probability values with regression tests. // Test the probability values with regression tests.
DiscreteValues assignment; DiscreteValues assignment;
EXPECT(assert_equal(0.166667, m00_prob, 1e-5)); EXPECT(assert_equal(0.0619233, m00_prob, 1e-5));
assignment[M(0)] = 0; assignment[M(0)] = 0;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.0619233, (*discreteConditional)(assignment), 1e-5));