rename atGaussian to atMixture and add new atGaussian for continuous conditionals, fix choose method for all types
parent
4e451d5c0b
commit
8c41f63167
|
|
@ -112,10 +112,15 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
|
||||||
return factors_.at(i)->asMixture();
|
return factors_.at(i)->asMixture();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||||
|
return factors_.at(i)->asGaussian();
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
||||||
return factors_.at(i)->asDiscreteConditional();
|
return factors_.at(i)->asDiscreteConditional();
|
||||||
|
|
@ -126,17 +131,22 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
const DiscreteValues &assignment) const {
|
const DiscreteValues &assignment) const {
|
||||||
GaussianBayesNet gbn;
|
GaussianBayesNet gbn;
|
||||||
for (size_t idx = 0; idx < size(); idx++) {
|
for (size_t idx = 0; idx < size(); idx++) {
|
||||||
try {
|
if (factors_.at(idx)->isHybrid()) {
|
||||||
GaussianMixture gm = *this->atGaussian(idx);
|
// If factor is hybrid, select based on assignment.
|
||||||
|
GaussianMixture gm = *this->atMixture(idx);
|
||||||
gbn.push_back(gm(assignment));
|
gbn.push_back(gm(assignment));
|
||||||
|
|
||||||
} catch (std::exception &exc) {
|
} else if (factors_.at(idx)->isContinuous()) {
|
||||||
// factor at `idx` is discrete-only, so we simply continue.
|
// If continuous only, add gaussian conditional.
|
||||||
assert(factors_.at(idx)->discreteKeys().size() ==
|
factors_.at(idx)->print();
|
||||||
factors_.at(idx)->keys().size());
|
gbn.push_back((this->atGaussian(idx)));
|
||||||
|
|
||||||
|
} else if (factors_.at(idx)->isDiscrete()) {
|
||||||
|
// If factor at `idx` is discrete-only, we simply continue.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return gbn;
|
return gbn;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a specific Gaussian mixture by index `i`.
|
/// Get a specific Gaussian mixture by index `i`.
|
||||||
GaussianMixture::shared_ptr atGaussian(size_t i) const;
|
GaussianMixture::shared_ptr atMixture(size_t i) const;
|
||||||
|
|
||||||
|
/// Get a specific Gaussian conditional by index `i`.
|
||||||
|
GaussianConditional::shared_ptr atGaussian(size_t i) const;
|
||||||
|
|
||||||
/// Get a specific discrete conditional by index `i`.
|
/// Get a specific discrete conditional by index `i`.
|
||||||
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
|
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
|
||||||
|
|
|
||||||
|
|
@ -73,16 +73,16 @@ TEST(HybridBayesNet, Choose) {
|
||||||
EXPECT_LONGS_EQUAL(4, gbn.size());
|
EXPECT_LONGS_EQUAL(4, gbn.size());
|
||||||
|
|
||||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||||
hybridBayesNet->atGaussian(0)))(assignment),
|
hybridBayesNet->atMixture(0)))(assignment),
|
||||||
*gbn.at(0)));
|
*gbn.at(0)));
|
||||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||||
hybridBayesNet->atGaussian(1)))(assignment),
|
hybridBayesNet->atMixture(1)))(assignment),
|
||||||
*gbn.at(1)));
|
*gbn.at(1)));
|
||||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||||
hybridBayesNet->atGaussian(2)))(assignment),
|
hybridBayesNet->atMixture(2)))(assignment),
|
||||||
*gbn.at(2)));
|
*gbn.at(2)));
|
||||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||||
hybridBayesNet->atGaussian(3)))(assignment),
|
hybridBayesNet->atMixture(3)))(assignment),
|
||||||
*gbn.at(3)));
|
*gbn.at(3)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue