Merge pull request #1837 from borglab/improved-api-2

release/4.3a0
Varun Agrawal 2024-09-22 19:35:11 -04:00 committed by GitHub
commit e52973b72d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 168 additions and 118 deletions

View File

@ -129,6 +129,22 @@ double HybridConditional::error(const HybridValues &values) const {
"HybridConditional::error: conditional type not handled");
}
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
const VectorValues &values) const {
if (auto gc = asGaussian()) {
return AlgebraicDecisionTree<Key>(gc->error(values));
}
if (auto gm = asHybrid()) {
return gm->errorTree(values);
}
if (auto dc = asDiscrete()) {
return AlgebraicDecisionTree<Key>(0.0);
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}
/* ************************************************************************ */
double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {

View File

@ -179,6 +179,16 @@ class GTSAM_EXPORT HybridConditional
/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;
/**
* @brief Compute error of the HybridConditional as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& values) const override;
/// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override;

View File

@ -136,6 +136,10 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// Return only the continuous keys for this factor.
const KeyVector &continuousKeys() const { return continuousKeys_; }
/// Virtual class to compute tree of linear errors.
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &values) const = 0;
/// @}
private:

View File

@ -323,40 +323,6 @@ AlgebraicDecisionTree<Key> HybridGaussianConditional::logProbability(
return DecisionTree<Key, double>(conditionals_, probFunc);
}
/* ************************************************************************* */
double HybridGaussianConditional::conditionalError(
const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const {
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
-logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianConditional::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditionalError(conditional, continuousValues);
};
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree;
}
/* *******************************************************************************/
double HybridGaussianConditional::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditionalError(conditional, values.continuous());
}
/* *******************************************************************************/
double HybridGaussianConditional::logProbability(
const HybridValues &values) const {

View File

@ -109,9 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
const Conditionals &conditionals);
/**
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
* conditionals. The DecisionTree-based constructor is preferred over this
* one.
* @brief Make a Hybrid Gaussian Conditional from
* a vector of Gaussian conditionals.
* The DecisionTree-based constructor is preferred over this one.
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
@ -174,43 +174,6 @@ class GTSAM_EXPORT HybridGaussianConditional
AlgebraicDecisionTree<Key> logProbability(
const VectorValues &continuousValues) const;
/**
* @brief Compute the error of this hybrid Gaussian conditional.
*
* This requires some care, as different components may have
* different normalization constants. Let's consider p(x|y,m), where m is
* discrete. We need the error to satisfy the invariant:
*
* error(x;y,m) = K - log(probability(x;y,m))
*
* For all x,y,m. But note that K, the (log) normalization constant defined
* in Conditional.h, should not depend on x, y, or m, only on the parameters
* of the density. Hence, we delegate to the underlying Gaussian
* conditionals, indexed by m, which do satisfy:
*
* log(probability_m(x;y)) = K_m - error_m(x;y)
*
* We resolve by having K == max(K_m) and
*
* error(x;y,m) = error_m(x;y) + K - K_m
*
* which also makes error(x;y,m) >= 0 for all x,y,m.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;
/**
* @brief Compute error of the HybridGaussianConditional as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
/**
* @brief Compute the logProbability of this hybrid Gaussian conditional.
*
@ -241,10 +204,6 @@ class GTSAM_EXPORT HybridGaussianConditional
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;
/// Helper method to compute the error of a conditional.
double conditionalError(const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const;
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;

View File

@ -151,12 +151,26 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
return {factors_, wrap};
}
/* *******************************************************************************/
double HybridGaussianFactor::potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &values) const {
// Check if valid pointer
if (gf) {
return gf->error(values);
} else {
// If not valid, pointer, it means this component was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues);
auto errorFunc = [this, &continuousValues](const sharedFactor &gf) {
return this->potentiallyPrunedComponentError(gf, continuousValues);
};
DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree;
@ -164,8 +178,9 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
/* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues &values) const {
// Directly index to get the component, no need to build the whole tree.
const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous());
return potentiallyPrunedComponentError(gf, values.continuous());
}
} // namespace gtsam

View File

@ -166,7 +166,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
const VectorValues &continuousValues) const override;
/**
* @brief Compute the log-likelihood, including the log-normalizing constant.
@ -186,6 +186,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// @}
private:
/// Helper method to compute the error of a component.
double potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &continuousValues) const;
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;

View File

@ -329,8 +329,8 @@ static std::shared_ptr<Factor> createDiscreteFactor(
// Logspace version of:
// exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
// We take negative of the logNormalizationConstant `log(1/k)`
// to get `log(k)`.
// We take negative of the logNormalizationConstant `log(k)`
// to get `log(1/k) = log(\sqrt{|2πΣ|})`.
return -factor->error(kEmpty) - conditional->logNormalizationConstant();
};
@ -539,36 +539,20 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor.
for (auto &factor : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error;
auto f = factor;
if (auto hc = dynamic_pointer_cast<HybridConditional>(factor)) {
f = hc->inner();
}
if (auto hybridGaussianCond =
dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Compute factor error and add it.
error_tree = error_tree + hybridGaussianCond->errorTree(continuousValues);
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = gaussian->error(continuousValues);
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
// If factor at `idx` is discrete-only, we skip.
if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
// Check for HybridFactor, and call errorTree
error_tree = error_tree + f->errorTree(continuousValues);
} else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
// Skip discrete factors
continue;
} else {
throwRuntimeError("HybridGaussianFactorGraph::error(VV)", f);
// Everything else is a continuous only factor
HybridValues hv(continuousValues, DiscreteValues());
error_tree = error_tree + AlgebraicDecisionTree<Key>(factor->error(hv));
}
}
return error_tree;
}

View File

@ -74,6 +74,13 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
/// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_;
/// HybridFactor method implementation. Should not be used.
AlgebraicDecisionTree<Key> errorTree(
const VectorValues& continuousValues) const override {
throw std::runtime_error(
"HybridNonlinearFactor::error does not take VectorValues.");
}
public:
HybridNonlinearFactor() = default;

View File

@ -86,6 +86,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
*/
std::shared_ptr<HybridGaussianFactorGraph> linearize(
const Values& continuousValues) const;
/// Expose error(const HybridValues&) method.
using Base::error;
/// @}
};

View File

@ -678,6 +678,55 @@ TEST(HybridGaussianFactorGraph, ErrorTreeWithConditional) {
EXPECT(assert_equal(expected, errorTree, 1e-9));
}
/* ****************************************************************************/
// Test hybrid gaussian factor graph errorTree during
// incremental operation
TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
Switching s(4);
HybridGaussianFactorGraph graph;
graph.push_back(s.linearizedFactorGraph.at(0)); // f(X0)
graph.push_back(s.linearizedFactorGraph.at(1)); // f(X0, X1, M0)
graph.push_back(s.linearizedFactorGraph.at(2)); // f(X1, X2, M1)
graph.push_back(s.linearizedFactorGraph.at(4)); // f(X1)
graph.push_back(s.linearizedFactorGraph.at(5)); // f(X2)
graph.push_back(s.linearizedFactorGraph.at(7)); // f(M0)
graph.push_back(s.linearizedFactorGraph.at(8)); // f(M0, M1)
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());
HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.errorTree(delta.continuous());
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
std::vector<double> leaves = {0.99985581, 0.4902432, 0.51936941,
0.0097568009};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
graph = HybridGaussianFactorGraph();
graph.push_back(*hybridBayesNet);
graph.push_back(s.linearizedFactorGraph.at(3)); // f(X2, X3, M2)
graph.push_back(s.linearizedFactorGraph.at(6)); // f(X3)
hybridBayesNet = graph.eliminateSequential();
EXPECT_LONGS_EQUAL(7, hybridBayesNet->size());
delta = hybridBayesNet->optimize();
auto error_tree2 = graph.errorTree(delta.continuous());
discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
leaves = {0.50985198, 0.0097577296, 0.50009425, 0,
0.52922138, 0.029127133, 0.50985105, 0.0097567964};
AlgebraicDecisionTree<Key> expected_error2(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-7));
}
/* ****************************************************************************/
// Check that assembleGraphTree assembles Gaussian factor graphs for each
// assignment.

View File

@ -51,7 +51,7 @@ using symbol_shorthand::X;
* Test that any linearizedFactorGraph gaussian factors are appended to the
* existing gaussian factor graph in the hybrid factor graph.
*/
TEST(HybridFactorGraph, GaussianFactorGraph) {
TEST(HybridNonlinearFactorGraph, GaussianFactorGraph) {
HybridNonlinearFactorGraph fg;
// Add a simple prior factor to the nonlinear factor graph
@ -181,7 +181,7 @@ TEST(HybridGaussianFactorGraph, HybridNonlinearFactor) {
/*****************************************************************************
* Test push_back on HFG makes the correct distinction.
*/
TEST(HybridFactorGraph, PushBack) {
TEST(HybridNonlinearFactorGraph, PushBack) {
HybridNonlinearFactorGraph fg;
auto nonlinearFactor = std::make_shared<BetweenFactor<double>>();
@ -240,7 +240,7 @@ TEST(HybridFactorGraph, PushBack) {
/****************************************************************************
* Test construction of switching-like hybrid factor graph.
*/
TEST(HybridFactorGraph, Switching) {
TEST(HybridNonlinearFactorGraph, Switching) {
Switching self(3);
EXPECT_LONGS_EQUAL(7, self.nonlinearFactorGraph.size());
@ -250,7 +250,7 @@ TEST(HybridFactorGraph, Switching) {
/****************************************************************************
* Test linearization on a switching-like hybrid factor graph.
*/
TEST(HybridFactorGraph, Linearization) {
TEST(HybridNonlinearFactorGraph, Linearization) {
Switching self(3);
// Linearize here:
@ -263,7 +263,7 @@ TEST(HybridFactorGraph, Linearization) {
/****************************************************************************
* Test elimination tree construction
*/
TEST(HybridFactorGraph, EliminationTree) {
TEST(HybridNonlinearFactorGraph, EliminationTree) {
Switching self(3);
// Create ordering.
@ -372,7 +372,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
/****************************************************************************
* Test partial elimination
*/
TEST(HybridFactorGraph, Partial_Elimination) {
TEST(HybridNonlinearFactorGraph, Partial_Elimination) {
Switching self(3);
auto linearizedFactorGraph = self.linearizedFactorGraph;
@ -401,7 +401,39 @@ TEST(HybridFactorGraph, Partial_Elimination) {
EXPECT(remainingFactorGraph->at(2)->keys() == KeyVector({M(0), M(1)}));
}
TEST(HybridFactorGraph, PrintErrors) {
/* ****************************************************************************/
TEST(HybridNonlinearFactorGraph, Error) {
Switching self(3);
HybridNonlinearFactorGraph fg = self.nonlinearFactorGraph;
{
HybridValues values(VectorValues(), DiscreteValues{{M(0), 0}, {M(1), 0}},
self.linearizationPoint);
// regression
EXPECT_DOUBLES_EQUAL(152.791759469, fg.error(values), 1e-9);
}
{
HybridValues values(VectorValues(), DiscreteValues{{M(0), 0}, {M(1), 1}},
self.linearizationPoint);
// regression
EXPECT_DOUBLES_EQUAL(151.598612289, fg.error(values), 1e-9);
}
{
HybridValues values(VectorValues(), DiscreteValues{{M(0), 1}, {M(1), 0}},
self.linearizationPoint);
// regression
EXPECT_DOUBLES_EQUAL(151.703972804, fg.error(values), 1e-9);
}
{
HybridValues values(VectorValues(), DiscreteValues{{M(0), 1}, {M(1), 1}},
self.linearizationPoint);
// regression
EXPECT_DOUBLES_EQUAL(151.609437912, fg.error(values), 1e-9);
}
}
/* ****************************************************************************/
TEST(HybridNonlinearFactorGraph, PrintErrors) {
Switching self(3);
// Get nonlinear factor graph and add linear factors to be holistic
@ -424,7 +456,7 @@ TEST(HybridFactorGraph, PrintErrors) {
/****************************************************************************
* Test full elimination
*/
TEST(HybridFactorGraph, Full_Elimination) {
TEST(HybridNonlinearFactorGraph, Full_Elimination) {
Switching self(3);
auto linearizedFactorGraph = self.linearizedFactorGraph;
@ -492,7 +524,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
/****************************************************************************
* Test printing
*/
TEST(HybridFactorGraph, Printing) {
TEST(HybridNonlinearFactorGraph, Printing) {
Switching self(3);
auto linearizedFactorGraph = self.linearizedFactorGraph;
@ -784,7 +816,7 @@ conditional 2: Hybrid P( x2 | m0 m1)
* The issue arises if we eliminate a landmark variable first since it is not
* connected to a HybridFactor.
*/
TEST(HybridFactorGraph, DefaultDecisionTree) {
TEST(HybridNonlinearFactorGraph, DefaultDecisionTree) {
HybridNonlinearFactorGraph fg;
// Add a prior on pose x0 at the origin.