sample variants
parent
a74da73936
commit
10e1bd2f61
|
@ -25,44 +25,51 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
template class FactorGraph<DiscreteConditional>;
|
||||
// Instantiate base class
|
||||
template class FactorGraph<DiscreteConditional>;
|
||||
|
||||
/* ************************************************************************* */
|
||||
bool DiscreteBayesNet::equals(const This& bn, double tol) const
|
||||
{
|
||||
/* ************************************************************************* */
|
||||
bool DiscreteBayesNet::equals(const This& bn, double tol) const {
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
|
||||
/* ************************************************************************* */
|
||||
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
||||
// evaluate all conditionals and multiply
|
||||
double result = 1.0;
|
||||
for(const DiscreteConditional::shared_ptr& conditional: *this)
|
||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
result *= (*conditional)(values);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||
// solve each node in turn in topological sort order (parents first)
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||
DiscreteValues result;
|
||||
for (auto conditional: boost::adaptors::reverse(*this))
|
||||
return optimize(result);
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
|
||||
// solve each node in turn in topological sort order (parents first)
|
||||
for (auto conditional : boost::adaptors::reverse(*this))
|
||||
conditional->solveInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteBayesNet::sample() const {
|
||||
// sample each node in turn in topological sort order (parents first)
|
||||
/* ************************************************************************* */
|
||||
DiscreteValues DiscreteBayesNet::sample() const {
|
||||
DiscreteValues result;
|
||||
for (auto conditional: boost::adaptors::reverse(*this))
|
||||
return sample(result);
|
||||
}
|
||||
|
||||
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||
// sample each node in turn in topological sort order (parents first)
|
||||
for (auto conditional : boost::adaptors::reverse(*this))
|
||||
conditional->sampleInPlace(&result);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::markdown(
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::markdown(
|
||||
const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
|
@ -71,11 +78,10 @@ namespace gtsam {
|
|||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
ss << conditional->markdown(keyFormatter, names) << endl;
|
||||
return ss.str();
|
||||
}
|
||||
}
|
||||
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::html(
|
||||
const KeyFormatter& keyFormatter,
|
||||
/* *********************************************************************** */
|
||||
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
|
||||
const DiscreteFactor::Names& names) const {
|
||||
using std::endl;
|
||||
std::stringstream ss;
|
||||
|
@ -83,7 +89,7 @@ namespace gtsam {
|
|||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||
ss << conditional->html(keyFormatter, names) << endl;
|
||||
return ss.str();
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
} // namespace
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -99,13 +99,47 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/**
|
||||
* Solve the DiscreteBayesNet by back-substitution
|
||||
* @brief solve by back-substitution.
|
||||
*
|
||||
* Assumes the Bayes net is reverse topologically sorted, i.e. last
|
||||
* conditional will be optimized first. If the Bayes net resulted from
|
||||
* eliminating a factor graph, this is true for the elimination ordering.
|
||||
*
|
||||
* @return a sampled value for all variables.
|
||||
*/
|
||||
DiscreteValues optimize() const;
|
||||
|
||||
/** Do ancestral sampling */
|
||||
/**
|
||||
* @brief solve by back-substitution, given certain variables.
|
||||
*
|
||||
* Assumes the Bayes net is reverse topologically sorted *and* that the
|
||||
* Bayes net does not contain any conditionals for the given values.
|
||||
*
|
||||
* @return given values extended with optimized value for other variables.
|
||||
*/
|
||||
DiscreteValues optimize(DiscreteValues given) const;
|
||||
|
||||
/**
|
||||
* @brief do ancestral sampling
|
||||
*
|
||||
* Assumes the Bayes net is reverse topologically sorted, i.e. last
|
||||
* conditional will be sampled first. If the Bayes net resulted from
|
||||
* eliminating a factor graph, this is true for the elimination ordering.
|
||||
*
|
||||
* @return a sampled value for all variables.
|
||||
*/
|
||||
DiscreteValues sample() const;
|
||||
|
||||
/**
|
||||
* @brief do ancestral sampling, given certain variables.
|
||||
*
|
||||
* Assumes the Bayes net is reverse topologically sorted *and* that the
|
||||
* Bayes net does not contain any conditionals for the given values.
|
||||
*
|
||||
* @return given values extended with sampled value for all other variables.
|
||||
*/
|
||||
DiscreteValues sample(DiscreteValues given) const;
|
||||
|
||||
///@}
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
|
|
@ -165,7 +165,9 @@ class DiscreteBayesNet {
|
|||
gtsam::DefaultKeyFormatter) const;
|
||||
double operator()(const gtsam::DiscreteValues& values) const;
|
||||
gtsam::DiscreteValues optimize() const;
|
||||
gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const;
|
||||
gtsam::DiscreteValues sample() const;
|
||||
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||
gtsam::DefaultKeyFormatter) const;
|
||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||
|
|
|
@ -17,6 +17,17 @@ from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
|||
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
||||
from gtsam.utils.test_case import GtsamTestCase
|
||||
|
||||
# Some keys:
|
||||
Asia = (0, 2)
|
||||
Smoking = (4, 2)
|
||||
Tuberculosis = (3, 2)
|
||||
LungCancer = (6, 2)
|
||||
|
||||
Bronchitis = (7, 2)
|
||||
Either = (5, 2)
|
||||
XRay = (2, 2)
|
||||
Dyspnea = (1, 2)
|
||||
|
||||
|
||||
class TestDiscreteBayesNet(GtsamTestCase):
|
||||
"""Tests for Discrete Bayes Nets."""
|
||||
|
@ -43,16 +54,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
def test_Asia(self):
|
||||
"""Test full Asia example."""
|
||||
|
||||
Asia = (0, 2)
|
||||
Smoking = (4, 2)
|
||||
Tuberculosis = (3, 2)
|
||||
LungCancer = (6, 2)
|
||||
|
||||
Bronchitis = (7, 2)
|
||||
Either = (5, 2)
|
||||
XRay = (2, 2)
|
||||
Dyspnea = (1, 2)
|
||||
|
||||
asia = DiscreteBayesNet()
|
||||
asia.add(Asia, "99/1")
|
||||
asia.add(Smoking, "50/50")
|
||||
|
@ -107,6 +108,28 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
|||
actualSample = chordal2.sample()
|
||||
self.assertEqual(len(actualSample), 8)
|
||||
|
||||
def test_fragment(self):
|
||||
"""Test sampling and optimizing for Asia fragment."""
|
||||
|
||||
# Create a reverse-topologically sorted fragment:
|
||||
fragment = DiscreteBayesNet()
|
||||
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
|
||||
fragment.add(Tuberculosis, [Asia], "99/1 95/5")
|
||||
fragment.add(LungCancer, [Smoking], "99/1 90/10")
|
||||
|
||||
# Create assignment with missing values:
|
||||
given = DiscreteValues()
|
||||
for key in [Asia, Smoking]:
|
||||
given[key[0]] = 0
|
||||
|
||||
# Now optimize fragment:
|
||||
actual = fragment.optimize(given)
|
||||
self.assertEqual(len(actual), 5)
|
||||
|
||||
# Now sample from fragment:
|
||||
actual = fragment.sample(given)
|
||||
self.assertEqual(len(actual), 5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue