Merge pull request #1836 from borglab/improved-api
commit
017044eabd
|
|
@ -28,23 +28,37 @@
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
HybridGaussianFactor::FactorValuePairs GetFactorValuePairs(
|
||||||
|
const HybridGaussianConditional::Conditionals &conditionals) {
|
||||||
|
auto func = [](const GaussianConditional::shared_ptr &conditional)
|
||||||
|
-> GaussianFactorValuePair {
|
||||||
|
double value = 0.0;
|
||||||
|
// Check if conditional is pruned
|
||||||
|
if (conditional) {
|
||||||
|
// Assign log(\sqrt(|2πΣ|)) = -log(1 / sqrt(|2πΣ|))
|
||||||
|
value = -conditional->logNormalizationConstant();
|
||||||
|
}
|
||||||
|
return {std::dynamic_pointer_cast<GaussianFactor>(conditional), value};
|
||||||
|
};
|
||||||
|
return HybridGaussianFactor::FactorValuePairs(conditionals, func);
|
||||||
|
}
|
||||||
|
|
||||||
HybridGaussianConditional::HybridGaussianConditional(
|
HybridGaussianConditional::HybridGaussianConditional(
|
||||||
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
||||||
const DiscreteKeys &discreteParents,
|
const DiscreteKeys &discreteParents,
|
||||||
const HybridGaussianConditional::Conditionals &conditionals)
|
const HybridGaussianConditional::Conditionals &conditionals)
|
||||||
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
|
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
|
||||||
discreteParents),
|
discreteParents, GetFactorValuePairs(conditionals)),
|
||||||
BaseConditional(continuousFrontals.size()),
|
BaseConditional(continuousFrontals.size()),
|
||||||
conditionals_(conditionals) {
|
conditionals_(conditionals) {
|
||||||
// Calculate logConstant_ as the maximum of the log constants of the
|
// Calculate logConstant_ as the minimum of the log normalizers of the
|
||||||
// conditionals, by visiting the decision tree:
|
// conditionals, by visiting the decision tree:
|
||||||
logConstant_ = -std::numeric_limits<double>::infinity();
|
logConstant_ = std::numeric_limits<double>::infinity();
|
||||||
conditionals_.visit(
|
conditionals_.visit(
|
||||||
[this](const GaussianConditional::shared_ptr &conditional) {
|
[this](const GaussianConditional::shared_ptr &conditional) {
|
||||||
if (conditional) {
|
if (conditional) {
|
||||||
this->logConstant_ = std::max(
|
this->logConstant_ = std::min(
|
||||||
this->logConstant_, conditional->logNormalizationConstant());
|
this->logConstant_, -conditional->logNormalizationConstant());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
@ -64,21 +78,6 @@ HybridGaussianConditional::HybridGaussianConditional(
|
||||||
DiscreteKeys{discreteParent},
|
DiscreteKeys{discreteParent},
|
||||||
Conditionals({discreteParent}, conditionals)) {}
|
Conditionals({discreteParent}, conditionals)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
|
||||||
// TODO(dellaert): This is copy/paste: HybridGaussianConditional should be
|
|
||||||
// derived from HybridGaussianFactor, no?
|
|
||||||
GaussianFactorGraphTree HybridGaussianConditional::add(
|
|
||||||
const GaussianFactorGraphTree &sum) const {
|
|
||||||
using Y = GaussianFactorGraph;
|
|
||||||
auto add = [](const Y &graph1, const Y &graph2) {
|
|
||||||
auto result = graph1;
|
|
||||||
result.push_back(graph2);
|
|
||||||
return result;
|
|
||||||
};
|
|
||||||
const auto tree = asGaussianFactorGraphTree();
|
|
||||||
return sum.empty() ? tree : sum.apply(tree, add);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
|
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
|
||||||
const {
|
const {
|
||||||
|
|
@ -86,7 +85,7 @@ GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
|
||||||
// First check if conditional has not been pruned
|
// First check if conditional has not been pruned
|
||||||
if (gc) {
|
if (gc) {
|
||||||
const double Cgm_Kgcm =
|
const double Cgm_Kgcm =
|
||||||
this->logConstant_ - gc->logNormalizationConstant();
|
-this->logConstant_ - gc->logNormalizationConstant();
|
||||||
// If there is a difference in the covariances, we need to account for
|
// If there is a difference in the covariances, we need to account for
|
||||||
// that since the error is dependent on the mode.
|
// that since the error is dependent on the mode.
|
||||||
if (Cgm_Kgcm > 0.0) {
|
if (Cgm_Kgcm > 0.0) {
|
||||||
|
|
@ -157,7 +156,8 @@ void HybridGaussianConditional::print(const std::string &s,
|
||||||
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
||||||
}
|
}
|
||||||
std::cout << std::endl
|
std::cout << std::endl
|
||||||
<< " logNormalizationConstant: " << logConstant_ << std::endl
|
<< " logNormalizationConstant: " << logNormalizationConstant()
|
||||||
|
<< std::endl
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
conditionals_.print(
|
conditionals_.print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"", [&](Key k) { return formatter(k); },
|
||||||
|
|
@ -216,7 +216,7 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
|
||||||
-> GaussianFactorValuePair {
|
-> GaussianFactorValuePair {
|
||||||
const auto likelihood_m = conditional->likelihood(given);
|
const auto likelihood_m = conditional->likelihood(given);
|
||||||
const double Cgm_Kgcm =
|
const double Cgm_Kgcm =
|
||||||
logConstant_ - conditional->logNormalizationConstant();
|
-logConstant_ - conditional->logNormalizationConstant();
|
||||||
if (Cgm_Kgcm == 0.0) {
|
if (Cgm_Kgcm == 0.0) {
|
||||||
return {likelihood_m, 0.0};
|
return {likelihood_m, 0.0};
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -330,7 +330,7 @@ double HybridGaussianConditional::conditionalError(
|
||||||
// Check if valid pointer
|
// Check if valid pointer
|
||||||
if (conditional) {
|
if (conditional) {
|
||||||
return conditional->error(continuousValues) + //
|
return conditional->error(continuousValues) + //
|
||||||
logConstant_ - conditional->logNormalizationConstant();
|
-logConstant_ - conditional->logNormalizationConstant();
|
||||||
} else {
|
} else {
|
||||||
// If not valid, pointer, it means this conditional was pruned,
|
// If not valid, pointer, it means this conditional was pruned,
|
||||||
// so we return maximum error.
|
// so we return maximum error.
|
||||||
|
|
|
||||||
|
|
@ -51,20 +51,22 @@ class HybridValues;
|
||||||
* @ingroup hybrid
|
* @ingroup hybrid
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT HybridGaussianConditional
|
class GTSAM_EXPORT HybridGaussianConditional
|
||||||
: public HybridFactor,
|
: public HybridGaussianFactor,
|
||||||
public Conditional<HybridFactor, HybridGaussianConditional> {
|
public Conditional<HybridGaussianFactor, HybridGaussianConditional> {
|
||||||
public:
|
public:
|
||||||
using This = HybridGaussianConditional;
|
using This = HybridGaussianConditional;
|
||||||
using shared_ptr = std::shared_ptr<HybridGaussianConditional>;
|
using shared_ptr = std::shared_ptr<This>;
|
||||||
using BaseFactor = HybridFactor;
|
using BaseFactor = HybridGaussianFactor;
|
||||||
using BaseConditional = Conditional<HybridFactor, HybridGaussianConditional>;
|
using BaseConditional = Conditional<BaseFactor, HybridGaussianConditional>;
|
||||||
|
|
||||||
/// typedef for Decision Tree of Gaussian Conditionals
|
/// typedef for Decision Tree of Gaussian Conditionals
|
||||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
|
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
|
||||||
double logConstant_; ///< log of the normalization constant.
|
///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
|
||||||
|
///< Take advantage of the neg-log space so everything is a minimization
|
||||||
|
double logConstant_;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert a HybridGaussianConditional of conditionals into
|
* @brief Convert a HybridGaussianConditional of conditionals into
|
||||||
|
|
@ -107,8 +109,9 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
const Conditionals &conditionals);
|
const Conditionals &conditionals);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian conditionals.
|
* @brief Make a Hybrid Gaussian Conditional from a vector of Gaussian
|
||||||
* The DecisionTree-based constructor is preferred over this one.
|
* conditionals. The DecisionTree-based constructor is preferred over this
|
||||||
|
* one.
|
||||||
*
|
*
|
||||||
* @param continuousFrontals The continuous frontal variables
|
* @param continuousFrontals The continuous frontal variables
|
||||||
* @param continuousParents The continuous parent variables
|
* @param continuousParents The continuous parent variables
|
||||||
|
|
@ -149,7 +152,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
|
|
||||||
/// The log normalization constant is max of the the individual
|
/// The log normalization constant is max of the the individual
|
||||||
/// log-normalization constants.
|
/// log-normalization constants.
|
||||||
double logNormalizationConstant() const override { return logConstant_; }
|
double logNormalizationConstant() const override { return -logConstant_; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a likelihood factor for a hybrid Gaussian conditional,
|
* Create a likelihood factor for a hybrid Gaussian conditional,
|
||||||
|
|
@ -232,14 +235,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
*/
|
*/
|
||||||
void prune(const DecisionTreeFactor &discreteProbs);
|
void prune(const DecisionTreeFactor &discreteProbs);
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
|
|
||||||
* maintaining the decision tree structure.
|
|
||||||
*
|
|
||||||
* @param sum Decision Tree of Gaussian Factor Graphs
|
|
||||||
* @return GaussianFactorGraphTree
|
|
||||||
*/
|
|
||||||
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
|
|
@ -100,7 +100,7 @@ TEST(HybridGaussianConditional, Error) {
|
||||||
auto actual = hybrid_conditional.errorTree(vv);
|
auto actual = hybrid_conditional.errorTree(vv);
|
||||||
|
|
||||||
// Check result.
|
// Check result.
|
||||||
std::vector<DiscreteKey> discrete_keys = {mode};
|
DiscreteKeys discrete_keys{mode};
|
||||||
std::vector<double> leaves = {conditionals[0]->error(vv),
|
std::vector<double> leaves = {conditionals[0]->error(vv),
|
||||||
conditionals[1]->error(vv)};
|
conditionals[1]->error(vv)};
|
||||||
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||||
|
|
@ -172,6 +172,37 @@ TEST(HybridGaussianConditional, ContinuousParents) {
|
||||||
EXPECT(continuousParentKeys[0] == X(0));
|
EXPECT(continuousParentKeys[0] == X(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
/// Check error with mode dependent constants.
|
||||||
|
TEST(HybridGaussianConditional, Error2) {
|
||||||
|
using namespace mode_dependent_constants;
|
||||||
|
auto actual = hybrid_conditional.errorTree(vv);
|
||||||
|
|
||||||
|
// Check result.
|
||||||
|
DiscreteKeys discrete_keys{mode};
|
||||||
|
double logNormalizer0 = -conditionals[0]->logNormalizationConstant();
|
||||||
|
double logNormalizer1 = -conditionals[1]->logNormalizationConstant();
|
||||||
|
double minLogNormalizer = std::min(logNormalizer0, logNormalizer1);
|
||||||
|
|
||||||
|
// Expected error is e(X) + log(|2πΣ|).
|
||||||
|
// We normalize log(|2πΣ|) with min(logNormalizers) so it is non-negative.
|
||||||
|
std::vector<double> leaves = {
|
||||||
|
conditionals[0]->error(vv) + logNormalizer0 - minLogNormalizer,
|
||||||
|
conditionals[1]->error(vv) + logNormalizer1 - minLogNormalizer};
|
||||||
|
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected, actual, 1e-6));
|
||||||
|
|
||||||
|
// Check for non-tree version.
|
||||||
|
for (size_t mode : {0, 1}) {
|
||||||
|
const HybridValues hv{vv, {{M(0), mode}}};
|
||||||
|
EXPECT_DOUBLES_EQUAL(conditionals[mode]->error(vv) -
|
||||||
|
conditionals[mode]->logNormalizationConstant() -
|
||||||
|
minLogNormalizer,
|
||||||
|
hybrid_conditional.error(hv), 1e-8);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
/// Check that the likelihood is proportional to the conditional density given
|
/// Check that the likelihood is proportional to the conditional density given
|
||||||
/// the measurements.
|
/// the measurements.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue