break up EliminateHybrid into smaller functions
parent
3bde044248
commit
6cd20fba4d
|
|
@ -55,14 +55,6 @@ namespace gtsam {
|
||||||
|
|
||||||
template class EliminateableFactorGraph<HybridFactorGraph>;
|
template class EliminateableFactorGraph<HybridFactorGraph>;
|
||||||
|
|
||||||
static std::string BLACK_BOLD = "\033[1;30m";
|
|
||||||
static std::string RED_BOLD = "\033[1;31m";
|
|
||||||
static std::string GREEN = "\033[0;32m";
|
|
||||||
static std::string GREEN_BOLD = "\033[1;32m";
|
|
||||||
static std::string RESET = "\033[0m";
|
|
||||||
|
|
||||||
constexpr bool DEBUG = false;
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static GaussianMixtureFactor::Sum &addGaussian(
|
static GaussianMixtureFactor::Sum &addGaussian(
|
||||||
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
||||||
|
|
@ -84,6 +76,163 @@ static GaussianMixtureFactor::Sum &addGaussian(
|
||||||
return sum;
|
return sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||||
|
continuousElimination(const HybridFactorGraph &factors,
|
||||||
|
const Ordering &frontalKeys) {
|
||||||
|
GaussianFactorGraph gfg;
|
||||||
|
for (auto &fp : factors) {
|
||||||
|
auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp);
|
||||||
|
if (ptr) {
|
||||||
|
gfg.push_back(ptr->inner());
|
||||||
|
} else {
|
||||||
|
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
|
||||||
|
if (p) {
|
||||||
|
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
|
||||||
|
} else {
|
||||||
|
// It is an orphan wrapped conditional
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = EliminatePreferCholesky(gfg, frontalKeys);
|
||||||
|
return std::make_pair(
|
||||||
|
boost::make_shared<HybridConditional>(result.first),
|
||||||
|
boost::make_shared<HybridGaussianFactor>(result.second));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||||
|
discreteElimination(const HybridFactorGraph &factors,
|
||||||
|
const Ordering &frontalKeys) {
|
||||||
|
DiscreteFactorGraph dfg;
|
||||||
|
for (auto &fp : factors) {
|
||||||
|
auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp);
|
||||||
|
if (ptr) {
|
||||||
|
dfg.push_back(ptr->inner());
|
||||||
|
} else {
|
||||||
|
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
|
||||||
|
if (p) {
|
||||||
|
dfg.push_back(boost::static_pointer_cast<DiscreteConditional>(p));
|
||||||
|
} else {
|
||||||
|
// It is an orphan wrapper
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||||
|
|
||||||
|
return std::make_pair(
|
||||||
|
boost::make_shared<HybridConditional>(result.first),
|
||||||
|
boost::make_shared<HybridDiscreteFactor>(result.second));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||||
|
hybridElimination(const HybridFactorGraph &factors, const Ordering &frontalKeys,
|
||||||
|
const KeySet &continuousSeparator,
|
||||||
|
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
||||||
|
// NOTE: since we use the special JunctionTree,
|
||||||
|
// only possiblity is continuous conditioned on discrete.
|
||||||
|
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
||||||
|
discreteSeparatorSet.end());
|
||||||
|
|
||||||
|
// sum out frontals, this is the factor on the separator
|
||||||
|
gttic(sum);
|
||||||
|
|
||||||
|
GaussianMixtureFactor::Sum sum;
|
||||||
|
std::vector<GaussianFactor::shared_ptr> deferredFactors;
|
||||||
|
|
||||||
|
for (auto &f : factors) {
|
||||||
|
if (f->isHybrid()) {
|
||||||
|
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
|
||||||
|
if (cgmf) {
|
||||||
|
sum = cgmf->add(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto gm = boost::dynamic_pointer_cast<HybridConditional>(f);
|
||||||
|
if (gm) {
|
||||||
|
sum = gm->asMixture()->add(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (f->isContinuous()) {
|
||||||
|
deferredFactors.push_back(
|
||||||
|
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
|
||||||
|
} else {
|
||||||
|
// We need to handle the case where the object is actually an
|
||||||
|
// BayesTreeOrphanWrapper!
|
||||||
|
auto orphan = boost::dynamic_pointer_cast<
|
||||||
|
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
|
||||||
|
if (!orphan) {
|
||||||
|
auto &fr = *f;
|
||||||
|
throw std::invalid_argument(
|
||||||
|
std::string("factor is discrete in continuous elimination") +
|
||||||
|
typeid(fr).name());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &f : deferredFactors) {
|
||||||
|
sum = addGaussian(sum, f);
|
||||||
|
}
|
||||||
|
|
||||||
|
gttoc(sum);
|
||||||
|
|
||||||
|
using EliminationPair = GaussianFactorGraph::EliminationResult;
|
||||||
|
|
||||||
|
KeyVector keysOfEliminated; // Not the ordering
|
||||||
|
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
||||||
|
|
||||||
|
auto eliminate = [&](const GaussianFactorGraph &graph)
|
||||||
|
-> GaussianFactorGraph::EliminationResult {
|
||||||
|
if (graph.empty()) {
|
||||||
|
return {nullptr, nullptr};
|
||||||
|
}
|
||||||
|
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
||||||
|
if (keysOfEliminated.empty()) {
|
||||||
|
keysOfEliminated =
|
||||||
|
result.first->keys(); // Initialize the keysOfEliminated to be the
|
||||||
|
}
|
||||||
|
// keysOfEliminated of the GaussianConditional
|
||||||
|
if (keysOfSeparator.empty()) {
|
||||||
|
keysOfSeparator = result.second->keys();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
|
||||||
|
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate);
|
||||||
|
|
||||||
|
auto pair = unzip(eliminationResults);
|
||||||
|
|
||||||
|
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
|
||||||
|
|
||||||
|
// Create the GaussianMixtureConditional from the conditionals
|
||||||
|
auto conditional = boost::make_shared<GaussianMixtureConditional>(
|
||||||
|
frontalKeys, keysOfSeparator, discreteSeparator, pair.first);
|
||||||
|
|
||||||
|
// If there are no more continuous parents, then we should create here a
|
||||||
|
// DiscreteFactor, with the error for each discrete choice.
|
||||||
|
if (keysOfSeparator.empty()) {
|
||||||
|
VectorValues empty_values;
|
||||||
|
auto factorError = [&](const GaussianFactor::shared_ptr &factor) {
|
||||||
|
if (!factor) return 0.0; // TODO(fan): does this make sense?
|
||||||
|
return exp(-factor->error(empty_values));
|
||||||
|
};
|
||||||
|
DecisionTree<Key, double> fdt(separatorFactors, factorError);
|
||||||
|
auto discreteFactor =
|
||||||
|
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||||
|
|
||||||
|
return {boost::make_shared<HybridConditional>(conditional),
|
||||||
|
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Create a resulting DCGaussianMixture on the separator.
|
||||||
|
auto factor = boost::make_shared<GaussianMixtureFactor>(
|
||||||
|
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
|
||||||
|
discreteSeparator, separatorFactors);
|
||||||
|
return {boost::make_shared<HybridConditional>(conditional), factor};
|
||||||
|
}
|
||||||
|
}
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
|
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
|
||||||
EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
||||||
|
|
@ -177,154 +326,17 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
||||||
|
|
||||||
// Case 1: we are only dealing with continuous
|
// Case 1: we are only dealing with continuous
|
||||||
if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) {
|
if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) {
|
||||||
GaussianFactorGraph gfg;
|
return continuousElimination(factors, frontalKeys);
|
||||||
for (auto &fp : factors) {
|
|
||||||
auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp);
|
|
||||||
if (ptr) {
|
|
||||||
gfg.push_back(ptr->inner());
|
|
||||||
} else {
|
|
||||||
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
|
|
||||||
if (p) {
|
|
||||||
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
|
|
||||||
} else {
|
|
||||||
// It is an orphan wrapped conditional
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto result = EliminatePreferCholesky(gfg, frontalKeys);
|
|
||||||
return std::make_pair(
|
|
||||||
boost::make_shared<HybridConditional>(result.first),
|
|
||||||
boost::make_shared<HybridGaussianFactor>(result.second));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case 2: we are only dealing with discrete
|
// Case 2: we are only dealing with discrete
|
||||||
if (allContinuousKeys.empty()) {
|
if (allContinuousKeys.empty()) {
|
||||||
DiscreteFactorGraph dfg;
|
return discreteElimination(factors, frontalKeys);
|
||||||
for (auto &fp : factors) {
|
|
||||||
auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp);
|
|
||||||
if (ptr) {
|
|
||||||
dfg.push_back(ptr->inner());
|
|
||||||
} else {
|
|
||||||
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
|
|
||||||
if (p) {
|
|
||||||
dfg.push_back(boost::static_pointer_cast<DiscreteConditional>(p));
|
|
||||||
} else {
|
|
||||||
// It is an orphan wrapper
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
|
||||||
|
|
||||||
return std::make_pair(
|
|
||||||
boost::make_shared<HybridConditional>(result.first),
|
|
||||||
boost::make_shared<HybridDiscreteFactor>(result.second));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case 3: We are now in the hybrid land!
|
// Case 3: We are now in the hybrid land!
|
||||||
// NOTE: since we use the special JunctionTree, only possiblity is cont.
|
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
||||||
// conditioned on disc.
|
discreteSeparatorSet);
|
||||||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
|
||||||
discreteSeparatorSet.end());
|
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
|
||||||
gttic(sum);
|
|
||||||
|
|
||||||
GaussianMixtureFactor::Sum sum;
|
|
||||||
|
|
||||||
std::vector<GaussianFactor::shared_ptr> deferredFactors;
|
|
||||||
|
|
||||||
for (auto &f : factors) {
|
|
||||||
if (f->isHybrid()) {
|
|
||||||
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
|
|
||||||
if (cgmf) {
|
|
||||||
sum = cgmf->add(sum);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto gm = boost::dynamic_pointer_cast<HybridConditional>(f);
|
|
||||||
if (gm) {
|
|
||||||
sum = gm->asMixture()->add(sum);
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if (f->isContinuous()) {
|
|
||||||
deferredFactors.push_back(
|
|
||||||
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
|
|
||||||
} else {
|
|
||||||
// We need to handle the case where the object is actually an
|
|
||||||
// BayesTreeOrphanWrapper!
|
|
||||||
auto orphan = boost::dynamic_pointer_cast<
|
|
||||||
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
|
|
||||||
if (!orphan) {
|
|
||||||
auto &fr = *f;
|
|
||||||
throw std::invalid_argument(
|
|
||||||
std::string("factor is discrete in continuous elimination") +
|
|
||||||
typeid(fr).name());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto &f : deferredFactors) {
|
|
||||||
sum = addGaussian(sum, f);
|
|
||||||
}
|
|
||||||
|
|
||||||
gttoc(sum);
|
|
||||||
|
|
||||||
using EliminationPair = GaussianFactorGraph::EliminationResult;
|
|
||||||
|
|
||||||
KeyVector keysOfEliminated; // Not the ordering
|
|
||||||
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
|
||||||
|
|
||||||
auto eliminate = [&](const GaussianFactorGraph &graph)
|
|
||||||
-> GaussianFactorGraph::EliminationResult {
|
|
||||||
if (graph.empty()) {
|
|
||||||
return {nullptr, nullptr};
|
|
||||||
}
|
|
||||||
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
|
||||||
if (keysOfEliminated.empty()) {
|
|
||||||
keysOfEliminated =
|
|
||||||
result.first->keys(); // Initialize the keysOfEliminated to be the
|
|
||||||
}
|
|
||||||
// keysOfEliminated of the GaussianConditional
|
|
||||||
if (keysOfSeparator.empty()) {
|
|
||||||
keysOfSeparator = result.second->keys();
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
};
|
|
||||||
|
|
||||||
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate);
|
|
||||||
|
|
||||||
auto pair = unzip(eliminationResults);
|
|
||||||
|
|
||||||
const GaussianMixtureConditional::Conditionals &conditionals = pair.first;
|
|
||||||
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
|
|
||||||
|
|
||||||
// Create the GaussianMixtureConditional from the conditionals
|
|
||||||
auto conditional = boost::make_shared<GaussianMixtureConditional>(
|
|
||||||
frontalKeys, keysOfSeparator, discreteSeparator, conditionals);
|
|
||||||
|
|
||||||
// If there are no more continuous parents, then we should create here a
|
|
||||||
// DiscreteFactor, with the error for each discrete choice.
|
|
||||||
if (keysOfSeparator.empty()) {
|
|
||||||
VectorValues empty_values;
|
|
||||||
auto factorError = [&](const GaussianFactor::shared_ptr &factor) {
|
|
||||||
if (!factor) return 0.0; // TODO(fan): does this make sense?
|
|
||||||
return exp(-factor->error(empty_values));
|
|
||||||
};
|
|
||||||
DecisionTree<Key, double> fdt(separatorFactors, factorError);
|
|
||||||
auto discreteFactor =
|
|
||||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
|
||||||
|
|
||||||
return {boost::make_shared<HybridConditional>(conditional),
|
|
||||||
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
|
|
||||||
|
|
||||||
} else {
|
|
||||||
// Create a resulting DCGaussianMixture on the separator.
|
|
||||||
auto factor = boost::make_shared<GaussianMixtureFactor>(
|
|
||||||
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
|
|
||||||
discreteSeparator, separatorFactors);
|
|
||||||
return {boost::make_shared<HybridConditional>(conditional), factor};
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue