Removed FactorAndConstant, no longer needed

release/4.3a0
Frank Dellaert 2023-01-12 22:34:34 -08:00
parent 1dcc6ddde9
commit 03ad393e12
7 changed files with 99 additions and 169 deletions

View File

@ -203,8 +203,8 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(given), 0.0};
return GaussianMixtureFactor::sharedFactor{
conditional->likelihood(given)};
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);

View File

@ -31,11 +31,8 @@ namespace gtsam {
/* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Mixture &factors)
: Base(continuousKeys, discreteKeys),
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
return FactorAndConstant{gf, 0.0};
}) {}
const Factors &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {}
/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
@ -48,11 +45,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
const FactorAndConstant &f2) {
return f1.factor->equals(*(f2.factor), tol) &&
std::abs(f1.constant - f2.constant) < tol;
});
factors_.equals(e->factors_,
[tol](const sharedFactor &f1, const sharedFactor &f2) {
return f1->equals(*f2, tol);
});
}
/* *******************************************************************************/
@ -65,8 +61,7 @@ void GaussianMixtureFactor::print(const std::string &s,
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const FactorAndConstant &gf_z) -> std::string {
auto gf = gf_z.factor;
[&](const sharedFactor &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
@ -81,14 +76,9 @@ void GaussianMixtureFactor::print(const std::string &s,
}
/* *******************************************************************************/
GaussianFactor::shared_ptr GaussianMixtureFactor::factor(
GaussianMixtureFactor::sharedFactor GaussianMixtureFactor::operator()(
const DiscreteValues &assignment) const {
return factors_(assignment).factor;
}
/* *******************************************************************************/
double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
return factors_(assignment).constant;
return factors_(assignment);
}
/* *******************************************************************************/
@ -107,10 +97,10 @@ GaussianFactorGraphTree GaussianMixtureFactor::add(
/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const FactorAndConstant &factor_z) {
auto wrap = [](const sharedFactor &gf) {
GaussianFactorGraph result;
result.push_back(factor_z.factor);
return GraphAndConstant(result, factor_z.constant);
result.push_back(gf);
return GraphAndConstant(result, 0.0);
};
return {factors_, wrap};
}
@ -119,8 +109,8 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
return factor_z.error(continuousValues);
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
@ -128,8 +118,8 @@ AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
/* *******************************************************************************/
double GaussianMixtureFactor::error(const HybridValues &values) const {
const FactorAndConstant factor_z = factors_(values.discrete());
return factor_z.error(values.continuous());
const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous());
}
/* *******************************************************************************/

View File

@ -39,7 +39,7 @@ class VectorValues;
* serves to "select" a mixture component corresponding to a GaussianFactor type
* of measurement.
*
* Represents the underlying Gaussian Mixture as a Decision Tree, where the set
* Represents the underlying Gaussian mixture as a Decision Tree, where the set
* of discrete variables indexes to the continuous gaussian distribution.
*
* @ingroup hybrid
@ -52,38 +52,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using sharedFactor = boost::shared_ptr<GaussianFactor>;
/// Gaussian factor and log of normalizing constant.
struct FactorAndConstant {
sharedFactor factor;
double constant;
// Return error with constant correction.
double error(const VectorValues &values) const {
// Note: constant is log of normalization constant for probabilities.
// Errors is the negative log-likelihood,
// hence we subtract the constant here.
if (!factor) return 0.0; // If nullptr, return 0.0 error
return factor->error(values) - constant;
}
// Check pointer equality.
bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant;
}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_NVP(factor);
ar &BOOST_SERIALIZATION_NVP(constant);
}
};
/// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>;
using Factors = DecisionTree<Key, sharedFactor>;
private:
/// Decision tree of Gaussian factors indexed by discrete keys.
@ -105,7 +75,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
GaussianMixtureFactor() = default;
/**
* @brief Construct a new Gaussian Mixture Factor object.
* @brief Construct a new Gaussian mixture factor.
*
* @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys A vector of keys representing discrete variables and
@ -115,12 +85,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Mixture &factors);
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors_and_z)
: Base(continuousKeys, discreteKeys), factors_(factors_and_z) {}
const Factors &factors);
/**
* @brief Construct a new GaussianMixtureFactor object using a vector of
@ -134,7 +99,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
const DiscreteKeys &discreteKeys,
const std::vector<sharedFactor> &factors)
: GaussianMixtureFactor(continuousKeys, discreteKeys,
Mixture(discreteKeys, factors)) {}
Factors(discreteKeys, factors)) {}
/// @}
/// @name Testable
@ -151,10 +116,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
/// @{
/// Get factor at a given discrete assignment.
sharedFactor factor(const DiscreteValues &assignment) const;
/// Get constant at a given discrete assignment.
double constant(const DiscreteValues &assignment) const;
sharedFactor operator()(const DiscreteValues &assignment) const;
/**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while

View File

@ -213,20 +213,20 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved.
GaussianFactorGraphTree sum = factors.assembleGraphTree();
GaussianFactorGraphTree factorGraphTree = factors.assembleGraphTree();
// Convert factor graphs with a nullptr to an empty factor graph.
// This is done after assembly since it is non-trivial to keep track of which
// FG has a nullptr as we're looping over the factors.
sum = removeEmpty(sum);
factorGraphTree = removeEmpty(factorGraphTree);
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::FactorAndConstant>;
GaussianMixtureFactor::sharedFactor>;
// This is the elimination method on the leaf nodes
auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair {
if (graph_z.graph.empty()) {
return {nullptr, {nullptr, 0.0}};
return {nullptr, nullptr};
}
#ifdef HYBRID_TIMING
@ -240,27 +240,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Get the log of the log normalization constant inverse and
// add it to the previous constant.
const double logZ =
graph_z.constant - conditional->logNormalizationConstant();
// Get the log of the log normalization constant inverse.
// double logZ = -conditional->logNormalizationConstant();
// // IF this is the last continuous variable to eliminated, we need to
// // calculate the error here: the value of all factors at the mean, see
// // ml_map_rao.pdf.
// if (continuousSeparator.empty()) {
// const auto posterior_mean = conditional->solve(VectorValues());
// logZ += graph_z.graph.error(posterior_mean);
// }
// const double logZ =
// graph_z.constant - conditional->logNormalizationConstant();
#ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate);
#endif
return {conditional, {newFactor, logZ}};
return {conditional, newFactor};
};
// Perform elimination!
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminateFunc);
DecisionTree<Key, EliminationPair> eliminationResults(factorGraphTree,
eliminateFunc);
#ifdef HYBRID_TIMING
tictoc_print_();
@ -279,26 +271,17 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// If there are no more continuous parents, then we should create a
// DiscreteFactor here, with the error for each discrete choice.
if (continuousSeparator.empty()) {
auto factorProb =
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
// This is the probability q(μ) at the MLE point.
// factor_z.factor is a factor without keys,
// just containing the residual.
return exp(-factor_z.error(VectorValues()));
};
auto factorProb = [&](const EliminationPair &conditionalAndFactor) {
// This is the probability q(μ) at the MLE point.
// conditionalAndFactor.second is a factor without keys, just containing the residual.
static const VectorValues kEmpty;
// return exp(-conditionalAndFactor.first->logNormalizationConstant());
// return exp(-conditionalAndFactor.first->logNormalizationConstant() - conditionalAndFactor.second->error(kEmpty));
return exp( - conditionalAndFactor.second->error(kEmpty));
// return 1.0;
};
const DecisionTree<Key, double> fdt(newFactors, factorProb);
// // Normalize the values of decision tree to be valid probabilities
// double sum = 0.0;
// auto visitor = [&](double y) { sum += y; };
// fdt.visit(visitor);
// // Check if sum is 0, and update accordingly.
// if (sum == 0) {
// sum = 1.0;
// }
// fdt = DecisionTree<Key, double>(fdt,
// [sum](const double &x) { return x / sum;
// });
const DecisionTree<Key, double> fdt(eliminationResults, factorProb);
const auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
@ -375,6 +358,11 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// PREPROCESS: Identify the nature of the current elimination
// TODO(dellaert): just check the factors:
// 1. if all factors are discrete, then we can do discrete elimination:
// 2. if all factors are continuous, then we can do continuous elimination:
// 3. if not, we do hybrid elimination:
// First, identify the separator keys, i.e. all keys that are not frontal.
KeySet separatorKeys;
for (auto &&factor : factors) {

View File

@ -197,8 +197,7 @@ TEST(GaussianMixture, Likelihood) {
const GaussianMixtureFactor::Factors factors(
gm.conditionals(),
[measurements](const GaussianConditional::shared_ptr& conditional) {
return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(measurements), 0.0};
return conditional->likelihood(measurements);
});
const GaussianMixtureFactor expected({X(0)}, {mode}, factors);
EXPECT(assert_equal(expected, *factor));

View File

@ -34,6 +34,7 @@ using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;
static const Key asiaKey = 0;
static const DiscreteKey Asia(asiaKey, 2);
@ -73,8 +74,12 @@ TEST(HybridBayesNet, EvaluatePureDiscrete) {
/* ****************************************************************************/
// Test creation of a tiny hybrid Bayes net.
TEST(HybridBayesNet, Tiny) {
auto bayesNet = tiny::createHybridBayesNet();
EXPECT_LONGS_EQUAL(3, bayesNet.size());
auto bn = tiny::createHybridBayesNet();
EXPECT_LONGS_EQUAL(3, bn.size());
const VectorValues measurements{{Z(0), Vector1(5.0)}};
auto fg = bn.toFactorGraph(measurements);
EXPECT_LONGS_EQUAL(4, fg.size());
}
/* ****************************************************************************/

View File

@ -57,6 +57,9 @@ using gtsam::symbol_shorthand::X;
using gtsam::symbol_shorthand::Y;
using gtsam::symbol_shorthand::Z;
// Set up sampling
std::mt19937_64 kRng(42);
/* ************************************************************************* */
TEST(HybridGaussianFactorGraph, Creation) {
HybridConditional conditional;
@ -638,24 +641,47 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
GaussianFactorGraphTree expected{
M(0),
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
mixture->constant(d0)},
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d1), prior}),
mixture->constant(d1)}};
{GaussianFactorGraph(std::vector<GF>{(*mixture)(d0), prior}), 0.0},
{GaussianFactorGraph(std::vector<GF>{(*mixture)(d1), prior}), 0.0}};
EXPECT(assert_equal(expected(d0), actual(d0), 1e-5));
EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));
}
/* ****************************************************************************/
// Check that the factor graph unnormalized probability is proportional to the
// Bayes net probability for the given measurements.
bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
const HybridGaussianFactorGraph &fg, size_t num_samples = 10) {
auto compute_ratio = [&](HybridValues *sample) -> double {
sample->update(measurements); // update sample with given measurements:
return bn.evaluate(*sample) / fg.probPrime(*sample);
// return bn.evaluate(*sample) / posterior->evaluate(*sample);
};
HybridValues sample = bn.sample(&kRng);
double expected_ratio = compute_ratio(&sample);
// Test ratios for a number of independent samples:
for (size_t i = 0; i < num_samples; i++) {
HybridValues sample = bn.sample(&kRng);
if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false;
}
return true;
}
/* ****************************************************************************/
// Check that eliminating tiny net with 1 measurement yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
using symbol_shorthand::Z;
const int num_measurements = 1;
auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
const VectorValues measurements{{Z(0), Vector1(5.0)}};
auto bn = tiny::createHybridBayesNet(num_measurements);
auto fg = bn.toFactorGraph(measurements);
EXPECT_LONGS_EQUAL(4, fg.size());
EXPECT(ratioTest(bn, measurements, fg));
// Create expected Bayes Net:
HybridBayesNet expectedBayesNet;
@ -675,6 +701,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
// Test elimination
const auto posterior = fg.eliminateSequential();
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
EXPECT(ratioTest(bn, measurements, *posterior));
}
/* ****************************************************************************/
@ -683,9 +711,9 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
// Create factor graph with 2 measurements such that posterior mean = 5.0.
using symbol_shorthand::Z;
const int num_measurements = 2;
auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements,
VectorValues{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}});
const VectorValues measurements{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}};
auto bn = tiny::createHybridBayesNet(num_measurements);
auto fg = bn.toFactorGraph(measurements);
EXPECT_LONGS_EQUAL(6, fg.size());
// Create expected Bayes Net:
@ -707,6 +735,8 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
// Test elimination
const auto posterior = fg.eliminateSequential();
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
EXPECT(ratioTest(bn, measurements, *posterior));
}
/* ****************************************************************************/
@ -723,32 +753,12 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) {
auto fg = bn.toFactorGraph(measurements);
EXPECT_LONGS_EQUAL(7, fg.size());
EXPECT(ratioTest(bn, measurements, fg));
// Test elimination
const auto posterior = fg.eliminateSequential();
// Compute the log-ratio between the Bayes net and the factor graph.
auto compute_ratio = [&](HybridValues *sample) -> double {
// update sample with given measurements:
sample->update(measurements);
return bn.evaluate(*sample) / posterior->evaluate(*sample);
};
// Set up sampling
std::mt19937_64 rng(42);
// The error evaluated by the factor graph and the Bayes net should differ by
// the normalizing term computed via the Bayes net determinant.
HybridValues sample = bn.sample(&rng);
double expected_ratio = compute_ratio(&sample);
// regression
EXPECT_DOUBLES_EQUAL(0.018253037966018862, expected_ratio, 1e-6);
// Test ratios for a number of independent samples:
constexpr int num_samples = 100;
for (size_t i = 0; i < num_samples; i++) {
HybridValues sample = bn.sample(&rng);
EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(&sample), 1e-6);
}
EXPECT(ratioTest(bn, measurements, *posterior));
}
/* ****************************************************************************/
@ -818,31 +828,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
// Test resulting posterior Bayes net has correct size:
EXPECT_LONGS_EQUAL(8, posterior->size());
// TODO(dellaert): below is copy/pasta from above, refactor
// Compute the log-ratio between the Bayes net and the factor graph.
auto compute_ratio = [&](HybridValues *sample) -> double {
// update sample with given measurements:
sample->update(measurements);
return bn.evaluate(*sample) / posterior->evaluate(*sample);
};
// Set up sampling
std::mt19937_64 rng(42);
// The error evaluated by the factor graph and the Bayes net should differ by
// the normalizing term computed via the Bayes net determinant.
HybridValues sample = bn.sample(&rng);
double expected_ratio = compute_ratio(&sample);
// regression
EXPECT_DOUBLES_EQUAL(0.0094526745785019472, expected_ratio, 1e-6);
// Test ratios for a number of independent samples:
constexpr int num_samples = 100;
for (size_t i = 0; i < num_samples; i++) {
HybridValues sample = bn.sample(&rng);
EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(&sample), 1e-6);
}
EXPECT(ratioTest(bn, measurements, *posterior));
}
/* ************************************************************************* */