single argument variants

release/4.3a0
Frank Dellaert 2021-12-26 15:21:02 -05:00
parent 4727783304
commit 4bc7b0ba85
2 changed files with 43 additions and 1 deletions

View File

@ -75,7 +75,34 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
const KeyFormatter& formatter = DefaultKeyFormatter) const override { const KeyFormatter& formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter); Base::print(s, formatter);
} }
/// @}
/// @name Standard interface
/// @{
/// Evaluate given a single value.
double operator()(size_t value) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"Single value operator can only be invoked on single-variable "
"priors");
DiscreteValues values;
values.emplace(keys_[0], value);
return Base::operator()(values);
}
/// Evaluate given a single value.
std::vector<double> pmf() const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"DiscretePrior::pmf only defined for single-variable priors");
const size_t nrValues = cardinalities_.at(keys_[0]);
std::vector<double> array;
array.reserve(nrValues);
for (size_t v = 0; v < nrValues; v++) {
array.push_back(operator()(v));
}
return array;
}
/// @} /// @}
}; };
// DiscretePrior // DiscretePrior

View File

@ -23,15 +23,30 @@
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
static const DiscreteKey X(0, 2);
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DiscretePrior, constructors) { TEST(DiscretePrior, constructors) {
DiscreteKey X(0, 2);
DiscretePrior actual(X % "2/3"); DiscretePrior actual(X % "2/3");
DecisionTreeFactor f(X, "0.4 0.6"); DecisionTreeFactor f(X, "0.4 0.6");
DiscretePrior expected(f); DiscretePrior expected(f);
EXPECT(assert_equal(expected, actual, 1e-9)); EXPECT(assert_equal(expected, actual, 1e-9));
} }
/* ************************************************************************* */
TEST(DiscretePrior, operator) {
DiscretePrior prior(X % "2/3");
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
}
/* ************************************************************************* */
TEST(DiscretePrior, to_vector) {
DiscretePrior prior(X % "2/3");
vector<double> expected {0.4, 0.6};
EXPECT(prior.pmf() == expected);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;