Merge pull request #1043 from borglab/feature/more_discrete_methods

release/4.3a0
Frank Dellaert 2022-01-18 11:11:41 -05:00 committed by GitHub
commit 296b469df9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 137 additions and 72 deletions

View File

@ -29,8 +29,7 @@ namespace gtsam {
template class FactorGraph<DiscreteConditional>; template class FactorGraph<DiscreteConditional>;
/* ************************************************************************* */ /* ************************************************************************* */
bool DiscreteBayesNet::equals(const This& bn, double tol) const bool DiscreteBayesNet::equals(const This& bn, double tol) const {
{
return Base::equals(bn, tol); return Base::equals(bn, tol);
} }
@ -45,8 +44,12 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteValues DiscreteBayesNet::optimize() const { DiscreteValues DiscreteBayesNet::optimize() const {
// solve each node in turn in topological sort order (parents first)
DiscreteValues result; DiscreteValues result;
return optimize(result);
}
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first)
for (auto conditional : boost::adaptors::reverse(*this)) for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result); conditional->solveInPlace(&result);
return result; return result;
@ -54,8 +57,12 @@ namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const { DiscreteValues DiscreteBayesNet::sample() const {
// sample each node in turn in topological sort order (parents first)
DiscreteValues result; DiscreteValues result;
return sample(result);
}
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// sample each node in turn in topological sort order (parents first)
for (auto conditional : boost::adaptors::reverse(*this)) for (auto conditional : boost::adaptors::reverse(*this))
conditional->sampleInPlace(&result); conditional->sampleInPlace(&result);
return result; return result;
@ -74,8 +81,7 @@ namespace gtsam {
} }
/* *********************************************************************** */ /* *********************************************************************** */
std::string DiscreteBayesNet::html( std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const { const DiscreteFactor::Names& names) const {
using std::endl; using std::endl;
std::stringstream ss; std::stringstream ss;
@ -86,4 +92,4 @@ namespace gtsam {
} }
/* ************************************************************************* */ /* ************************************************************************* */
} // namespace } // namespace gtsam

View File

@ -99,13 +99,47 @@ namespace gtsam {
} }
/** /**
* Solve the DiscreteBayesNet by back-substitution * @brief solve by back-substitution.
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be optimized first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/ */
DiscreteValues optimize() const; DiscreteValues optimize() const;
/** Do ancestral sampling */ /**
* @brief solve by back-substitution, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with optimized value for other variables.
*/
DiscreteValues optimize(DiscreteValues given) const;
/**
* @brief do ancestral sampling
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be sampled first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues sample() const; DiscreteValues sample() const;
/**
* @brief do ancestral sampling, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with sampled value for all other variables.
*/
DiscreteValues sample(DiscreteValues given) const;
///@} ///@}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -165,7 +165,9 @@ class DiscreteBayesNet {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter, string markdown(const gtsam::KeyFormatter& keyFormatter,

View File

@ -17,6 +17,17 @@ from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering) DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
# Some keys:
Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)
Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)
class TestDiscreteBayesNet(GtsamTestCase): class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets.""" """Tests for Discrete Bayes Nets."""
@ -43,16 +54,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
def test_Asia(self): def test_Asia(self):
"""Test full Asia example.""" """Test full Asia example."""
Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)
Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)
asia = DiscreteBayesNet() asia = DiscreteBayesNet()
asia.add(Asia, "99/1") asia.add(Asia, "99/1")
asia.add(Smoking, "50/50") asia.add(Smoking, "50/50")
@ -107,6 +108,28 @@ class TestDiscreteBayesNet(GtsamTestCase):
actualSample = chordal2.sample() actualSample = chordal2.sample()
self.assertEqual(len(actualSample), 8) self.assertEqual(len(actualSample), 8)
def test_fragment(self):
"""Test sampling and optimizing for Asia fragment."""
# Create a reverse-topologically sorted fragment:
fragment = DiscreteBayesNet()
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
fragment.add(Tuberculosis, [Asia], "99/1 95/5")
fragment.add(LungCancer, [Smoking], "99/1 90/10")
# Create assignment with missing values:
given = DiscreteValues()
for key in [Asia, Smoking]:
given[key[0]] = 0
# Now optimize fragment:
actual = fragment.optimize(given)
self.assertEqual(len(actual), 5)
# Now sample from fragment:
actual = fragment.sample(given)
self.assertEqual(len(actual), 5)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()