Address Varun's comments
parent
d5fd279449
commit
2c4990b613
|
|
@ -50,7 +50,7 @@ GaussianMixture GaussianMixture::FromConditionalList(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixture::Sum GaussianMixture::addTo(
|
GaussianMixture::Sum GaussianMixture::add(
|
||||||
const GaussianMixture::Sum &sum) const {
|
const GaussianMixture::Sum &sum) const {
|
||||||
using Y = GaussianFactorGraph;
|
using Y = GaussianFactorGraph;
|
||||||
auto add = [](const Y &graph1, const Y &graph2) {
|
auto add = [](const Y &graph1, const Y &graph2) {
|
||||||
|
|
@ -58,20 +58,21 @@ GaussianMixture::Sum GaussianMixture::addTo(
|
||||||
result.push_back(graph2);
|
result.push_back(graph2);
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
const Sum wrapped = wrappedConditionals();
|
const Sum wrapped = asGraph();
|
||||||
return sum.empty() ? wrapped : sum.apply(wrapped, add);
|
return sum.empty() ? wrapped : sum.apply(wrapped, add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixture::Sum GaussianMixture::wrappedConditionals() const {
|
GaussianMixture::Sum GaussianMixture::asGraph() const {
|
||||||
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
|
auto lambda = [](const GaussianFactor::shared_ptr &factor) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor);
|
result.push_back(factor);
|
||||||
return result;
|
return result;
|
||||||
};
|
};
|
||||||
return {conditionals_, wrap};
|
return {conditionals_, lambda};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* TODO(fan): this (for Testable) is not implemented! */
|
||||||
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
|
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@
|
||||||
* @file GaussianMixture.h
|
* @file GaussianMixture.h
|
||||||
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme
|
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
|
* @author Varun Agrawal
|
||||||
* @date Mar 12, 2022
|
* @date Mar 12, 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
@ -55,10 +56,10 @@ class GaussianMixture : public HybridFactor,
|
||||||
const Conditionals &conditionals();
|
const Conditionals &conditionals();
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
Sum addTo(const Sum &sum) const;
|
Sum add(const Sum &sum) const;
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
Sum wrappedConditionals() const;
|
Sum asGraph() const;
|
||||||
|
|
||||||
static This FromConditionalList(
|
static This FromConditionalList(
|
||||||
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::addTo(
|
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
||||||
const GaussianMixtureFactor::Sum &sum) const {
|
const GaussianMixtureFactor::Sum &sum) const {
|
||||||
using Y = GaussianFactorGraph;
|
using Y = GaussianFactorGraph;
|
||||||
auto add = [](const Y &graph1, const Y &graph2) {
|
auto add = [](const Y &graph1, const Y &graph2) {
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,7 @@ class GaussianMixtureFactor : public HybridFactor {
|
||||||
const std::vector<GaussianFactor::shared_ptr> &factors);
|
const std::vector<GaussianFactor::shared_ptr> &factors);
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
Sum addTo(const Sum &sum) const;
|
Sum add(const Sum &sum) const;
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
Sum wrappedFactors() const;
|
Sum wrappedFactors() const;
|
||||||
|
|
|
||||||
|
|
@ -286,12 +286,12 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
|
||||||
if (f->isHybrid_) {
|
if (f->isHybrid_) {
|
||||||
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
|
auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f);
|
||||||
if (cgmf) {
|
if (cgmf) {
|
||||||
sum = cgmf->addTo(sum);
|
sum = cgmf->add(sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto gm = boost::dynamic_pointer_cast<HybridConditional>(f);
|
auto gm = boost::dynamic_pointer_cast<HybridConditional>(f);
|
||||||
if (gm) {
|
if (gm) {
|
||||||
sum = gm->asMixture()->addTo(sum);
|
sum = gm->asMixture()->add(sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (f->isContinuous_) {
|
} else if (f->isContinuous_) {
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A HybridGaussianFactor is a wrapper for GaussianFactor so that we do not have
|
* A HybridGaussianFactor is a layer over GaussianFactor so that we do not have
|
||||||
* a diamond inheritance.
|
* a diamond inheritance.
|
||||||
*/
|
*/
|
||||||
class HybridGaussianFactor : public HybridFactor {
|
class HybridGaussianFactor : public HybridFactor {
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,21 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/*
|
||||||
|
* @file Switching.h
|
||||||
|
* @date Mar 11, 2022
|
||||||
|
* @author Varun Agrawal
|
||||||
|
* @author Fan Jiang
|
||||||
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/Matrix.h>
|
#include <gtsam/base/Matrix.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@
|
||||||
* -------------------------------------------------------------------------- */
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* @file testHybridConditional.cpp
|
* @file testHybridFactorGraph.cpp
|
||||||
* @date Mar 11, 2022
|
* @date Mar 11, 2022
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
*/
|
*/
|
||||||
Loading…
Reference in New Issue