/* ---------------------------------------------------------------------------- * GTSAM Copyright 2010, Georgia Tech Research Corporation, * Atlanta, Georgia 30332-0415 * All Rights Reserved * Authors: Frank Dellaert, et al. (see THANKS for the full author list) * See LICENSE for the license information * -------------------------------------------------------------------------- */ /** * @file HybridGaussianFactorGraph.cpp * @brief Hybrid factor graph that uses type erasure * @author Fan Jiang * @author Varun Agrawal * @author Frank Dellaert * @date Mar 11, 2022 */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace gtsam { /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: template class EliminateableFactorGraph; using std::dynamic_pointer_cast; using OrphanWrapper = BayesTreeOrphanWrapper; /// Result from elimination. struct Result { // Gaussian conditional resulting from elimination. GaussianConditional::shared_ptr conditional; double negLogK; // Negative log of the normalization constant K. GaussianFactor::shared_ptr factor; // Leftover factor 𝜏. double scalar; // Scalar value associated with factor 𝜏. bool operator==(const Result &other) const { return conditional == other.conditional && negLogK == other.negLogK && factor == other.factor && scalar == other.scalar; } }; using ResultTree = DecisionTree; static const VectorValues kEmpty; /* ************************************************************************ */ // Throw a runtime exception for method specified in string s, and factor f: static void throwRuntimeError(const std::string &s, const std::shared_ptr &f) { auto &fr = *f; throw std::runtime_error(s + " not implemented for factor type " + demangle(typeid(fr).name()) + "."); } /* ************************************************************************ */ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) { KeySet discrete_keys = graph.discreteKeySet(); const VariableIndex index(graph); return Ordering::ColamdConstrainedLast( index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); } /* ************************************************************************ */ static void printFactor(const std::shared_ptr &factor, const DiscreteValues &assignment, const KeyFormatter &keyFormatter) { if (auto hgf = dynamic_pointer_cast(factor)) { if (assignment.empty()) { hgf->print("HybridGaussianFactor:", keyFormatter); } else { hgf->operator()(assignment) .first->print("HybridGaussianFactor, component:", keyFormatter); } } else if (auto gf = dynamic_pointer_cast(factor)) { factor->print("GaussianFactor:\n", keyFormatter); } else if (auto df = dynamic_pointer_cast(factor)) { factor->print("DiscreteFactor:\n", keyFormatter); } else if (auto hc = dynamic_pointer_cast(factor)) { if (hc->isContinuous()) { factor->print("GaussianConditional:\n", keyFormatter); } else if (hc->isDiscrete()) { factor->print("DiscreteConditional:\n", keyFormatter); } else { if (assignment.empty()) { hc->print("HybridConditional:", keyFormatter); } else { hc->asHybrid() ->choose(assignment) ->print("HybridConditional, component:\n", keyFormatter); } } } else { factor->print("Unknown factor type\n", keyFormatter); } } /* ************************************************************************ */ void HybridGaussianFactorGraph::print(const std::string &s, const KeyFormatter &keyFormatter) const { std::cout << (s.empty() ? "" : s + " ") << std::endl; std::cout << "size: " << size() << std::endl; for (size_t i = 0; i < factors_.size(); i++) { auto &&factor = factors_[i]; if (factor == nullptr) { std::cout << "Factor " << i << ": nullptr\n"; continue; } // Print the factor std::cout << "Factor " << i << "\n"; printFactor(factor, {}, keyFormatter); std::cout << "\n"; } std::cout.flush(); } /* ************************************************************************ */ void HybridGaussianFactorGraph::printErrors( const HybridValues &values, const std::string &str, const KeyFormatter &keyFormatter, const std::function &printCondition) const { std::cout << str << "size: " << size() << std::endl << std::endl; for (size_t i = 0; i < factors_.size(); i++) { auto &&factor = factors_[i]; if (factor == nullptr) { std::cout << "Factor " << i << ": nullptr\n"; continue; } const double errorValue = factor->error(values); if (!printCondition(factor.get(), errorValue, i)) continue; // User-provided filter did not pass // Print the factor std::cout << "Factor " << i << ", error = " << errorValue << "\n"; printFactor(factor, values.discrete(), keyFormatter); std::cout << "\n"; } std::cout.flush(); } /* ************************************************************************ */ HybridGaussianProductFactor HybridGaussianFactorGraph::collectProductFactor() const { HybridGaussianProductFactor result; for (auto &f : factors_) { // TODO(dellaert): can we make this cleaner and less error-prone? if (auto orphan = dynamic_pointer_cast(f)) { continue; // Ignore OrphanWrapper } else if (auto gf = dynamic_pointer_cast(f)) { result += gf; } else if (auto gc = dynamic_pointer_cast(f)) { result += gc; } else if (auto gmf = dynamic_pointer_cast(f)) { result += *gmf; } else if (auto gm = dynamic_pointer_cast(f)) { result += *gm; // handled above already? } else if (auto hc = dynamic_pointer_cast(f)) { if (auto gm = hc->asHybrid()) { result += *gm; } else if (auto g = hc->asGaussian()) { result += g; } else { // Has to be discrete. // TODO(dellaert): in C++20, we can use std::visit. continue; } } else if (dynamic_pointer_cast(f)) { // Don't do anything for discrete-only factors // since we want to eliminate continuous values only. continue; } else { // TODO(dellaert): there was an unattributed comment here: We need to // handle the case where the object is actually an BayesTreeOrphanWrapper! throwRuntimeError("gtsam::collectProductFactor", f); } } return result; } /* ************************************************************************ */ static std::pair> continuousElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys) { GaussianFactorGraph gfg; for (auto &f : factors) { if (auto gf = dynamic_pointer_cast(f)) { gfg.push_back(gf); } else if (auto orphan = dynamic_pointer_cast(f)) { // Ignore orphaned clique. // TODO(dellaert): is this correct? If so explain here. } else if (auto hc = dynamic_pointer_cast(f)) { auto gc = hc->asGaussian(); if (!gc) throwRuntimeError("continuousElimination", gc); gfg.push_back(gc); } else { throwRuntimeError("continuousElimination", f); } } auto result = EliminatePreferCholesky(gfg, frontalKeys); return {std::make_shared(result.first), result.second}; } /* ************************************************************************ */ /** * @brief Take negative log-values, shift them so that the minimum value is 0, * and then exponentiate to create a TableFactor (not normalized yet!). * * @param errors DecisionTree of (unnormalized) errors. * @return TableFactor::shared_ptr */ static TableFactor::shared_ptr DiscreteFactorFromErrors( const DiscreteKeys &discreteKeys, const AlgebraicDecisionTree &errors) { double min_log = errors.min(); AlgebraicDecisionTree potentials( errors, [&min_log](const double x) { return exp(-(x - min_log)); }); return std::make_shared(discreteKeys, potentials); } /* ************************************************************************ */ TableFactor TableProductAndNormalize(const DiscreteFactorGraph &factors) { // PRODUCT: multiply all factors #if GTSAM_HYBRID_TIMING gttic_(DiscreteProduct); #endif TableFactor product; for (auto &&factor : factors) { if (factor) { if (auto dtc = std::dynamic_pointer_cast(factor)) { product = product * dtc->table(); } else if (auto f = std::dynamic_pointer_cast(factor)) { product = product * (*f); } else if (auto dtf = std::dynamic_pointer_cast(factor)) { product = product * TableFactor(*dtf); } } } #if GTSAM_HYBRID_TIMING gttoc_(DiscreteProduct); #endif // Max over all the potentials by pretending all keys are frontal: auto normalizer = product.max(product.size()); #if GTSAM_HYBRID_TIMING gttic_(DiscreteNormalize); #endif // Normalize the product factor to prevent underflow. product = product / (*normalizer); #if GTSAM_HYBRID_TIMING gttoc_(DiscreteNormalize); #endif return product; } /* ************************************************************************ */ static std::pair> discreteElimination(const HybridGaussianFactorGraph &factors, const Ordering &frontalKeys) { DiscreteFactorGraph dfg; for (auto &f : factors) { if (auto df = dynamic_pointer_cast(f)) { dfg.push_back(df); } else if (auto gmf = dynamic_pointer_cast(f)) { // Case where we have a HybridGaussianFactor with no continuous keys. // In this case, compute a discrete factor from the remaining error. auto calculateError = [&](const auto &pair) -> double { auto [factor, scalar] = pair; // If factor is null, it has been pruned, hence return infinite error if (!factor) return std::numeric_limits::infinity(); return scalar + factor->error(kEmpty); }; AlgebraicDecisionTree errors(gmf->factors(), calculateError); dfg.push_back(DiscreteFactorFromErrors(gmf->discreteKeys(), errors)); } else if (auto orphan = dynamic_pointer_cast(f)) { // Ignore orphaned clique. // TODO(dellaert): is this correct? If so explain here. } else if (auto hc = dynamic_pointer_cast(f)) { auto dc = hc->asDiscrete(); if (!dc) throwRuntimeError("discreteElimination", dc); #if GTSAM_HYBRID_TIMING gttic_(ConvertConditionalToTableFactor); #endif if (auto dtc = std::dynamic_pointer_cast(dc)) { /// Get the underlying TableFactor dfg.push_back(dtc->table()); } else { // Convert DiscreteConditional to TableFactor auto tdc = std::make_shared(*dc); dfg.push_back(tdc); } #if GTSAM_HYBRID_TIMING gttoc_(ConvertConditionalToTableFactor); #endif } else { throwRuntimeError("discreteElimination", f); } } #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscrete); #endif /**** NOTE: This does sum-product. ****/ // Get product factor TableFactor product = TableProductAndNormalize(dfg); #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteSum); #endif // All the discrete variables should form a single clique, // so we can sum out on all the variables as frontals. // This should give an empty separator. TableFactor::shared_ptr sum = product.sum(frontalKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteSum); #endif // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); #if GTSAM_HYBRID_TIMING gttic_(EliminateDiscreteFormDiscreteConditional); #endif // Finally, get the conditional auto conditional = std::make_shared(product, *sum, orderedKeys); #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscreteFormDiscreteConditional); #endif #if GTSAM_HYBRID_TIMING gttoc_(EliminateDiscrete); #endif return {std::make_shared(conditional), sum}; } /* ************************************************************************ */ /** * Compute the probability p(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m) * from the residual error ||b||^2 at the mean μ. * The residual error contains no keys, and only * depends on the discrete separator if present. */ static std::shared_ptr createDiscreteFactor( const ResultTree &eliminationResults, const DiscreteKeys &discreteSeparator) { auto calculateError = [&](const Result &result) -> double { if (result.conditional && result.factor) { // `error` has the following contributions: // - the scalar is the sum of all mode-dependent constants // - factor->error(kempty) is the error remaining after elimination // - negLogK is what is given to the conditional to normalize return result.scalar + result.factor->error(kEmpty) - result.negLogK; } else if (!result.conditional && !result.factor) { // If the factor has been pruned, return infinite error return std::numeric_limits::infinity(); } else { throw std::runtime_error("createDiscreteFactor has mixed NULLs"); } }; #if GTSAM_HYBRID_TIMING gttic_(DiscreteBoundaryErrors); #endif AlgebraicDecisionTree errors(eliminationResults, calculateError); #if GTSAM_HYBRID_TIMING gttoc_(DiscreteBoundaryErrors); gttic_(DiscreteBoundaryResult); #endif auto result = DiscreteFactorFromErrors(discreteSeparator, errors); #if GTSAM_HYBRID_TIMING gttoc_(DiscreteBoundaryResult); #endif return result; } /* *******************************************************************************/ // Create HybridGaussianFactor on the separator, taking care to correct // for conditional constants. static std::shared_ptr createHybridGaussianFactor( const ResultTree &eliminationResults, const DiscreteKeys &discreteSeparator) { // Correct for the normalization constant used up by the conditional auto correct = [&](const Result &result) -> GaussianFactorValuePair { if (result.conditional && result.factor) { return {result.factor, result.scalar - result.negLogK}; } else if (!result.conditional && !result.factor) { return {nullptr, std::numeric_limits::infinity()}; } else { throw std::runtime_error("createHybridGaussianFactors has mixed NULLs"); } }; DecisionTree newFactors(eliminationResults, correct); return std::make_shared(discreteSeparator, newFactors); } /* *******************************************************************************/ /// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys. static auto GetDiscreteKeys = [](const HybridGaussianFactorGraph &hfg) -> DiscreteKeys { const std::set discreteKeySet = hfg.discreteKeys(); return {discreteKeySet.begin(), discreteKeySet.end()}; }; /* *******************************************************************************/ std::pair> HybridGaussianFactorGraph::eliminate(const Ordering &keys) const { // Since we eliminate all continuous variables first, // the discrete separator will be *all* the discrete keys. DiscreteKeys discreteSeparator = GetDiscreteKeys(*this); #if GTSAM_HYBRID_TIMING gttic_(HybridCollectProductFactor); #endif // Collect all the factors to create a set of Gaussian factor graphs in a // decision tree indexed by all discrete keys involved. Just like any hybrid // factor, every assignment also has a scalar error, in this case the sum of // all errors in the graph. This error is assignment-specific and accounts for // any difference in noise models used. HybridGaussianProductFactor productFactor = collectProductFactor(); #if GTSAM_HYBRID_TIMING gttoc_(HybridCollectProductFactor); #endif // Check if a factor is null auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; }; // This is the elimination method on the leaf nodes bool someContinuousLeft = false; auto eliminate = [&](const std::pair &pair) -> Result { const auto &[graph, scalar] = pair; // If any product contains a pruned factor, prune it here. Done here as it's // non non-trivial to do within collectProductFactor. if (graph.empty() || std::any_of(graph.begin(), graph.end(), isNull)) { return {nullptr, 0.0, nullptr, 0.0}; } // Expensive elimination of product factor. auto [conditional, factor] = EliminatePreferCholesky(graph, keys); /// <<<<<< MOST COMPUTE IS HERE // Record whether there any continuous variables left someContinuousLeft |= !factor->empty(); // We pass on the scalar unmodified. return {conditional, conditional->negLogConstant(), factor, scalar}; }; #if GTSAM_HYBRID_TIMING gttic_(HybridEliminate); #endif // Perform elimination! const ResultTree eliminationResults(productFactor, eliminate); #if GTSAM_HYBRID_TIMING gttoc_(HybridEliminate); #endif // If there are no more continuous parents we create a DiscreteFactor with the // error for each discrete choice. Otherwise, create a HybridGaussianFactor // on the separator, taking care to correct for conditional constants. auto newFactor = someContinuousLeft ? createHybridGaussianFactor(eliminationResults, discreteSeparator) : createDiscreteFactor(eliminationResults, discreteSeparator); // Create the HybridGaussianConditional without re-calculating constants: HybridGaussianConditional::FactorValuePairs pairs( eliminationResults, [](const Result &result) -> GaussianFactorValuePair { return {result.conditional, result.negLogK}; }); auto hybridGaussian = std::make_shared(discreteSeparator, pairs); return {std::make_shared(hybridGaussian), newFactor}; } /* ************************************************************************ * Function to eliminate variables **under the following assumptions**: * 1. When the ordering is fully continuous, and the graph only contains * continuous and hybrid factors * 2. When the ordering is fully discrete, and the graph only contains discrete * factors * * Any usage outside of this is considered incorrect. * * \warning This function is not meant to be used with arbitrary hybrid factor * graphs. For example, if there exists continuous parents, and one tries to * eliminate a discrete variable (as specified in the ordering), the result will * be INCORRECT and there will be NO error raised. */ std::pair> // EliminateHybrid(const HybridGaussianFactorGraph &factors, const Ordering &keys) { // NOTE: Because we are in the Conditional Gaussian regime there are only // a few cases: // 1. continuous variable, make a hybrid Gaussian conditional if there are // hybrid factors; // 2. continuous variable, we make a Gaussian Factor if there are no hybrid // factors; // 3. discrete variable, no continuous factor is allowed // (escapes Conditional Gaussian regime), if discrete only we do the discrete // elimination. // However it is not that simple. During elimination it is possible that the // multifrontal needs to eliminate an ordering that contains both Gaussian and // hybrid variables, for example x1, c1. // In this scenario, we will have a density P(x1, c1) that is a Conditional // Linear Gaussian P(x1|c1)P(c1) (see Murphy02). // The issue here is that, how can we know which variable is discrete if we // unify Values? Obviously we can tell using the factors, but is that fast? // In the case of multifrontal, we will need to use a constrained ordering // so that the discrete parts will be guaranteed to be eliminated last! // Because of all these reasons, we carefully consider how to // implement the hybrid factors so that we do not get poor performance. // The first thing is how to represent the HybridGaussianConditional. // A very possible scenario is that the incoming factors will have different // levels of discrete keys. For example, imagine we are going to eliminate the // fragment: $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid. // Now we will need to know how to retrieve the corresponding continuous // densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO // defined order!). We also need to consider when there is pruning. Two // hybrid factors could have different pruning patterns - one could have // (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this // creates a big problem in how to identify the intersection of non-pruned // branches. // Our approach is first building the collection of all discrete keys. After // that we enumerate the space of all key combinations *lazily* so that the // exploration branch terminates whenever an assignment yields NULL in any of // the hybrid factors. // When the number of assignments is large we may encounter stack overflows. // However this is also the case with iSAM2, so no pressure :) // Check the factors: // 1. if all factors are discrete, then we can do discrete elimination: // 2. if all factors are continuous, then we can do continuous elimination: // 3. if not, we do hybrid elimination: bool only_discrete = true, only_continuous = true; for (auto &&factor : factors) { if (auto hybrid_factor = std::dynamic_pointer_cast(factor)) { if (hybrid_factor->isDiscrete()) { only_continuous = false; } else if (hybrid_factor->isContinuous()) { only_discrete = false; } else if (hybrid_factor->isHybrid()) { only_continuous = false; only_discrete = false; break; } } else if (auto cont_factor = std::dynamic_pointer_cast(factor)) { only_discrete = false; } else if (auto discrete_factor = std::dynamic_pointer_cast(factor)) { only_continuous = false; } } // NOTE: We should really defer the product here because of pruning if (only_discrete) { // Case 1: we are only dealing with discrete return discreteElimination(factors, keys); } else if (only_continuous) { // Case 2: we are only dealing with continuous return continuousElimination(factors, keys); } else { // Case 3: We are now in the hybrid land! return factors.eliminate(keys); } } /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::errorTree( const VectorValues &continuousValues) const { AlgebraicDecisionTree result(0.0); // Iterate over each factor. for (auto &factor : factors_) { if (auto hf = std::dynamic_pointer_cast(factor)) { // Add errorTree for hybrid factors, includes HybridGaussianConditionals! result = result + hf->errorTree(continuousValues); } else if (auto df = std::dynamic_pointer_cast(factor)) { // If discrete, just add its errorTree as well result = result + df->errorTree(); } else if (auto gf = dynamic_pointer_cast(factor)) { // For a continuous only factor, just add its error result = result + gf->error(continuousValues); } else { throwRuntimeError("HybridGaussianFactorGraph::errorTree", factor); } } return result; } /* ************************************************************************ */ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const { double error = this->error(values); // NOTE: The 0.5 term is handled by each factor return std::exp(-error); } /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::discretePosterior( const VectorValues &continuousValues) const { AlgebraicDecisionTree errors = this->errorTree(continuousValues); AlgebraicDecisionTree p = errors.apply([](double error) { // NOTE: The 0.5 term is handled by each factor return exp(-error); }); return p / p.sum(); } /* ************************************************************************ */ GaussianFactorGraph HybridGaussianFactorGraph::choose( const DiscreteValues &assignment) const { GaussianFactorGraph gfg; for (auto &&f : *this) { if (auto gf = std::dynamic_pointer_cast(f)) { gfg.push_back(gf); } else if (auto gc = std::dynamic_pointer_cast(f)) { gfg.push_back(gf); } else if (auto hgf = std::dynamic_pointer_cast(f)) { gfg.push_back((*hgf)(assignment).first); } else if (auto hgc = dynamic_pointer_cast(f)) { gfg.push_back((*hgc)(assignment)); } else if (auto hc = dynamic_pointer_cast(f)) { if (auto gc = hc->asGaussian()) gfg.push_back(gc); else if (auto hgc = hc->asHybrid()) gfg.push_back((*hgc)(assignment)); } else { continue; } } return gfg; } /* ************************************************************************ */ DiscreteFactorGraph HybridGaussianFactorGraph::discreteFactors() const { DiscreteFactorGraph dfg; for (auto &&f : factors_) { if (auto discreteFactor = std::dynamic_pointer_cast(f)) { dfg.push_back(discreteFactor); } } return dfg; } } // namespace gtsam