add improved versions of push_back for HybridBayesNet

release/4.3a0
Varun Agrawal 2024-08-25 13:51:33 -04:00
parent b54ed7209e
commit 351f0bd3a5
2 changed files with 70 additions and 25 deletions

View File

@ -33,6 +33,18 @@ namespace gtsam {
* @ingroup hybrid
*/
class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
template <typename T>
struct is_shared_ptr : std::false_type {};
template <typename T>
struct is_shared_ptr<std::shared_ptr<T>> : std::true_type {};
/// Helper templates for checking if a type is a shared pointer or not
template <typename T>
using IsSharedPtr = typename std::enable_if<is_shared_ptr<T>::value>::type;
template <typename T>
using IsNotSharedPtr =
typename std::enable_if<!is_shared_ptr<T>::value>::type;
public:
using Base = BayesNet<HybridConditional>;
using This = HybridBayesNet;
@ -70,20 +82,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
factors_.push_back(conditional);
}
/**
* Preferred: add a conditional directly using a pointer.
*
* Examples:
* hbn.emplace_back(new GaussianMixture(...)));
* hbn.emplace_back(new GaussianConditional(...)));
* hbn.emplace_back(new DiscreteConditional(...)));
*/
template <class Conditional>
void emplace_back(Conditional *conditional) {
factors_.push_back(std::make_shared<HybridConditional>(
std::shared_ptr<Conditional>(conditional)));
}
/**
* Add a conditional using a shared_ptr, using implicit conversion to
* a HybridConditional.
@ -101,6 +99,54 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
std::make_shared<HybridConditional>(std::move(conditional)));
}
/**
* @brief Add a conditional to the Bayes net.
* Implicitly convert to a HybridConditional.
*
* E.g.
* hbn.push_back(std::make_shared<DiscreteConditional>(m, "1/1"));
*
* @tparam CONDITIONAL Type of conditional. This is shared_ptr version.
* @param conditional The conditional as a shared pointer.
* @return IsSharedPtr<CONDITIONAL>
*/
template <class CONDITIONAL>
IsSharedPtr<CONDITIONAL> push_back(const CONDITIONAL &conditional) {
factors_.push_back(std::make_shared<HybridConditional>(conditional));
}
/**
* @brief Add a conditional to the Bayes net.
* Implicitly convert to a HybridConditional.
*
* E.g.
* hbn.push_back(DiscreteConditional(m, "1/1"));
* hbn.push_back(GaussianConditional(X(0), Vector1(0.0), I_1x1));
*
* @tparam CONDITIONAL Type of conditional. This is const ref version.
* @param conditional The conditional as a const reference.
* @return IsSharedPtr<CONDITIONAL>
*/
template <class CONDITIONAL>
IsNotSharedPtr<CONDITIONAL> push_back(const CONDITIONAL &conditional) {
auto cond_shared_ptr = std::make_shared<CONDITIONAL>(conditional);
push_back(cond_shared_ptr);
}
/**
* Preferred: add a conditional directly using a pointer.
*
* Examples:
* hbn.emplace_back(new GaussianMixture(...)));
* hbn.emplace_back(new GaussianConditional(...)));
* hbn.emplace_back(new DiscreteConditional(...)));
*/
template <class Conditional>
void emplace_back(Conditional *conditional) {
factors_.push_back(std::make_shared<HybridConditional>(
std::shared_ptr<Conditional>(conditional)));
}
/**
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
* value assignment.

View File

@ -221,12 +221,12 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel) {
auto c0 = make_shared<GaussianConditional>(z, Vector1(mu0), I_1x1, model),
c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model);
auto gm = new GaussianMixture({z}, {}, {m}, {c0, c1});
auto mixing = new DiscreteConditional(m, "0.5/0.5");
GaussianMixture gm({z}, {}, {m}, {c0, c1});
DiscreteConditional mixing(m, "0.5/0.5");
HybridBayesNet hbn;
hbn.emplace_back(gm);
hbn.emplace_back(mixing);
hbn.push_back(gm);
hbn.push_back(mixing);
// The result should be a sigmoid.
// So should be m = 0.5 at z=3.0 - 1.0=2.0
@ -237,7 +237,7 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel) {
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
HybridBayesNet expected;
expected.emplace_back(new DiscreteConditional(m, "0.5/0.5"));
expected.push_back(DiscreteConditional(m, "0.5/0.5"));
EXPECT(assert_equal(expected, *bn));
}
@ -265,12 +265,12 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel2) {
auto c0 = make_shared<GaussianConditional>(z, Vector1(mu0), I_1x1, model0),
c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model1);
auto gm = new GaussianMixture({z}, {}, {m}, {c0, c1});
auto mixing = new DiscreteConditional(m, "0.5/0.5");
GaussianMixture gm({z}, {}, {m}, {c0, c1});
DiscreteConditional mixing(m, "0.5/0.5");
HybridBayesNet hbn;
hbn.emplace_back(gm);
hbn.emplace_back(mixing);
hbn.push_back(gm);
hbn.push_back(mixing);
// The result should be a sigmoid leaning towards model1
// since it has the tighter covariance.
@ -281,8 +281,7 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel2) {
HybridBayesNet::shared_ptr bn = gfg.eliminateSequential();
HybridBayesNet expected;
expected.emplace_back(
new DiscreteConditional(m, "0.338561851224/0.661438148776"));
expected.push_back(DiscreteConditional(m, "0.338561851224/0.661438148776"));
EXPECT(assert_equal(expected, *bn));
}