Address Varun's comments
parent
d5fd279449
commit
2c4990b613
|
|
@ -50,7 +50,7 @@ GaussianMixture GaussianMixture::FromConditionalList(
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianMixture::Sum GaussianMixture::addTo(
|
||||
GaussianMixture::Sum GaussianMixture::add(
|
||||
const GaussianMixture::Sum &sum) const {
|
||||
using Y = GaussianFactorGraph;
|
||||
auto add = [](const Y &graph1, const Y &graph2) {
|
||||
|
|
@ -58,20 +58,21 @@ GaussianMixture::Sum GaussianMixture::addTo(
|
|||
result.push_back(graph2);
|
||||
return result;
|
||||
};
|
||||
const Sum wrapped = wrappedConditionals();
|
||||
const Sum wrapped = asGraph();
|
||||
return sum.empty() ? wrapped : sum.apply(wrapped, add);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianMixture::Sum GaussianMixture::wrappedConditionals() const {
|
||||
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
|
||||
GaussianMixture::Sum GaussianMixture::asGraph() const {
|
||||
auto lambda = [](const GaussianFactor::shared_ptr &factor) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor);
|
||||
return result;
|
||||
};
|
||||
return {conditionals_, wrap};
|
||||
return {conditionals_, lambda};
|
||||
}
|
||||
|
||||
/* TODO(fan): this (for Testable) is not implemented! */
|
||||
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
* @file GaussianMixture.h
|
||||
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme
|
||||
* @author Fan Jiang
|
||||
* @author Varun Agrawal
|
||||
* @date Mar 12, 2022
|
||||
*/
|
||||
|
||||
|
|
@ -55,10 +56,10 @@ class GaussianMixture : public HybridFactor,
|
|||
const Conditionals &conditionals();
|
||||
|
||||
/* *******************************************************************************/
|
||||
Sum addTo(const Sum &sum) const;
|
||||
Sum add(const Sum &sum) const;
|
||||
|
||||
/* *******************************************************************************/
|
||||
Sum wrappedConditionals() const;
|
||||
Sum asGraph() const;
|
||||
|
||||
static This FromConditionalList(
|
||||
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
||||
|
|
@ -71,4 +72,4 @@ class GaussianMixture : public HybridFactor,
|
|||
const std::string &s = "GaussianMixture\n",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
};
|
||||
} // namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::addTo(
|
||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
||||
const GaussianMixtureFactor::Sum &sum) const {
|
||||
using Y = GaussianFactorGraph;
|
||||
auto add = [](const Y &graph1, const Y &graph2) {
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class GaussianMixtureFactor : public HybridFactor {
|
|||
const std::vector<GaussianFactor::shared_ptr> &factors);
|
||||
|
||||
/* *******************************************************************************/
|
||||
Sum addTo(const Sum &sum) const;
|
||||
Sum add(const Sum &sum) const;
|
||||
|
||||
/* *******************************************************************************/
|
||||
Sum wrappedFactors() const;
|
||||
|
|
|
|||
|
|
@ -76,4 +76,4 @@ void HybridFactor::print(
|
|||
|
||||
HybridFactor::~HybridFactor() = default;
|
||||
|
||||
} // namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -286,12 +286,12 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
|||
if (f->isHybrid_) {
|
||||
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
|
||||
if (cgmf) {
|
||||
sum = cgmf->addTo(sum);
|
||||
sum = cgmf->add(sum);
|
||||
}
|
||||
|
||||
auto gm = boost::dynamic_pointer_cast<HybridConditional>(f);
|
||||
if (gm) {
|
||||
sum = gm->asMixture()->addTo(sum);
|
||||
sum = gm->asMixture()->add(sum);
|
||||
}
|
||||
|
||||
} else if (f->isContinuous_) {
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@
|
|||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* A HybridGaussianFactor is a wrapper for GaussianFactor so that we do not have
|
||||
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
|
||||
* a diamond inheritance.
|
||||
*/
|
||||
class HybridGaussianFactor : public HybridFactor {
|
||||
|
|
|
|||
|
|
@ -1,3 +1,21 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file Switching.h
|
||||
* @date Mar 11, 2022
|
||||
* @author Varun Agrawal
|
||||
* @author Fan Jiang
|
||||
*/
|
||||
|
||||
#include <gtsam/base/Matrix.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
|
|
@ -65,4 +83,4 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
|
|||
return {new_order, levels};
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* @file testHybridConditional.cpp
|
||||
* @file testHybridFactorGraph.cpp
|
||||
* @date Mar 11, 2022
|
||||
* @author Fan Jiang
|
||||
*/
|
||||
Loading…
Reference in New Issue