override sample in TableDistribution
parent
b81ab86b69
commit
3629c33ecd
|
@ -199,7 +199,7 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
* @param parentsValues Known values of the parents
|
* @param parentsValues Known values of the parents
|
||||||
* @return sample from conditional
|
* @return sample from conditional
|
||||||
*/
|
*/
|
||||||
size_t sample(const DiscreteValues& parentsValues) const;
|
virtual size_t sample(const DiscreteValues& parentsValues) const;
|
||||||
|
|
||||||
/// Single parent version.
|
/// Single parent version.
|
||||||
size_t sample(size_t parent_value) const;
|
size_t sample(size_t parent_value) const;
|
||||||
|
|
|
@ -138,4 +138,37 @@ void TableDistribution::prune(size_t maxNrAssignments) {
|
||||||
table_ = table_.prune(maxNrAssignments);
|
table_ = table_.prune(maxNrAssignments);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
size_t TableDistribution::sample(const DiscreteValues& parentsValues) const {
|
||||||
|
static mt19937 rng(2); // random number generator
|
||||||
|
|
||||||
|
DiscreteKeys parentsKeys;
|
||||||
|
for (auto&& [key, _] : parentsValues) {
|
||||||
|
parentsKeys.push_back({key, table_.cardinality(key)});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the correct conditional distribution: P(F|S=parentsValues)
|
||||||
|
TableFactor pFS = table_.choose(parentsValues, parentsKeys);
|
||||||
|
|
||||||
|
// TODO(Duy): only works for one key now, seems horribly slow this way
|
||||||
|
if (nrFrontals() != 1) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"TableDistribution::sample can only be called on single variable "
|
||||||
|
"conditionals");
|
||||||
|
}
|
||||||
|
Key key = firstFrontalKey();
|
||||||
|
size_t nj = cardinality(key);
|
||||||
|
vector<double> p(nj);
|
||||||
|
DiscreteValues frontals;
|
||||||
|
for (size_t value = 0; value < nj; value++) {
|
||||||
|
frontals[key] = value;
|
||||||
|
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
|
||||||
|
if (p[value] == 1.0) {
|
||||||
|
return value; // shortcut exit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
|
||||||
|
return distribution(rng);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -133,6 +133,13 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
||||||
*/
|
*/
|
||||||
DiscreteValues argmax() const;
|
DiscreteValues argmax() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* sample
|
||||||
|
* @param parentsValues Known values of the parents
|
||||||
|
* @return sample from conditional
|
||||||
|
*/
|
||||||
|
virtual size_t sample(const DiscreteValues& parentsValues) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -211,7 +211,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||||
|
|
||||||
/// Create a TableFactor that is a subset of this TableFactor
|
/// Create a TableFactor that is a subset of this TableFactor
|
||||||
TableFactor choose(const DiscreteValues assignments,
|
TableFactor choose(const DiscreteValues parentAssignments,
|
||||||
DiscreteKeys parent_keys) const;
|
DiscreteKeys parent_keys) const;
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
|
|
Loading…
Reference in New Issue