Merge pull request #1832 from borglab/hybrid-enum

release/4.3a0
Varun Agrawal 2024-09-18 16:17:59 -04:00 committed by GitHub
commit 2c140df196
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 56 additions and 37 deletions

View File

@ -28,13 +28,8 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
const DiscreteKeys &discreteFrontals, const DiscreteKeys &discreteFrontals,
const KeyVector &continuousParents, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents) const DiscreteKeys &discreteParents)
: HybridConditional( : HybridConditional(CollectKeys(continuousFrontals, continuousParents),
CollectKeys( CollectDiscreteKeys(discreteFrontals, discreteParents),
{continuousFrontals.begin(), continuousFrontals.end()},
KeyVector{continuousParents.begin(), continuousParents.end()}),
CollectDiscreteKeys(
{discreteFrontals.begin(), discreteFrontals.end()},
{discreteParents.begin(), discreteParents.end()}),
continuousFrontals.size() + discreteFrontals.size()) {} continuousFrontals.size() + discreteFrontals.size()) {}
/* ************************************************************************ */ /* ************************************************************************ */
@ -56,9 +51,7 @@ HybridConditional::HybridConditional(
/* ************************************************************************ */ /* ************************************************************************ */
HybridConditional::HybridConditional( HybridConditional::HybridConditional(
const std::shared_ptr<HybridGaussianConditional> &gaussianMixture) const std::shared_ptr<HybridGaussianConditional> &gaussianMixture)
: BaseFactor(KeyVector(gaussianMixture->keys().begin(), : BaseFactor(gaussianMixture->continuousKeys(),
gaussianMixture->keys().begin() +
gaussianMixture->nrContinuous()),
gaussianMixture->discreteKeys()), gaussianMixture->discreteKeys()),
BaseConditional(gaussianMixture->nrFrontals()) { BaseConditional(gaussianMixture->nrFrontals()) {
inner_ = gaussianMixture; inner_ = gaussianMixture;

View File

@ -50,31 +50,43 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
/* ************************************************************************ */ /* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &keys) HybridFactor::HybridFactor(const KeyVector &keys)
: Base(keys), isContinuous_(true), continuousKeys_(keys) {} : Base(keys), category_(Category::Continuous), continuousKeys_(keys) {}
/* ************************************************************************ */
HybridFactor::Category GetCategory(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys) {
if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) {
return HybridFactor::Category::Discrete;
} else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) {
return HybridFactor::Category::Continuous;
} else if ((continuousKeys.size() != 0) && (discreteKeys.size() != 0)) {
return HybridFactor::Category::Hybrid;
} else {
// Case where we have no keys. Should never happen.
return HybridFactor::Category::None;
}
}
/* ************************************************************************ */ /* ************************************************************************ */
HybridFactor::HybridFactor(const KeyVector &continuousKeys, HybridFactor::HybridFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys) const DiscreteKeys &discreteKeys)
: Base(CollectKeys(continuousKeys, discreteKeys)), : Base(CollectKeys(continuousKeys, discreteKeys)),
isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)), category_(GetCategory(continuousKeys, discreteKeys)),
isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)),
isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)),
discreteKeys_(discreteKeys), discreteKeys_(discreteKeys),
continuousKeys_(continuousKeys) {} continuousKeys_(continuousKeys) {}
/* ************************************************************************ */ /* ************************************************************************ */
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
: Base(CollectKeys({}, discreteKeys)), : Base(CollectKeys({}, discreteKeys)),
isDiscrete_(true), category_(Category::Discrete),
discreteKeys_(discreteKeys), discreteKeys_(discreteKeys),
continuousKeys_({}) {} continuousKeys_({}) {}
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridFactor::equals(const HybridFactor &lf, double tol) const { bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol) && return e != nullptr && Base::equals(*e, tol) && category_ == e->category_ &&
isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ && continuousKeys_ == e->continuousKeys_ &&
isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ &&
discreteKeys_ == e->discreteKeys_; discreteKeys_ == e->discreteKeys_;
} }
@ -82,9 +94,21 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
void HybridFactor::print(const std::string &s, void HybridFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << (s.empty() ? "" : s + "\n"); std::cout << (s.empty() ? "" : s + "\n");
if (isContinuous_) std::cout << "Continuous "; switch (category_) {
if (isDiscrete_) std::cout << "Discrete "; case Category::Continuous:
if (isHybrid_) std::cout << "Hybrid "; std::cout << "Continuous ";
break;
case Category::Discrete:
std::cout << "Discrete ";
break;
case Category::Hybrid:
std::cout << "Hybrid ";
break;
case Category::None:
std::cout << "None ";
break;
}
std::cout << "["; std::cout << "[";
for (size_t c = 0; c < continuousKeys_.size(); c++) { for (size_t c = 0; c < continuousKeys_.size(); c++) {
std::cout << formatter(continuousKeys_.at(c)); std::cout << formatter(continuousKeys_.at(c));

View File

@ -52,10 +52,13 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
* @ingroup hybrid * @ingroup hybrid
*/ */
class GTSAM_EXPORT HybridFactor : public Factor { class GTSAM_EXPORT HybridFactor : public Factor {
public:
/// Enum to help with categorizing hybrid factors.
enum class Category { None, Discrete, Continuous, Hybrid };
private: private:
bool isDiscrete_ = false; /// Record what category of HybridFactor this is.
bool isContinuous_ = false; Category category_ = Category::None;
bool isHybrid_ = false;
protected: protected:
// Set of DiscreteKeys for this factor. // Set of DiscreteKeys for this factor.
@ -116,13 +119,13 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// @{ /// @{
/// True if this is a factor of discrete variables only. /// True if this is a factor of discrete variables only.
bool isDiscrete() const { return isDiscrete_; } bool isDiscrete() const { return category_ == Category::Discrete; }
/// True if this is a factor of continuous variables only. /// True if this is a factor of continuous variables only.
bool isContinuous() const { return isContinuous_; } bool isContinuous() const { return category_ == Category::Continuous; }
/// True is this is a Discrete-Continuous factor. /// True is this is a Discrete-Continuous factor.
bool isHybrid() const { return isHybrid_; } bool isHybrid() const { return category_ == Category::Hybrid; }
/// Return the number of continuous variables in this factor. /// Return the number of continuous variables in this factor.
size_t nrContinuous() const { return continuousKeys_.size(); } size_t nrContinuous() const { return continuousKeys_.size(); }
@ -142,9 +145,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
template <class ARCHIVE> template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) { void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(isDiscrete_); ar &BOOST_SERIALIZATION_NVP(category_);
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
ar &BOOST_SERIALIZATION_NVP(discreteKeys_); ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
ar &BOOST_SERIALIZATION_NVP(continuousKeys_); ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
} }

View File

@ -114,10 +114,11 @@ void HybridGaussianFactorGraph::printErrors(
<< "\n"; << "\n";
} else { } else {
// Is hybrid // Is hybrid
auto mixtureComponent = auto conditionalComponent =
hc->asMixture()->operator()(values.discrete()); hc->asMixture()->operator()(values.discrete());
mixtureComponent->print(ss.str(), keyFormatter); conditionalComponent->print(ss.str(), keyFormatter);
std::cout << "error = " << mixtureComponent->error(values) << "\n"; std::cout << "error = " << conditionalComponent->error(values)
<< "\n";
} }
} }
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) { } else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
@ -411,10 +412,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Create the HybridGaussianConditional from the conditionals // Create the HybridGaussianConditional from the conditionals
HybridGaussianConditional::Conditionals conditionals( HybridGaussianConditional::Conditionals conditionals(
eliminationResults, [](const Result &pair) { return pair.first; }); eliminationResults, [](const Result &pair) { return pair.first; });
auto gaussianMixture = std::make_shared<HybridGaussianConditional>( auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
frontalKeys, continuousSeparator, discreteSeparator, conditionals); frontalKeys, continuousSeparator, discreteSeparator, conditionals);
return {std::make_shared<HybridConditional>(gaussianMixture), newFactor}; return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
} }
/* ************************************************************************ /* ************************************************************************
@ -465,7 +466,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// Now we will need to know how to retrieve the corresponding continuous // Now we will need to know how to retrieve the corresponding continuous
// densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO // densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO
// defined order!). We also need to consider when there is pruning. Two // defined order!). We also need to consider when there is pruning. Two
// mixture factors could have different pruning patterns - one could have // hybrid factors could have different pruning patterns - one could have
// (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this // (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this
// creates a big problem in how to identify the intersection of non-pruned // creates a big problem in how to identify the intersection of non-pruned
// branches. // branches.