Working version with helper

release/4.3a0
Frank Dellaert 2024-09-26 13:22:32 -07:00
parent 08cf399bbf
commit 69b1313eed
3 changed files with 52 additions and 64 deletions

View File

@ -70,14 +70,13 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment(
}
/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(
HybridGaussianFactor::ConstructorHelper::ConstructorHelper(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: Base(HybridFactor::Category::Hybrid) {
// Extract continuous keys from first-null factor and verify all others
KeyVector continuousKeys;
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto &factor : factors) {
if (!factor) continue;
if (!factor) continue; // Skip null factors
if (continuousKeys.empty()) {
continuousKeys = factor->keys();
} else if (factor->keys() != continuousKeys) {
@ -85,25 +84,18 @@ HybridGaussianFactor::HybridGaussianFactor(
}
}
// Initialize the base class
Factor::keys_ = continuousKeys;
Factor::keys_.push_back(discreteKey.first);
Base::discreteKeys_ = {discreteKey};
Base::continuousKeys_ = continuousKeys;
// Build the DecisionTree from factor vector
factors_ = Factors({discreteKey}, factors);
// Build the DecisionTree from the factor vector
factorsTree = Factors(discreteKeys, factors);
}
/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(
HybridGaussianFactor::ConstructorHelper::ConstructorHelper(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: Base(HybridFactor::Category::Hybrid) {
// Extract continuous keys from first-null factor and verify all others
KeyVector continuousKeys;
: discreteKeys({discreteKey}) {
// Extract continuous keys from the first non-null factor
for (const auto &pair : factorPairs) {
if (!pair.first) continue;
if (!pair.first) continue; // Skip null factors
if (continuousKeys.empty()) {
continuousKeys = pair.first->keys();
} else if (pair.first->keys() != continuousKeys) {
@ -111,45 +103,26 @@ HybridGaussianFactor::HybridGaussianFactor(
}
}
// Initialize the base class
Factor::keys_ = continuousKeys;
Factor::keys_.push_back(discreteKey.first);
Base::discreteKeys_ = {discreteKey};
Base::continuousKeys_ = continuousKeys;
// Build the FactorValuePairs DecisionTree
FactorValuePairs pairTree({discreteKey}, factorPairs);
// Assign factors_ after calling augment
factors_ = augment(pairTree);
pairs = FactorValuePairs(discreteKeys, factorPairs);
}
/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factorPairs)
: Base(HybridFactor::Category::Hybrid) {
// Verify that all factors have the same keys
KeyVector continuousKeys;
HybridGaussianFactor::ConstructorHelper::ConstructorHelper(
const DiscreteKeys &discreteKeys, const FactorValuePairs &factorPairs)
: discreteKeys(discreteKeys) {
// Extract continuous keys from the first non-null factor
factorPairs.visit([&](const GaussianFactorValuePair &pair) {
if (pair.first) {
if (continuousKeys.empty()) {
continuousKeys = pair.first->keys();
} else if (pair.first->keys() != continuousKeys) {
throw std::invalid_argument("All factors must have the same keys");
}
if (!pair.first) return; // Skip null factors
if (continuousKeys.empty()) {
continuousKeys = pair.first->keys();
} else if (pair.first->keys() != continuousKeys) {
throw std::invalid_argument("All factors must have the same keys");
}
});
// Initialize the base class
Factor::keys_ = continuousKeys;
for (const auto &discreteKey : discreteKeys) {
Factor::keys_.push_back(discreteKey.first);
}
Base::discreteKeys_ = discreteKeys;
Base::continuousKeys_ = continuousKeys;
// Assign factors_ after calling augment
factors_ = augment(factorPairs);
// Build the FactorValuePairs DecisionTree
pairs = factorPairs;
}
/* *******************************************************************************/

View File

@ -89,7 +89,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @param factors Vector of gaussian factors, one for each mode.
*/
HybridGaussianFactor(const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors);
const std::vector<GaussianFactor::shared_ptr> &factors)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}
/**
* @brief Construct a new HybridGaussianFactor on a single discrete key,
@ -101,7 +102,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @param factorPairs Vector of gaussian factor-scalar pairs, one per mode.
*/
HybridGaussianFactor(const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs);
const std::vector<GaussianFactorValuePair> &factorPairs)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
/**
* @brief Construct a new HybridGaussianFactor on a several discrete keys M,
@ -113,7 +115,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* @param factors The decision tree of Gaussian factor/scalar pairs.
*/
HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factors);
const FactorValuePairs &factors)
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}
/// @}
/// @name Testable
@ -193,6 +196,30 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
double potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &continuousValues) const;
/// Helper struct to assist in constructing the HybridGaussianFactor
struct ConstructorHelper {
KeyVector continuousKeys; // Continuous keys extracted from factors
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
FactorValuePairs pairs; // Used only if factorsTree is empty
Factors factorsTree;
ConstructorHelper(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factorsVec);
ConstructorHelper(const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs);
ConstructorHelper(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factorPairs);
};
// Private constructor using ConstructorHelper
HybridGaussianFactor(const ConstructorHelper &helper)
: Base(helper.continuousKeys, helper.discreteKeys),
factors_(helper.factorsTree.empty() ? augment(helper.pairs)
: helper.factorsTree) {}
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;

View File

@ -101,12 +101,6 @@ TEST(GaussianMixture, GaussianMixtureModel) {
auto hbn = GaussianMixtureModel(mu0, mu1, sigma, sigma);
// Check the number of keys matches what we expect
auto hgc = hbn.at(0)->asHybrid();
EXPECT_LONGS_EQUAL(2, hgc->keys().size());
EXPECT_LONGS_EQUAL(1, hgc->continuousKeys().size());
EXPECT_LONGS_EQUAL(1, hgc->discreteKeys().size());
// At the halfway point between the means, we should get P(m|z)=0.5
double midway = mu1 - mu0;
auto pMid = SolveHBN(hbn, midway);
@ -141,12 +135,6 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
auto hbn = GaussianMixtureModel(mu0, mu1, sigma0, sigma1);
// Check the number of keys matches what we expect
auto hgc = hbn.at(0)->asHybrid();
EXPECT_LONGS_EQUAL(2, hgc->keys().size());
EXPECT_LONGS_EQUAL(1, hgc->continuousKeys().size());
EXPECT_LONGS_EQUAL(1, hgc->discreteKeys().size());
// We get zMax=3.1333 by finding the maximum value of the function, at which
// point the mode m==1 is about twice as probable as m==0.
double zMax = 3.133;