break up EliminateHybrid into smaller functions

release/4.3a0
Varun Agrawal 2022-05-27 18:34:13 -04:00
parent 3bde044248
commit 6cd20fba4d
1 changed files with 161 additions and 149 deletions

View File

@ -55,14 +55,6 @@ namespace gtsam {
template class EliminateableFactorGraph<HybridFactorGraph>; template class EliminateableFactorGraph<HybridFactorGraph>;
static std::string BLACK_BOLD = "\033[1;30m";
static std::string RED_BOLD = "\033[1;31m";
static std::string GREEN = "\033[0;32m";
static std::string GREEN_BOLD = "\033[1;32m";
static std::string RESET = "\033[0m";
constexpr bool DEBUG = false;
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianMixtureFactor::Sum &addGaussian( static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
@ -84,6 +76,163 @@ static GaussianMixtureFactor::Sum &addGaussian(
return sum; return sum;
} }
/* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
continuousElimination(const HybridFactorGraph &factors,
const Ordering &frontalKeys) {
GaussianFactorGraph gfg;
for (auto &fp : factors) {
auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp);
if (ptr) {
gfg.push_back(ptr->inner());
} else {
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
if (p) {
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
} else {
// It is an orphan wrapped conditional
}
}
}
auto result = EliminatePreferCholesky(gfg, frontalKeys);
return std::make_pair(
boost::make_shared<HybridConditional>(result.first),
boost::make_shared<HybridGaussianFactor>(result.second));
}
/* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
discreteElimination(const HybridFactorGraph &factors,
const Ordering &frontalKeys) {
DiscreteFactorGraph dfg;
for (auto &fp : factors) {
auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp);
if (ptr) {
dfg.push_back(ptr->inner());
} else {
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
if (p) {
dfg.push_back(boost::static_pointer_cast<DiscreteConditional>(p));
} else {
// It is an orphan wrapper
}
}
}
auto result = EliminateDiscrete(dfg, frontalKeys);
return std::make_pair(
boost::make_shared<HybridConditional>(result.first),
boost::make_shared<HybridDiscreteFactor>(result.second));
}
/* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
hybridElimination(const HybridFactorGraph &factors, const Ordering &frontalKeys,
const KeySet &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree,
// only possiblity is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end());
// sum out frontals, this is the factor on the separator
gttic(sum);
GaussianMixtureFactor::Sum sum;
std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors) {
if (f->isHybrid()) {
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
if (cgmf) {
sum = cgmf->add(sum);
}
auto gm = boost::dynamic_pointer_cast<HybridConditional>(f);
if (gm) {
sum = gm->asMixture()->add(sum);
}
} else if (f->isContinuous()) {
deferredFactors.push_back(
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
} else {
// We need to handle the case where the object is actually an
// BayesTreeOrphanWrapper!
auto orphan = boost::dynamic_pointer_cast<
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
if (!orphan) {
auto &fr = *f;
throw std::invalid_argument(
std::string("factor is discrete in continuous elimination") +
typeid(fr).name());
}
}
}
for (auto &f : deferredFactors) {
sum = addGaussian(sum, f);
}
gttoc(sum);
using EliminationPair = GaussianFactorGraph::EliminationResult;
KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
auto eliminate = [&](const GaussianFactorGraph &graph)
-> GaussianFactorGraph::EliminationResult {
if (graph.empty()) {
return {nullptr, nullptr};
}
auto result = EliminatePreferCholesky(graph, frontalKeys);
if (keysOfEliminated.empty()) {
keysOfEliminated =
result.first->keys(); // Initialize the keysOfEliminated to be the
}
// keysOfEliminated of the GaussianConditional
if (keysOfSeparator.empty()) {
keysOfSeparator = result.second->keys();
}
return result;
};
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate);
auto pair = unzip(eliminationResults);
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
// Create the GaussianMixtureConditional from the conditionals
auto conditional = boost::make_shared<GaussianMixtureConditional>(
frontalKeys, keysOfSeparator, discreteSeparator, pair.first);
// If there are no more continuous parents, then we should create here a
// DiscreteFactor, with the error for each discrete choice.
if (keysOfSeparator.empty()) {
VectorValues empty_values;
auto factorError = [&](const GaussianFactor::shared_ptr &factor) {
if (!factor) return 0.0; // TODO(fan): does this make sense?
return exp(-factor->error(empty_values));
};
DecisionTree<Key, double> fdt(separatorFactors, factorError);
auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
return {boost::make_shared<HybridConditional>(conditional),
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
} else {
// Create a resulting DCGaussianMixture on the separator.
auto factor = boost::make_shared<GaussianMixtureFactor>(
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
discreteSeparator, separatorFactors);
return {boost::make_shared<HybridConditional>(conditional), factor};
}
}
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> // std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
@ -177,154 +326,17 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
// Case 1: we are only dealing with continuous // Case 1: we are only dealing with continuous
if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) { if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) {
GaussianFactorGraph gfg; return continuousElimination(factors, frontalKeys);
for (auto &fp : factors) {
auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp);
if (ptr) {
gfg.push_back(ptr->inner());
} else {
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
if (p) {
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
} else {
// It is an orphan wrapped conditional
}
}
}
auto result = EliminatePreferCholesky(gfg, frontalKeys);
return std::make_pair(
boost::make_shared<HybridConditional>(result.first),
boost::make_shared<HybridGaussianFactor>(result.second));
} }
// Case 2: we are only dealing with discrete // Case 2: we are only dealing with discrete
if (allContinuousKeys.empty()) { if (allContinuousKeys.empty()) {
DiscreteFactorGraph dfg; return discreteElimination(factors, frontalKeys);
for (auto &fp : factors) {
auto ptr = boost::dynamic_pointer_cast<HybridDiscreteFactor>(fp);
if (ptr) {
dfg.push_back(ptr->inner());
} else {
auto p = boost::static_pointer_cast<HybridConditional>(fp)->inner();
if (p) {
dfg.push_back(boost::static_pointer_cast<DiscreteConditional>(p));
} else {
// It is an orphan wrapper
}
}
}
auto result = EliminateDiscrete(dfg, frontalKeys);
return std::make_pair(
boost::make_shared<HybridConditional>(result.first),
boost::make_shared<HybridDiscreteFactor>(result.second));
} }
// Case 3: We are now in the hybrid land! // Case 3: We are now in the hybrid land!
// NOTE: since we use the special JunctionTree, only possiblity is cont. return hybridElimination(factors, frontalKeys, continuousSeparator,
// conditioned on disc. discreteSeparatorSet);
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end());
// sum out frontals, this is the factor on the separator
gttic(sum);
GaussianMixtureFactor::Sum sum;
std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors) {
if (f->isHybrid()) {
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
if (cgmf) {
sum = cgmf->add(sum);
}
auto gm = boost::dynamic_pointer_cast<HybridConditional>(f);
if (gm) {
sum = gm->asMixture()->add(sum);
}
} else if (f->isContinuous()) {
deferredFactors.push_back(
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
} else {
// We need to handle the case where the object is actually an
// BayesTreeOrphanWrapper!
auto orphan = boost::dynamic_pointer_cast<
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f);
if (!orphan) {
auto &fr = *f;
throw std::invalid_argument(
std::string("factor is discrete in continuous elimination") +
typeid(fr).name());
}
}
}
for (auto &f : deferredFactors) {
sum = addGaussian(sum, f);
}
gttoc(sum);
using EliminationPair = GaussianFactorGraph::EliminationResult;
KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
auto eliminate = [&](const GaussianFactorGraph &graph)
-> GaussianFactorGraph::EliminationResult {
if (graph.empty()) {
return {nullptr, nullptr};
}
auto result = EliminatePreferCholesky(graph, frontalKeys);
if (keysOfEliminated.empty()) {
keysOfEliminated =
result.first->keys(); // Initialize the keysOfEliminated to be the
}
// keysOfEliminated of the GaussianConditional
if (keysOfSeparator.empty()) {
keysOfSeparator = result.second->keys();
}
return result;
};
DecisionTree<Key, EliminationPair> eliminationResults(sum, eliminate);
auto pair = unzip(eliminationResults);
const GaussianMixtureConditional::Conditionals &conditionals = pair.first;
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
// Create the GaussianMixtureConditional from the conditionals
auto conditional = boost::make_shared<GaussianMixtureConditional>(
frontalKeys, keysOfSeparator, discreteSeparator, conditionals);
// If there are no more continuous parents, then we should create here a
// DiscreteFactor, with the error for each discrete choice.
if (keysOfSeparator.empty()) {
VectorValues empty_values;
auto factorError = [&](const GaussianFactor::shared_ptr &factor) {
if (!factor) return 0.0; // TODO(fan): does this make sense?
return exp(-factor->error(empty_values));
};
DecisionTree<Key, double> fdt(separatorFactors, factorError);
auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
return {boost::make_shared<HybridConditional>(conditional),
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
} else {
// Create a resulting DCGaussianMixture on the separator.
auto factor = boost::make_shared<GaussianMixtureFactor>(
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
discreteSeparator, separatorFactors);
return {boost::make_shared<HybridConditional>(conditional), factor};
}
} }
/* ************************************************************************ */ /* ************************************************************************ */