sample variants
parent
a74da73936
commit
10e1bd2f61
|
@ -25,65 +25,71 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
// Instantiate base class
|
// Instantiate base class
|
||||||
template class FactorGraph<DiscreteConditional>;
|
template class FactorGraph<DiscreteConditional>;
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
bool DiscreteBayesNet::equals(const This& bn, double tol) const
|
|
||||||
{
|
|
||||||
return Base::equals(bn, tol);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
|
|
||||||
// evaluate all conditionals and multiply
|
|
||||||
double result = 1.0;
|
|
||||||
for(const DiscreteConditional::shared_ptr& conditional: *this)
|
|
||||||
result *= (*conditional)(values);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
DiscreteValues DiscreteBayesNet::optimize() const {
|
|
||||||
// solve each node in turn in topological sort order (parents first)
|
|
||||||
DiscreteValues result;
|
|
||||||
for (auto conditional: boost::adaptors::reverse(*this))
|
|
||||||
conditional->solveInPlace(&result);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
|
||||||
DiscreteValues DiscreteBayesNet::sample() const {
|
|
||||||
// sample each node in turn in topological sort order (parents first)
|
|
||||||
DiscreteValues result;
|
|
||||||
for (auto conditional: boost::adaptors::reverse(*this))
|
|
||||||
conditional->sampleInPlace(&result);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* *********************************************************************** */
|
|
||||||
std::string DiscreteBayesNet::markdown(
|
|
||||||
const KeyFormatter& keyFormatter,
|
|
||||||
const DiscreteFactor::Names& names) const {
|
|
||||||
using std::endl;
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
|
||||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
|
||||||
ss << conditional->markdown(keyFormatter, names) << endl;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* *********************************************************************** */
|
|
||||||
std::string DiscreteBayesNet::html(
|
|
||||||
const KeyFormatter& keyFormatter,
|
|
||||||
const DiscreteFactor::Names& names) const {
|
|
||||||
using std::endl;
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
|
|
||||||
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
|
||||||
ss << conditional->html(keyFormatter, names) << endl;
|
|
||||||
return ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
} // namespace
|
bool DiscreteBayesNet::equals(const This& bn, double tol) const {
|
||||||
|
return Base::equals(bn, tol);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
|
||||||
|
// evaluate all conditionals and multiply
|
||||||
|
double result = 1.0;
|
||||||
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
|
result *= (*conditional)(values);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
DiscreteValues DiscreteBayesNet::optimize() const {
|
||||||
|
DiscreteValues result;
|
||||||
|
return optimize(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
|
||||||
|
// solve each node in turn in topological sort order (parents first)
|
||||||
|
for (auto conditional : boost::adaptors::reverse(*this))
|
||||||
|
conditional->solveInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
DiscreteValues DiscreteBayesNet::sample() const {
|
||||||
|
DiscreteValues result;
|
||||||
|
return sample(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||||
|
// sample each node in turn in topological sort order (parents first)
|
||||||
|
for (auto conditional : boost::adaptors::reverse(*this))
|
||||||
|
conditional->sampleInPlace(&result);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *********************************************************************** */
|
||||||
|
std::string DiscreteBayesNet::markdown(
|
||||||
|
const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
|
||||||
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
|
ss << conditional->markdown(keyFormatter, names) << endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *********************************************************************** */
|
||||||
|
std::string DiscreteBayesNet::html(const KeyFormatter& keyFormatter,
|
||||||
|
const DiscreteFactor::Names& names) const {
|
||||||
|
using std::endl;
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "<div><p><tt>DiscreteBayesNet</tt> of size " << size() << "</p>";
|
||||||
|
for (const DiscreteConditional::shared_ptr& conditional : *this)
|
||||||
|
ss << conditional->html(keyFormatter, names) << endl;
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
@ -99,13 +99,47 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Solve the DiscreteBayesNet by back-substitution
|
* @brief solve by back-substitution.
|
||||||
|
*
|
||||||
|
* Assumes the Bayes net is reverse topologically sorted, i.e. last
|
||||||
|
* conditional will be optimized first. If the Bayes net resulted from
|
||||||
|
* eliminating a factor graph, this is true for the elimination ordering.
|
||||||
|
*
|
||||||
|
* @return a sampled value for all variables.
|
||||||
*/
|
*/
|
||||||
DiscreteValues optimize() const;
|
DiscreteValues optimize() const;
|
||||||
|
|
||||||
/** Do ancestral sampling */
|
/**
|
||||||
|
* @brief solve by back-substitution, given certain variables.
|
||||||
|
*
|
||||||
|
* Assumes the Bayes net is reverse topologically sorted *and* that the
|
||||||
|
* Bayes net does not contain any conditionals for the given values.
|
||||||
|
*
|
||||||
|
* @return given values extended with optimized value for other variables.
|
||||||
|
*/
|
||||||
|
DiscreteValues optimize(DiscreteValues given) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief do ancestral sampling
|
||||||
|
*
|
||||||
|
* Assumes the Bayes net is reverse topologically sorted, i.e. last
|
||||||
|
* conditional will be sampled first. If the Bayes net resulted from
|
||||||
|
* eliminating a factor graph, this is true for the elimination ordering.
|
||||||
|
*
|
||||||
|
* @return a sampled value for all variables.
|
||||||
|
*/
|
||||||
DiscreteValues sample() const;
|
DiscreteValues sample() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief do ancestral sampling, given certain variables.
|
||||||
|
*
|
||||||
|
* Assumes the Bayes net is reverse topologically sorted *and* that the
|
||||||
|
* Bayes net does not contain any conditionals for the given values.
|
||||||
|
*
|
||||||
|
* @return given values extended with sampled value for all other variables.
|
||||||
|
*/
|
||||||
|
DiscreteValues sample(DiscreteValues given) const;
|
||||||
|
|
||||||
///@}
|
///@}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -165,7 +165,9 @@ class DiscreteBayesNet {
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
double operator()(const gtsam::DiscreteValues& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DiscreteValues optimize() const;
|
gtsam::DiscreteValues optimize() const;
|
||||||
|
gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const;
|
||||||
gtsam::DiscreteValues sample() const;
|
gtsam::DiscreteValues sample() const;
|
||||||
|
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
string markdown(const gtsam::KeyFormatter& keyFormatter,
|
||||||
|
|
|
@ -17,6 +17,17 @@ from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||||
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
# Some keys:
|
||||||
|
Asia = (0, 2)
|
||||||
|
Smoking = (4, 2)
|
||||||
|
Tuberculosis = (3, 2)
|
||||||
|
LungCancer = (6, 2)
|
||||||
|
|
||||||
|
Bronchitis = (7, 2)
|
||||||
|
Either = (5, 2)
|
||||||
|
XRay = (2, 2)
|
||||||
|
Dyspnea = (1, 2)
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteBayesNet(GtsamTestCase):
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
"""Tests for Discrete Bayes Nets."""
|
"""Tests for Discrete Bayes Nets."""
|
||||||
|
@ -43,16 +54,6 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
def test_Asia(self):
|
def test_Asia(self):
|
||||||
"""Test full Asia example."""
|
"""Test full Asia example."""
|
||||||
|
|
||||||
Asia = (0, 2)
|
|
||||||
Smoking = (4, 2)
|
|
||||||
Tuberculosis = (3, 2)
|
|
||||||
LungCancer = (6, 2)
|
|
||||||
|
|
||||||
Bronchitis = (7, 2)
|
|
||||||
Either = (5, 2)
|
|
||||||
XRay = (2, 2)
|
|
||||||
Dyspnea = (1, 2)
|
|
||||||
|
|
||||||
asia = DiscreteBayesNet()
|
asia = DiscreteBayesNet()
|
||||||
asia.add(Asia, "99/1")
|
asia.add(Asia, "99/1")
|
||||||
asia.add(Smoking, "50/50")
|
asia.add(Smoking, "50/50")
|
||||||
|
@ -107,6 +108,28 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
actualSample = chordal2.sample()
|
actualSample = chordal2.sample()
|
||||||
self.assertEqual(len(actualSample), 8)
|
self.assertEqual(len(actualSample), 8)
|
||||||
|
|
||||||
|
def test_fragment(self):
|
||||||
|
"""Test sampling and optimizing for Asia fragment."""
|
||||||
|
|
||||||
|
# Create a reverse-topologically sorted fragment:
|
||||||
|
fragment = DiscreteBayesNet()
|
||||||
|
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
|
||||||
|
fragment.add(Tuberculosis, [Asia], "99/1 95/5")
|
||||||
|
fragment.add(LungCancer, [Smoking], "99/1 90/10")
|
||||||
|
|
||||||
|
# Create assignment with missing values:
|
||||||
|
given = DiscreteValues()
|
||||||
|
for key in [Asia, Smoking]:
|
||||||
|
given[key[0]] = 0
|
||||||
|
|
||||||
|
# Now optimize fragment:
|
||||||
|
actual = fragment.optimize(given)
|
||||||
|
self.assertEqual(len(actual), 5)
|
||||||
|
|
||||||
|
# Now sample from fragment:
|
||||||
|
actual = fragment.sample(given)
|
||||||
|
self.assertEqual(len(actual), 5)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue