Merge branch 'hybrid/error' into hybrid/tests

release/4.3a0
Varun Agrawal 2022-11-07 15:43:24 -05:00
commit a97d27e981
6 changed files with 118 additions and 2 deletions

View File

@ -71,7 +71,7 @@ namespace gtsam {
static inline double id(const double& x) { return x; }
};
AlgebraicDecisionTree() : Base(1.0) {}
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}

View File

@ -273,4 +273,10 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
return error_tree;
}
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
return error_tree.apply([](double error) { return exp(-error); });
}
} // namespace gtsam

View File

@ -144,6 +144,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
/**
* @brief Compute unnormalized probability for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues &continuousValues) const;
/// @}
private:

View File

@ -439,4 +439,53 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
return ordering;
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> factor_error;
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixtureFactor::shared_ptr gaussianMixture =
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
factor_error = gaussianMixture->error(continuousValues);
if (idx == 0) {
error_tree = factor_error;
} else {
error_tree = error_tree + factor_error;
}
} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
auto hybridGaussianFactor =
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();
double error = gaussian->error(continuousValues);
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
continue;
}
}
return error_tree;
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> prob_tree =
error_tree.apply([](double error) { return exp(-error); });
return prob_tree;
}
} // namespace gtsam

View File

@ -170,6 +170,26 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
}
}
/**
* @brief Compute error for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
/**
* @brief Compute unnormalized probability for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;
/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.

View File

@ -562,6 +562,36 @@ TEST(HybridGaussianFactorGraph, Conditionals) {
EXPECT(assert_equal(expected_discrete, result.discrete()));
}
/* ****************************************************************************/
// Test hybrid gaussian factor graph error and unnormalized probabilities
TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
Switching s(3);
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
Ordering hybridOrdering = graph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering);
HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.error(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
auto probs = graph.probPrime(delta.continuous());
std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
0.99029064};
AlgebraicDecisionTree<Key> expected_probs(discrete_keys, prob_leaves);
// regression
EXPECT(assert_equal(expected_probs, probs, 1e-7));
}
/* ************************************************************************* */
int main() {
TestResult tr;