Merge branch 'hybrid-custom-discrete' into discrete-table-conditional

release/4.3a0
Varun Agrawal 2024-12-31 21:28:18 -05:00
commit 782f39a0e2
5 changed files with 36 additions and 2 deletions

View File

@ -47,6 +47,15 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
const DecisionTreeFactor& f,
const Ordering& orderedKeys)
: BaseFactor(f), BaseConditional(nrFrontals) {
keys_.clear();
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
const DiscreteKeys& keys,

View File

@ -56,6 +56,17 @@ class GTSAM_EXPORT DiscreteConditional
/// Construct from factor, taking the first `nFrontals` keys as frontals.
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
/**
* @brief Construct from DecisionTreeFactor,
* taking the first `nrFrontals` from `orderedKeys`.
*
* @param nrFrontals The number of frontal variables.
* @param f The DecisionTreeFactor to construct from.
* @param orderedKeys Ordered list of keys involved in the conditional.
*/
DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f,
const Ordering& orderedKeys);
/**
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
* `nFrontals` keys as frontals, in the order given.

View File

@ -252,6 +252,15 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys();
// If no keys, then return empty DecisionTreeFactor
if (dkeys.size() == 0) {
AlgebraicDecisionTree<Key> tree;
if (sparse_table_.size() != 0) {
tree = AlgebraicDecisionTree<Key>(sparse_table_.coeff(0));
}
return DecisionTreeFactor(dkeys, tree);
}
std::vector<double> table;
for (auto i = 0; i < sparse_table_.size(); i++) {
table.push_back(sparse_table_.coeff(i));

View File

@ -360,12 +360,16 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
// All the discrete variables should form a single clique,
// so we can sum out on all the variables as frontals.
// This should give an empty separator.
Ordering orderedKeys(product.keys());
TableFactor::shared_ptr sum = product.sum(orderedKeys);
TableFactor::shared_ptr sum = product.sum(frontalKeys);
#if GTSAM_HYBRID_TIMING
gttoc_(EliminateDiscreteSum);
#endif
// Ordering keys for the conditional so that frontalKeys are really in front
Ordering orderedKeys;
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
#if GTSAM_HYBRID_TIMING
gttic_(EliminateDiscreteFormDiscreteConditional);
#endif

View File

@ -169,6 +169,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
}
}
/* ************************************************************************* */
int main() {
TestResult tr;