Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
commit
782f39a0e2
|
@ -47,6 +47,15 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||||
const DecisionTreeFactor& f)
|
const DecisionTreeFactor& f)
|
||||||
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
|
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||||
|
const DecisionTreeFactor& f,
|
||||||
|
const Ordering& orderedKeys)
|
||||||
|
: BaseFactor(f), BaseConditional(nrFrontals) {
|
||||||
|
keys_.clear();
|
||||||
|
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||||
const DiscreteKeys& keys,
|
const DiscreteKeys& keys,
|
||||||
|
|
|
@ -56,6 +56,17 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||||
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct from DecisionTreeFactor,
|
||||||
|
* taking the first `nrFrontals` from `orderedKeys`.
|
||||||
|
*
|
||||||
|
* @param nrFrontals The number of frontal variables.
|
||||||
|
* @param f The DecisionTreeFactor to construct from.
|
||||||
|
* @param orderedKeys Ordered list of keys involved in the conditional.
|
||||||
|
*/
|
||||||
|
DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f,
|
||||||
|
const Ordering& orderedKeys);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
* Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first
|
||||||
* `nFrontals` keys as frontals, in the order given.
|
* `nFrontals` keys as frontals, in the order given.
|
||||||
|
|
|
@ -252,6 +252,15 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
|
||||||
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
|
||||||
DiscreteKeys dkeys = discreteKeys();
|
DiscreteKeys dkeys = discreteKeys();
|
||||||
|
|
||||||
|
// If no keys, then return empty DecisionTreeFactor
|
||||||
|
if (dkeys.size() == 0) {
|
||||||
|
AlgebraicDecisionTree<Key> tree;
|
||||||
|
if (sparse_table_.size() != 0) {
|
||||||
|
tree = AlgebraicDecisionTree<Key>(sparse_table_.coeff(0));
|
||||||
|
}
|
||||||
|
return DecisionTreeFactor(dkeys, tree);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<double> table;
|
std::vector<double> table;
|
||||||
for (auto i = 0; i < sparse_table_.size(); i++) {
|
for (auto i = 0; i < sparse_table_.size(); i++) {
|
||||||
table.push_back(sparse_table_.coeff(i));
|
table.push_back(sparse_table_.coeff(i));
|
||||||
|
|
|
@ -360,12 +360,16 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// All the discrete variables should form a single clique,
|
// All the discrete variables should form a single clique,
|
||||||
// so we can sum out on all the variables as frontals.
|
// so we can sum out on all the variables as frontals.
|
||||||
// This should give an empty separator.
|
// This should give an empty separator.
|
||||||
Ordering orderedKeys(product.keys());
|
TableFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||||
TableFactor::shared_ptr sum = product.sum(orderedKeys);
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttoc_(EliminateDiscreteSum);
|
gttoc_(EliminateDiscreteSum);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
|
Ordering orderedKeys;
|
||||||
|
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end());
|
||||||
|
orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end());
|
||||||
|
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttic_(EliminateDiscreteFormDiscreteConditional);
|
gttic_(EliminateDiscreteFormDiscreteConditional);
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -169,6 +169,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
||||||
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue