rename atGaussian to atMixture and add new atGaussian for continuous conditionals, fix choose method for all types
parent
4e451d5c0b
commit
8c41f63167
|
|
@ -112,10 +112,15 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
|
||||
return factors_.at(i)->asMixture();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||
return factors_.at(i)->asGaussian();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
||||
return factors_.at(i)->asDiscreteConditional();
|
||||
|
|
@ -126,17 +131,22 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
const DiscreteValues &assignment) const {
|
||||
GaussianBayesNet gbn;
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
try {
|
||||
GaussianMixture gm = *this->atGaussian(idx);
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment.
|
||||
GaussianMixture gm = *this->atMixture(idx);
|
||||
gbn.push_back(gm(assignment));
|
||||
|
||||
} catch (std::exception &exc) {
|
||||
// factor at `idx` is discrete-only, so we simply continue.
|
||||
assert(factors_.at(idx)->discreteKeys().size() ==
|
||||
factors_.at(idx)->keys().size());
|
||||
} else if (factors_.at(idx)->isContinuous()) {
|
||||
// If continuous only, add gaussian conditional.
|
||||
factors_.at(idx)->print();
|
||||
gbn.push_back((this->atGaussian(idx)));
|
||||
|
||||
} else if (factors_.at(idx)->isDiscrete()) {
|
||||
// If factor at `idx` is discrete-only, we simply continue.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return gbn;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
}
|
||||
|
||||
/// Get a specific Gaussian mixture by index `i`.
|
||||
GaussianMixture::shared_ptr atGaussian(size_t i) const;
|
||||
GaussianMixture::shared_ptr atMixture(size_t i) const;
|
||||
|
||||
/// Get a specific Gaussian conditional by index `i`.
|
||||
GaussianConditional::shared_ptr atGaussian(size_t i) const;
|
||||
|
||||
/// Get a specific discrete conditional by index `i`.
|
||||
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
|
||||
|
|
|
|||
|
|
@ -73,16 +73,16 @@ TEST(HybridBayesNet, Choose) {
|
|||
EXPECT_LONGS_EQUAL(4, gbn.size());
|
||||
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(0)))(assignment),
|
||||
hybridBayesNet->atMixture(0)))(assignment),
|
||||
*gbn.at(0)));
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(1)))(assignment),
|
||||
hybridBayesNet->atMixture(1)))(assignment),
|
||||
*gbn.at(1)));
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(2)))(assignment),
|
||||
hybridBayesNet->atMixture(2)))(assignment),
|
||||
*gbn.at(2)));
|
||||
EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
|
||||
hybridBayesNet->atGaussian(3)))(assignment),
|
||||
hybridBayesNet->atMixture(3)))(assignment),
|
||||
*gbn.at(3)));
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue