Set up unit test to verify that the probPrimeTree has the same values as the individual factor graphs

release/4.3a0
Varun Agrawal 2022-11-07 18:25:23 -05:00
parent 98febf2f0c
commit 610a535b30
1 changed file with 120 additions and 82 deletions
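In brief, the commit adds a probPrimeTree() helper that builds an AlgebraicDecisionTree of unnormalized probabilities exp(-error) over all discrete mode assignments, and extends TEST(HybridEstimation, Probability) to check that tree against the probability computed from each mode-fixed linear factor graph. The snippet below is only a condensed sketch of that new check, not part of the diff; every identifier in it (K, M(), expected_prob_primes, getDiscreteSequence, probPrimeTree) is taken from the test code shown further down.

  // Sketch only: mirrors the verification loop added to
  // TEST(HybridEstimation, Probability) in the diff below.
  for (size_t i = 0; i < pow(2, K - 1); i++) {
    // Discrete mode sequence corresponding to index i.
    std::vector<size_t> discrete_seq = getDiscreteSequence<K>(i);
    // Build the matching assignment M(0)..M(K-2).
    Assignment<Key> discrete_assignment;
    for (size_t v = 0; v < discrete_seq.size(); v++) {
      discrete_assignment[M(v)] = discrete_seq[v];
    }
    // The tree evaluated at this assignment should equal the unnormalized
    // probability obtained from the individual (mode-fixed) factor graph.
    EXPECT_DOUBLES_EQUAL(expected_prob_primes.at(i),
                         probPrimeTree(discrete_assignment), 1e-8);
  }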


@@ -69,57 +69,63 @@ Ordering getOrdering(HybridGaussianFactorGraph& factors,
   return ordering;
 }
 
-// /****************************************************************************/
-// // Test approximate inference with an additional pruning step.
-// TEST(HybridEstimation, Incremental) {
-//   size_t K = 15;
-//   std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 7, 8,
-//                                       9, 9, 9, 10, 11, 11, 11, 11};
-//   // Ground truth discrete seq
-//   std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1,
-//                                       0, 0, 1, 1, 0, 0, 0};
-//   Switching switching(K, 1.0, 0.1, measurements);
-//   HybridSmoother smoother;
-//   HybridNonlinearFactorGraph graph;
-//   Values initial;
+/****************************************************************************/
+// Test approximate inference with an additional pruning step.
+TEST(HybridEstimation, Incremental) {
+  // size_t K = 15;
+  // std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
+  //                                     7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
+  // // Ground truth discrete seq
+  // std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
+  //                                     1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
+  size_t K = 4;
+  std::vector<double> measurements = {0, 1, 2, 2};
+  // Ground truth discrete seq
+  std::vector<size_t> discrete_seq = {1, 1, 0};
+  Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
+  HybridSmoother smoother;
+  HybridNonlinearFactorGraph graph;
+  Values initial;
 
-//   // Add the X(0) prior
-//   graph.push_back(switching.nonlinearFactorGraph.at(0));
-//   initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
-//   HybridGaussianFactorGraph linearized;
-//   HybridGaussianFactorGraph bayesNet;
-//   for (size_t k = 1; k < K; k++) {
-//     // Motion Model
-//     graph.push_back(switching.nonlinearFactorGraph.at(k));
-//     // Measurement
-//     graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1));
-//     initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
-//     bayesNet = smoother.hybridBayesNet();
-//     linearized = *graph.linearize(initial);
-//     Ordering ordering = getOrdering(bayesNet, linearized);
-//     smoother.update(linearized, ordering, 3);
-//     graph.resize(0);
-//   }
-//   HybridValues delta = smoother.hybridBayesNet().optimize();
-//   Values result = initial.retract(delta.continuous());
-//   DiscreteValues expected_discrete;
-//   for (size_t k = 0; k < K - 1; k++) {
-//     expected_discrete[M(k)] = discrete_seq[k];
-//   }
-//   EXPECT(assert_equal(expected_discrete, delta.discrete()));
-//   Values expected_continuous;
-//   for (size_t k = 0; k < K; k++) {
-//     expected_continuous.insert(X(k), measurements[k]);
-//   }
-//   EXPECT(assert_equal(expected_continuous, result));
-// }
+  // Add the X(0) prior
+  graph.push_back(switching.nonlinearFactorGraph.at(0));
+  initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
+
+  HybridGaussianFactorGraph linearized;
+  HybridGaussianFactorGraph bayesNet;
+
+  for (size_t k = 1; k < K; k++) {
+    // Motion Model
+    graph.push_back(switching.nonlinearFactorGraph.at(k));
+    // Measurement
+    graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1));
+    initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
+    bayesNet = smoother.hybridBayesNet();
+    linearized = *graph.linearize(initial);
+    Ordering ordering = getOrdering(bayesNet, linearized);
+    smoother.update(linearized, ordering, 3);
+    graph.resize(0);
+  }
+
+  HybridValues delta = smoother.hybridBayesNet().optimize();
+  Values result = initial.retract(delta.continuous());
+
+  DiscreteValues expected_discrete;
+  for (size_t k = 0; k < K - 1; k++) {
+    expected_discrete[M(k)] = discrete_seq[k];
+  }
+  EXPECT(assert_equal(expected_discrete, delta.discrete()));
+
+  Values expected_continuous;
+  for (size_t k = 0; k < K; k++) {
+    expected_continuous.insert(X(k), measurements[k]);
+  }
+  EXPECT(assert_equal(expected_continuous, result));
+}
 
 /**
  * @brief A function to get a specific 1D robot motion problem as a linearized
@@ -180,6 +186,50 @@ std::vector<size_t> getDiscreteSequence(size_t x) {
   return discrete_seq;
 }
 
+AlgebraicDecisionTree<Key> probPrimeTree(
+    const HybridGaussianFactorGraph& graph) {
+  HybridBayesNet::shared_ptr bayesNet;
+  HybridGaussianFactorGraph::shared_ptr remainingGraph;
+  Ordering continuous(graph.continuousKeys());
+  std::tie(bayesNet, remainingGraph) =
+      graph.eliminatePartialSequential(continuous);
+
+  auto last_conditional = bayesNet->at(bayesNet->size() - 1);
+  DiscreteKeys discrete_keys = last_conditional->discreteKeys();
+
+  const std::vector<DiscreteValues> assignments =
+      DiscreteValues::CartesianProduct(discrete_keys);
+  std::reverse(discrete_keys.begin(), discrete_keys.end());
+
+  vector<VectorValues::shared_ptr> vector_values;
+  for (const DiscreteValues& assignment : assignments) {
+    VectorValues values = bayesNet->optimize(assignment);
+    vector_values.push_back(boost::make_shared<VectorValues>(values));
+  }
+  DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
+                                                         vector_values);
+
+  std::vector<double> probPrimes;
+  for (const DiscreteValues& assignment : assignments) {
+    double error = 0.0;
+    VectorValues delta = *delta_tree(assignment);
+    for (auto factor : graph) {
+      if (factor->isHybrid()) {
+        auto f = boost::static_pointer_cast<GaussianMixtureFactor>(factor);
+        error += f->error(delta, assignment);
+      } else if (factor->isContinuous()) {
+        auto f = boost::static_pointer_cast<HybridGaussianFactor>(factor);
+        error += f->inner()->error(delta);
+      }
+    }
+    probPrimes.push_back(exp(-error));
+  }
+  AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
+  return probPrimeTree;
+}
+
 TEST(HybridEstimation, Probability) {
   constexpr size_t K = 4;
   std::vector<double> measurements = {0, 1, 2, 2};
@@ -202,63 +252,51 @@ TEST(HybridEstimation, Probability) {
     expected_errors.push_back(linear_graph->error(values));
     expected_prob_primes.push_back(linear_graph->probPrime(values));
-    std::cout << i << " : " << expected_errors.at(i) << "\t|\t"
-              << expected_prob_primes.at(i) << std::endl;
   }
 
-  // std::vector<size_t> discrete_seq = getDiscreteSequence<K>(0);
-  // GaussianFactorGraph::shared_ptr linear_graph = specificProblem(
-  //     K, measurements, discrete_seq, measurement_sigma, between_sigma);
-  // auto bayes_net = linear_graph->eliminateSequential();
-  // VectorValues values = bayes_net->optimize();
-  // std::cout << "Total NLFG Error: " << linear_graph->error(values) << std::endl;
-  // std::cout << "===============" << std::endl;
   Switching switching(K, between_sigma, measurement_sigma, measurements);
   auto graph = switching.linearizedFactorGraph;
   Ordering ordering = getOrdering(graph, HybridGaussianFactorGraph());
 
-  HybridBayesNet::shared_ptr bayesNet;
-  HybridGaussianFactorGraph::shared_ptr remainingGraph;
-  Ordering continuous(graph.continuousKeys());
-  std::tie(bayesNet, remainingGraph) =
-      graph.eliminatePartialSequential(continuous);
+  AlgebraicDecisionTree<Key> expected_probPrimeTree = probPrimeTree(graph);
+
+  // Eliminate continuous
+  Ordering continuous_ordering(graph.continuousKeys());
+  HybridBayesNet::shared_ptr bayesNet;
+  HybridGaussianFactorGraph::shared_ptr discreteGraph;
+  std::tie(bayesNet, discreteGraph) =
+      graph.eliminatePartialSequential(continuous_ordering);
 
+  // Get the last continuous conditional which will have all the discrete keys
   auto last_conditional = bayesNet->at(bayesNet->size() - 1);
   DiscreteKeys discrete_keys = last_conditional->discreteKeys();
 
   const std::vector<DiscreteValues> assignments =
       DiscreteValues::CartesianProduct(discrete_keys);
 
-  vector<VectorValues::shared_ptr> vector_values;
-  for (const DiscreteValues& assignment : assignments) {
-    VectorValues values = bayesNet->optimize(assignment);
-    vector_values.push_back(boost::make_shared<VectorValues>(values));
-  }
+  // Reverse discrete keys order for correct tree construction
   std::reverse(discrete_keys.begin(), discrete_keys.end());
 
-  DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
-                                                         vector_values);
-  vector<double> probPrimes;
-  for (const DiscreteValues& assignment : assignments) {
-    double error = 0.0;
-    VectorValues delta = *delta_tree(assignment);
-    for (auto factor : graph) {
-      if (factor->isHybrid()) {
-        auto f = boost::static_pointer_cast<GaussianMixtureFactor>(factor);
-        error += f->error(delta, assignment);
-      } else if (factor->isContinuous()) {
-        auto f = boost::static_pointer_cast<HybridGaussianFactor>(factor);
-        error += f->inner()->error(delta);
-      }
-    }
-    // std::cout << "\n" << std::endl;
-    // assignment.print();
-    // std::cout << error << " | " << exp(-error) << std::endl;
-    probPrimes.push_back(exp(-error));
-  }
-  AlgebraicDecisionTree<Key> expected_probPrimeTree(discrete_keys, probPrimes);
-  expected_probPrimeTree.print("", DefaultKeyFormatter);
+  // Create a decision tree of all the different VectorValues
+  DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
+      graph.continuousDelta(discrete_keys, bayesNet, assignments);
+
+  AlgebraicDecisionTree<Key> probPrimeTree =
+      graph.continuousProbPrimes(discrete_keys, bayesNet, assignments);
+
+  EXPECT(assert_equal(expected_probPrimeTree, probPrimeTree));
+
+  // Test if the probPrimeTree matches the probability of
+  // the individual factor graphs
+  for (size_t i = 0; i < pow(2, K - 1); i++) {
+    std::vector<size_t> discrete_seq = getDiscreteSequence<K>(i);
+    Assignment<Key> discrete_assignment;
+    for (size_t v = 0; v < discrete_seq.size(); v++) {
+      discrete_assignment[M(v)] = discrete_seq[v];
+    }
+    EXPECT_DOUBLES_EQUAL(expected_prob_primes.at(i),
+                         probPrimeTree(discrete_assignment), 1e-8);
+  }
 
   // remainingGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));