override sample in TableDistribution

release/4.3a0
Varun Agrawal 2025-01-07 14:53:28 -05:00
parent b81ab86b69
commit 3629c33ecd
4 changed files with 42 additions and 2 deletions

View File

@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional
* @param parentsValues Known values of the parents * @param parentsValues Known values of the parents
* @return sample from conditional * @return sample from conditional
*/ */
size_t sample(const DiscreteValues& parentsValues) const; virtual size_t sample(const DiscreteValues& parentsValues) const;
/// Single parent version. /// Single parent version.
size_t sample(size_t parent_value) const; size_t sample(size_t parent_value) const;

View File

@ -138,4 +138,37 @@ void TableDistribution::prune(size_t maxNrAssignments) {
table_ = table_.prune(maxNrAssignments); table_ = table_.prune(maxNrAssignments);
} }
/* ****************************************************************************/
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator
DiscreteKeys parentsKeys;
for (auto&& [key, _] : parentsValues) {
parentsKeys.push_back({key, table_.cardinality(key)});
}
// Get the correct conditional distribution: P(F|S=parentsValues)
TableFactor pFS = table_.choose(parentsValues, parentsKeys);
// TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) {
throw std::invalid_argument(
"TableDistribution::sample can only be called on single variable "
"conditionals");
}
Key key = firstFrontalKey();
size_t nj = cardinality(key);
vector<double> p(nj);
DiscreteValues frontals;
for (size_t value = 0; value < nj; value++) {
frontals[key] = value;
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
if (p[value] == 1.0) {
return value; // shortcut exit
}
}
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng);
}
} // namespace gtsam } // namespace gtsam

View File

@ -133,6 +133,13 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
*/ */
DiscreteValues argmax() const; DiscreteValues argmax() const;
/**
* sample
* @param parentsValues Known values of the parents
* @return sample from conditional
*/
virtual size_t sample(const DiscreteValues& parentsValues) const override;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{

View File

@ -211,7 +211,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DecisionTreeFactor toDecisionTreeFactor() const override; DecisionTreeFactor toDecisionTreeFactor() const override;
/// Create a TableFactor that is a subset of this TableFactor /// Create a TableFactor that is a subset of this TableFactor
TableFactor choose(const DiscreteValues assignments, TableFactor choose(const DiscreteValues parentAssignments,
DiscreteKeys parent_keys) const; DiscreteKeys parent_keys) const;
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values