Merge pull request #1575 from borglab/hybrid-tablefactor-2
						commit
						ba7c077a25
					
				| 
						 | 
				
			
			@ -93,7 +93,8 @@ namespace gtsam {
 | 
			
		|||
    /// print
 | 
			
		||||
    void print(const std::string& s, const LabelFormatter& labelFormatter,
 | 
			
		||||
               const ValueFormatter& valueFormatter) const override {
 | 
			
		||||
      std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
 | 
			
		||||
      std::cout << s << " Leaf [" << nrAssignments() << "] "
 | 
			
		||||
                << valueFormatter(constant_) << std::endl;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /** Write graphviz format to stream `os`. */
 | 
			
		||||
| 
						 | 
				
			
			@ -827,6 +828,16 @@ namespace gtsam {
 | 
			
		|||
    return total;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /****************************************************************************/
 | 
			
		||||
  template <typename L, typename Y>
 | 
			
		||||
  size_t DecisionTree<L, Y>::nrAssignments() const {
 | 
			
		||||
    size_t n = 0;
 | 
			
		||||
    this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
 | 
			
		||||
      n += leaf.nrAssignments();
 | 
			
		||||
    });
 | 
			
		||||
    return n;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /****************************************************************************/
 | 
			
		||||
  // fold is just done with a visit
 | 
			
		||||
  template <typename L, typename Y>
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -320,6 +320,42 @@ namespace gtsam {
 | 
			
		|||
    /// Return the number of leaves in the tree.
 | 
			
		||||
    size_t nrLeaves() const;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief This is a convenience function which returns the total number of
 | 
			
		||||
     * leaf assignments in the decision tree.
 | 
			
		||||
     * This function is not used for anymajor operations within the discrete
 | 
			
		||||
     * factor graph framework.
 | 
			
		||||
     *
 | 
			
		||||
     * Leaf assignments represent the cardinality of each leaf node, e.g. in a
 | 
			
		||||
     * binary tree each leaf has 2 assignments. This includes counts removed
 | 
			
		||||
     * from implicit pruning hence, it will always be >= nrLeaves().
 | 
			
		||||
     *
 | 
			
		||||
     * E.g. we have a decision tree as below, where each node has 2 branches:
 | 
			
		||||
     *
 | 
			
		||||
     * Choice(m1)
 | 
			
		||||
     * 0 Choice(m0)
 | 
			
		||||
     * 0 0 Leaf 0.0
 | 
			
		||||
     * 0 1 Leaf 0.0
 | 
			
		||||
     * 1 Choice(m0)
 | 
			
		||||
     * 1 0 Leaf 1.0
 | 
			
		||||
     * 1 1 Leaf 2.0
 | 
			
		||||
     *
 | 
			
		||||
     * In the unpruned form, the tree will have 4 assignments, 2 for each key,
 | 
			
		||||
     * and 4 leaves.
 | 
			
		||||
     *
 | 
			
		||||
     * In the pruned form, the number of assignments is still 4 but the number
 | 
			
		||||
     * of leaves is now 3, as below:
 | 
			
		||||
     *
 | 
			
		||||
     * Choice(m1)
 | 
			
		||||
     * 0 Leaf 0.0
 | 
			
		||||
     * 1 Choice(m0)
 | 
			
		||||
     * 1 0 Leaf 1.0
 | 
			
		||||
     * 1 1 Leaf 2.0
 | 
			
		||||
     *
 | 
			
		||||
     * @return size_t
 | 
			
		||||
     */
 | 
			
		||||
    size_t nrAssignments() const;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Fold a binary function over the tree, returning accumulator.
 | 
			
		||||
     *
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -101,6 +101,14 @@ namespace gtsam {
 | 
			
		|||
    return DecisionTreeFactor(keys, result);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************ */
 | 
			
		||||
  DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
 | 
			
		||||
    // apply operand
 | 
			
		||||
    ADT result = ADT::apply(op);
 | 
			
		||||
    // Make a new factor
 | 
			
		||||
    return DecisionTreeFactor(discreteKeys(), result);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /* ************************************************************************ */
 | 
			
		||||
  DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
 | 
			
		||||
      size_t nrFrontals, ADT::Binary op) const {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -182,6 +182,12 @@ namespace gtsam {
 | 
			
		|||
    /// @name Advanced Interface
 | 
			
		||||
    /// @{
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Apply unary operator (*this) "op" f
 | 
			
		||||
     * @param op a unary operator that operates on AlgebraicDecisionTree
 | 
			
		||||
     */
 | 
			
		||||
    DecisionTreeFactor apply(ADT::UnaryAssignment op) const;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Apply binary operator (*this) "op" f
 | 
			
		||||
     * @param f the second argument for op
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -25,6 +25,7 @@
 | 
			
		|||
#include <gtsam/base/serializationTestHelpers.h>
 | 
			
		||||
#include <gtsam/discrete/DecisionTree-inl.h>
 | 
			
		||||
#include <gtsam/discrete/Signature.h>
 | 
			
		||||
#include <gtsam/inference/Symbol.h>
 | 
			
		||||
 | 
			
		||||
#include <iomanip>
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -37,24 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
 | 
			
		|||
  return Base::equals(bn, tol);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
 | 
			
		||||
  AlgebraicDecisionTree<Key> discreteProbs;
 | 
			
		||||
 | 
			
		||||
  // The canonical decision tree factor which will get
 | 
			
		||||
  // the discrete conditionals added to it.
 | 
			
		||||
  DecisionTreeFactor discreteProbsFactor;
 | 
			
		||||
 | 
			
		||||
  for (auto &&conditional : *this) {
 | 
			
		||||
    if (conditional->isDiscrete()) {
 | 
			
		||||
      // Convert to a DecisionTreeFactor and add it to the main factor.
 | 
			
		||||
      DecisionTreeFactor f(*conditional->asDiscrete());
 | 
			
		||||
      discreteProbsFactor = discreteProbsFactor * f;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
/**
 | 
			
		||||
 * @brief Helper function to get the pruner functional.
 | 
			
		||||
| 
						 | 
				
			
			@ -144,53 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
void HybridBayesNet::updateDiscreteConditionals(
 | 
			
		||||
    const DecisionTreeFactor &prunedDiscreteProbs) {
 | 
			
		||||
  KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
 | 
			
		||||
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
 | 
			
		||||
    size_t maxNrLeaves) {
 | 
			
		||||
  // Get the joint distribution of only the discrete keys
 | 
			
		||||
  gttic_(HybridBayesNet_PruneDiscreteConditionals);
 | 
			
		||||
  // The joint discrete probability.
 | 
			
		||||
  DiscreteConditional discreteProbs;
 | 
			
		||||
 | 
			
		||||
  std::vector<size_t> discrete_factor_idxs;
 | 
			
		||||
  // Record frontal keys so we can maintain ordering
 | 
			
		||||
  Ordering discrete_frontals;
 | 
			
		||||
 | 
			
		||||
  // Loop with index since we need it later.
 | 
			
		||||
  for (size_t i = 0; i < this->size(); i++) {
 | 
			
		||||
    HybridConditional::shared_ptr conditional = this->at(i);
 | 
			
		||||
    auto conditional = this->at(i);
 | 
			
		||||
    if (conditional->isDiscrete()) {
 | 
			
		||||
      auto discrete = conditional->asDiscrete();
 | 
			
		||||
      discreteProbs = discreteProbs * (*conditional->asDiscrete());
 | 
			
		||||
 | 
			
		||||
      // Convert pointer from conditional to factor
 | 
			
		||||
      auto discreteTree =
 | 
			
		||||
          std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
 | 
			
		||||
      // Apply prunerFunc to the underlying AlgebraicDecisionTree
 | 
			
		||||
      DecisionTreeFactor::ADT prunedDiscreteTree =
 | 
			
		||||
          discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
 | 
			
		||||
 | 
			
		||||
      gttic_(HybridBayesNet_MakeConditional);
 | 
			
		||||
      // Create the new (hybrid) conditional
 | 
			
		||||
      KeyVector frontals(discrete->frontals().begin(),
 | 
			
		||||
                         discrete->frontals().end());
 | 
			
		||||
      auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
 | 
			
		||||
          frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
 | 
			
		||||
      conditional = std::make_shared<HybridConditional>(prunedDiscrete);
 | 
			
		||||
      gttoc_(HybridBayesNet_MakeConditional);
 | 
			
		||||
 | 
			
		||||
      // Add it back to the BayesNet
 | 
			
		||||
      this->at(i) = conditional;
 | 
			
		||||
      Ordering conditional_keys(conditional->frontals());
 | 
			
		||||
      discrete_frontals += conditional_keys;
 | 
			
		||||
      discrete_factor_idxs.push_back(i);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  const DecisionTreeFactor prunedDiscreteProbs =
 | 
			
		||||
      discreteProbs.prune(maxNrLeaves);
 | 
			
		||||
  gttoc_(HybridBayesNet_PruneDiscreteConditionals);
 | 
			
		||||
 | 
			
		||||
  // Eliminate joint probability back into conditionals
 | 
			
		||||
  gttic_(HybridBayesNet_UpdateDiscreteConditionals);
 | 
			
		||||
  DiscreteFactorGraph dfg{prunedDiscreteProbs};
 | 
			
		||||
  DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
 | 
			
		||||
 | 
			
		||||
  // Assign pruned discrete conditionals back at the correct indices.
 | 
			
		||||
  for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
 | 
			
		||||
    size_t idx = discrete_factor_idxs.at(i);
 | 
			
		||||
    this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
 | 
			
		||||
  }
 | 
			
		||||
  gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
 | 
			
		||||
 | 
			
		||||
  return prunedDiscreteProbs;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
 | 
			
		||||
  // Get the decision tree of only the discrete keys
 | 
			
		||||
  gttic_(HybridBayesNet_PruneDiscreteConditionals);
 | 
			
		||||
  DecisionTreeFactor::shared_ptr discreteConditionals =
 | 
			
		||||
      this->discreteConditionals();
 | 
			
		||||
  const DecisionTreeFactor prunedDiscreteProbs =
 | 
			
		||||
      discreteConditionals->prune(maxNrLeaves);
 | 
			
		||||
  gttoc_(HybridBayesNet_PruneDiscreteConditionals);
 | 
			
		||||
  DecisionTreeFactor prunedDiscreteProbs =
 | 
			
		||||
      this->pruneDiscreteConditionals(maxNrLeaves);
 | 
			
		||||
 | 
			
		||||
  gttic_(HybridBayesNet_UpdateDiscreteConditionals);
 | 
			
		||||
  this->updateDiscreteConditionals(prunedDiscreteProbs);
 | 
			
		||||
  gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
 | 
			
		||||
 | 
			
		||||
  /* To Prune, we visitWith every leaf in the GaussianMixture.
 | 
			
		||||
  /* To prune, we visitWith every leaf in the GaussianMixture.
 | 
			
		||||
   * For each leaf, using the assignment we can check the discrete decision tree
 | 
			
		||||
   * for 0.0 probability, then just set the leaf to a nullptr.
 | 
			
		||||
   *
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 | 
			
		|||
   */
 | 
			
		||||
  VectorValues optimize(const DiscreteValues &assignment) const;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Get all the discrete conditionals as a decision tree factor.
 | 
			
		||||
   *
 | 
			
		||||
   * @return DecisionTreeFactor::shared_ptr
 | 
			
		||||
   */
 | 
			
		||||
  DecisionTreeFactor::shared_ptr discreteConditionals() const;
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Sample from an incomplete BayesNet, given missing variables.
 | 
			
		||||
   *
 | 
			
		||||
| 
						 | 
				
			
			@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 | 
			
		|||
 | 
			
		||||
 private:
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Update the discrete conditionals with the pruned versions.
 | 
			
		||||
   * @brief Prune all the discrete conditionals.
 | 
			
		||||
   *
 | 
			
		||||
   * @param prunedDiscreteProbs
 | 
			
		||||
   * @param maxNrLeaves
 | 
			
		||||
   */
 | 
			
		||||
  void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
 | 
			
		||||
  DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
 | 
			
		||||
 | 
			
		||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
 | 
			
		||||
  /** Serialization function */
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,6 +20,7 @@
 | 
			
		|||
#include <gtsam/base/Testable.h>
 | 
			
		||||
#include <gtsam/discrete/DecisionTree.h>
 | 
			
		||||
#include <gtsam/discrete/DiscreteKey.h>
 | 
			
		||||
#include <gtsam/discrete/TableFactor.h>
 | 
			
		||||
#include <gtsam/inference/Factor.h>
 | 
			
		||||
#include <gtsam/linear/GaussianFactorGraph.h>
 | 
			
		||||
#include <gtsam/nonlinear/Values.h>
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,7 +17,6 @@
 | 
			
		|||
 * @date   January, 2023
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
			
		||||
#include <gtsam/hybrid/HybridFactorGraph.h>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
| 
						 | 
				
			
			@ -26,7 +25,7 @@ namespace gtsam {
 | 
			
		|||
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
 | 
			
		||||
  std::set<DiscreteKey> keys;
 | 
			
		||||
  for (auto& factor : factors_) {
 | 
			
		||||
    if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
 | 
			
		||||
    if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
 | 
			
		||||
      for (const DiscreteKey& key : p->discreteKeys()) {
 | 
			
		||||
        keys.insert(key);
 | 
			
		||||
      }
 | 
			
		||||
| 
						 | 
				
			
			@ -67,6 +66,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const {
 | 
			
		|||
      for (const Key& key : p->continuousKeys()) {
 | 
			
		||||
        keys.insert(key);
 | 
			
		||||
      }
 | 
			
		||||
    } else if (auto p = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
 | 
			
		||||
      keys.insert(p->keys().begin(), p->keys().end());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return keys;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -48,8 +48,6 @@
 | 
			
		|||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
// #define HYBRID_TIMING
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
 | 
			
		||||
| 
						 | 
				
			
			@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
 | 
			
		|||
        // TODO(dellaert): in C++20, we can use std::visit.
 | 
			
		||||
        continue;
 | 
			
		||||
      }
 | 
			
		||||
    } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
 | 
			
		||||
    } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
 | 
			
		||||
      // Don't do anything for discrete-only factors
 | 
			
		||||
      // since we want to eliminate continuous values only.
 | 
			
		||||
      continue;
 | 
			
		||||
| 
						 | 
				
			
			@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
 | 
			
		|||
  DiscreteFactorGraph dfg;
 | 
			
		||||
 | 
			
		||||
  for (auto &f : factors) {
 | 
			
		||||
    if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
 | 
			
		||||
      dfg.push_back(dtf);
 | 
			
		||||
    if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
 | 
			
		||||
      dfg.push_back(df);
 | 
			
		||||
    } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
 | 
			
		||||
      // Ignore orphaned clique.
 | 
			
		||||
      // TODO(dellaert): is this correct? If so explain here.
 | 
			
		||||
| 
						 | 
				
			
			@ -262,6 +260,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
 | 
			
		|||
    };
 | 
			
		||||
 | 
			
		||||
    DecisionTree<Key, double> probabilities(eliminationResults, probability);
 | 
			
		||||
 | 
			
		||||
    return {
 | 
			
		||||
        std::make_shared<HybridConditional>(gaussianMixture),
 | 
			
		||||
        std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities)};
 | 
			
		||||
| 
						 | 
				
			
			@ -348,64 +347,68 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
 | 
			
		|||
  // When the number of assignments is large we may encounter stack overflows.
 | 
			
		||||
  // However this is also the case with iSAM2, so no pressure :)
 | 
			
		||||
 | 
			
		||||
  // PREPROCESS: Identify the nature of the current elimination
 | 
			
		||||
 | 
			
		||||
  // TODO(dellaert): just check the factors:
 | 
			
		||||
  // Check the factors:
 | 
			
		||||
  // 1. if all factors are discrete, then we can do discrete elimination:
 | 
			
		||||
  // 2. if all factors are continuous, then we can do continuous elimination:
 | 
			
		||||
  // 3. if not, we do hybrid elimination:
 | 
			
		||||
 | 
			
		||||
  // First, identify the separator keys, i.e. all keys that are not frontal.
 | 
			
		||||
  KeySet separatorKeys;
 | 
			
		||||
  bool only_discrete = true, only_continuous = true;
 | 
			
		||||
  for (auto &&factor : factors) {
 | 
			
		||||
    separatorKeys.insert(factor->begin(), factor->end());
 | 
			
		||||
  }
 | 
			
		||||
  // remove frontals from separator
 | 
			
		||||
  for (auto &k : frontalKeys) {
 | 
			
		||||
    separatorKeys.erase(k);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Build a map from keys to DiscreteKeys
 | 
			
		||||
  auto mapFromKeyToDiscreteKey = factors.discreteKeyMap();
 | 
			
		||||
 | 
			
		||||
  // Fill in discrete frontals and continuous frontals.
 | 
			
		||||
  std::set<DiscreteKey> discreteFrontals;
 | 
			
		||||
  KeySet continuousFrontals;
 | 
			
		||||
  for (auto &k : frontalKeys) {
 | 
			
		||||
    if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
 | 
			
		||||
      discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
 | 
			
		||||
    } else {
 | 
			
		||||
      continuousFrontals.insert(k);
 | 
			
		||||
    if (auto hybrid_factor = std::dynamic_pointer_cast<HybridFactor>(factor)) {
 | 
			
		||||
      if (hybrid_factor->isDiscrete()) {
 | 
			
		||||
        only_continuous = false;
 | 
			
		||||
      } else if (hybrid_factor->isContinuous()) {
 | 
			
		||||
        only_discrete = false;
 | 
			
		||||
      } else if (hybrid_factor->isHybrid()) {
 | 
			
		||||
        only_continuous = false;
 | 
			
		||||
        only_discrete = false;
 | 
			
		||||
      }
 | 
			
		||||
    } else if (auto cont_factor =
 | 
			
		||||
                   std::dynamic_pointer_cast<GaussianFactor>(factor)) {
 | 
			
		||||
      only_discrete = false;
 | 
			
		||||
    } else if (auto discrete_factor =
 | 
			
		||||
                   std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
 | 
			
		||||
      only_continuous = false;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Fill in discrete discrete separator keys and continuous separator keys.
 | 
			
		||||
  std::set<DiscreteKey> discreteSeparatorSet;
 | 
			
		||||
  KeyVector continuousSeparator;
 | 
			
		||||
  for (auto &k : separatorKeys) {
 | 
			
		||||
    if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
 | 
			
		||||
      discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
 | 
			
		||||
    } else {
 | 
			
		||||
      continuousSeparator.push_back(k);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Check if we have any continuous keys:
 | 
			
		||||
  const bool discrete_only =
 | 
			
		||||
      continuousFrontals.empty() && continuousSeparator.empty();
 | 
			
		||||
 | 
			
		||||
  // NOTE: We should really defer the product here because of pruning
 | 
			
		||||
 | 
			
		||||
  if (discrete_only) {
 | 
			
		||||
  if (only_discrete) {
 | 
			
		||||
    // Case 1: we are only dealing with discrete
 | 
			
		||||
    return discreteElimination(factors, frontalKeys);
 | 
			
		||||
  } else if (mapFromKeyToDiscreteKey.empty()) {
 | 
			
		||||
  } else if (only_continuous) {
 | 
			
		||||
    // Case 2: we are only dealing with continuous
 | 
			
		||||
    return continuousElimination(factors, frontalKeys);
 | 
			
		||||
  } else {
 | 
			
		||||
    // Case 3: We are now in the hybrid land!
 | 
			
		||||
    KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end());
 | 
			
		||||
 | 
			
		||||
    // Find all the keys in the set of continuous keys
 | 
			
		||||
    // which are not in the frontal keys. This is our continuous separator.
 | 
			
		||||
    KeyVector continuousSeparator;
 | 
			
		||||
    auto continuousKeySet = factors.continuousKeySet();
 | 
			
		||||
    std::set_difference(
 | 
			
		||||
        continuousKeySet.begin(), continuousKeySet.end(),
 | 
			
		||||
        frontalKeysSet.begin(), frontalKeysSet.end(),
 | 
			
		||||
        std::inserter(continuousSeparator, continuousSeparator.begin()));
 | 
			
		||||
 | 
			
		||||
    // Similarly for the discrete separator.
 | 
			
		||||
    KeySet discreteSeparatorSet;
 | 
			
		||||
    std::set<DiscreteKey> discreteSeparator;
 | 
			
		||||
    auto discreteKeySet = factors.discreteKeySet();
 | 
			
		||||
    std::set_difference(
 | 
			
		||||
        discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(),
 | 
			
		||||
        frontalKeysSet.end(),
 | 
			
		||||
        std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin()));
 | 
			
		||||
    // Convert from set of keys to set of DiscreteKeys
 | 
			
		||||
    auto discreteKeyMap = factors.discreteKeyMap();
 | 
			
		||||
    for (auto key : discreteSeparatorSet) {
 | 
			
		||||
      discreteSeparator.insert(discreteKeyMap.at(key));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return hybridElimination(factors, frontalKeys, continuousSeparator,
 | 
			
		||||
                             discreteSeparatorSet);
 | 
			
		||||
                             discreteSeparator);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -429,7 +432,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
 | 
			
		|||
      // Add the gaussian factor error to every leaf of the error tree.
 | 
			
		||||
      error_tree = error_tree.apply(
 | 
			
		||||
          [error](double leaf_value) { return leaf_value + error; });
 | 
			
		||||
    } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
 | 
			
		||||
    } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
 | 
			
		||||
      // If factor at `idx` is discrete-only, we skip.
 | 
			
		||||
      continue;
 | 
			
		||||
    } else {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -40,6 +40,7 @@ class HybridEliminationTree;
 | 
			
		|||
class HybridBayesTree;
 | 
			
		||||
class HybridJunctionTree;
 | 
			
		||||
class DecisionTreeFactor;
 | 
			
		||||
class TableFactor;
 | 
			
		||||
class JacobianFactor;
 | 
			
		||||
class HybridValues;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -66,7 +66,7 @@ struct HybridConstructorTraversalData {
 | 
			
		|||
        for (auto& k : hf->discreteKeys()) {
 | 
			
		||||
          data.discreteKeys.insert(k.first);
 | 
			
		||||
        }
 | 
			
		||||
      } else if (auto hf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
 | 
			
		||||
      } else if (auto hf = std::dynamic_pointer_cast<DiscreteFactor>(f)) {
 | 
			
		||||
        for (auto& k : hf->discreteKeys()) {
 | 
			
		||||
          data.discreteKeys.insert(k.first);
 | 
			
		||||
        }
 | 
			
		||||
| 
						 | 
				
			
			@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree(
 | 
			
		|||
  Data rootData(0);
 | 
			
		||||
  rootData.junctionTreeNode =
 | 
			
		||||
      std::make_shared<typename Base::Node>();  // Make a dummy node to gather
 | 
			
		||||
                                                  // the junction tree roots
 | 
			
		||||
                                                // the junction tree roots
 | 
			
		||||
  treeTraversal::DepthFirstForest(eliminationTree, rootData,
 | 
			
		||||
                                  Data::ConstructorTraversalVisitorPre,
 | 
			
		||||
                                  Data::ConstructorTraversalVisitorPost);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,6 +17,7 @@
 | 
			
		|||
 */
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
			
		||||
#include <gtsam/discrete/TableFactor.h>
 | 
			
		||||
#include <gtsam/hybrid/GaussianMixture.h>
 | 
			
		||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
 | 
			
		||||
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
 | 
			
		||||
| 
						 | 
				
			
			@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
 | 
			
		|||
    } else if (auto nlf = dynamic_pointer_cast<NonlinearFactor>(f)) {
 | 
			
		||||
      const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues);
 | 
			
		||||
      linearFG->push_back(gf);
 | 
			
		||||
    } else if (dynamic_pointer_cast<DecisionTreeFactor>(f)) {
 | 
			
		||||
    } else if (dynamic_pointer_cast<DiscreteFactor>(f)) {
 | 
			
		||||
      // If discrete-only: doesn't need linearization.
 | 
			
		||||
      linearFG->push_back(f);
 | 
			
		||||
    } else if (auto gmf = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -72,7 +72,8 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
 | 
			
		|||
      addConditionals(graph, hybridBayesNet_, ordering);
 | 
			
		||||
 | 
			
		||||
  // Eliminate.
 | 
			
		||||
  auto bayesNetFragment = graph.eliminateSequential(ordering);
 | 
			
		||||
  HybridBayesNet::shared_ptr bayesNetFragment =
 | 
			
		||||
      graph.eliminateSequential(ordering);
 | 
			
		||||
 | 
			
		||||
  /// Prune
 | 
			
		||||
  if (maxNrLeaves) {
 | 
			
		||||
| 
						 | 
				
			
			@ -96,7 +97,8 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
 | 
			
		|||
  HybridGaussianFactorGraph graph(originalGraph);
 | 
			
		||||
  HybridBayesNet hybridBayesNet(originalHybridBayesNet);
 | 
			
		||||
 | 
			
		||||
  // If we are not at the first iteration, means we have conditionals to add.
 | 
			
		||||
  // If hybridBayesNet is not empty,
 | 
			
		||||
  // it means we have conditionals to add to the factor graph.
 | 
			
		||||
  if (!hybridBayesNet.empty()) {
 | 
			
		||||
    // We add all relevant conditional mixtures on the last continuous variable
 | 
			
		||||
    // in the previous `hybridBayesNet` to the graph
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -202,31 +202,16 @@ struct Switching {
 | 
			
		|||
   * @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
 | 
			
		||||
   * E.g. if K=4, we want M0, M1 and M2.
 | 
			
		||||
   *
 | 
			
		||||
   * @param fg The nonlinear factor graph to which the mode chain is added.
 | 
			
		||||
   * @param fg The factor graph to which the mode chain is added.
 | 
			
		||||
   */
 | 
			
		||||
  void addModeChain(HybridNonlinearFactorGraph *fg,
 | 
			
		||||
  template <typename FACTORGRAPH>
 | 
			
		||||
  void addModeChain(FACTORGRAPH *fg,
 | 
			
		||||
                    std::string discrete_transition_prob = "1/2 3/2") {
 | 
			
		||||
    fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
 | 
			
		||||
    fg->template emplace_shared<DiscreteDistribution>(modes[0], "1/1");
 | 
			
		||||
    for (size_t k = 0; k < K - 2; k++) {
 | 
			
		||||
      auto parents = {modes[k]};
 | 
			
		||||
      fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
 | 
			
		||||
                                              discrete_transition_prob);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Add "mode chain" to HybridGaussianFactorGraph from M(0) to M(K-2).
 | 
			
		||||
   * E.g. if K=4, we want M0, M1 and M2.
 | 
			
		||||
   *
 | 
			
		||||
   * @param fg The gaussian factor graph to which the mode chain is added.
 | 
			
		||||
   */
 | 
			
		||||
  void addModeChain(HybridGaussianFactorGraph *fg,
 | 
			
		||||
                    std::string discrete_transition_prob = "1/2 3/2") {
 | 
			
		||||
    fg->emplace_shared<DiscreteDistribution>(modes[0], "1/1");
 | 
			
		||||
    for (size_t k = 0; k < K - 2; k++) {
 | 
			
		||||
      auto parents = {modes[k]};
 | 
			
		||||
      fg->emplace_shared<DiscreteConditional>(modes[k + 1], parents,
 | 
			
		||||
                                              discrete_transition_prob);
 | 
			
		||||
      fg->template emplace_shared<DiscreteConditional>(
 | 
			
		||||
          modes[k + 1], parents, discrete_transition_prob);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) {
 | 
			
		|||
  std::string expected =
 | 
			
		||||
      R"(Hybrid [x1 x2; 1]{
 | 
			
		||||
 Choice(1) 
 | 
			
		||||
 0 Leaf :
 | 
			
		||||
 0 Leaf [1] :
 | 
			
		||||
  A[x1] = [
 | 
			
		||||
	0;
 | 
			
		||||
	0
 | 
			
		||||
| 
						 | 
				
			
			@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) {
 | 
			
		|||
  b = [ 0 0 ]
 | 
			
		||||
  No noise model
 | 
			
		||||
 | 
			
		||||
 1 Leaf :
 | 
			
		||||
 1 Leaf [1] :
 | 
			
		||||
  A[x1] = [
 | 
			
		||||
	0;
 | 
			
		||||
	0
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
 | 
			
		|||
  auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
 | 
			
		||||
 | 
			
		||||
  // Regression test on pruned logProbability tree
 | 
			
		||||
  std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
 | 
			
		||||
  std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
 | 
			
		||||
  AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
 | 
			
		||||
  EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
 | 
			
		|||
  logProbability +=
 | 
			
		||||
      posterior->at(4)->asDiscrete()->logProbability(hybridValues);
 | 
			
		||||
 | 
			
		||||
  // Regression
 | 
			
		||||
  double density = exp(logProbability);
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(density,
 | 
			
		||||
                       1.6078460548731697 * actualTree(discrete_values), 1e-6);
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
 | 
			
		||||
                       1e-9);
 | 
			
		||||
| 
						 | 
				
			
			@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
 | 
			
		|||
  EXPECT_LONGS_EQUAL(7, posterior->size());
 | 
			
		||||
 | 
			
		||||
  size_t maxNrLeaves = 3;
 | 
			
		||||
  auto discreteConditionals = posterior->discreteConditionals();
 | 
			
		||||
  DiscreteConditional discreteConditionals;
 | 
			
		||||
  for (auto&& conditional : *posterior) {
 | 
			
		||||
    if (conditional->isDiscrete()) {
 | 
			
		||||
      discreteConditionals =
 | 
			
		||||
          discreteConditionals * (*conditional->asDiscrete());
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  const DecisionTreeFactor::shared_ptr prunedDecisionTree =
 | 
			
		||||
      std::make_shared<DecisionTreeFactor>(
 | 
			
		||||
          discreteConditionals->prune(maxNrLeaves));
 | 
			
		||||
          discreteConditionals.prune(maxNrLeaves));
 | 
			
		||||
 | 
			
		||||
  EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
 | 
			
		||||
                     prunedDecisionTree->nrLeaves());
 | 
			
		||||
 | 
			
		||||
  auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
 | 
			
		||||
  // regression
 | 
			
		||||
  DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
 | 
			
		||||
  DecisionTreeFactor::ADT potentials(
 | 
			
		||||
      dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
 | 
			
		||||
  DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
 | 
			
		||||
 | 
			
		||||
  // Prune!
 | 
			
		||||
  posterior->prune(maxNrLeaves);
 | 
			
		||||
 | 
			
		||||
  // Functor to verify values against the original_discrete_conditionals
 | 
			
		||||
  // Functor to verify values against the expected_discrete_conditionals
 | 
			
		||||
  auto checker = [&](const Assignment<Key>& assignment,
 | 
			
		||||
                     double probability) -> double {
 | 
			
		||||
    // typecast so we can use this to get probability value
 | 
			
		||||
| 
						 | 
				
			
			@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
 | 
			
		|||
    if (prunedDecisionTree->operator()(choices) == 0) {
 | 
			
		||||
      EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
 | 
			
		||||
    } else {
 | 
			
		||||
      EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
 | 
			
		||||
      EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
 | 
			
		||||
                           1e-9);
 | 
			
		||||
    }
 | 
			
		||||
    return 0.0;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) {
 | 
			
		|||
 | 
			
		||||
  DiscreteFactorGraph dfg;
 | 
			
		||||
  for (auto&& f : *remainingFactorGraph) {
 | 
			
		||||
    auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(f);
 | 
			
		||||
    auto discreteFactor = dynamic_pointer_cast<DiscreteFactor>(f);
 | 
			
		||||
    assert(discreteFactor);
 | 
			
		||||
    dfg.push_back(discreteFactor);
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -140,6 +140,61 @@ TEST(HybridEstimation, IncrementalSmoother) {
 | 
			
		|||
  EXPECT(assert_equal(expected_continuous, result));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/****************************************************************************/
 | 
			
		||||
// Test approximate inference with an additional pruning step.
 | 
			
		||||
TEST(HybridEstimation, ISAM) {
 | 
			
		||||
  size_t K = 15;
 | 
			
		||||
  std::vector<double> measurements = {0, 1, 2, 2, 2, 2,  3,  4,  5,  6, 6,
 | 
			
		||||
                                      7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
 | 
			
		||||
  // Ground truth discrete seq
 | 
			
		||||
  std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
 | 
			
		||||
                                      1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
 | 
			
		||||
  // Switching example of robot moving in 1D
 | 
			
		||||
  // with given measurements and equal mode priors.
 | 
			
		||||
  Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
 | 
			
		||||
  HybridNonlinearISAM isam;
 | 
			
		||||
  HybridNonlinearFactorGraph graph;
 | 
			
		||||
  Values initial;
 | 
			
		||||
 | 
			
		||||
  // gttic_(Estimation);
 | 
			
		||||
 | 
			
		||||
  // Add the X(0) prior
 | 
			
		||||
  graph.push_back(switching.nonlinearFactorGraph.at(0));
 | 
			
		||||
  initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
 | 
			
		||||
 | 
			
		||||
  HybridGaussianFactorGraph linearized;
 | 
			
		||||
 | 
			
		||||
  for (size_t k = 1; k < K; k++) {
 | 
			
		||||
    // Motion Model
 | 
			
		||||
    graph.push_back(switching.nonlinearFactorGraph.at(k));
 | 
			
		||||
    // Measurement
 | 
			
		||||
    graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1));
 | 
			
		||||
 | 
			
		||||
    initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
 | 
			
		||||
 | 
			
		||||
    isam.update(graph, initial, 3);
 | 
			
		||||
    // isam.bayesTree().print("\n\n");
 | 
			
		||||
 | 
			
		||||
    graph.resize(0);
 | 
			
		||||
    initial.clear();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Values result = isam.estimate();
 | 
			
		||||
  DiscreteValues assignment = isam.assignment();
 | 
			
		||||
 | 
			
		||||
  DiscreteValues expected_discrete;
 | 
			
		||||
  for (size_t k = 0; k < K - 1; k++) {
 | 
			
		||||
    expected_discrete[M(k)] = discrete_seq[k];
 | 
			
		||||
  }
 | 
			
		||||
  EXPECT(assert_equal(expected_discrete, assignment));
 | 
			
		||||
 | 
			
		||||
  Values expected_continuous;
 | 
			
		||||
  for (size_t k = 0; k < K; k++) {
 | 
			
		||||
    expected_continuous.insert(X(k), measurements[k]);
 | 
			
		||||
  }
 | 
			
		||||
  EXPECT(assert_equal(expected_continuous, result));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief A function to get a specific 1D robot motion problem as a linearized
 | 
			
		||||
 * factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -18,7 +18,9 @@
 | 
			
		|||
#include <gtsam/base/TestableAssertions.h>
 | 
			
		||||
#include <gtsam/base/utilities.h>
 | 
			
		||||
#include <gtsam/hybrid/HybridFactorGraph.h>
 | 
			
		||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
 | 
			
		||||
#include <gtsam/inference/Symbol.h>
 | 
			
		||||
#include <gtsam/linear/JacobianFactor.h>
 | 
			
		||||
#include <gtsam/nonlinear/PriorFactor.h>
 | 
			
		||||
 | 
			
		||||
using namespace std;
 | 
			
		||||
| 
						 | 
				
			
			@ -37,6 +39,32 @@ TEST(HybridFactorGraph, Constructor) {
 | 
			
		|||
  HybridFactorGraph fg;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
// Test if methods to get keys work as expected.
 | 
			
		||||
TEST(HybridFactorGraph, Keys) {
 | 
			
		||||
  HybridGaussianFactorGraph hfg;
 | 
			
		||||
 | 
			
		||||
  // Add prior on x0
 | 
			
		||||
  hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
 | 
			
		||||
 | 
			
		||||
  // Add factor between x0 and x1
 | 
			
		||||
  hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
 | 
			
		||||
 | 
			
		||||
  // Add a gaussian mixture factor ϕ(x1, c1)
 | 
			
		||||
  DiscreteKey m1(M(1), 2);
 | 
			
		||||
  DecisionTree<Key, GaussianFactor::shared_ptr> dt(
 | 
			
		||||
      M(1), std::make_shared<JacobianFactor>(X(1), I_3x3, Z_3x1),
 | 
			
		||||
      std::make_shared<JacobianFactor>(X(1), I_3x3, Vector3::Ones()));
 | 
			
		||||
  hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));
 | 
			
		||||
 | 
			
		||||
  KeySet expected_continuous{X(0), X(1)};
 | 
			
		||||
  EXPECT(
 | 
			
		||||
      assert_container_equality(expected_continuous, hfg.continuousKeySet()));
 | 
			
		||||
 | 
			
		||||
  KeySet expected_discrete{M(1)};
 | 
			
		||||
  EXPECT(assert_container_equality(expected_discrete, hfg.discreteKeySet()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
int main() {
 | 
			
		||||
  TestResult tr;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -902,7 +902,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
 | 
			
		|||
  // Test resulting posterior Bayes net has correct size:
 | 
			
		||||
  EXPECT_LONGS_EQUAL(8, posterior->size());
 | 
			
		||||
 | 
			
		||||
  // TODO(dellaert): this test fails - no idea why.
 | 
			
		||||
  // Ratio test
 | 
			
		||||
  EXPECT(ratioTest(bn, measurements, *posterior));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -492,7 +492,7 @@ factor 0:
 | 
			
		|||
factor 1: 
 | 
			
		||||
Hybrid [x0 x1; m0]{
 | 
			
		||||
 Choice(m0) 
 | 
			
		||||
 0 Leaf :
 | 
			
		||||
 0 Leaf [1] :
 | 
			
		||||
  A[x0] = [
 | 
			
		||||
	-1
 | 
			
		||||
]
 | 
			
		||||
| 
						 | 
				
			
			@ -502,7 +502,7 @@ Hybrid [x0 x1; m0]{
 | 
			
		|||
  b = [ -1 ]
 | 
			
		||||
  No noise model
 | 
			
		||||
 | 
			
		||||
 1 Leaf :
 | 
			
		||||
 1 Leaf [1] :
 | 
			
		||||
  A[x0] = [
 | 
			
		||||
	-1
 | 
			
		||||
]
 | 
			
		||||
| 
						 | 
				
			
			@ -516,7 +516,7 @@ Hybrid [x0 x1; m0]{
 | 
			
		|||
factor 2: 
 | 
			
		||||
Hybrid [x1 x2; m1]{
 | 
			
		||||
 Choice(m1) 
 | 
			
		||||
 0 Leaf :
 | 
			
		||||
 0 Leaf [1] :
 | 
			
		||||
  A[x1] = [
 | 
			
		||||
	-1
 | 
			
		||||
]
 | 
			
		||||
| 
						 | 
				
			
			@ -526,7 +526,7 @@ Hybrid [x1 x2; m1]{
 | 
			
		|||
  b = [ -1 ]
 | 
			
		||||
  No noise model
 | 
			
		||||
 | 
			
		||||
 1 Leaf :
 | 
			
		||||
 1 Leaf [1] :
 | 
			
		||||
  A[x1] = [
 | 
			
		||||
	-1
 | 
			
		||||
]
 | 
			
		||||
| 
						 | 
				
			
			@ -550,16 +550,16 @@ factor 4:
 | 
			
		|||
  b = [ -10 ]
 | 
			
		||||
  No noise model
 | 
			
		||||
factor 5:  P( m0 ):
 | 
			
		||||
 Leaf  0.5
 | 
			
		||||
 Leaf [2]  0.5
 | 
			
		||||
 | 
			
		||||
factor 6:  P( m1 | m0 ):
 | 
			
		||||
 Choice(m1) 
 | 
			
		||||
 0 Choice(m0) 
 | 
			
		||||
 0 0 Leaf 0.33333333
 | 
			
		||||
 0 1 Leaf  0.6
 | 
			
		||||
 0 0 Leaf [1] 0.33333333
 | 
			
		||||
 0 1 Leaf [1]  0.6
 | 
			
		||||
 1 Choice(m0) 
 | 
			
		||||
 1 0 Leaf 0.66666667
 | 
			
		||||
 1 1 Leaf  0.4
 | 
			
		||||
 1 0 Leaf [1] 0.66666667
 | 
			
		||||
 1 1 Leaf [1]  0.4
 | 
			
		||||
 | 
			
		||||
)";
 | 
			
		||||
  EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph));
 | 
			
		||||
| 
						 | 
				
			
			@ -570,13 +570,13 @@ size: 3
 | 
			
		|||
conditional 0: Hybrid  P( x0 | x1 m0)
 | 
			
		||||
 Discrete Keys = (m0, 2), 
 | 
			
		||||
 Choice(m0) 
 | 
			
		||||
 0 Leaf  p(x0 | x1)
 | 
			
		||||
 0 Leaf [1] p(x0 | x1)
 | 
			
		||||
  R = [ 10.0499 ]
 | 
			
		||||
  S[x1] = [ -0.0995037 ]
 | 
			
		||||
  d = [ -9.85087 ]
 | 
			
		||||
  No noise model
 | 
			
		||||
 | 
			
		||||
 1 Leaf  p(x0 | x1)
 | 
			
		||||
 1 Leaf [1] p(x0 | x1)
 | 
			
		||||
  R = [ 10.0499 ]
 | 
			
		||||
  S[x1] = [ -0.0995037 ]
 | 
			
		||||
  d = [ -9.95037 ]
 | 
			
		||||
| 
						 | 
				
			
			@ -586,26 +586,26 @@ conditional 1: Hybrid  P( x1 | x2 m0 m1)
 | 
			
		|||
 Discrete Keys = (m0, 2), (m1, 2), 
 | 
			
		||||
 Choice(m1) 
 | 
			
		||||
 0 Choice(m0) 
 | 
			
		||||
 0 0 Leaf  p(x1 | x2)
 | 
			
		||||
 0 0 Leaf [1] p(x1 | x2)
 | 
			
		||||
  R = [ 10.099 ]
 | 
			
		||||
  S[x2] = [ -0.0990196 ]
 | 
			
		||||
  d = [ -9.99901 ]
 | 
			
		||||
  No noise model
 | 
			
		||||
 | 
			
		||||
 0 1 Leaf  p(x1 | x2)
 | 
			
		||||
 0 1 Leaf [1] p(x1 | x2)
 | 
			
		||||
  R = [ 10.099 ]
 | 
			
		||||
  S[x2] = [ -0.0990196 ]
 | 
			
		||||
  d = [ -9.90098 ]
 | 
			
		||||
  No noise model
 | 
			
		||||
 | 
			
		||||
 1 Choice(m0) 
 | 
			
		||||
 1 0 Leaf  p(x1 | x2)
 | 
			
		||||
 1 0 Leaf [1] p(x1 | x2)
 | 
			
		||||
  R = [ 10.099 ]
 | 
			
		||||
  S[x2] = [ -0.0990196 ]
 | 
			
		||||
  d = [ -10.098 ]
 | 
			
		||||
  No noise model
 | 
			
		||||
 | 
			
		||||
 1 1 Leaf  p(x1 | x2)
 | 
			
		||||
 1 1 Leaf [1] p(x1 | x2)
 | 
			
		||||
  R = [ 10.099 ]
 | 
			
		||||
  S[x2] = [ -0.0990196 ]
 | 
			
		||||
  d = [ -10 ]
 | 
			
		||||
| 
						 | 
				
			
			@ -615,14 +615,14 @@ conditional 2: Hybrid  P( x2 | m0 m1)
 | 
			
		|||
 Discrete Keys = (m0, 2), (m1, 2), 
 | 
			
		||||
 Choice(m1) 
 | 
			
		||||
 0 Choice(m0) 
 | 
			
		||||
 0 0 Leaf  p(x2)
 | 
			
		||||
 0 0 Leaf [1] p(x2)
 | 
			
		||||
  R = [ 10.0494 ]
 | 
			
		||||
  d = [ -10.1489 ]
 | 
			
		||||
  mean: 1 elements
 | 
			
		||||
  x2: -1.0099
 | 
			
		||||
  No noise model
 | 
			
		||||
 | 
			
		||||
 0 1 Leaf  p(x2)
 | 
			
		||||
 0 1 Leaf [1] p(x2)
 | 
			
		||||
  R = [ 10.0494 ]
 | 
			
		||||
  d = [ -10.1479 ]
 | 
			
		||||
  mean: 1 elements
 | 
			
		||||
| 
						 | 
				
			
			@ -630,14 +630,14 @@ conditional 2: Hybrid  P( x2 | m0 m1)
 | 
			
		|||
  No noise model
 | 
			
		||||
 | 
			
		||||
 1 Choice(m0) 
 | 
			
		||||
 1 0 Leaf  p(x2)
 | 
			
		||||
 1 0 Leaf [1] p(x2)
 | 
			
		||||
  R = [ 10.0494 ]
 | 
			
		||||
  d = [ -10.0504 ]
 | 
			
		||||
  mean: 1 elements
 | 
			
		||||
  x2: -1.0001
 | 
			
		||||
  No noise model
 | 
			
		||||
 | 
			
		||||
 1 1 Leaf  p(x2)
 | 
			
		||||
 1 1 Leaf [1] p(x2)
 | 
			
		||||
  R = [ 10.0494 ]
 | 
			
		||||
  d = [ -10.0494 ]
 | 
			
		||||
  mean: 1 elements
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) {
 | 
			
		|||
      R"(Hybrid [x1 x2; 1]
 | 
			
		||||
MixtureFactor
 | 
			
		||||
 Choice(1) 
 | 
			
		||||
 0 Leaf Nonlinear factor on 2 keys
 | 
			
		||||
 1 Leaf Nonlinear factor on 2 keys
 | 
			
		||||
 0 Leaf [1] Nonlinear factor on 2 keys
 | 
			
		||||
 1 Leaf [1] Nonlinear factor on 2 keys
 | 
			
		||||
)";
 | 
			
		||||
  EXPECT(assert_print_equal(expected, mixtureFactor));
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -99,7 +99,7 @@ namespace gtsam {
 | 
			
		|||
 | 
			
		||||
  /* ************************************************************************ */
 | 
			
		||||
  void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const {
 | 
			
		||||
    cout << s << " p(";
 | 
			
		||||
    cout << (s.empty() ? "" : s + " ") << "p(";
 | 
			
		||||
    for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
 | 
			
		||||
      cout << formatter(*it) << (nrFrontals() > 1 ? " " : "");
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue