sample variants

release/4.3a0
Frank Dellaert 2022-01-17 22:59:17 -05:00
parent a74da73936
commit 10e1bd2f61
4 changed files with 137 additions and 72 deletions

View File

@ -25,44 +25,51 @@
namespace gtsam {
// Instantiate base class
template class FactorGraph<DiscreteConditional>;
// Instantiate base class
template class FactorGraph<DiscreteConditional>;
/* ************************************************************************* */
bool DiscreteBayesNet::equals(const This& bn, double tol) const
{
/* ************************************************************************* */
bool DiscreteBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
}
}
/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
// evaluate all conditionals and multiply
double result = 1.0;
for(const DiscreteConditional::shared_ptr& conditional: *this)
for (const DiscreteConditional::shared_ptr& conditional : *this)
result *= (*conditional)(values);
return result;
}
}
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::optimize() const {
// solve each node in turn in topological sort order (parents first)
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result;
for (auto conditional: boost::adaptors::reverse(*this))
return optimize(result);
}
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first)
for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}
}
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
// sample each node in turn in topological sort order (parents first)
/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
DiscreteValues result;
for (auto conditional: boost::adaptors::reverse(*this))
return sample(result);
}
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// sample each node in turn in topological sort order (parents first)
for (auto conditional : boost::adaptors::reverse(*this))
conditional->sampleInPlace(&result);
return result;
}
}
/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
@ -71,11 +78,10 @@ namespace gtsam {
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->markdown(keyFormatter, names) << endl;
return ss.str();
}
}
/* *********************************************************************** */
std::string DiscreteBayesNet::html(
const KeyFormatter& keyFormatter,
/* *********************************************************************** */
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
const DiscreteFactor::Names& names) const {
using std::endl;
std::stringstream ss;
@ -83,7 +89,7 @@ namespace gtsam {
for (const DiscreteConditional::shared_ptr& conditional : *this)
ss << conditional->html(keyFormatter, names) << endl;
return ss.str();
}
}
/* ************************************************************************* */
} // namespace
} // namespace gtsam

View File

@ -99,13 +99,47 @@ namespace gtsam {
}
/**
* Solve the DiscreteBayesNet by back-substitution
* @brief solve by back-substitution.
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be optimized first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues optimize() const;
/** Do ancestral sampling */
/**
* @brief solve by back-substitution, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with optimized value for other variables.
*/
DiscreteValues optimize(DiscreteValues given) const;
/**
* @brief do ancestral sampling
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be sampled first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues sample() const;
/**
* @brief do ancestral sampling, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with sampled value for all other variables.
*/
DiscreteValues sample(DiscreteValues given) const;
///@}
/// @name Wrapper support
/// @{

View File

@ -165,7 +165,9 @@ class DiscreteBayesNet {
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const;
gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,

View File

@ -17,6 +17,17 @@ from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
from gtsam.utils.test_case import GtsamTestCase
# Some keys:
Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)
Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)
class TestDiscreteBayesNet(GtsamTestCase):
"""Tests for Discrete Bayes Nets."""
@ -43,16 +54,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
def test_Asia(self):
"""Test full Asia example."""
Asia = (0, 2)
Smoking = (4, 2)
Tuberculosis = (3, 2)
LungCancer = (6, 2)
Bronchitis = (7, 2)
Either = (5, 2)
XRay = (2, 2)
Dyspnea = (1, 2)
asia = DiscreteBayesNet()
asia.add(Asia, "99/1")
asia.add(Smoking, "50/50")
@ -107,6 +108,28 @@ class TestDiscreteBayesNet(GtsamTestCase):
actualSample = chordal2.sample()
self.assertEqual(len(actualSample), 8)
def test_fragment(self):
"""Test sampling and optimizing for Asia fragment."""
# Create a reverse-topologically sorted fragment:
fragment = DiscreteBayesNet()
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
fragment.add(Tuberculosis, [Asia], "99/1 95/5")
fragment.add(LungCancer, [Smoking], "99/1 90/10")
# Create assignment with missing values:
given = DiscreteValues()
for key in [Asia, Smoking]:
given[key[0]] = 0
# Now optimize fragment:
actual = fragment.optimize(given)
self.assertEqual(len(actual), 5)
# Now sample from fragment:
actual = fragment.sample(given)
self.assertEqual(len(actual), 5)
if __name__ == "__main__":
unittest.main()