Fixed issues with sample
parent
88c79a2a56
commit
53a6523943
|
|
@ -282,6 +282,15 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
|
||||||
return sample(values);
|
return sample(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ******************************************************************************** */
|
||||||
|
size_t DiscreteConditional::sample() const {
|
||||||
|
if (nrParents() != 0)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"sample() can only be invoked on no-parent prior");
|
||||||
|
DiscreteValues values;
|
||||||
|
return sample(values);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
|
||||||
const Names& names) const {
|
const Names& names) const {
|
||||||
|
|
|
||||||
|
|
@ -162,9 +162,12 @@ public:
|
||||||
size_t sample(const DiscreteValues& parentsValues) const;
|
size_t sample(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
|
|
||||||
/// Single value version.
|
/// Single parent version.
|
||||||
size_t sample(size_t parent_value) const;
|
size_t sample(size_t parent_value) const;
|
||||||
|
|
||||||
|
/// Zero parent version.
|
||||||
|
size_t sample() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
||||||
* sample
|
* sample
|
||||||
* @return sample from conditional
|
* @return sample from conditional
|
||||||
*/
|
*/
|
||||||
size_t sample() const { return Base::sample(DiscreteValues()); }
|
size_t sample() const { return Base::sample(); }
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -86,6 +86,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
|
||||||
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
size_t solve(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
|
||||||
size_t sample(size_t value) const;
|
size_t sample(size_t value) const;
|
||||||
|
size_t sample() const;
|
||||||
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void solveInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
|
|
@ -105,7 +106,6 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
|
||||||
double operator()(size_t value) const;
|
double operator()(size_t value) const;
|
||||||
std::vector<double> pmf() const;
|
std::vector<double> pmf() const;
|
||||||
size_t solve() const;
|
size_t solve() const;
|
||||||
size_t sample() const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,8 @@ static const DiscreteKey X(0, 2);
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, constructors) {
|
TEST(DiscretePrior, constructors) {
|
||||||
DiscretePrior actual(X % "2/3");
|
DiscretePrior actual(X % "2/3");
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual.nrFrontals());
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual.nrParents());
|
||||||
DecisionTreeFactor f(X, "0.4 0.6");
|
DecisionTreeFactor f(X, "0.4 0.6");
|
||||||
DiscretePrior expected(f);
|
DiscretePrior expected(f);
|
||||||
EXPECT(assert_equal(expected, actual, 1e-9));
|
EXPECT(assert_equal(expected, actual, 1e-9));
|
||||||
|
|
@ -41,12 +43,18 @@ TEST(DiscretePrior, operator) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, to_vector) {
|
TEST(DiscretePrior, pmf) {
|
||||||
DiscretePrior prior(X % "2/3");
|
DiscretePrior prior(X % "2/3");
|
||||||
vector<double> expected {0.4, 0.6};
|
vector<double> expected {0.4, 0.6};
|
||||||
EXPECT(prior.pmf() == expected);
|
EXPECT(prior.pmf() == expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscretePrior, sample) {
|
||||||
|
DiscretePrior prior(X % "2/3");
|
||||||
|
prior.sample();
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue