add method to select underlying continuous Gaussian graph given discrete assignment
parent
8372d8490c
commit
f62805f8b3
|
|
@ -296,7 +296,8 @@ static std::shared_ptr<Factor> createDiscreteFactor(
|
||||||
|
|
||||||
// Logspace version of:
|
// Logspace version of:
|
||||||
// exp(-factor->error(kEmpty)) * conditional->normalizationConstant();
|
// exp(-factor->error(kEmpty)) * conditional->normalizationConstant();
|
||||||
// We take negative of the logNormalizationConstant `log(1/k)` to get `log(k)`.
|
// We take negative of the logNormalizationConstant `log(1/k)`
|
||||||
|
// to get `log(k)`.
|
||||||
return -factor->error(kEmpty) + (-conditional->logNormalizationConstant());
|
return -factor->error(kEmpty) + (-conditional->logNormalizationConstant());
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -326,6 +327,7 @@ static std::shared_ptr<Factor> createGaussianMixtureFactor(
|
||||||
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
|
auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
|
||||||
if (!hf) throw std::runtime_error("Expected HessianFactor!");
|
if (!hf) throw std::runtime_error("Expected HessianFactor!");
|
||||||
// Add 2.0 term since the constant term will be premultiplied by 0.5
|
// Add 2.0 term since the constant term will be premultiplied by 0.5
|
||||||
|
// as per the Hessian definition
|
||||||
hf->constantTerm() += 2.0 * conditional->logNormalizationConstant();
|
hf->constantTerm() += 2.0 * conditional->logNormalizationConstant();
|
||||||
}
|
}
|
||||||
return factor;
|
return factor;
|
||||||
|
|
@ -563,4 +565,24 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
||||||
return prob_tree;
|
return prob_tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
GaussianFactorGraph HybridGaussianFactorGraph::operator()(
|
||||||
|
const DiscreteValues &assignment) const {
|
||||||
|
GaussianFactorGraph gfg;
|
||||||
|
for (auto &&f : *this) {
|
||||||
|
if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||||
|
gfg.push_back(gf);
|
||||||
|
} else if (auto gc = std::dynamic_pointer_cast<GaussianConditional>(f)) {
|
||||||
|
gfg.push_back(gf);
|
||||||
|
} else if (auto gmf = std::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
|
gfg.push_back((*gmf)(assignment));
|
||||||
|
} else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
|
||||||
|
gfg.push_back((*gm)(assignment));
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return gfg;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -210,6 +210,10 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
GaussianFactorGraphTree assembleGraphTree() const;
|
GaussianFactorGraphTree assembleGraphTree() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
/// Get the GaussianFactorGraph at a given discrete assignment.
|
||||||
|
GaussianFactorGraph operator()(const DiscreteValues& assignment) const;
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -490,6 +490,58 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Select a particular continuous factor graph given a discrete assignment
|
||||||
|
TEST(HybridGaussianFactorGraph, DiscreteSelection) {
|
||||||
|
Switching s(3);
|
||||||
|
|
||||||
|
HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
|
||||||
|
|
||||||
|
DiscreteValues dv00{{M(0), 0}, {M(1), 0}};
|
||||||
|
GaussianFactorGraph continuous_00 = graph(dv00);
|
||||||
|
GaussianFactorGraph expected_00;
|
||||||
|
expected_00.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
|
||||||
|
expected_00.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-1)));
|
||||||
|
expected_00.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-1)));
|
||||||
|
expected_00.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
|
||||||
|
expected_00.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected_00, continuous_00));
|
||||||
|
|
||||||
|
DiscreteValues dv01{{M(0), 0}, {M(1), 1}};
|
||||||
|
GaussianFactorGraph continuous_01 = graph(dv01);
|
||||||
|
GaussianFactorGraph expected_01;
|
||||||
|
expected_01.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
|
||||||
|
expected_01.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-1)));
|
||||||
|
expected_01.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-0)));
|
||||||
|
expected_01.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
|
||||||
|
expected_01.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected_01, continuous_01));
|
||||||
|
|
||||||
|
DiscreteValues dv10{{M(0), 1}, {M(1), 0}};
|
||||||
|
GaussianFactorGraph continuous_10 = graph(dv10);
|
||||||
|
GaussianFactorGraph expected_10;
|
||||||
|
expected_10.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
|
||||||
|
expected_10.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-0)));
|
||||||
|
expected_10.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-1)));
|
||||||
|
expected_10.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
|
||||||
|
expected_10.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected_10, continuous_10));
|
||||||
|
|
||||||
|
DiscreteValues dv11{{M(0), 1}, {M(1), 1}};
|
||||||
|
GaussianFactorGraph continuous_11 = graph(dv11);
|
||||||
|
GaussianFactorGraph expected_11;
|
||||||
|
expected_11.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
|
||||||
|
expected_11.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-0)));
|
||||||
|
expected_11.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-0)));
|
||||||
|
expected_11.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
|
||||||
|
expected_11.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected_11, continuous_11));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(HybridGaussianFactorGraph, optimize) {
|
TEST(HybridGaussianFactorGraph, optimize) {
|
||||||
HybridGaussianFactorGraph hfg;
|
HybridGaussianFactorGraph hfg;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue