Make HybridBayesNet testable and add serialization
							parent
							
								
									eb5092897b
								
							
						
					
					
						commit
						8692ae63ea
					
				| 
						 | 
					@ -18,6 +18,7 @@
 | 
				
			||||||
#pragma once
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
					#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
				
			||||||
 | 
					#include <gtsam/global_includes.h>
 | 
				
			||||||
#include <gtsam/hybrid/HybridConditional.h>
 | 
					#include <gtsam/hybrid/HybridConditional.h>
 | 
				
			||||||
#include <gtsam/hybrid/HybridValues.h>
 | 
					#include <gtsam/hybrid/HybridValues.h>
 | 
				
			||||||
#include <gtsam/inference/BayesNet.h>
 | 
					#include <gtsam/inference/BayesNet.h>
 | 
				
			||||||
| 
						 | 
					@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 | 
				
			||||||
  using shared_ptr = boost::shared_ptr<HybridBayesNet>;
 | 
					  using shared_ptr = boost::shared_ptr<HybridBayesNet>;
 | 
				
			||||||
  using sharedConditional = boost::shared_ptr<ConditionalType>;
 | 
					  using sharedConditional = boost::shared_ptr<ConditionalType>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// @name Standard Constructors
 | 
				
			||||||
 | 
					  /// @{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /** Construct empty bayes net */
 | 
					  /** Construct empty bayes net */
 | 
				
			||||||
  HybridBayesNet() = default;
 | 
					  HybridBayesNet() = default;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// Prune the Hybrid Bayes Net given the discrete decision tree.
 | 
					  /// @}
 | 
				
			||||||
  HybridBayesNet prune(
 | 
					  /// @name Testable
 | 
				
			||||||
      const DecisionTreeFactor::shared_ptr &discreteFactor) const;
 | 
					  /// @{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Check equality */
 | 
				
			||||||
 | 
					  bool equals(const This &bn, double tol = 1e-9) const {
 | 
				
			||||||
 | 
					    return Base::equals(bn, tol);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// print graph
 | 
				
			||||||
 | 
					  void print(
 | 
				
			||||||
 | 
					      const std::string &s = "",
 | 
				
			||||||
 | 
					      const KeyFormatter &formatter = DefaultKeyFormatter) const override {
 | 
				
			||||||
 | 
					    Base::print(s, formatter);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// @}
 | 
				
			||||||
 | 
					  /// @name Standard Interface
 | 
				
			||||||
 | 
					  /// @{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// Add HybridConditional to Bayes Net
 | 
					  /// Add HybridConditional to Bayes Net
 | 
				
			||||||
  using Base::add;
 | 
					  using Base::add;
 | 
				
			||||||
| 
						 | 
					@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  GaussianBayesNet choose(const DiscreteValues &assignment) const;
 | 
					  GaussianBayesNet choose(const DiscreteValues &assignment) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// Solve the HybridBayesNet by back-substitution.
 | 
					  /**
 | 
				
			||||||
  /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
 | 
					   * @brief Solve the HybridBayesNet by first computing the MPE of all the
 | 
				
			||||||
  /// put this method there?
 | 
					   * discrete variables and then optimizing the continuous variables based on
 | 
				
			||||||
 | 
					   * the MPE assignment.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @return HybridValues
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
  HybridValues optimize() const;
 | 
					  HybridValues optimize() const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
| 
						 | 
					@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 | 
				
			||||||
   * @return Values
 | 
					   * @return Values
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  VectorValues optimize(const DiscreteValues &assignment) const;
 | 
					  VectorValues optimize(const DiscreteValues &assignment) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Prune the Hybrid Bayes Net given the discrete decision tree.
 | 
				
			||||||
 | 
					  HybridBayesNet prune(
 | 
				
			||||||
 | 
					      const DecisionTreeFactor::shared_ptr &discreteFactor) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// @}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  /** Serialization function */
 | 
				
			||||||
 | 
					  friend class boost::serialization::access;
 | 
				
			||||||
 | 
					  template <class ARCHIVE>
 | 
				
			||||||
 | 
					  void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
 | 
				
			||||||
 | 
					    ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// traits
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace gtsam
 | 
					}  // namespace gtsam
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -18,6 +18,7 @@
 | 
				
			||||||
 * @date    December 2021
 | 
					 * @date    December 2021
 | 
				
			||||||
 */
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <gtsam/base/serializationTestHelpers.h>
 | 
				
			||||||
#include <gtsam/hybrid/HybridBayesNet.h>
 | 
					#include <gtsam/hybrid/HybridBayesNet.h>
 | 
				
			||||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
 | 
					#include <gtsam/nonlinear/NonlinearFactorGraph.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -28,6 +29,8 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using namespace std;
 | 
					using namespace std;
 | 
				
			||||||
using namespace gtsam;
 | 
					using namespace gtsam;
 | 
				
			||||||
 | 
					using namespace gtsam::serializationTestHelpers;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using noiseModel::Isotropic;
 | 
					using noiseModel::Isotropic;
 | 
				
			||||||
using symbol_shorthand::M;
 | 
					using symbol_shorthand::M;
 | 
				
			||||||
using symbol_shorthand::X;
 | 
					using symbol_shorthand::X;
 | 
				
			||||||
| 
						 | 
					@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) {
 | 
				
			||||||
  EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
 | 
					  EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ****************************************************************************/
 | 
				
			||||||
 | 
					// Test HybridBayesNet serialization.
 | 
				
			||||||
 | 
					TEST(HybridBayesNet, Serialization) {
 | 
				
			||||||
 | 
					  Switching s(4);
 | 
				
			||||||
 | 
					  Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
 | 
				
			||||||
 | 
					  HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT(equalsObj<HybridBayesNet>(hbn));
 | 
				
			||||||
 | 
					  EXPECT(equalsXML<HybridBayesNet>(hbn));
 | 
				
			||||||
 | 
					  EXPECT(equalsBinary<HybridBayesNet>(hbn));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ************************************************************************* */
 | 
					/* ************************************************************************* */
 | 
				
			||||||
int main() {
 | 
					int main() {
 | 
				
			||||||
  TestResult tr;
 | 
					  TestResult tr;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue