Working version with helper
parent
08cf399bbf
commit
69b1313eed
|
@ -70,14 +70,13 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment(
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianFactor::HybridGaussianFactor(
|
||||
HybridGaussianFactor::ConstructorHelper::ConstructorHelper(
|
||||
const DiscreteKey &discreteKey,
|
||||
const std::vector<GaussianFactor::shared_ptr> &factors)
|
||||
: Base(HybridFactor::Category::Hybrid) {
|
||||
// Extract continuous keys from first-null factor and verify all others
|
||||
KeyVector continuousKeys;
|
||||
: discreteKeys({discreteKey}) {
|
||||
// Extract continuous keys from the first non-null factor
|
||||
for (const auto &factor : factors) {
|
||||
if (!factor) continue;
|
||||
if (!factor) continue; // Skip null factors
|
||||
if (continuousKeys.empty()) {
|
||||
continuousKeys = factor->keys();
|
||||
} else if (factor->keys() != continuousKeys) {
|
||||
|
@ -85,25 +84,18 @@ HybridGaussianFactor::HybridGaussianFactor(
|
|||
}
|
||||
}
|
||||
|
||||
// Initialize the base class
|
||||
Factor::keys_ = continuousKeys;
|
||||
Factor::keys_.push_back(discreteKey.first);
|
||||
Base::discreteKeys_ = {discreteKey};
|
||||
Base::continuousKeys_ = continuousKeys;
|
||||
|
||||
// Build the DecisionTree from factor vector
|
||||
factors_ = Factors({discreteKey}, factors);
|
||||
// Build the DecisionTree from the factor vector
|
||||
factorsTree = Factors(discreteKeys, factors);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianFactor::HybridGaussianFactor(
|
||||
HybridGaussianFactor::ConstructorHelper::ConstructorHelper(
|
||||
const DiscreteKey &discreteKey,
|
||||
const std::vector<GaussianFactorValuePair> &factorPairs)
|
||||
: Base(HybridFactor::Category::Hybrid) {
|
||||
// Extract continuous keys from first-null factor and verify all others
|
||||
KeyVector continuousKeys;
|
||||
: discreteKeys({discreteKey}) {
|
||||
// Extract continuous keys from the first non-null factor
|
||||
for (const auto &pair : factorPairs) {
|
||||
if (!pair.first) continue;
|
||||
if (!pair.first) continue; // Skip null factors
|
||||
if (continuousKeys.empty()) {
|
||||
continuousKeys = pair.first->keys();
|
||||
} else if (pair.first->keys() != continuousKeys) {
|
||||
|
@ -111,45 +103,26 @@ HybridGaussianFactor::HybridGaussianFactor(
|
|||
}
|
||||
}
|
||||
|
||||
// Initialize the base class
|
||||
Factor::keys_ = continuousKeys;
|
||||
Factor::keys_.push_back(discreteKey.first);
|
||||
Base::discreteKeys_ = {discreteKey};
|
||||
Base::continuousKeys_ = continuousKeys;
|
||||
|
||||
// Build the FactorValuePairs DecisionTree
|
||||
FactorValuePairs pairTree({discreteKey}, factorPairs);
|
||||
|
||||
// Assign factors_ after calling augment
|
||||
factors_ = augment(pairTree);
|
||||
pairs = FactorValuePairs(discreteKeys, factorPairs);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
|
||||
const FactorValuePairs &factorPairs)
|
||||
: Base(HybridFactor::Category::Hybrid) {
|
||||
// Verify that all factors have the same keys
|
||||
KeyVector continuousKeys;
|
||||
HybridGaussianFactor::ConstructorHelper::ConstructorHelper(
|
||||
const DiscreteKeys &discreteKeys, const FactorValuePairs &factorPairs)
|
||||
: discreteKeys(discreteKeys) {
|
||||
// Extract continuous keys from the first non-null factor
|
||||
factorPairs.visit([&](const GaussianFactorValuePair &pair) {
|
||||
if (pair.first) {
|
||||
if (continuousKeys.empty()) {
|
||||
continuousKeys = pair.first->keys();
|
||||
} else if (pair.first->keys() != continuousKeys) {
|
||||
throw std::invalid_argument("All factors must have the same keys");
|
||||
}
|
||||
if (!pair.first) return; // Skip null factors
|
||||
if (continuousKeys.empty()) {
|
||||
continuousKeys = pair.first->keys();
|
||||
} else if (pair.first->keys() != continuousKeys) {
|
||||
throw std::invalid_argument("All factors must have the same keys");
|
||||
}
|
||||
});
|
||||
|
||||
// Initialize the base class
|
||||
Factor::keys_ = continuousKeys;
|
||||
for (const auto &discreteKey : discreteKeys) {
|
||||
Factor::keys_.push_back(discreteKey.first);
|
||||
}
|
||||
Base::discreteKeys_ = discreteKeys;
|
||||
Base::continuousKeys_ = continuousKeys;
|
||||
|
||||
// Assign factors_ after calling augment
|
||||
factors_ = augment(factorPairs);
|
||||
// Build the FactorValuePairs DecisionTree
|
||||
pairs = factorPairs;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -89,7 +89,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
* @param factors Vector of gaussian factors, one for each mode.
|
||||
*/
|
||||
HybridGaussianFactor(const DiscreteKey &discreteKey,
|
||||
const std::vector<GaussianFactor::shared_ptr> &factors);
|
||||
const std::vector<GaussianFactor::shared_ptr> &factors)
|
||||
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}
|
||||
|
||||
/**
|
||||
* @brief Construct a new HybridGaussianFactor on a single discrete key,
|
||||
|
@ -101,7 +102,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
* @param factorPairs Vector of gaussian factor-scalar pairs, one per mode.
|
||||
*/
|
||||
HybridGaussianFactor(const DiscreteKey &discreteKey,
|
||||
const std::vector<GaussianFactorValuePair> &factorPairs);
|
||||
const std::vector<GaussianFactorValuePair> &factorPairs)
|
||||
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
|
||||
|
||||
/**
|
||||
* @brief Construct a new HybridGaussianFactor on a several discrete keys M,
|
||||
|
@ -113,7 +115,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
* @param factors The decision tree of Gaussian factor/scalar pairs.
|
||||
*/
|
||||
HybridGaussianFactor(const DiscreteKeys &discreteKeys,
|
||||
const FactorValuePairs &factors);
|
||||
const FactorValuePairs &factors)
|
||||
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
|
@ -193,6 +196,30 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
double potentiallyPrunedComponentError(
|
||||
const sharedFactor &gf, const VectorValues &continuousValues) const;
|
||||
|
||||
/// Helper struct to assist in constructing the HybridGaussianFactor
|
||||
struct ConstructorHelper {
|
||||
KeyVector continuousKeys; // Continuous keys extracted from factors
|
||||
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
|
||||
FactorValuePairs pairs; // Used only if factorsTree is empty
|
||||
Factors factorsTree;
|
||||
|
||||
ConstructorHelper(
|
||||
const DiscreteKey &discreteKey,
|
||||
const std::vector<GaussianFactor::shared_ptr> &factorsVec);
|
||||
|
||||
ConstructorHelper(const DiscreteKey &discreteKey,
|
||||
const std::vector<GaussianFactorValuePair> &factorPairs);
|
||||
|
||||
ConstructorHelper(const DiscreteKeys &discreteKeys,
|
||||
const FactorValuePairs &factorPairs);
|
||||
};
|
||||
|
||||
// Private constructor using ConstructorHelper
|
||||
HybridGaussianFactor(const ConstructorHelper &helper)
|
||||
: Base(helper.continuousKeys, helper.discreteKeys),
|
||||
factors_(helper.factorsTree.empty() ? augment(helper.pairs)
|
||||
: helper.factorsTree) {}
|
||||
|
||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
|
@ -101,12 +101,6 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
|||
|
||||
auto hbn = GaussianMixtureModel(mu0, mu1, sigma, sigma);
|
||||
|
||||
// Check the number of keys matches what we expect
|
||||
auto hgc = hbn.at(0)->asHybrid();
|
||||
EXPECT_LONGS_EQUAL(2, hgc->keys().size());
|
||||
EXPECT_LONGS_EQUAL(1, hgc->continuousKeys().size());
|
||||
EXPECT_LONGS_EQUAL(1, hgc->discreteKeys().size());
|
||||
|
||||
// At the halfway point between the means, we should get P(m|z)=0.5
|
||||
double midway = mu1 - mu0;
|
||||
auto pMid = SolveHBN(hbn, midway);
|
||||
|
@ -141,12 +135,6 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
|
||||
auto hbn = GaussianMixtureModel(mu0, mu1, sigma0, sigma1);
|
||||
|
||||
// Check the number of keys matches what we expect
|
||||
auto hgc = hbn.at(0)->asHybrid();
|
||||
EXPECT_LONGS_EQUAL(2, hgc->keys().size());
|
||||
EXPECT_LONGS_EQUAL(1, hgc->continuousKeys().size());
|
||||
EXPECT_LONGS_EQUAL(1, hgc->discreteKeys().size());
|
||||
|
||||
// We get zMax=3.1333 by finding the maximum value of the function, at which
|
||||
// point the mode m==1 is about twice as probable as m==0.
|
||||
double zMax = 3.133;
|
||||
|
|
Loading…
Reference in New Issue