Merge branch 'feature/discrete_wrapper' into feature/discrete_wrapper_2
						commit
						7401b6e0c2
					
				| 
						 | 
				
			
			@ -83,6 +83,11 @@ namespace gtsam {
 | 
			
		|||
    //** evaluate for given DiscreteValues */
 | 
			
		||||
    double evaluate(const DiscreteValues & values) const;
 | 
			
		||||
 | 
			
		||||
    //** (Preferred) sugar for the above for given DiscreteValues */
 | 
			
		||||
    double operator()(const DiscreteValues & values) const {
 | 
			
		||||
      return evaluate(values);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
    * Solve the DiscreteBayesNet by back-substitution
 | 
			
		||||
    */
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -80,6 +80,12 @@ class GTSAM_EXPORT DiscreteBayesTree
 | 
			
		|||
 | 
			
		||||
  //** evaluate probability for given DiscreteValues */
 | 
			
		||||
  double evaluate(const DiscreteValues& values) const;
 | 
			
		||||
 | 
			
		||||
  //** (Preferred) sugar for the above for given DiscreteValues */
 | 
			
		||||
  double operator()(const DiscreteValues & values) const {
 | 
			
		||||
    return evaluate(values);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -58,24 +58,28 @@ public:
 | 
			
		|||
  DiscreteConditional(const Signature& signature);
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
  * Construct from key, parents, and a Table specifying the CPT.
 | 
			
		||||
  *
 | 
			
		||||
  * The first string is parsed to add a key and parents.
 | 
			
		||||
  * 
 | 
			
		||||
  * Example: DiscreteConditional P(D, {B,E}, table);
 | 
			
		||||
  */
 | 
			
		||||
   * Construct from key, parents, and a Signature::Table specifying the
 | 
			
		||||
   * conditional probability table (CPT) in 00 01 10 11 order. For
 | 
			
		||||
   * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
 | 
			
		||||
   *
 | 
			
		||||
   * The first string is parsed to add a key and parents.
 | 
			
		||||
   *
 | 
			
		||||
   * Example: DiscreteConditional P(D, {B,E}, table);
 | 
			
		||||
   */
 | 
			
		||||
  DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
 | 
			
		||||
                      const Signature::Table& table)
 | 
			
		||||
      : DiscreteConditional(Signature(key, parents, table)) {}
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
  * Construct from key, parents, and a string specifying the CPT.
 | 
			
		||||
  *
 | 
			
		||||
  * The first string is parsed to add a key and parents. The second string
 | 
			
		||||
  * parses into a table.
 | 
			
		||||
  * 
 | 
			
		||||
  * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
 | 
			
		||||
  */
 | 
			
		||||
   * Construct from key, parents, and a string specifying the conditional
 | 
			
		||||
   * probability table (CPT) in 00 01 10 11 order. For three-valued, it would
 | 
			
		||||
   * be 00 01 02 10 11 12 20 21 22, etc....
 | 
			
		||||
   *
 | 
			
		||||
   * The first string is parsed to add a key and parents. The second string
 | 
			
		||||
   * parses into a table.
 | 
			
		||||
   *
 | 
			
		||||
   * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
 | 
			
		||||
   */
 | 
			
		||||
  DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
 | 
			
		||||
                      const std::string& spec)
 | 
			
		||||
      : DiscreteConditional(Signature(key, parents, spec)) {}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -83,9 +83,6 @@ public:
 | 
			
		|||
  /// Find value for given assignment of values to variables
 | 
			
		||||
  virtual double operator()(const DiscreteValues&) const = 0;
 | 
			
		||||
 | 
			
		||||
  /// Synonym for operator(), mostly for wrapper
 | 
			
		||||
  double evaluate(const DiscreteValues& values) const { return operator()(values); }
 | 
			
		||||
 | 
			
		||||
  /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
 | 
			
		||||
  virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -136,9 +136,6 @@ public:
 | 
			
		|||
   */
 | 
			
		||||
  double operator()(const DiscreteValues& values) const;
 | 
			
		||||
 | 
			
		||||
  /// Synonym for operator(), mostly for wrapper
 | 
			
		||||
  double evaluate(const DiscreteValues& values) const { return operator()(values); }
 | 
			
		||||
 | 
			
		||||
  /// print
 | 
			
		||||
  void print(
 | 
			
		||||
      const std::string& s = "DiscreteFactorGraph",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,7 +30,7 @@ namespace gtsam {
 | 
			
		|||
   * The format is (Key % string) for nodes with no parents,
 | 
			
		||||
   * and (Key | Key, Key = string) for nodes with parents.
 | 
			
		||||
   *
 | 
			
		||||
   * The string specifies a conditional probability spec in the 00 01 10 11 order.
 | 
			
		||||
   * The string specifies a conditional probability table in 00 01 10 11 order.
 | 
			
		||||
   * For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc...
 | 
			
		||||
   *
 | 
			
		||||
   * For example, given the following keys
 | 
			
		||||
| 
						 | 
				
			
			@ -73,22 +73,29 @@ namespace gtsam {
 | 
			
		|||
 | 
			
		||||
  public:
 | 
			
		||||
   /**
 | 
			
		||||
    * Construct from key, parents, and a Table specifying the CPT.
 | 
			
		||||
    * Construct from key, parents, and a Signature::Table specifying the
 | 
			
		||||
    * conditional probability table (CPT) in 00 01 10 11 order. For
 | 
			
		||||
    * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
 | 
			
		||||
    *
 | 
			
		||||
    * The first string is parsed to add a key and parents.
 | 
			
		||||
    * 
 | 
			
		||||
    * Example: Signature sig(D, {B,E}, table);
 | 
			
		||||
    *
 | 
			
		||||
    * Example:
 | 
			
		||||
    *   Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
 | 
			
		||||
    *   Signature sig(D, {E, B}, table);
 | 
			
		||||
    */
 | 
			
		||||
   Signature(const DiscreteKey& key, const DiscreteKeys& parents,
 | 
			
		||||
             const Table& table);
 | 
			
		||||
 | 
			
		||||
   /**
 | 
			
		||||
    * Construct from key, parents, and a string specifying the CPT.
 | 
			
		||||
    * Construct from key, parents, and a string specifying the conditional
 | 
			
		||||
    * probability table (CPT) in 00 01 10 11 order. For three-valued, it would
 | 
			
		||||
    * be 00 01 02 10 11 12 20 21 22, etc....
 | 
			
		||||
    *
 | 
			
		||||
    * The first string is parsed to add a key and parents. The second string
 | 
			
		||||
    * parses into a table.
 | 
			
		||||
    * 
 | 
			
		||||
    * Example: Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
 | 
			
		||||
    *
 | 
			
		||||
    * Example (same CPT as above):
 | 
			
		||||
    *   Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
 | 
			
		||||
    */
 | 
			
		||||
   Signature(const DiscreteKey& key, const DiscreteKeys& parents,
 | 
			
		||||
             const std::string& spec);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,5 @@
 | 
			
		|||
//*************************************************************************
 | 
			
		||||
// basis
 | 
			
		||||
// discrete
 | 
			
		||||
//*************************************************************************
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
| 
						 | 
				
			
			@ -26,7 +26,7 @@ class DiscreteFactor {
 | 
			
		|||
  bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
 | 
			
		||||
  bool empty() const;
 | 
			
		||||
  size_t size() const;
 | 
			
		||||
  double evaluate(const gtsam::DiscreteValues& values) const;
 | 
			
		||||
  double operator()(const gtsam::DiscreteValues& values) const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
			
		||||
| 
						 | 
				
			
			@ -38,7 +38,7 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
 | 
			
		|||
             const gtsam::KeyFormatter& keyFormatter =
 | 
			
		||||
                 gtsam::DefaultKeyFormatter) const;
 | 
			
		||||
  bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
 | 
			
		||||
  double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
 | 
			
		||||
  double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteConditional.h>
 | 
			
		||||
| 
						 | 
				
			
			@ -53,7 +53,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
 | 
			
		|||
                      const gtsam::DecisionTreeFactor& marginal,
 | 
			
		||||
                      const gtsam::Ordering& orderedKeys);
 | 
			
		||||
  size_t size() const; // TODO(dellaert): why do I have to repeat???
 | 
			
		||||
  double evaluate(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
 | 
			
		||||
  double operator()(const gtsam::DiscreteValues& values) const; // TODO(dellaert): why do I have to repeat???
 | 
			
		||||
  void print(string s = "Discrete Conditional\n",
 | 
			
		||||
             const gtsam::KeyFormatter& keyFormatter =
 | 
			
		||||
                 gtsam::DefaultKeyFormatter) const;
 | 
			
		||||
| 
						 | 
				
			
			@ -86,7 +86,7 @@ class DiscreteBayesNet {
 | 
			
		|||
                const gtsam::KeyFormatter& keyFormatter =
 | 
			
		||||
                 gtsam::DefaultKeyFormatter) const;
 | 
			
		||||
  void add(const gtsam::DiscreteConditional& s);
 | 
			
		||||
  double evaluate(const gtsam::DiscreteValues& values) const;
 | 
			
		||||
  double operator()(const gtsam::DiscreteValues& values) const;
 | 
			
		||||
  gtsam::DiscreteValues optimize() const;
 | 
			
		||||
  gtsam::DiscreteValues sample() const;
 | 
			
		||||
};
 | 
			
		||||
| 
						 | 
				
			
			@ -98,7 +98,7 @@ class DiscreteBayesTree {
 | 
			
		|||
             const gtsam::KeyFormatter& keyFormatter =
 | 
			
		||||
                 gtsam::DefaultKeyFormatter) const;
 | 
			
		||||
  bool equals(const gtsam::DiscreteBayesTree& other, double tol = 1e-9) const;
 | 
			
		||||
  double evaluate(const gtsam::DiscreteValues& values) const;
 | 
			
		||||
  double operator()(const gtsam::DiscreteValues& values) const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
 | 
			
		||||
| 
						 | 
				
			
			@ -119,7 +119,7 @@ class DiscreteFactorGraph {
 | 
			
		|||
  bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
 | 
			
		||||
  
 | 
			
		||||
  gtsam::DecisionTreeFactor product() const;
 | 
			
		||||
  double evaluate(const gtsam::DiscreteValues& values) const;
 | 
			
		||||
  double operator()(const gtsam::DiscreteValues& values) const;
 | 
			
		||||
  gtsam::DiscreteValues optimize() const;
 | 
			
		||||
 | 
			
		||||
  gtsam::DiscreteBayesNet eliminateSequential();
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -92,7 +92,6 @@ TEST(testSignature, all_examples) {
 | 
			
		|||
  Signature b(B, {S}, "70/30  40/60");
 | 
			
		||||
  Signature e(E, {T, L}, "F F F 1");
 | 
			
		||||
  Signature x(X, {E}, "95/5 2/98");
 | 
			
		||||
  Signature d(D, {E, B}, "9/1 2/8 3/7 1/9");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Make sure we can create all signatures for Asia network with operator magic.
 | 
			
		||||
| 
						 | 
				
			
			@ -105,7 +104,17 @@ TEST(testSignature, all_examples_magic) {
 | 
			
		|||
  Signature b(B | S = "70/30  40/60");
 | 
			
		||||
  Signature e((E | T, L) = "F F F 1");
 | 
			
		||||
  Signature x(X | E = "95/5 2/98");
 | 
			
		||||
  Signature d((D | E, B) = "9/1 2/8 3/7 1/9");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Check example from docs.
 | 
			
		||||
TEST(testSignature, doxygen_example) {
 | 
			
		||||
  Signature::Table table{{0.9, 0.1}, {0.2, 0.8}, {0.3, 0.7}, {0.1, 0.9}};
 | 
			
		||||
  Signature d1(D, {E, B}, table);
 | 
			
		||||
  Signature d2((D | E, B) = "9/1 2/8 3/7 1/9");
 | 
			
		||||
  Signature d3(D, {E, B}, "9/1 2/8 3/7 1/9");
 | 
			
		||||
  EXPECT(*(d1.table()) == table);
 | 
			
		||||
  EXPECT(*(d2.table()) == table);
 | 
			
		||||
  EXPECT(*(d3.table()) == table);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -91,6 +91,9 @@ class TestDiscreteBayesNet(GtsamTestCase):
 | 
			
		|||
        self.assertEqual(list(actualMPE.items()),
 | 
			
		||||
                         list(expectedMPE.items()))
 | 
			
		||||
 | 
			
		||||
        # Check value for MPE is the same
 | 
			
		||||
        self.assertAlmostEqual(asia(actualMPE), fg(actualMPE))
 | 
			
		||||
 | 
			
		||||
        # add evidence, we were in Asia and we have dyspnea
 | 
			
		||||
        fg.add(Asia, "0 1")
 | 
			
		||||
        fg.add(Dyspnea, "0 1")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -44,9 +44,9 @@ class TestDiscreteFactorGraph(GtsamTestCase):
 | 
			
		|||
        assignment[1] = 1
 | 
			
		||||
 | 
			
		||||
        # Check if graph evaluation works ( 0.3*0.6*4 )
 | 
			
		||||
        self.assertAlmostEqual(.72, graph.evaluate(assignment))
 | 
			
		||||
        self.assertAlmostEqual(.72, graph(assignment))
 | 
			
		||||
 | 
			
		||||
        # Creating a new test with third node and adding unary and ternary factors on it
 | 
			
		||||
        # Create a new test with third node and adding unary and ternary factor
 | 
			
		||||
        graph.add(P3, "0.9 0.2 0.5")
 | 
			
		||||
        keys = DiscreteKeys()
 | 
			
		||||
        keys.push_back(P1)
 | 
			
		||||
| 
						 | 
				
			
			@ -54,25 +54,25 @@ class TestDiscreteFactorGraph(GtsamTestCase):
 | 
			
		|||
        keys.push_back(P3)
 | 
			
		||||
        graph.add(keys, "1 2 3 4 5 6 7 8 9 10 11 12")
 | 
			
		||||
 | 
			
		||||
        # Below assignment lead to selecting the 8th index in the ternary factor table
 | 
			
		||||
        # Below assignment selects the 8th index in the ternary factor table
 | 
			
		||||
        assignment[0] = 1
 | 
			
		||||
        assignment[1] = 0
 | 
			
		||||
        assignment[2] = 1
 | 
			
		||||
 | 
			
		||||
        # Check if graph evaluation works (0.3*0.9*1*0.2*8)
 | 
			
		||||
        self.assertAlmostEqual(4.32, graph.evaluate(assignment))
 | 
			
		||||
        self.assertAlmostEqual(4.32, graph(assignment))
 | 
			
		||||
 | 
			
		||||
        # Below assignment lead to selecting the 3rd index in the ternary factor table
 | 
			
		||||
        # Below assignment selects the 3rd index in the ternary factor table
 | 
			
		||||
        assignment[0] = 0
 | 
			
		||||
        assignment[1] = 1
 | 
			
		||||
        assignment[2] = 0
 | 
			
		||||
 | 
			
		||||
        # Check if graph evaluation works (0.9*0.6*1*0.9*4)
 | 
			
		||||
        self.assertAlmostEqual(1.944, graph.evaluate(assignment))
 | 
			
		||||
        self.assertAlmostEqual(1.944, graph(assignment))
 | 
			
		||||
 | 
			
		||||
        # Check if graph product works
 | 
			
		||||
        product = graph.product()
 | 
			
		||||
        self.assertAlmostEqual(1.944, product.evaluate(assignment))
 | 
			
		||||
        self.assertAlmostEqual(1.944, product(assignment))
 | 
			
		||||
 | 
			
		||||
    def test_optimize(self):
 | 
			
		||||
        """Test constructing and optizing a discrete factor graph."""
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue