Fix tests once more

release/4.3a0
Frank Dellaert 2022-12-30 13:59:48 -05:00
parent 3a8220c264
commit bcf8a9ddfd
5 changed files with 18 additions and 16 deletions

View File

@ -179,8 +179,9 @@ TEST(GaussianMixture, Likelihood) {
const GaussianMixtureFactor::Factors factors(
gm.conditionals(),
[measurements](const GaussianConditional::shared_ptr& conditional) {
return std::make_pair(conditional->likelihood(measurements),
0.5 * conditional->logDeterminant());
return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(measurements),
conditional->logNormalizationConstant()};
});
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
EXPECT(assert_equal(*factor, expected));

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianFactorGraph.h>
@ -188,7 +189,8 @@ TEST(GaussianMixtureFactor, Error) {
DiscreteValues discreteValues;
discreteValues[m1.first] = 1;
EXPECT_DOUBLES_EQUAL(
4.0, mixtureFactor.error({continuousValues, discreteValues}), 1e-9);
4.0, mixtureFactor.error({continuousValues, discreteValues}),
1e-9);
}
/* ************************************************************************* */

View File

@ -188,14 +188,14 @@ TEST(HybridBayesNet, Optimize) {
HybridValues delta = hybridBayesNet->optimize();
//TODO(Varun) The expectedAssignment should be 111, not 101
// TODO(Varun) The expectedAssignment should be 111, not 101
DiscreteValues expectedAssignment;
expectedAssignment[M(0)] = 1;
expectedAssignment[M(1)] = 0;
expectedAssignment[M(2)] = 1;
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
//TODO(Varun) This should be all -Vector1::Ones()
// TODO(Varun) This should be all -Vector1::Ones()
VectorValues expectedValues;
expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
@ -243,8 +243,8 @@ TEST(HybridBayesNet, Error) {
double total_error = 0;
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
if (hybridBayesNet->at(idx)->isHybrid()) {
double error = hybridBayesNet->atMixture(idx)->error(delta.continuous(),
discrete_values);
double error = hybridBayesNet->atMixture(idx)->error(
{delta.continuous(), discrete_values});
total_error += error;
} else if (hybridBayesNet->at(idx)->isContinuous()) {
double error = hybridBayesNet->atGaussian(idx)->error(delta.continuous());
@ -253,7 +253,7 @@ TEST(HybridBayesNet, Error) {
}
EXPECT_DOUBLES_EQUAL(
total_error, hybridBayesNet->error(delta.continuous(), discrete_values),
total_error, hybridBayesNet->error({delta.continuous(), discrete_values}),
1e-9);
EXPECT_DOUBLES_EQUAL(total_error, error_tree(discrete_values), 1e-9);
EXPECT_DOUBLES_EQUAL(total_error, pruned_error_tree(discrete_values), 1e-9);

View File

@ -273,7 +273,7 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
continue;
}
double error = graph.error(delta, assignment);
double error = graph.error({delta, assignment});
probPrimes.push_back(exp(-error));
}
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
@ -487,8 +487,8 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
const HybridValues& sample) -> double {
const DiscreteValues assignment = sample.discrete();
// Compute in log form for numerical stability
double log_ratio = bayesNet->error(sample.continuous(), assignment) -
factorGraph->error(sample.continuous(), assignment);
double log_ratio = bayesNet->error({sample.continuous(), assignment}) -
factorGraph->error({sample.continuous(), assignment});
double ratio = exp(-log_ratio);
return ratio;
};

View File

@ -575,15 +575,14 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
double error = graph.error(delta.continuous(), delta.discrete());
const HybridValues delta = hybridBayesNet->optimize();
const double error = graph.error(delta);
double expected_error = 0.490243199;
// regression
EXPECT(assert_equal(expected_error, error, 1e-9));
EXPECT(assert_equal(0.490243199, error, 1e-9));
double probs = exp(-error);
double expected_probs = graph.probPrime(delta.continuous(), delta.discrete());
double expected_probs = graph.probPrime(delta);
// regression
EXPECT(assert_equal(expected_probs, probs, 1e-7));