Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
commit
782f39a0e2
|
@ -47,6 +47,15 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
|||
const DecisionTreeFactor& f)
|
||||
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||
const DecisionTreeFactor& f,
|
||||
const Ordering& orderedKeys)
|
||||
: BaseFactor(f), BaseConditional(nrFrontals) {
|
||||
keys_.clear();
|
||||
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||
const DiscreteKeys& keys,
|
||||
|
|
|
@ -56,6 +56,17 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
||||
|
||||
/**
|
||||
* @brief Construct from DecisionTreeFactor,
|
||||
* taking the first `nrFrontals` from `orderedKeys`.
|
||||
*
|
||||
* @param nrFrontals The number of frontal variables.
|
||||
* @param f The DecisionTreeFactor to construct from.
|
||||
* @param orderedKeys Ordered list of keys involved in the conditional.
|
||||
*/
|
||||
DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f,
|
||||
const Ordering& orderedKeys);
|
||||
|
||||
/**
|
||||
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
||||
* `nFrontals` keys as frontals, in the order given.
|
||||
|
|
|
@ -252,6 +252,15 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
|||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||
DiscreteKeys dkeys = discreteKeys();
|
||||
|
||||
// If no keys, then return empty DecisionTreeFactor
|
||||
if (dkeys.size() == 0) {
|
||||
AlgebraicDecisionTree<Key> tree;
|
||||
if (sparse_table_.size() != 0) {
|
||||
tree = AlgebraicDecisionTree<Key>(sparse_table_.coeff(0));
|
||||
}
|
||||
return DecisionTreeFactor(dkeys, tree);
|
||||
}
|
||||
|
||||
std::vector<double> table;
|
||||
for (auto i = 0; i < sparse_table_.size(); i++) {
|
||||
table.push_back(sparse_table_.coeff(i));
|
||||
|
|
|
@ -360,12 +360,16 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
// All the discrete variables should form a single clique,
|
||||
// so we can sum out on all the variables as frontals.
|
||||
// This should give an empty separator.
|
||||
Ordering orderedKeys(product.keys());
|
||||
TableFactor::shared_ptr sum = product.sum(orderedKeys);
|
||||
TableFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttoc_(EliminateDiscreteSum);
|
||||
#endif
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
Ordering orderedKeys;
|
||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
|
||||
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
|
||||
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||
#endif
|
||||
|
|
|
@ -169,6 +169,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue