Merge pull request #1832 from borglab/hybrid-enum
commit
2c140df196
|
|
@ -28,13 +28,8 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
|
||||||
const DiscreteKeys &discreteFrontals,
|
const DiscreteKeys &discreteFrontals,
|
||||||
const KeyVector &continuousParents,
|
const KeyVector &continuousParents,
|
||||||
const DiscreteKeys &discreteParents)
|
const DiscreteKeys &discreteParents)
|
||||||
: HybridConditional(
|
: HybridConditional(CollectKeys(continuousFrontals, continuousParents),
|
||||||
CollectKeys(
|
CollectDiscreteKeys(discreteFrontals, discreteParents),
|
||||||
{continuousFrontals.begin(), continuousFrontals.end()},
|
|
||||||
KeyVector{continuousParents.begin(), continuousParents.end()}),
|
|
||||||
CollectDiscreteKeys(
|
|
||||||
{discreteFrontals.begin(), discreteFrontals.end()},
|
|
||||||
{discreteParents.begin(), discreteParents.end()}),
|
|
||||||
continuousFrontals.size() + discreteFrontals.size()) {}
|
continuousFrontals.size() + discreteFrontals.size()) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
@ -56,9 +51,7 @@ HybridConditional::HybridConditional(
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
HybridConditional::HybridConditional(
|
HybridConditional::HybridConditional(
|
||||||
const std::shared_ptr<HybridGaussianConditional> &gaussianMixture)
|
const std::shared_ptr<HybridGaussianConditional> &gaussianMixture)
|
||||||
: BaseFactor(KeyVector(gaussianMixture->keys().begin(),
|
: BaseFactor(gaussianMixture->continuousKeys(),
|
||||||
gaussianMixture->keys().begin() +
|
|
||||||
gaussianMixture->nrContinuous()),
|
|
||||||
gaussianMixture->discreteKeys()),
|
gaussianMixture->discreteKeys()),
|
||||||
BaseConditional(gaussianMixture->nrFrontals()) {
|
BaseConditional(gaussianMixture->nrFrontals()) {
|
||||||
inner_ = gaussianMixture;
|
inner_ = gaussianMixture;
|
||||||
|
|
|
||||||
|
|
@ -50,31 +50,43 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
HybridFactor::HybridFactor(const KeyVector &keys)
|
HybridFactor::HybridFactor(const KeyVector &keys)
|
||||||
: Base(keys), isContinuous_(true), continuousKeys_(keys) {}
|
: Base(keys), category_(Category::Continuous), continuousKeys_(keys) {}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
HybridFactor::Category GetCategory(const KeyVector &continuousKeys,
|
||||||
|
const DiscreteKeys &discreteKeys) {
|
||||||
|
if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) {
|
||||||
|
return HybridFactor::Category::Discrete;
|
||||||
|
} else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) {
|
||||||
|
return HybridFactor::Category::Continuous;
|
||||||
|
} else if ((continuousKeys.size() != 0) && (discreteKeys.size() != 0)) {
|
||||||
|
return HybridFactor::Category::Hybrid;
|
||||||
|
} else {
|
||||||
|
// Case where we have no keys. Should never happen.
|
||||||
|
return HybridFactor::Category::None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
HybridFactor::HybridFactor(const KeyVector &continuousKeys,
|
HybridFactor::HybridFactor(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys)
|
const DiscreteKeys &discreteKeys)
|
||||||
: Base(CollectKeys(continuousKeys, discreteKeys)),
|
: Base(CollectKeys(continuousKeys, discreteKeys)),
|
||||||
isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)),
|
category_(GetCategory(continuousKeys, discreteKeys)),
|
||||||
isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)),
|
|
||||||
isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)),
|
|
||||||
discreteKeys_(discreteKeys),
|
discreteKeys_(discreteKeys),
|
||||||
continuousKeys_(continuousKeys) {}
|
continuousKeys_(continuousKeys) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
|
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
|
||||||
: Base(CollectKeys({}, discreteKeys)),
|
: Base(CollectKeys({}, discreteKeys)),
|
||||||
isDiscrete_(true),
|
category_(Category::Discrete),
|
||||||
discreteKeys_(discreteKeys),
|
discreteKeys_(discreteKeys),
|
||||||
continuousKeys_({}) {}
|
continuousKeys_({}) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
|
bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
const This *e = dynamic_cast<const This *>(&lf);
|
const This *e = dynamic_cast<const This *>(&lf);
|
||||||
return e != nullptr && Base::equals(*e, tol) &&
|
return e != nullptr && Base::equals(*e, tol) && category_ == e->category_ &&
|
||||||
isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ &&
|
continuousKeys_ == e->continuousKeys_ &&
|
||||||
isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ &&
|
|
||||||
discreteKeys_ == e->discreteKeys_;
|
discreteKeys_ == e->discreteKeys_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -82,9 +94,21 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
void HybridFactor::print(const std::string &s,
|
void HybridFactor::print(const std::string &s,
|
||||||
const KeyFormatter &formatter) const {
|
const KeyFormatter &formatter) const {
|
||||||
std::cout << (s.empty() ? "" : s + "\n");
|
std::cout << (s.empty() ? "" : s + "\n");
|
||||||
if (isContinuous_) std::cout << "Continuous ";
|
switch (category_) {
|
||||||
if (isDiscrete_) std::cout << "Discrete ";
|
case Category::Continuous:
|
||||||
if (isHybrid_) std::cout << "Hybrid ";
|
std::cout << "Continuous ";
|
||||||
|
break;
|
||||||
|
case Category::Discrete:
|
||||||
|
std::cout << "Discrete ";
|
||||||
|
break;
|
||||||
|
case Category::Hybrid:
|
||||||
|
std::cout << "Hybrid ";
|
||||||
|
break;
|
||||||
|
case Category::None:
|
||||||
|
std::cout << "None ";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
std::cout << "[";
|
std::cout << "[";
|
||||||
for (size_t c = 0; c < continuousKeys_.size(); c++) {
|
for (size_t c = 0; c < continuousKeys_.size(); c++) {
|
||||||
std::cout << formatter(continuousKeys_.at(c));
|
std::cout << formatter(continuousKeys_.at(c));
|
||||||
|
|
|
||||||
|
|
@ -52,10 +52,13 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
|
||||||
* @ingroup hybrid
|
* @ingroup hybrid
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT HybridFactor : public Factor {
|
class GTSAM_EXPORT HybridFactor : public Factor {
|
||||||
|
public:
|
||||||
|
/// Enum to help with categorizing hybrid factors.
|
||||||
|
enum class Category { None, Discrete, Continuous, Hybrid };
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool isDiscrete_ = false;
|
/// Record what category of HybridFactor this is.
|
||||||
bool isContinuous_ = false;
|
Category category_ = Category::None;
|
||||||
bool isHybrid_ = false;
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Set of DiscreteKeys for this factor.
|
// Set of DiscreteKeys for this factor.
|
||||||
|
|
@ -116,13 +119,13 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// True if this is a factor of discrete variables only.
|
/// True if this is a factor of discrete variables only.
|
||||||
bool isDiscrete() const { return isDiscrete_; }
|
bool isDiscrete() const { return category_ == Category::Discrete; }
|
||||||
|
|
||||||
/// True if this is a factor of continuous variables only.
|
/// True if this is a factor of continuous variables only.
|
||||||
bool isContinuous() const { return isContinuous_; }
|
bool isContinuous() const { return category_ == Category::Continuous; }
|
||||||
|
|
||||||
/// True is this is a Discrete-Continuous factor.
|
/// True is this is a Discrete-Continuous factor.
|
||||||
bool isHybrid() const { return isHybrid_; }
|
bool isHybrid() const { return category_ == Category::Hybrid; }
|
||||||
|
|
||||||
/// Return the number of continuous variables in this factor.
|
/// Return the number of continuous variables in this factor.
|
||||||
size_t nrContinuous() const { return continuousKeys_.size(); }
|
size_t nrContinuous() const { return continuousKeys_.size(); }
|
||||||
|
|
@ -142,9 +145,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
||||||
template <class ARCHIVE>
|
template <class ARCHIVE>
|
||||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
|
ar &BOOST_SERIALIZATION_NVP(category_);
|
||||||
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
|
|
||||||
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
|
|
||||||
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
|
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
|
||||||
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
|
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -114,10 +114,11 @@ void HybridGaussianFactorGraph::printErrors(
|
||||||
<< "\n";
|
<< "\n";
|
||||||
} else {
|
} else {
|
||||||
// Is hybrid
|
// Is hybrid
|
||||||
auto mixtureComponent =
|
auto conditionalComponent =
|
||||||
hc->asMixture()->operator()(values.discrete());
|
hc->asMixture()->operator()(values.discrete());
|
||||||
mixtureComponent->print(ss.str(), keyFormatter);
|
conditionalComponent->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = " << mixtureComponent->error(values) << "\n";
|
std::cout << "error = " << conditionalComponent->error(values)
|
||||||
|
<< "\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||||
|
|
@ -411,10 +412,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// Create the HybridGaussianConditional from the conditionals
|
// Create the HybridGaussianConditional from the conditionals
|
||||||
HybridGaussianConditional::Conditionals conditionals(
|
HybridGaussianConditional::Conditionals conditionals(
|
||||||
eliminationResults, [](const Result &pair) { return pair.first; });
|
eliminationResults, [](const Result &pair) { return pair.first; });
|
||||||
auto gaussianMixture = std::make_shared<HybridGaussianConditional>(
|
auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
|
||||||
frontalKeys, continuousSeparator, discreteSeparator, conditionals);
|
frontalKeys, continuousSeparator, discreteSeparator, conditionals);
|
||||||
|
|
||||||
return {std::make_shared<HybridConditional>(gaussianMixture), newFactor};
|
return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************
|
/* ************************************************************************
|
||||||
|
|
@ -465,7 +466,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
// Now we will need to know how to retrieve the corresponding continuous
|
// Now we will need to know how to retrieve the corresponding continuous
|
||||||
// densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO
|
// densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO
|
||||||
// defined order!). We also need to consider when there is pruning. Two
|
// defined order!). We also need to consider when there is pruning. Two
|
||||||
// mixture factors could have different pruning patterns - one could have
|
// hybrid factors could have different pruning patterns - one could have
|
||||||
// (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this
|
// (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this
|
||||||
// creates a big problem in how to identify the intersection of non-pruned
|
// creates a big problem in how to identify the intersection of non-pruned
|
||||||
// branches.
|
// branches.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue