break up EliminateHybrid into smaller functions
parent
3bde044248
commit
6cd20fba4d
|
|
@ -55,14 +55,6 @@ namespace gtsam {
|
|||
|
||||
template class EliminateableFactorGraph<HybridFactorGraph>;
|
||||
|
||||
static std::string BLACK_BOLD = "\033[1;30m";
|
||||
static std::string RED_BOLD = "\033[1;31m";
|
||||
static std::string GREEN = "\033[0;32m";
|
||||
static std::string GREEN_BOLD = "\033[1;32m";
|
||||
static std::string RESET = "\033[0m";
|
||||
|
||||
constexpr bool DEBUG = false;
|
||||
|
||||
/* ************************************************************************ */
|
||||
static GaussianMixtureFactor::Sum &addGaussian(
|
||||
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
||||
|
|
@ -84,6 +76,163 @@ static GaussianMixtureFactor::Sum &addGaussian(
|
|||
return sum;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
continuousElimination(const HybridFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
GaussianFactorGraph gfg;
|
||||
for (auto &fp : factors) {
|
||||
auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp);
|
||||
if (ptr) {
|
||||
gfg.push_back(ptr->inner());
|
||||
} else {
|
||||
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
|
||||
if (p) {
|
||||
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
|
||||
} else {
|
||||
// It is an orphan wrapped conditional
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto result = EliminatePreferCholesky(gfg, frontalKeys);
|
||||
return std::make_pair(
|
||||
boost::make_shared<HybridConditional>(result.first),
|
||||
boost::make_shared<HybridGaussianFactor>(result.second));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
discreteElimination(const HybridFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
DiscreteFactorGraph dfg;
|
||||
for (auto &fp : factors) {
|
||||
auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp);
|
||||
if (ptr) {
|
||||
dfg.push_back(ptr->inner());
|
||||
} else {
|
||||
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
|
||||
if (p) {
|
||||
dfg.push_back(boost::static_pointer_cast<DiscreteConditional>(p));
|
||||
} else {
|
||||
// It is an orphan wrapper
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
|
||||
return std::make_pair(
|
||||
boost::make_shared<HybridConditional>(result.first),
|
||||
boost::make_shared<HybridDiscreteFactor>(result.second));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||
hybridElimination(const HybridFactorGraph &factors, const Ordering &frontalKeys,
|
||||
const KeySet &continuousSeparator,
|
||||
const std::set<DiscreteKey> &discreteSeparatorSet) {
|
||||
// NOTE: since we use the special JunctionTree,
|
||||
// only possiblity is continuous conditioned on discrete.
|
||||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
||||
discreteSeparatorSet.end());
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
gttic(sum);
|
||||
|
||||
GaussianMixtureFactor::Sum sum;
|
||||
std::vector<GaussianFactor::shared_ptr> deferredFactors;
|
||||
|
||||
for (auto &f : factors) {
|
||||
if (f->isHybrid()) {
|
||||
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
|
||||
if (cgmf) {
|
||||
sum = cgmf->add(sum);
|
||||
}
|
||||
|
||||
auto gm = boost::dynamic_pointer_cast<HybridConditional>(f);
|
||||
if (gm) {
|
||||
sum = gm->asMixture()->add(sum);
|
||||
}
|
||||
|
||||
} else if (f->isContinuous()) {
|
||||
deferredFactors.push_back(
|
||||
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
|
||||
} else {
|
||||
// We need to handle the case where the object is actually an
|
||||
// BayesTreeOrphanWrapper!
|
||||
auto orphan = boost::dynamic_pointer_cast<
|
||||
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
|
||||
if (!orphan) {
|
||||
auto &fr = *f;
|
||||
throw std::invalid_argument(
|
||||
std::string("factor is discrete in continuous elimination") +
|
||||
typeid(fr).name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &f : deferredFactors) {
|
||||
sum = addGaussian(sum, f);
|
||||
}
|
||||
|
||||
gttoc(sum);
|
||||
|
||||
using EliminationPair = GaussianFactorGraph::EliminationResult;
|
||||
|
||||
KeyVector keysOfEliminated; // Not the ordering
|
||||
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
||||
|
||||
auto eliminate = [&](const GaussianFactorGraph &graph)
|
||||
-> GaussianFactorGraph::EliminationResult {
|
||||
if (graph.empty()) {
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
||||
if (keysOfEliminated.empty()) {
|
||||
keysOfEliminated =
|
||||
result.first->keys(); // Initialize the keysOfEliminated to be the
|
||||
}
|
||||
// keysOfEliminated of the GaussianConditional
|
||||
if (keysOfSeparator.empty()) {
|
||||
keysOfSeparator = result.second->keys();
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate);
|
||||
|
||||
auto pair = unzip(eliminationResults);
|
||||
|
||||
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
|
||||
|
||||
// Create the GaussianMixtureConditional from the conditionals
|
||||
auto conditional = boost::make_shared<GaussianMixtureConditional>(
|
||||
frontalKeys, keysOfSeparator, discreteSeparator, pair.first);
|
||||
|
||||
// If there are no more continuous parents, then we should create here a
|
||||
// DiscreteFactor, with the error for each discrete choice.
|
||||
if (keysOfSeparator.empty()) {
|
||||
VectorValues empty_values;
|
||||
auto factorError = [&](const GaussianFactor::shared_ptr &factor) {
|
||||
if (!factor) return 0.0; // TODO(fan): does this make sense?
|
||||
return exp(-factor->error(empty_values));
|
||||
};
|
||||
DecisionTree<Key, double> fdt(separatorFactors, factorError);
|
||||
auto discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||
|
||||
return {boost::make_shared<HybridConditional>(conditional),
|
||||
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
|
||||
|
||||
} else {
|
||||
// Create a resulting DCGaussianMixture on the separator.
|
||||
auto factor = boost::make_shared<GaussianMixtureFactor>(
|
||||
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
|
||||
discreteSeparator, separatorFactors);
|
||||
return {boost::make_shared<HybridConditional>(conditional), factor};
|
||||
}
|
||||
}
|
||||
/* ************************************************************************ */
|
||||
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
|
||||
EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
||||
|
|
@ -177,154 +326,17 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
|||
|
||||
// Case 1: we are only dealing with continuous
|
||||
if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) {
|
||||
GaussianFactorGraph gfg;
|
||||
for (auto &fp : factors) {
|
||||
auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp);
|
||||
if (ptr) {
|
||||
gfg.push_back(ptr->inner());
|
||||
} else {
|
||||
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
|
||||
if (p) {
|
||||
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
|
||||
} else {
|
||||
// It is an orphan wrapped conditional
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto result = EliminatePreferCholesky(gfg, frontalKeys);
|
||||
return std::make_pair(
|
||||
boost::make_shared<HybridConditional>(result.first),
|
||||
boost::make_shared<HybridGaussianFactor>(result.second));
|
||||
return continuousElimination(factors, frontalKeys);
|
||||
}
|
||||
|
||||
// Case 2: we are only dealing with discrete
|
||||
if (allContinuousKeys.empty()) {
|
||||
DiscreteFactorGraph dfg;
|
||||
for (auto &fp : factors) {
|
||||
auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp);
|
||||
if (ptr) {
|
||||
dfg.push_back(ptr->inner());
|
||||
} else {
|
||||
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
|
||||
if (p) {
|
||||
dfg.push_back(boost::static_pointer_cast<DiscreteConditional>(p));
|
||||
} else {
|
||||
// It is an orphan wrapper
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto result = EliminateDiscrete(dfg, frontalKeys);
|
||||
|
||||
return std::make_pair(
|
||||
boost::make_shared<HybridConditional>(result.first),
|
||||
boost::make_shared<HybridDiscreteFactor>(result.second));
|
||||
return discreteElimination(factors, frontalKeys);
|
||||
}
|
||||
|
||||
// Case 3: We are now in the hybrid land!
|
||||
// NOTE: since we use the special JunctionTree, only possiblity is cont.
|
||||
// conditioned on disc.
|
||||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
||||
discreteSeparatorSet.end());
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
gttic(sum);
|
||||
|
||||
GaussianMixtureFactor::Sum sum;
|
||||
|
||||
std::vector<GaussianFactor::shared_ptr> deferredFactors;
|
||||
|
||||
for (auto &f : factors) {
|
||||
if (f->isHybrid()) {
|
||||
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
|
||||
if (cgmf) {
|
||||
sum = cgmf->add(sum);
|
||||
}
|
||||
|
||||
auto gm = boost::dynamic_pointer_cast<HybridConditional>(f);
|
||||
if (gm) {
|
||||
sum = gm->asMixture()->add(sum);
|
||||
}
|
||||
|
||||
} else if (f->isContinuous()) {
|
||||
deferredFactors.push_back(
|
||||
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
|
||||
} else {
|
||||
// We need to handle the case where the object is actually an
|
||||
// BayesTreeOrphanWrapper!
|
||||
auto orphan = boost::dynamic_pointer_cast<
|
||||
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
|
||||
if (!orphan) {
|
||||
auto &fr = *f;
|
||||
throw std::invalid_argument(
|
||||
std::string("factor is discrete in continuous elimination") +
|
||||
typeid(fr).name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &f : deferredFactors) {
|
||||
sum = addGaussian(sum, f);
|
||||
}
|
||||
|
||||
gttoc(sum);
|
||||
|
||||
using EliminationPair = GaussianFactorGraph::EliminationResult;
|
||||
|
||||
KeyVector keysOfEliminated; // Not the ordering
|
||||
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
||||
|
||||
auto eliminate = [&](const GaussianFactorGraph &graph)
|
||||
-> GaussianFactorGraph::EliminationResult {
|
||||
if (graph.empty()) {
|
||||
return {nullptr, nullptr};
|
||||
}
|
||||
auto result = EliminatePreferCholesky(graph, frontalKeys);
|
||||
if (keysOfEliminated.empty()) {
|
||||
keysOfEliminated =
|
||||
result.first->keys(); // Initialize the keysOfEliminated to be the
|
||||
}
|
||||
// keysOfEliminated of the GaussianConditional
|
||||
if (keysOfSeparator.empty()) {
|
||||
keysOfSeparator = result.second->keys();
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate);
|
||||
|
||||
auto pair = unzip(eliminationResults);
|
||||
|
||||
const GaussianMixtureConditional::Conditionals &conditionals = pair.first;
|
||||
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
|
||||
|
||||
// Create the GaussianMixtureConditional from the conditionals
|
||||
auto conditional = boost::make_shared<GaussianMixtureConditional>(
|
||||
frontalKeys, keysOfSeparator, discreteSeparator, conditionals);
|
||||
|
||||
// If there are no more continuous parents, then we should create here a
|
||||
// DiscreteFactor, with the error for each discrete choice.
|
||||
if (keysOfSeparator.empty()) {
|
||||
VectorValues empty_values;
|
||||
auto factorError = [&](const GaussianFactor::shared_ptr &factor) {
|
||||
if (!factor) return 0.0; // TODO(fan): does this make sense?
|
||||
return exp(-factor->error(empty_values));
|
||||
};
|
||||
DecisionTree<Key, double> fdt(separatorFactors, factorError);
|
||||
auto discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||
|
||||
return {boost::make_shared<HybridConditional>(conditional),
|
||||
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
|
||||
|
||||
} else {
|
||||
// Create a resulting DCGaussianMixture on the separator.
|
||||
auto factor = boost::make_shared<GaussianMixtureFactor>(
|
||||
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
|
||||
discreteSeparator, separatorFactors);
|
||||
return {boost::make_shared<HybridConditional>(conditional), factor};
|
||||
}
|
||||
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
||||
discreteSeparatorSet);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
|||
Loading…
Reference in New Issue