Merge branch 'hybrid/error' into hybrid/tests
commit
a97d27e981
|
|
@ -71,7 +71,7 @@ namespace gtsam {
|
|||
static inline double id(const double& x) { return x; }
|
||||
};
|
||||
|
||||
AlgebraicDecisionTree() : Base(1.0) {}
|
||||
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
|
||||
|
||||
// Explicitly non-explicit constructor
|
||||
AlgebraicDecisionTree(const Base& add) : Base(add) {}
|
||||
|
|
|
|||
|
|
@ -273,4 +273,10 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
|||
return error_tree;
|
||||
}
|
||||
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
||||
return error_tree.apply([](double error) { return exp(-error); });
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -144,6 +144,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability for each discrete assignment,
|
||||
* and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the
|
||||
* probability.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> probPrime(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -439,4 +439,53 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
|||
return ordering;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
AlgebraicDecisionTree<Key> factor_error;
|
||||
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment.
|
||||
GaussianMixtureFactor::shared_ptr gaussianMixture =
|
||||
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
|
||||
factor_error = gaussianMixture->error(continuousValues);
|
||||
|
||||
if (idx == 0) {
|
||||
error_tree = factor_error;
|
||||
} else {
|
||||
error_tree = error_tree + factor_error;
|
||||
}
|
||||
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
auto hybridGaussianFactor =
|
||||
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
|
||||
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();
|
||||
|
||||
double error = gaussian->error(continuousValues);
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
} else if (factors_.at(idx)->isDiscrete()) {
|
||||
// If factor at `idx` is discrete-only, we skip.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return error_tree;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
||||
AlgebraicDecisionTree<Key> prob_tree =
|
||||
error_tree.apply([](double error) { return exp(-error); });
|
||||
return prob_tree;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -170,6 +170,26 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Compute error for each discrete assignment,
|
||||
* and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability for each discrete assignment,
|
||||
* and return as a tree.
|
||||
*
|
||||
* @param continuousValues Continuous values at which to compute the
|
||||
* probability.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> probPrime(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||
* eliminated after the continuous keys.
|
||||
|
|
|
|||
|
|
@ -562,6 +562,36 @@ TEST(HybridGaussianFactorGraph, Conditionals) {
|
|||
EXPECT(assert_equal(expected_discrete, result.discrete()));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test hybrid gaussian factor graph error and unnormalized probabilities
|
||||
TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
|
||||
Switching s(3);
|
||||
|
||||
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
||||
|
||||
Ordering hybridOrdering = graph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
graph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = graph.error(delta.continuous());
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
|
||||
|
||||
auto probs = graph.probPrime(delta.continuous());
|
||||
std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
|
||||
0.99029064};
|
||||
AlgebraicDecisionTree<Key> expected_probs(discrete_keys, prob_leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_probs, probs, 1e-7));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
Loading…
Reference in New Issue