separate HybridGaussianFactorGraph::error() using both continuous and discrete values
parent
1b168cefba
commit
cb55af3a81
|
@ -483,6 +483,34 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||||
return error_tree;
|
return error_tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double HybridGaussianFactorGraph::error(
|
||||||
|
const VectorValues &continuousValues,
|
||||||
|
const DiscreteValues &discreteValues) const {
|
||||||
|
double error = 0.0;
|
||||||
|
for (size_t idx = 0; idx < size(); idx++) {
|
||||||
|
auto factor = factors_.at(idx);
|
||||||
|
|
||||||
|
if (factor->isHybrid()) {
|
||||||
|
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||||
|
error += c->asMixture()->error(continuousValues, discreteValues);
|
||||||
|
}
|
||||||
|
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
|
||||||
|
error += f->error(continuousValues, discreteValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (factor->isContinuous()) {
|
||||||
|
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||||
|
error += f->inner()->error(continuousValues);
|
||||||
|
}
|
||||||
|
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||||
|
error += cg->asGaussian()->error(continuousValues);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
|
@ -539,32 +567,11 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
double error = 0.0;
|
|
||||||
// Compute the error given the delta and the assignment.
|
// Compute the error given the delta and the assignment.
|
||||||
for (size_t idx = 0; idx < size(); idx++) {
|
double error = this->error(delta, assignment);
|
||||||
auto factor = factors_.at(idx);
|
|
||||||
|
|
||||||
if (factor->isHybrid()) {
|
|
||||||
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
|
||||||
error += c->asMixture()->error(delta, assignment);
|
|
||||||
}
|
|
||||||
if (auto f =
|
|
||||||
boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
|
|
||||||
error += f->error(delta, assignment);
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if (factor->isContinuous()) {
|
|
||||||
if (auto f =
|
|
||||||
boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
|
||||||
error += f->inner()->error(delta);
|
|
||||||
}
|
|
||||||
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
|
|
||||||
error += cg->asGaussian()->error(delta);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
probPrimes.push_back(exp(-error));
|
probPrimes.push_back(exp(-error));
|
||||||
}
|
}
|
||||||
|
|
||||||
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
|
||||||
return probPrimeTree;
|
return probPrimeTree;
|
||||||
}
|
}
|
||||||
|
|
|
@ -180,6 +180,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
|
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute error given a continuous vector values
|
||||||
|
* and a discrete assignment.
|
||||||
|
*
|
||||||
|
* @param continuousValues The continuous VectorValues
|
||||||
|
* for computing the error.
|
||||||
|
* @param discreteValues The specific discrete assignment
|
||||||
|
* whose error we wish to compute.
|
||||||
|
* @return double
|
||||||
|
*/
|
||||||
|
double error(const VectorValues& continuousValues,
|
||||||
|
const DiscreteValues& discreteValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute unnormalized probability for each discrete assignment,
|
* @brief Compute unnormalized probability for each discrete assignment,
|
||||||
* and return as a tree.
|
* and return as a tree.
|
||||||
|
|
|
@ -569,6 +569,31 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
|
||||||
|
|
||||||
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
||||||
|
|
||||||
|
Ordering hybridOrdering = graph.getHybridOrdering();
|
||||||
|
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||||
|
graph.eliminateSequential(hybridOrdering);
|
||||||
|
|
||||||
|
HybridValues delta = hybridBayesNet->optimize();
|
||||||
|
double error = graph.error(delta.continuous(), delta.discrete());
|
||||||
|
|
||||||
|
double expected_error = 0.490243199;
|
||||||
|
// regression
|
||||||
|
EXPECT(assert_equal(expected_error, error, 1e-9));
|
||||||
|
|
||||||
|
double probs = exp(-error);
|
||||||
|
double expected_probs = exp(-expected_error);
|
||||||
|
|
||||||
|
// regression
|
||||||
|
EXPECT(assert_equal(expected_probs, probs, 1e-7));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test hybrid gaussian factor graph error and unnormalized probabilities
|
||||||
|
TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
||||||
|
Switching s(3);
|
||||||
|
|
||||||
|
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
||||||
|
|
||||||
Ordering hybridOrdering = graph.getHybridOrdering();
|
Ordering hybridOrdering = graph.getHybridOrdering();
|
||||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||||
graph.eliminateSequential(hybridOrdering);
|
graph.eliminateSequential(hybridOrdering);
|
||||||
|
|
Loading…
Reference in New Issue