diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 19cc3a798..1bca0b09f 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional * @param parentsValues Known values of the parents * @return sample from conditional */ - size_t sample(const DiscreteValues& parentsValues) const; + virtual size_t sample(const DiscreteValues& parentsValues) const; /// Single parent version. size_t sample(size_t parent_value) const; diff --git a/gtsam/discrete/TableDistribution.cpp b/gtsam/discrete/TableDistribution.cpp index 4b9979d3a..e8696c5b1 100644 --- a/gtsam/discrete/TableDistribution.cpp +++ b/gtsam/discrete/TableDistribution.cpp @@ -138,4 +138,37 @@ void TableDistribution::prune(size_t maxNrAssignments) { table_ = table_.prune(maxNrAssignments); } +/* ****************************************************************************/ +size_t TableDistribution::sample(const DiscreteValues& parentsValues) const { + static mt19937 rng(2); // random number generator + + DiscreteKeys parentsKeys; + for (auto&& [key, _] : parentsValues) { + parentsKeys.push_back({key, table_.cardinality(key)}); + } + + // Get the correct conditional distribution: P(F|S=parentsValues) + TableFactor pFS = table_.choose(parentsValues, parentsKeys); + + // TODO(Duy): only works for one key now, seems horribly slow this way + if (nrFrontals() != 1) { + throw std::invalid_argument( + "TableDistribution::sample can only be called on single variable " + "conditionals"); + } + Key key = firstFrontalKey(); + size_t nj = cardinality(key); + vector p(nj); + DiscreteValues frontals; + for (size_t value = 0; value < nj; value++) { + frontals[key] = value; + p[value] = pFS(frontals); // P(F=value|S=parentsValues) + if (p[value] == 1.0) { + return value; // shortcut exit + } + } + std::discrete_distribution distribution(p.begin(), p.end()); + return distribution(rng); +} + } // namespace gtsam diff --git a/gtsam/discrete/TableDistribution.h b/gtsam/discrete/TableDistribution.h index 15ec9959c..72786a515 100644 --- a/gtsam/discrete/TableDistribution.h +++ b/gtsam/discrete/TableDistribution.h @@ -133,6 +133,13 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional { */ DiscreteValues argmax() const; + /** + * sample + * @param parentsValues Known values of the parents + * @return sample from conditional + */ + virtual size_t sample(const DiscreteValues& parentsValues) const override; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 43f84f874..1cb9eda8b 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -211,7 +211,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { DecisionTreeFactor toDecisionTreeFactor() const override; /// Create a TableFactor that is a subset of this TableFactor - TableFactor choose(const DiscreteValues assignments, + TableFactor choose(const DiscreteValues parentAssignments, DiscreteKeys parent_keys) const; /// Create new factor by summing all values with the same separator values