override sample in TableDistribution
parent
b81ab86b69
commit
3629c33ecd
|
@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
* @param parentsValues Known values of the parents
|
||||
* @return sample from conditional
|
||||
*/
|
||||
size_t sample(const DiscreteValues& parentsValues) const;
|
||||
virtual size_t sample(const DiscreteValues& parentsValues) const;
|
||||
|
||||
/// Single parent version.
|
||||
size_t sample(size_t parent_value) const;
|
||||
|
|
|
@ -138,4 +138,37 @@ void TableDistribution::prune(size_t maxNrAssignments) {
|
|||
table_ = table_.prune(maxNrAssignments);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
|
||||
static mt19937 rng(2); // random number generator
|
||||
|
||||
DiscreteKeys parentsKeys;
|
||||
for (auto&& [key, _] : parentsValues) {
|
||||
parentsKeys.push_back({key, table_.cardinality(key)});
|
||||
}
|
||||
|
||||
// Get the correct conditional distribution: P(F|S=parentsValues)
|
||||
TableFactor pFS = table_.choose(parentsValues, parentsKeys);
|
||||
|
||||
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::invalid_argument(
|
||||
"TableDistribution::sample can only be called on single variable "
|
||||
"conditionals");
|
||||
}
|
||||
Key key = firstFrontalKey();
|
||||
size_t nj = cardinality(key);
|
||||
vector<double> p(nj);
|
||||
DiscreteValues frontals;
|
||||
for (size_t value = 0; value < nj; value++) {
|
||||
frontals[key] = value;
|
||||
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
|
||||
if (p[value] == 1.0) {
|
||||
return value; // shortcut exit
|
||||
}
|
||||
}
|
||||
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
||||
return distribution(rng);
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -133,6 +133,13 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
|||
*/
|
||||
DiscreteValues argmax() const;
|
||||
|
||||
/**
|
||||
* sample
|
||||
* @param parentsValues Known values of the parents
|
||||
* @return sample from conditional
|
||||
*/
|
||||
virtual size_t sample(const DiscreteValues& parentsValues) const override;
|
||||
|
||||
/// @}
|
||||
/// @name Advanced Interface
|
||||
/// @{
|
||||
|
|
|
@ -211,7 +211,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||
|
||||
/// Create a TableFactor that is a subset of this TableFactor
|
||||
TableFactor choose(const DiscreteValues assignments,
|
||||
TableFactor choose(const DiscreteValues parentAssignments,
|
||||
DiscreteKeys parent_keys) const;
|
||||
|
||||
/// Create new factor by summing all values with the same separator values
|
||||
|
|
Loading…
Reference in New Issue