Fixed issues with sample

release/4.3a0
Frank Dellaert 2022-01-02 23:23:51 -05:00
parent 88c79a2a56
commit 53a6523943
5 changed files with 24 additions and 4 deletions

View File

@ -282,6 +282,15 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(values);
}
/* ******************************************************************************** */
size_t DiscreteConditional::sample() const {
if (nrParents() != 0)
throw std::invalid_argument(
"sample() can only be invoked on no-parent prior");
DiscreteValues values;
return sample(values);
}
/* ************************************************************************* */
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
const Names& names) const {

View File

@ -162,9 +162,12 @@ public:
size_t sample(const DiscreteValues& parentsValues) const;
/// Single value version.
/// Single parent version.
size_t sample(size_t parent_value) const;
/// Zero parent version.
size_t sample() const;
/// @}
/// @name Advanced Interface
/// @{

View File

@ -98,7 +98,7 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
* sample
* @return sample from conditional
*/
size_t sample() const { return Base::sample(DiscreteValues()); }
size_t sample() const { return Base::sample(); }
/// @}
};

View File

@ -86,6 +86,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
size_t sample(size_t value) const;
size_t sample() const;
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
@ -105,7 +106,6 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
double operator()(size_t value) const;
std::vector<double> pmf() const;
size_t solve() const;
size_t sample() const;
};
#include <gtsam/discrete/DiscreteBayesNet.h>

View File

@ -28,6 +28,8 @@ static const DiscreteKey X(0, 2);
/* ************************************************************************* */
TEST(DiscretePrior, constructors) {
DiscretePrior actual(X % "2/3");
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
EXPECT_LONGS_EQUAL(0, actual.nrParents());
DecisionTreeFactor f(X, "0.4 0.6");
DiscretePrior expected(f);
EXPECT(assert_equal(expected, actual, 1e-9));
@ -41,12 +43,18 @@ TEST(DiscretePrior, operator) {
}
/* ************************************************************************* */
TEST(DiscretePrior, to_vector) {
TEST(DiscretePrior, pmf) {
DiscretePrior prior(X % "2/3");
vector<double> expected {0.4, 0.6};
EXPECT(prior.pmf() == expected);
}
/* ************************************************************************* */
TEST(DiscretePrior, sample) {
DiscretePrior prior(X % "2/3");
prior.sample();
}
/* ************************************************************************* */
int main() {
TestResult tr;