Wrap single-argument methods
parent
10628a0ddc
commit
a1b8f52da8
|
@ -79,6 +79,8 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
|
||||||
void print(string s = "Discrete Prior\n",
|
void print(string s = "Discrete Prior\n",
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
double operator()(size_t value) const;
|
||||||
|
std::vector<double> pmf() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
|
|
@ -13,16 +13,18 @@ Author: Varun Agrawal
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import DiscretePrior, DecisionTreeFactor, DiscreteKeys
|
import numpy as np
|
||||||
|
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
X = 0, 2
|
||||||
|
|
||||||
|
|
||||||
class TestDiscretePrior(GtsamTestCase):
|
class TestDiscretePrior(GtsamTestCase):
|
||||||
"""Tests for Discrete Priors."""
|
"""Tests for Discrete Priors."""
|
||||||
|
|
||||||
def test_constructor(self):
|
def test_constructor(self):
|
||||||
"""Test various constructors."""
|
"""Test various constructors."""
|
||||||
X = 0, 2
|
|
||||||
actual = DiscretePrior(X, "2/3")
|
actual = DiscretePrior(X, "2/3")
|
||||||
keys = DiscreteKeys()
|
keys = DiscreteKeys()
|
||||||
keys.push_back(X)
|
keys.push_back(X)
|
||||||
|
@ -30,10 +32,19 @@ class TestDiscretePrior(GtsamTestCase):
|
||||||
expected = DiscretePrior(f)
|
expected = DiscretePrior(f)
|
||||||
self.gtsamAssertEquals(actual, expected)
|
self.gtsamAssertEquals(actual, expected)
|
||||||
|
|
||||||
|
def test_operator(self):
|
||||||
|
prior = DiscretePrior(X, "2/3")
|
||||||
|
self.assertAlmostEqual(prior(0), 0.4)
|
||||||
|
self.assertAlmostEqual(prior(1), 0.6)
|
||||||
|
|
||||||
|
def test_pmf(self):
|
||||||
|
prior = DiscretePrior(X, "2/3")
|
||||||
|
expected = np.array([0.4, 0.6])
|
||||||
|
np.testing.assert_allclose(expected, prior.pmf())
|
||||||
|
|
||||||
def test_markdown(self):
|
def test_markdown(self):
|
||||||
"""Test the _repr_markdown_ method."""
|
"""Test the _repr_markdown_ method."""
|
||||||
|
|
||||||
X = 0, 2
|
|
||||||
prior = DiscretePrior(X, "2/3")
|
prior = DiscretePrior(X, "2/3")
|
||||||
expected = " $P(0)$:\n" \
|
expected = " $P(0)$:\n" \
|
||||||
"|0|value|\n" \
|
"|0|value|\n" \
|
||||||
|
|
Loading…
Reference in New Issue