single argument variants
parent
4727783304
commit
4bc7b0ba85
|
|
@ -75,7 +75,34 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional {
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
||||||
Base::print(s, formatter);
|
Base::print(s, formatter);
|
||||||
}
|
}
|
||||||
|
/// @}
|
||||||
|
/// @name Standard interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Evaluate given a single value.
|
||||||
|
double operator()(size_t value) const {
|
||||||
|
if (nrFrontals() != 1)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"Single value operator can only be invoked on single-variable "
|
||||||
|
"priors");
|
||||||
|
DiscreteValues values;
|
||||||
|
values.emplace(keys_[0], value);
|
||||||
|
return Base::operator()(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evaluate given a single value.
|
||||||
|
std::vector<double> pmf() const {
|
||||||
|
if (nrFrontals() != 1)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscretePrior::pmf only defined for single-variable priors");
|
||||||
|
const size_t nrValues = cardinalities_.at(keys_[0]);
|
||||||
|
std::vector<double> array;
|
||||||
|
array.reserve(nrValues);
|
||||||
|
for (size_t v = 0; v < nrValues; v++) {
|
||||||
|
array.push_back(operator()(v));
|
||||||
|
}
|
||||||
|
return array;
|
||||||
|
}
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
// DiscretePrior
|
// DiscretePrior
|
||||||
|
|
|
||||||
|
|
@ -23,15 +23,30 @@
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
|
static const DiscreteKey X(0, 2);
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DiscretePrior, constructors) {
|
TEST(DiscretePrior, constructors) {
|
||||||
DiscreteKey X(0, 2);
|
|
||||||
DiscretePrior actual(X % "2/3");
|
DiscretePrior actual(X % "2/3");
|
||||||
DecisionTreeFactor f(X, "0.4 0.6");
|
DecisionTreeFactor f(X, "0.4 0.6");
|
||||||
DiscretePrior expected(f);
|
DiscretePrior expected(f);
|
||||||
EXPECT(assert_equal(expected, actual, 1e-9));
|
EXPECT(assert_equal(expected, actual, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscretePrior, operator) {
|
||||||
|
DiscretePrior prior(X % "2/3");
|
||||||
|
EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(DiscretePrior, to_vector) {
|
||||||
|
DiscretePrior prior(X % "2/3");
|
||||||
|
vector<double> expected {0.4, 0.6};
|
||||||
|
EXPECT(prior.pmf() == expected);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue