From 268a49ec1cbdb441a53b2da6acd175f9c68c8730 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 26 Dec 2021 13:59:48 -0500 Subject: [PATCH] DiscretePrior class --- gtsam/discrete/DiscretePrior.h | 87 ++++++++++++++++++++++ gtsam/discrete/tests/testDiscretePrior.cpp | 40 ++++++++++ 2 files changed, 127 insertions(+) create mode 100644 gtsam/discrete/DiscretePrior.h create mode 100644 gtsam/discrete/tests/testDiscretePrior.cpp diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h new file mode 100644 index 000000000..f38c78ca1 --- /dev/null +++ b/gtsam/discrete/DiscretePrior.h @@ -0,0 +1,87 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscretePrior.h + * @date December 2021 + * @author Frank Dellaert + */ + +#pragma once + +#include + +#include + +namespace gtsam { + +/** + * A prior probability on a set of discrete variables. + * Derives from DiscreteConditional + */ +class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { + public: + using Base = DiscreteConditional; + + /// @name Standard Constructors + /// @{ + + /// Default constructor needed for serialization. + DiscretePrior() {} + + /// Constructor from factor. + DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {} + + /** + * Construct from a Signature. + * + * Example: DiscretePrior P(D % "3/2"); + */ + DiscretePrior(const Signature& s) : Base(s) {} + + /** + * Construct from key and a Signature::Table specifying the + * conditional probability table (CPT). + * + * Example: DiscretePrior P(D, table); + */ + DiscretePrior(const DiscreteKey& key, const Signature::Table& table) + : Base(Signature(key, {}, table)) {} + + /** + * Construct from key and a string specifying the conditional + * probability table (CPT). + * + * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); + */ + DiscretePrior(const DiscreteKey& key, const std::string& spec) + : DiscretePrior(Signature(key, {}, spec)) {} + + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Prior: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + Base::print(s, formatter); + } + + /// @} +}; +// DiscretePrior + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp new file mode 100644 index 000000000..f63b8af0b --- /dev/null +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -0,0 +1,40 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * @file testDiscretePrior.cpp + * @brief unit tests for DiscretePrior + * @author Frank dellaert + * @date December 2021 + */ + +#include +#include +#include + +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +TEST(DiscretePrior, constructors) { + DiscreteKey X(0, 2); + DiscretePrior actual(X % "2/3"); + DecisionTreeFactor f(X, "0.4 0.6"); + DiscretePrior expected(f); + EXPECT(assert_equal(expected, actual, 1e-9)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */