diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp
index 892a07d2d..329044aca 100644
--- a/gtsam/hybrid/GaussianMixture.cpp
+++ b/gtsam/hybrid/GaussianMixture.cpp
@@ -53,11 +53,11 @@ GaussianMixture::GaussianMixture(
 /* *******************************************************************************/
 GaussianMixture::Sum GaussianMixture::add(
     const GaussianMixture::Sum &sum) const {
-  using Y = GaussianFactorGraph;
+  using Y = GaussianMixtureFactor::GraphAndConstant;
   auto add = [](const Y &graph1, const Y &graph2) {
-    auto result = graph1;
-    result.push_back(graph2);
-    return result;
+    auto result = graph1.graph;
+    result.push_back(graph2.graph);
+    return Y(result, graph1.constant + graph2.constant);
   };
   const Sum tree = asGaussianFactorGraphTree();
   return sum.empty() ? tree : sum.apply(tree, add);
@@ -65,10 +65,15 @@ GaussianMixture::Sum GaussianMixture::add(
 
 /* *******************************************************************************/
 GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
-  auto lambda = [](const GaussianFactor::shared_ptr &factor) {
+  auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
     GaussianFactorGraph result;
-    result.push_back(factor);
-    return result;
+    result.push_back(conditional);
+    if (conditional) {
+      return GaussianMixtureFactor::GraphAndConstant(
+          result, conditional->logNormalizationConstant());
+    } else {
+      return GaussianMixtureFactor::GraphAndConstant(result, 0.0);
+    }
   };
   return {conditionals_, lambda};
 }
diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h
index a9b05f250..e8ba78d61 100644
--- a/gtsam/hybrid/GaussianMixture.h
+++ b/gtsam/hybrid/GaussianMixture.h
@@ -23,13 +23,13 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
 
 namespace gtsam {
 
-class GaussianMixtureFactor;
 class HybridValues;
 
 /**
@@ -60,7 +60,7 @@ class GTSAM_EXPORT GaussianMixture
   using BaseConditional = Conditional;
 
   /// Alias for DecisionTree of GaussianFactorGraphs
-  using Sum = DecisionTree;
+  using Sum = DecisionTree;
 
   /// typedef for Decision Tree of Gaussian Conditionals
   using Conditionals = DecisionTree;
diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp
index e60368717..fb931c5e5 100644
--- a/gtsam/hybrid/GaussianMixtureFactor.cpp
+++ b/gtsam/hybrid/GaussianMixtureFactor.cpp
@@ -90,11 +90,11 @@ const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
 /* *******************************************************************************/
 GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
     const GaussianMixtureFactor::Sum &sum) const {
-  using Y = GaussianFactorGraph;
+  using Y = GaussianMixtureFactor::GraphAndConstant;
   auto add = [](const Y &graph1, const Y &graph2) {
-    auto result = graph1;
-    result.push_back(graph2);
-    return result;
+    auto result = graph1.graph;
+    result.push_back(graph2.graph);
+    return Y(result, graph1.constant + graph2.constant);
   };
   const Sum tree = asGaussianFactorGraphTree();
   return sum.empty() ? tree : sum.apply(tree, add);
@@ -106,7 +106,7 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
   auto wrap = [](const FactorAndConstant &factor_z) {
     GaussianFactorGraph result;
     result.push_back(factor_z.factor);
-    return result;
+    return GaussianMixtureFactor::GraphAndConstant(result, factor_z.constant);
   };
   return {factors_, wrap};
 }
diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h
index ce011fecc..f0f8d060d 100644
--- a/gtsam/hybrid/GaussianMixtureFactor.h
+++ b/gtsam/hybrid/GaussianMixtureFactor.h
@@ -25,10 +25,10 @@
 #include
 #include
 #include
+#include
 
 namespace gtsam {
 
-class GaussianFactorGraph;
 class HybridValues;
 class DiscreteValues;
 class VectorValues;
@@ -50,7 +50,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
   using This = GaussianMixtureFactor;
   using shared_ptr = boost::shared_ptr;
 
-  using Sum = DecisionTree;
   using sharedFactor = boost::shared_ptr;
 
   /// Gaussian factor and log of normalizing constant.
@@ -60,8 +59,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
 
     // Return error with constant correction.
    double error(const VectorValues &values) const {
-      // Note minus sign: constant is log of normalization constant for probabilities.
-      // Errors is the negative log-likelihood, hence we subtract the constant here.
+      // Note: constant is log of normalization constant for probabilities.
+      // The error is the negative log-likelihood,
+      // hence we subtract the constant here.
      return factor->error(values) - constant;
    }
 
@@ -71,6 +71,22 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
    }
  };
 
+  /// Gaussian factor graph and log of normalizing constant.
+  struct GraphAndConstant {
+    GaussianFactorGraph graph;
+    double constant;
+
+    GraphAndConstant(const GaussianFactorGraph &graph, double constant)
+        : graph(graph), constant(constant) {}
+
+    // Check pointer equality.
+    bool operator==(const GraphAndConstant &other) const {
+      return graph == other.graph && constant == other.constant;
+    }
+  };
+
+  using Sum = DecisionTree;
+
   /// typedef for Decision Tree of Gaussian factors and log-constant.
   using Factors = DecisionTree;
   using Mixture = DecisionTree;
diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
index 2170e7b68..10e12bf77 100644
--- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp
+++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -61,18 +61,19 @@ template class EliminateableFactorGraph;
 /* ************************************************************************ */
 static GaussianMixtureFactor::Sum &addGaussian(
     GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
-  using Y = GaussianFactorGraph;
 
   // If the decision tree is not initialized, then initialize it.
   if (sum.empty()) {
     GaussianFactorGraph result;
     result.push_back(factor);
-    sum = GaussianMixtureFactor::Sum(result);
+    sum = GaussianMixtureFactor::Sum(
+        GaussianMixtureFactor::GraphAndConstant(result, 0.0));
   } else {
-    auto add = [&factor](const Y &graph) {
-      auto result = graph;
+    auto add = [&factor](
+                   const GaussianMixtureFactor::GraphAndConstant &graph_z) {
+      auto result = graph_z.graph;
       result.push_back(factor);
-      return result;
+      return GaussianMixtureFactor::GraphAndConstant(result, graph_z.constant);
     };
     sum = sum.apply(add);
   }
@@ -190,31 +191,36 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
                                  discreteSeparatorSet.end());
 
-  // sum out frontals, this is the factor 𝜏 on the separator
+  // Collect all the frontal factors to create Gaussian factor graphs
+  // indexed on the discrete keys.
   GaussianMixtureFactor::Sum sum = sumFrontals(factors);
 
   // If a tree leaf contains nullptr,
   // convert that leaf to an empty GaussianFactorGraph.
   // Needed since the DecisionTree will otherwise create
   // a GFG with a single (null) factor.
-  auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
-    bool hasNull =
-        std::any_of(gfg.begin(), gfg.end(),
-                    [](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
+  auto emptyGaussian =
+      [](const GaussianMixtureFactor::GraphAndConstant &graph_z) {
+        bool hasNull = std::any_of(
+            graph_z.graph.begin(), graph_z.graph.end(),
+            [](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
 
-    return hasNull ? GaussianFactorGraph() : gfg;
-  };
+        return hasNull ? GaussianMixtureFactor::GraphAndConstant(
+                             GaussianFactorGraph(), 0.0)
+                       : graph_z;
+      };
   sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
 
   using EliminationPair = std::pair,
                                     GaussianMixtureFactor::FactorAndConstant>;
 
   KeyVector keysOfEliminated;  // Not the ordering
-  KeyVector keysOfSeparator;  // TODO(frank): Is this just (keys - ordering)?
+  KeyVector keysOfSeparator;
 
   // This is the elimination method on the leaf nodes
-  auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
-    if (graph.empty()) {
+  auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z)
+      -> EliminationPair {
+    if (graph_z.graph.empty()) {
       return {nullptr, {nullptr, 0.0}};
     }
 
@@ -224,7 +230,8 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
 
     std::pair,
               boost::shared_ptr>
-        conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
+        conditional_factor =
+            EliminatePreferCholesky(graph_z.graph, frontalKeys);
 
     // Initialize the keysOfEliminated to be the keys of the
     // eliminated GaussianConditional
@@ -235,8 +242,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
     gttoc_(hybrid_eliminate);
 #endif
 
-    // TODO(Varun) The normalizing constant has to be computed correctly
-    return {conditional_factor.first, {conditional_factor.second, 0.0}};
+    GaussianConditional::shared_ptr conditional = conditional_factor.first;
+    // Get the log of the normalization constant inverse.
+    double logZ = -conditional->logNormalizationConstant() + graph_z.constant;
+    return {conditional, {conditional_factor.second, logZ}};
   };
 
   // Perform elimination!
@@ -270,6 +279,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
     }
   };
   DecisionTree fdt(separatorFactors, factorProb);
+  // Normalize the values of the decision tree to be valid probabilities
+  double sum = 0.0;
+  auto visitor = [&](double y) { sum += y; };
+  fdt.visit(visitor);
+  // fdt = DecisionTree(fdt,
+  //                    [sum](const double &x) { return x / sum;
+  //                    });
 
   auto discreteFactor =
       boost::make_shared(discreteSeparator, fdt);