Removed FactorAndConstant, no longer needed
parent
1dcc6ddde9
commit
03ad393e12
|
|
@ -203,8 +203,8 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
|
||||||
const KeyVector continuousParentKeys = continuousParents();
|
const KeyVector continuousParentKeys = continuousParents();
|
||||||
const GaussianMixtureFactor::Factors likelihoods(
|
const GaussianMixtureFactor::Factors likelihoods(
|
||||||
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
|
||||||
return GaussianMixtureFactor::FactorAndConstant{
|
return GaussianMixtureFactor::sharedFactor{
|
||||||
conditional->likelihood(given), 0.0};
|
conditional->likelihood(given)};
|
||||||
});
|
});
|
||||||
return boost::make_shared<GaussianMixtureFactor>(
|
return boost::make_shared<GaussianMixtureFactor>(
|
||||||
continuousParentKeys, discreteParentKeys, likelihoods);
|
continuousParentKeys, discreteParentKeys, likelihoods);
|
||||||
|
|
|
||||||
|
|
@ -31,11 +31,8 @@ namespace gtsam {
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
|
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys,
|
const DiscreteKeys &discreteKeys,
|
||||||
const Mixture &factors)
|
const Factors &factors)
|
||||||
: Base(continuousKeys, discreteKeys),
|
: Base(continuousKeys, discreteKeys), factors_(factors) {}
|
||||||
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
|
|
||||||
return FactorAndConstant{gf, 0.0};
|
|
||||||
}) {}
|
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
|
|
@ -48,11 +45,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
|
|
||||||
// Check the base and the factors:
|
// Check the base and the factors:
|
||||||
return Base::equals(*e, tol) &&
|
return Base::equals(*e, tol) &&
|
||||||
factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
|
factors_.equals(e->factors_,
|
||||||
const FactorAndConstant &f2) {
|
[tol](const sharedFactor &f1, const sharedFactor &f2) {
|
||||||
return f1.factor->equals(*(f2.factor), tol) &&
|
return f1->equals(*f2, tol);
|
||||||
std::abs(f1.constant - f2.constant) < tol;
|
});
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
@ -65,8 +61,7 @@ void GaussianMixtureFactor::print(const std::string &s,
|
||||||
} else {
|
} else {
|
||||||
factors_.print(
|
factors_.print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"", [&](Key k) { return formatter(k); },
|
||||||
[&](const FactorAndConstant &gf_z) -> std::string {
|
[&](const sharedFactor &gf) -> std::string {
|
||||||
auto gf = gf_z.factor;
|
|
||||||
RedirectCout rd;
|
RedirectCout rd;
|
||||||
std::cout << ":\n";
|
std::cout << ":\n";
|
||||||
if (gf && !gf->empty()) {
|
if (gf && !gf->empty()) {
|
||||||
|
|
@ -81,14 +76,9 @@ void GaussianMixtureFactor::print(const std::string &s,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianFactor::shared_ptr GaussianMixtureFactor::factor(
|
GaussianMixtureFactor::sharedFactor GaussianMixtureFactor::operator()(
|
||||||
const DiscreteValues &assignment) const {
|
const DiscreteValues &assignment) const {
|
||||||
return factors_(assignment).factor;
|
return factors_(assignment);
|
||||||
}
|
|
||||||
|
|
||||||
/* *******************************************************************************/
|
|
||||||
double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
|
|
||||||
return factors_(assignment).constant;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
@ -107,10 +97,10 @@ GaussianFactorGraphTree GaussianMixtureFactor::add(
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
|
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||||
const {
|
const {
|
||||||
auto wrap = [](const FactorAndConstant &factor_z) {
|
auto wrap = [](const sharedFactor &gf) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor_z.factor);
|
result.push_back(gf);
|
||||||
return GraphAndConstant(result, factor_z.constant);
|
return GraphAndConstant(result, 0.0);
|
||||||
};
|
};
|
||||||
return {factors_, wrap};
|
return {factors_, wrap};
|
||||||
}
|
}
|
||||||
|
|
@ -119,8 +109,8 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||||
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
// functor to convert from sharedFactor to double error value.
|
// functor to convert from sharedFactor to double error value.
|
||||||
auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
|
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
|
||||||
return factor_z.error(continuousValues);
|
return gf->error(continuousValues);
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||||
return errorTree;
|
return errorTree;
|
||||||
|
|
@ -128,8 +118,8 @@ AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixtureFactor::error(const HybridValues &values) const {
|
double GaussianMixtureFactor::error(const HybridValues &values) const {
|
||||||
const FactorAndConstant factor_z = factors_(values.discrete());
|
const sharedFactor gf = factors_(values.discrete());
|
||||||
return factor_z.error(values.continuous());
|
return gf->error(values.continuous());
|
||||||
}
|
}
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ class VectorValues;
|
||||||
* serves to "select" a mixture component corresponding to a GaussianFactor type
|
* serves to "select" a mixture component corresponding to a GaussianFactor type
|
||||||
* of measurement.
|
* of measurement.
|
||||||
*
|
*
|
||||||
* Represents the underlying Gaussian Mixture as a Decision Tree, where the set
|
* Represents the underlying Gaussian mixture as a Decision Tree, where the set
|
||||||
* of discrete variables indexes to the continuous gaussian distribution.
|
* of discrete variables indexes to the continuous gaussian distribution.
|
||||||
*
|
*
|
||||||
* @ingroup hybrid
|
* @ingroup hybrid
|
||||||
|
|
@ -52,38 +52,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
|
|
||||||
using sharedFactor = boost::shared_ptr<GaussianFactor>;
|
using sharedFactor = boost::shared_ptr<GaussianFactor>;
|
||||||
|
|
||||||
/// Gaussian factor and log of normalizing constant.
|
|
||||||
struct FactorAndConstant {
|
|
||||||
sharedFactor factor;
|
|
||||||
double constant;
|
|
||||||
|
|
||||||
// Return error with constant correction.
|
|
||||||
double error(const VectorValues &values) const {
|
|
||||||
// Note: constant is log of normalization constant for probabilities.
|
|
||||||
// Errors is the negative log-likelihood,
|
|
||||||
// hence we subtract the constant here.
|
|
||||||
if (!factor) return 0.0; // If nullptr, return 0.0 error
|
|
||||||
return factor->error(values) - constant;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check pointer equality.
|
|
||||||
bool operator==(const FactorAndConstant &other) const {
|
|
||||||
return factor == other.factor && constant == other.constant;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/** Serialization function */
|
|
||||||
friend class boost::serialization::access;
|
|
||||||
template <class ARCHIVE>
|
|
||||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
|
||||||
ar &BOOST_SERIALIZATION_NVP(factor);
|
|
||||||
ar &BOOST_SERIALIZATION_NVP(constant);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
||||||
using Factors = DecisionTree<Key, FactorAndConstant>;
|
using Factors = DecisionTree<Key, sharedFactor>;
|
||||||
using Mixture = DecisionTree<Key, sharedFactor>;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Decision tree of Gaussian factors indexed by discrete keys.
|
/// Decision tree of Gaussian factors indexed by discrete keys.
|
||||||
|
|
@ -105,7 +75,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
GaussianMixtureFactor() = default;
|
GaussianMixtureFactor() = default;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Construct a new Gaussian Mixture Factor object.
|
* @brief Construct a new Gaussian mixture factor.
|
||||||
*
|
*
|
||||||
* @param continuousKeys A vector of keys representing continuous variables.
|
* @param continuousKeys A vector of keys representing continuous variables.
|
||||||
* @param discreteKeys A vector of keys representing discrete variables and
|
* @param discreteKeys A vector of keys representing discrete variables and
|
||||||
|
|
@ -115,12 +85,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
*/
|
*/
|
||||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys,
|
const DiscreteKeys &discreteKeys,
|
||||||
const Mixture &factors);
|
const Factors &factors);
|
||||||
|
|
||||||
GaussianMixtureFactor(const KeyVector &continuousKeys,
|
|
||||||
const DiscreteKeys &discreteKeys,
|
|
||||||
const Factors &factors_and_z)
|
|
||||||
: Base(continuousKeys, discreteKeys), factors_(factors_and_z) {}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Construct a new GaussianMixtureFactor object using a vector of
|
* @brief Construct a new GaussianMixtureFactor object using a vector of
|
||||||
|
|
@ -134,7 +99,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
const DiscreteKeys &discreteKeys,
|
const DiscreteKeys &discreteKeys,
|
||||||
const std::vector<sharedFactor> &factors)
|
const std::vector<sharedFactor> &factors)
|
||||||
: GaussianMixtureFactor(continuousKeys, discreteKeys,
|
: GaussianMixtureFactor(continuousKeys, discreteKeys,
|
||||||
Mixture(discreteKeys, factors)) {}
|
Factors(discreteKeys, factors)) {}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
|
@ -151,10 +116,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Get factor at a given discrete assignment.
|
/// Get factor at a given discrete assignment.
|
||||||
sharedFactor factor(const DiscreteValues &assignment) const;
|
sharedFactor operator()(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
/// Get constant at a given discrete assignment.
|
|
||||||
double constant(const DiscreteValues &assignment) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
|
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
|
||||||
|
|
|
||||||
|
|
@ -213,20 +213,20 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
||||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||||
// decision tree indexed by all discrete keys involved.
|
// decision tree indexed by all discrete keys involved.
|
||||||
GaussianFactorGraphTree sum = factors.assembleGraphTree();
|
GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();
|
||||||
|
|
||||||
// Convert factor graphs with a nullptr to an empty factor graph.
|
// Convert factor graphs with a nullptr to an empty factor graph.
|
||||||
// This is done after assembly since it is non-trivial to keep track of which
|
// This is done after assembly since it is non-trivial to keep track of which
|
||||||
// FG has a nullptr as we're looping over the factors.
|
// FG has a nullptr as we're looping over the factors.
|
||||||
sum = removeEmpty(sum);
|
factorGraphTree = removeEmpty(factorGraphTree);
|
||||||
|
|
||||||
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
|
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
|
||||||
GaussianMixtureFactor::FactorAndConstant>;
|
GaussianMixtureFactor::sharedFactor>;
|
||||||
|
|
||||||
// This is the elimination method on the leaf nodes
|
// This is the elimination method on the leaf nodes
|
||||||
auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair {
|
auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair {
|
||||||
if (graph_z.graph.empty()) {
|
if (graph_z.graph.empty()) {
|
||||||
return {nullptr, {nullptr, 0.0}};
|
return {nullptr, nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
#ifdef HYBRID_TIMING
|
||||||
|
|
@ -240,27 +240,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
||||||
// Get the log of the log normalization constant inverse and
|
// Get the log of the log normalization constant inverse and
|
||||||
// add it to the previous constant.
|
// add it to the previous constant.
|
||||||
const double logZ =
|
// const double logZ =
|
||||||
graph_z.constant - conditional->logNormalizationConstant();
|
// graph_z.constant - conditional->logNormalizationConstant();
|
||||||
// Get the log of the log normalization constant inverse.
|
|
||||||
// double logZ = -conditional->logNormalizationConstant();
|
|
||||||
// // IF this is the last continuous variable to eliminated, we need to
|
|
||||||
// // calculate the error here: the value of all factors at the mean, see
|
|
||||||
// // ml_map_rao.pdf.
|
|
||||||
// if (continuousSeparator.empty()) {
|
|
||||||
// const auto posterior_mean = conditional->solve(VectorValues());
|
|
||||||
// logZ += graph_z.graph.error(posterior_mean);
|
|
||||||
// }
|
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
#ifdef HYBRID_TIMING
|
||||||
gttoc_(hybrid_eliminate);
|
gttoc_(hybrid_eliminate);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return {conditional, {newFactor, logZ}};
|
return {conditional, newFactor};
|
||||||
};
|
};
|
||||||
|
|
||||||
// Perform elimination!
|
// Perform elimination!
|
||||||
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminateFunc);
|
DecisionTree<Key, EliminationPair> eliminationResults(factorGraphTree,
|
||||||
|
eliminateFunc);
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
#ifdef HYBRID_TIMING
|
||||||
tictoc_print_();
|
tictoc_print_();
|
||||||
|
|
@ -279,26 +271,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// If there are no more continuous parents, then we should create a
|
// If there are no more continuous parents, then we should create a
|
||||||
// DiscreteFactor here, with the error for each discrete choice.
|
// DiscreteFactor here, with the error for each discrete choice.
|
||||||
if (continuousSeparator.empty()) {
|
if (continuousSeparator.empty()) {
|
||||||
auto factorProb =
|
auto factorProb = [&](const EliminationPair &conditionalAndFactor) {
|
||||||
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
|
// This is the probability q(μ) at the MLE point.
|
||||||
// This is the probability q(μ) at the MLE point.
|
// conditionalAndFactor.second is a factor without keys, just containing the residual.
|
||||||
// factor_z.factor is a factor without keys,
|
static const VectorValues kEmpty;
|
||||||
// just containing the residual.
|
// return exp(-conditionalAndFactor.first->logNormalizationConstant());
|
||||||
return exp(-factor_z.error(VectorValues()));
|
// return exp(-conditionalAndFactor.first->logNormalizationConstant() - conditionalAndFactor.second->error(kEmpty));
|
||||||
};
|
return exp( - conditionalAndFactor.second->error(kEmpty));
|
||||||
|
// return 1.0;
|
||||||
|
};
|
||||||
|
|
||||||
const DecisionTree<Key, double> fdt(newFactors, factorProb);
|
const DecisionTree<Key, double> fdt(eliminationResults, factorProb);
|
||||||
// // Normalize the values of decision tree to be valid probabilities
|
|
||||||
// double sum = 0.0;
|
|
||||||
// auto visitor = [&](double y) { sum += y; };
|
|
||||||
// fdt.visit(visitor);
|
|
||||||
// // Check if sum is 0, and update accordingly.
|
|
||||||
// if (sum == 0) {
|
|
||||||
// sum = 1.0;
|
|
||||||
// }
|
|
||||||
// fdt = DecisionTree<Key, double>(fdt,
|
|
||||||
// [sum](const double &x) { return x / sum;
|
|
||||||
// });
|
|
||||||
const auto discreteFactor =
|
const auto discreteFactor =
|
||||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||||
|
|
||||||
|
|
@ -375,6 +358,11 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
||||||
// PREPROCESS: Identify the nature of the current elimination
|
// PREPROCESS: Identify the nature of the current elimination
|
||||||
|
|
||||||
|
// TODO(dellaert): just check the factors:
|
||||||
|
// 1. if all factors are discrete, then we can do discrete elimination:
|
||||||
|
// 2. if all factors are continuous, then we can do continuous elimination:
|
||||||
|
// 3. if not, we do hybrid elimination:
|
||||||
|
|
||||||
// First, identify the separator keys, i.e. all keys that are not frontal.
|
// First, identify the separator keys, i.e. all keys that are not frontal.
|
||||||
KeySet separatorKeys;
|
KeySet separatorKeys;
|
||||||
for (auto &&factor : factors) {
|
for (auto &&factor : factors) {
|
||||||
|
|
|
||||||
|
|
@ -197,8 +197,7 @@ TEST(GaussianMixture, Likelihood) {
|
||||||
const GaussianMixtureFactor::Factors factors(
|
const GaussianMixtureFactor::Factors factors(
|
||||||
gm.conditionals(),
|
gm.conditionals(),
|
||||||
[measurements](const GaussianConditional::shared_ptr& conditional) {
|
[measurements](const GaussianConditional::shared_ptr& conditional) {
|
||||||
return GaussianMixtureFactor::FactorAndConstant{
|
return conditional->likelihood(measurements);
|
||||||
conditional->likelihood(measurements), 0.0};
|
|
||||||
});
|
});
|
||||||
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
|
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
|
||||||
EXPECT(assert_equal(expected, *factor));
|
EXPECT(assert_equal(expected, *factor));
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ using namespace gtsam;
|
||||||
using noiseModel::Isotropic;
|
using noiseModel::Isotropic;
|
||||||
using symbol_shorthand::M;
|
using symbol_shorthand::M;
|
||||||
using symbol_shorthand::X;
|
using symbol_shorthand::X;
|
||||||
|
using symbol_shorthand::Z;
|
||||||
|
|
||||||
static const Key asiaKey = 0;
|
static const Key asiaKey = 0;
|
||||||
static const DiscreteKey Asia(asiaKey, 2);
|
static const DiscreteKey Asia(asiaKey, 2);
|
||||||
|
|
@ -73,8 +74,12 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test creation of a tiny hybrid Bayes net.
|
// Test creation of a tiny hybrid Bayes net.
|
||||||
TEST(HybridBayesNet, Tiny) {
|
TEST(HybridBayesNet, Tiny) {
|
||||||
auto bayesNet = tiny::createHybridBayesNet();
|
auto bn = tiny::createHybridBayesNet();
|
||||||
EXPECT_LONGS_EQUAL(3, bayesNet.size());
|
EXPECT_LONGS_EQUAL(3, bn.size());
|
||||||
|
|
||||||
|
const VectorValues measurements{{Z(0), Vector1(5.0)}};
|
||||||
|
auto fg = bn.toFactorGraph(measurements);
|
||||||
|
EXPECT_LONGS_EQUAL(4, fg.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,9 @@ using gtsam::symbol_shorthand::X;
|
||||||
using gtsam::symbol_shorthand::Y;
|
using gtsam::symbol_shorthand::Y;
|
||||||
using gtsam::symbol_shorthand::Z;
|
using gtsam::symbol_shorthand::Z;
|
||||||
|
|
||||||
|
// Set up sampling
|
||||||
|
std::mt19937_64 kRng(42);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(HybridGaussianFactorGraph, Creation) {
|
TEST(HybridGaussianFactorGraph, Creation) {
|
||||||
HybridConditional conditional;
|
HybridConditional conditional;
|
||||||
|
|
@ -638,24 +641,47 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
|
||||||
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
|
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
|
||||||
GaussianFactorGraphTree expected{
|
GaussianFactorGraphTree expected{
|
||||||
M(0),
|
M(0),
|
||||||
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
|
{GaussianFactorGraph(std::vector<GF>{(*mixture)(d0), prior}), 0.0},
|
||||||
mixture->constant(d0)},
|
{GaussianFactorGraph(std::vector<GF>{(*mixture)(d1), prior}), 0.0}};
|
||||||
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d1), prior}),
|
|
||||||
mixture->constant(d1)}};
|
|
||||||
|
|
||||||
EXPECT(assert_equal(expected(d0), actual(d0), 1e-5));
|
EXPECT(assert_equal(expected(d0), actual(d0), 1e-5));
|
||||||
EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));
|
EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Check that the factor graph unnormalized probability is proportional to the
|
||||||
|
// Bayes net probability for the given measurements.
|
||||||
|
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
|
||||||
|
const HybridGaussianFactorGraph &fg, size_t num_samples = 10) {
|
||||||
|
auto compute_ratio = [&](HybridValues *sample) -> double {
|
||||||
|
sample->update(measurements); // update sample with given measurements:
|
||||||
|
return bn.evaluate(*sample) / fg.probPrime(*sample);
|
||||||
|
// return bn.evaluate(*sample) / posterior->evaluate(*sample);
|
||||||
|
};
|
||||||
|
|
||||||
|
HybridValues sample = bn.sample(&kRng);
|
||||||
|
double expected_ratio = compute_ratio(&sample);
|
||||||
|
|
||||||
|
// Test ratios for a number of independent samples:
|
||||||
|
for (size_t i = 0; i < num_samples; i++) {
|
||||||
|
HybridValues sample = bn.sample(&kRng);
|
||||||
|
if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Check that eliminating tiny net with 1 measurement yields correct result.
|
// Check that eliminating tiny net with 1 measurement yields correct result.
|
||||||
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
||||||
using symbol_shorthand::Z;
|
using symbol_shorthand::Z;
|
||||||
const int num_measurements = 1;
|
const int num_measurements = 1;
|
||||||
auto fg = tiny::createHybridGaussianFactorGraph(
|
const VectorValues measurements{{Z(0), Vector1(5.0)}};
|
||||||
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
|
auto bn = tiny::createHybridBayesNet(num_measurements);
|
||||||
|
auto fg = bn.toFactorGraph(measurements);
|
||||||
EXPECT_LONGS_EQUAL(4, fg.size());
|
EXPECT_LONGS_EQUAL(4, fg.size());
|
||||||
|
|
||||||
|
EXPECT(ratioTest(bn, measurements, fg));
|
||||||
|
|
||||||
// Create expected Bayes Net:
|
// Create expected Bayes Net:
|
||||||
HybridBayesNet expectedBayesNet;
|
HybridBayesNet expectedBayesNet;
|
||||||
|
|
||||||
|
|
@ -675,6 +701,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
||||||
// Test elimination
|
// Test elimination
|
||||||
const auto posterior = fg.eliminateSequential();
|
const auto posterior = fg.eliminateSequential();
|
||||||
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||||
|
|
||||||
|
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
@ -683,9 +711,9 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
||||||
// Create factor graph with 2 measurements such that posterior mean = 5.0.
|
// Create factor graph with 2 measurements such that posterior mean = 5.0.
|
||||||
using symbol_shorthand::Z;
|
using symbol_shorthand::Z;
|
||||||
const int num_measurements = 2;
|
const int num_measurements = 2;
|
||||||
auto fg = tiny::createHybridGaussianFactorGraph(
|
const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
|
||||||
num_measurements,
|
auto bn = tiny::createHybridBayesNet(num_measurements);
|
||||||
VectorValues{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}});
|
auto fg = bn.toFactorGraph(measurements);
|
||||||
EXPECT_LONGS_EQUAL(6, fg.size());
|
EXPECT_LONGS_EQUAL(6, fg.size());
|
||||||
|
|
||||||
// Create expected Bayes Net:
|
// Create expected Bayes Net:
|
||||||
|
|
@ -707,6 +735,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
|
||||||
// Test elimination
|
// Test elimination
|
||||||
const auto posterior = fg.eliminateSequential();
|
const auto posterior = fg.eliminateSequential();
|
||||||
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||||
|
|
||||||
|
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
@ -723,32 +753,12 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) {
|
||||||
auto fg = bn.toFactorGraph(measurements);
|
auto fg = bn.toFactorGraph(measurements);
|
||||||
EXPECT_LONGS_EQUAL(7, fg.size());
|
EXPECT_LONGS_EQUAL(7, fg.size());
|
||||||
|
|
||||||
|
EXPECT(ratioTest(bn, measurements, fg));
|
||||||
|
|
||||||
// Test elimination
|
// Test elimination
|
||||||
const auto posterior = fg.eliminateSequential();
|
const auto posterior = fg.eliminateSequential();
|
||||||
|
|
||||||
// Compute the log-ratio between the Bayes net and the factor graph.
|
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||||
auto compute_ratio = [&](HybridValues *sample) -> double {
|
|
||||||
// update sample with given measurements:
|
|
||||||
sample->update(measurements);
|
|
||||||
return bn.evaluate(*sample) / posterior->evaluate(*sample);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Set up sampling
|
|
||||||
std::mt19937_64 rng(42);
|
|
||||||
|
|
||||||
// The error evaluated by the factor graph and the Bayes net should differ by
|
|
||||||
// the normalizing term computed via the Bayes net determinant.
|
|
||||||
HybridValues sample = bn.sample(&rng);
|
|
||||||
double expected_ratio = compute_ratio(&sample);
|
|
||||||
// regression
|
|
||||||
EXPECT_DOUBLES_EQUAL(0.018253037966018862, expected_ratio, 1e-6);
|
|
||||||
|
|
||||||
// Test ratios for a number of independent samples:
|
|
||||||
constexpr int num_samples = 100;
|
|
||||||
for (size_t i = 0; i < num_samples; i++) {
|
|
||||||
HybridValues sample = bn.sample(&rng);
|
|
||||||
EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(&sample), 1e-6);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
|
|
@ -818,31 +828,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
|
||||||
// Test resulting posterior Bayes net has correct size:
|
// Test resulting posterior Bayes net has correct size:
|
||||||
EXPECT_LONGS_EQUAL(8, posterior->size());
|
EXPECT_LONGS_EQUAL(8, posterior->size());
|
||||||
|
|
||||||
// TODO(dellaert): below is copy/pasta from above, refactor
|
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||||
|
|
||||||
// Compute the log-ratio between the Bayes net and the factor graph.
|
|
||||||
auto compute_ratio = [&](HybridValues *sample) -> double {
|
|
||||||
// update sample with given measurements:
|
|
||||||
sample->update(measurements);
|
|
||||||
return bn.evaluate(*sample) / posterior->evaluate(*sample);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Set up sampling
|
|
||||||
std::mt19937_64 rng(42);
|
|
||||||
|
|
||||||
// The error evaluated by the factor graph and the Bayes net should differ by
|
|
||||||
// the normalizing term computed via the Bayes net determinant.
|
|
||||||
HybridValues sample = bn.sample(&rng);
|
|
||||||
double expected_ratio = compute_ratio(&sample);
|
|
||||||
// regression
|
|
||||||
EXPECT_DOUBLES_EQUAL(0.0094526745785019472, expected_ratio, 1e-6);
|
|
||||||
|
|
||||||
// Test ratios for a number of independent samples:
|
|
||||||
constexpr int num_samples = 100;
|
|
||||||
for (size_t i = 0; i < num_samples; i++) {
|
|
||||||
HybridValues sample = bn.sample(&rng);
|
|
||||||
EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(&sample), 1e-6);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue