separate HybridGaussianFactorGraph::error() using both continuous and discrete values

release/4.3a0
Varun Agrawal 2022-11-08 14:20:51 -05:00
parent 1b168cefba
commit cb55af3a81
3 changed files with 68 additions and 23 deletions

View File

@ -483,6 +483,34 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
return error_tree;
}
/* ************************************************************************ */
double HybridGaussianFactorGraph::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = 0.0;
for (size_t idx = 0; idx < size(); idx++) {
auto factor = factors_.at(idx);
if (factor->isHybrid()) {
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(continuousValues, discreteValues);
}
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(continuousValues, discreteValues);
}
} else if (factor->isContinuous()) {
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(continuousValues);
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(continuousValues);
}
}
}
return error;
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
@ -539,32 +567,11 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
continue;
}
double error = 0.0;
// Compute the error given the delta and the assignment.
for (size_t idx = 0; idx < size(); idx++) {
auto factor = factors_.at(idx);
if (factor->isHybrid()) {
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(delta, assignment);
}
if (auto f =
boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(delta, assignment);
}
} else if (factor->isContinuous()) {
if (auto f =
boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(delta);
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(delta);
}
}
}
double error = this->error(delta, assignment);
probPrimes.push_back(exp(-error));
}
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
return probPrimeTree;
}

View File

@ -180,6 +180,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
/**
* @brief Compute error given a continuous vector values
* and a discrete assignment.
*
* @param continuousValues The continuous VectorValues
* for computing the error.
* @param discreteValues The specific discrete assignment
* whose error we wish to compute.
* @return double
*/
double error(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;
/**
* @brief Compute unnormalized probability for each discrete assignment,
* and return as a tree.

View File

@ -569,6 +569,31 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
Ordering hybridOrdering = graph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
double error = graph.error(delta.continuous(), delta.discrete());
double expected_error = 0.490243199;
// regression
EXPECT(assert_equal(expected_error, error, 1e-9));
double probs = exp(-error);
double expected_probs = exp(-expected_error);
// regression
EXPECT(assert_equal(expected_probs, probs, 1e-7));
}
/* ****************************************************************************/
// Test hybrid gaussian factor graph error and unnormalized probabilities
TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
Switching s(3);
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
Ordering hybridOrdering = graph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering);