diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 8fdedab44..c9c6afa79 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -73,6 +73,8 @@ struct HybridAssignmentData { GaussianBayesTree::sharedNode parentClique_; // The gaussian bayes tree that will be recursively created. GaussianBayesTree* gaussianbayesTree_; + // Flag indicating if all the nodes are valid. Used in optimize(). + bool valid_; /** * @brief Construct a new Hybrid Assignment Data object. @@ -83,10 +85,13 @@ struct HybridAssignmentData { */ HybridAssignmentData(const DiscreteValues& assignment, const GaussianBayesTree::sharedNode& parentClique, - GaussianBayesTree* gbt) + GaussianBayesTree* gbt, bool valid = true) : assignment_(assignment), parentClique_(parentClique), - gaussianbayesTree_(gbt) {} + gaussianbayesTree_(gbt), + valid_(valid) {} + + bool isValid() const { return valid_; } /** * @brief A function used during tree traversal that operates on each node @@ -101,6 +106,7 @@ struct HybridAssignmentData { HybridAssignmentData& parentData) { // Extract the gaussian conditional from the Hybrid clique HybridConditional::shared_ptr hybrid_conditional = node->conditional(); + GaussianConditional::shared_ptr conditional; if (hybrid_conditional->isHybrid()) { conditional = (*hybrid_conditional->asMixture())(parentData.assignment_); @@ -111,15 +117,21 @@ struct HybridAssignmentData { conditional = boost::make_shared(); } - // Create the GaussianClique for the current node - auto clique = boost::make_shared(conditional); - // Add the current clique to the GaussianBayesTree. - parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_); + GaussianBayesTree::sharedNode clique; + if (conditional) { + // Create the GaussianClique for the current node + clique = boost::make_shared(conditional); + // Add the current clique to the GaussianBayesTree. + parentData.gaussianbayesTree_->addClique(clique, + parentData.parentClique_); + } else { + parentData.valid_ = false; + } // Create new HybridAssignmentData where the current node is the parent // This will be passed down to the children nodes HybridAssignmentData data(parentData.assignment_, clique, - parentData.gaussianbayesTree_); + parentData.gaussianbayesTree_, parentData.valid_); return data; } }; @@ -138,6 +150,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { visitorPost); } + if (!rootData.isValid()) { + return VectorValues(); + } VectorValues result = gbt.optimize(); // Return the optimized bayes net result. @@ -151,6 +166,8 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); decisionTree->root_ = prunedDecisionTree.root_; + // this->print(); + // decisionTree->print("", DefaultKeyFormatter); /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData {