address review comments
parent
3d8603b23b
commit
2cf2100710
|
@ -51,7 +51,7 @@ struct HybridGaussianConditional::Helper {
|
|||
std::vector<GC::shared_ptr> gcs;
|
||||
fvs.reserve(p.size());
|
||||
gcs.reserve(p.size());
|
||||
for (const auto &[mean, sigma] : p) {
|
||||
for (auto &&[mean, sigma] : p) {
|
||||
auto gaussianConditional =
|
||||
GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
|
||||
double value = gaussianConditional->negLogConstant();
|
||||
|
@ -96,38 +96,34 @@ HybridGaussianConditional::HybridGaussianConditional(
|
|||
conditionals_(helper.conditionals),
|
||||
negLogConstant_(helper.minNegLogConstant) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianConditional::HybridGaussianConditional(
|
||||
const DiscreteKey &mode,
|
||||
const DiscreteKey &discreteParent,
|
||||
const std::vector<GaussianConditional::shared_ptr> &conditionals)
|
||||
: HybridGaussianConditional(DiscreteKeys{mode},
|
||||
Conditionals({mode}, conditionals)) {}
|
||||
: HybridGaussianConditional(DiscreteKeys{discreteParent},
|
||||
Conditionals({discreteParent}, conditionals)) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianConditional::HybridGaussianConditional(
|
||||
const DiscreteKey &mode, Key key, //
|
||||
const DiscreteKey &discreteParent, Key key, //
|
||||
const std::vector<std::pair<Vector, double>> ¶meters)
|
||||
: HybridGaussianConditional(DiscreteKeys{mode},
|
||||
Helper(mode, parameters, key)) {}
|
||||
: HybridGaussianConditional(DiscreteKeys{discreteParent},
|
||||
Helper(discreteParent, parameters, key)) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianConditional::HybridGaussianConditional(
|
||||
const DiscreteKey &mode, Key key, //
|
||||
const DiscreteKey &discreteParent, Key key, //
|
||||
const Matrix &A, Key parent,
|
||||
const std::vector<std::pair<Vector, double>> ¶meters)
|
||||
: HybridGaussianConditional(DiscreteKeys{mode},
|
||||
Helper(mode, parameters, key, A, parent)) {}
|
||||
: HybridGaussianConditional(
|
||||
DiscreteKeys{discreteParent},
|
||||
Helper(discreteParent, parameters, key, A, parent)) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianConditional::HybridGaussianConditional(
|
||||
const DiscreteKey &mode, Key key, //
|
||||
const DiscreteKey &discreteParent, Key key, //
|
||||
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
|
||||
const std::vector<std::pair<Vector, double>> ¶meters)
|
||||
: HybridGaussianConditional(
|
||||
DiscreteKeys{mode},
|
||||
Helper(mode, parameters, key, A1, parent1, A2, parent2)) {}
|
||||
DiscreteKeys{discreteParent},
|
||||
Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianConditional::HybridGaussianConditional(
|
||||
const DiscreteKeys &discreteParents,
|
||||
const HybridGaussianConditional::Conditionals &conditionals)
|
||||
|
|
|
@ -79,45 +79,45 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
/**
|
||||
* @brief Construct from one discrete key and vector of conditionals.
|
||||
*
|
||||
* @param mode Single discrete parent variable
|
||||
* @param discreteParent Single discrete parent variable
|
||||
* @param conditionals Vector of conditionals with the same size as the
|
||||
* cardinality of the discrete parent.
|
||||
*/
|
||||
HybridGaussianConditional(
|
||||
const DiscreteKey &mode,
|
||||
const DiscreteKey &discreteParent,
|
||||
const std::vector<GaussianConditional::shared_ptr> &conditionals);
|
||||
|
||||
/**
|
||||
* @brief Constructs a HybridGaussianConditional with means mu_i and
|
||||
* standard deviations sigma_i.
|
||||
*
|
||||
* @param mode The discrete mode key.
|
||||
* @param discreteParent The discrete parent or "mode" key.
|
||||
* @param key The key for this conditional variable.
|
||||
* @param parameters A vector of pairs (mu_i, sigma_i).
|
||||
*/
|
||||
HybridGaussianConditional(
|
||||
const DiscreteKey &mode, Key key,
|
||||
const DiscreteKey &discreteParent, Key key,
|
||||
const std::vector<std::pair<Vector, double>> ¶meters);
|
||||
|
||||
/**
|
||||
* @brief Constructs a HybridGaussianConditional with conditional means
|
||||
* A × parent + b_i and standard deviations sigma_i.
|
||||
*
|
||||
* @param mode The discrete mode key.
|
||||
* @param discreteParent The discrete parent or "mode" key.
|
||||
* @param key The key for this conditional variable.
|
||||
* @param A The matrix A.
|
||||
* @param parent The key of the parent variable.
|
||||
* @param parameters A vector of pairs (b_i, sigma_i).
|
||||
*/
|
||||
HybridGaussianConditional(
|
||||
const DiscreteKey &mode, Key key, const Matrix &A, Key parent,
|
||||
const DiscreteKey &discreteParent, Key key, const Matrix &A, Key parent,
|
||||
const std::vector<std::pair<Vector, double>> ¶meters);
|
||||
|
||||
/**
|
||||
* @brief Constructs a HybridGaussianConditional with conditional means
|
||||
* A1 × parent1 + A2 × parent2 + b_i and standard deviations sigma_i.
|
||||
*
|
||||
* @param mode The discrete mode key.
|
||||
* @param discreteParent The discrete parent or "mode" key.
|
||||
* @param key The key for this conditional variable.
|
||||
* @param A1 The first matrix.
|
||||
* @param parent1 The key of the first parent variable.
|
||||
|
@ -126,7 +126,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
* @param parameters A vector of pairs (b_i, sigma_i).
|
||||
*/
|
||||
HybridGaussianConditional(
|
||||
const DiscreteKey &mode, Key key, //
|
||||
const DiscreteKey &discreteParent, Key key, //
|
||||
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
|
||||
const std::vector<std::pair<Vector, double>> ¶meters);
|
||||
|
||||
|
|
|
@ -79,15 +79,16 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
|||
double mu0 = 1.0, mu1 = 3.0;
|
||||
double sigma = 2.0;
|
||||
|
||||
HybridBayesNet hbn;
|
||||
// Create a Gaussian mixture model p(z|m) with same sigma.
|
||||
HybridBayesNet gmm;
|
||||
std::vector<std::pair<Vector, double>> parameters{{Vector1(mu0), sigma},
|
||||
{Vector1(mu1), sigma}};
|
||||
hbn.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters);
|
||||
hbn.push_back(mixing);
|
||||
gmm.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters);
|
||||
gmm.push_back(mixing);
|
||||
|
||||
// At the halfway point between the means, we should get P(m|z)=0.5
|
||||
double midway = mu1 - mu0;
|
||||
auto pMid = SolveHBN(hbn, midway);
|
||||
auto pMid = SolveHBN(gmm, midway);
|
||||
EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid));
|
||||
|
||||
// Everywhere else, the result should be a sigmoid.
|
||||
|
@ -96,7 +97,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
|
|||
const double expected = prob_m_z(mu0, mu1, sigma, sigma, z);
|
||||
|
||||
// Workflow 1: convert HBN to HFG and solve
|
||||
auto posterior1 = SolveHBN(hbn, z);
|
||||
auto posterior1 = SolveHBN(gmm, z);
|
||||
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
||||
|
||||
// Workflow 2: directly specify HFG and solve
|
||||
|
@ -117,16 +118,17 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
double mu0 = 1.0, mu1 = 3.0;
|
||||
double sigma0 = 8.0, sigma1 = 4.0;
|
||||
|
||||
HybridBayesNet hbn;
|
||||
// Create a Gaussian mixture model p(z|m) with same sigma.
|
||||
HybridBayesNet gmm;
|
||||
std::vector<std::pair<Vector, double>> parameters{{Vector1(mu0), sigma0},
|
||||
{Vector1(mu1), sigma1}};
|
||||
hbn.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters);
|
||||
hbn.push_back(mixing);
|
||||
gmm.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters);
|
||||
gmm.push_back(mixing);
|
||||
|
||||
// We get zMax=3.1333 by finding the maximum value of the function, at which
|
||||
// point the mode m==1 is about twice as probable as m==0.
|
||||
double zMax = 3.133;
|
||||
auto pMax = SolveHBN(hbn, zMax);
|
||||
auto pMax = SolveHBN(gmm, zMax);
|
||||
EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4));
|
||||
|
||||
// Everywhere else, the result should be a bell curve like function.
|
||||
|
@ -135,7 +137,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
|
|||
const double expected = prob_m_z(mu0, mu1, sigma0, sigma1, z);
|
||||
|
||||
// Workflow 1: convert HBN to HFG and solve
|
||||
auto posterior1 = SolveHBN(hbn, z);
|
||||
auto posterior1 = SolveHBN(gmm, z);
|
||||
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
|
||||
|
||||
// Workflow 2: directly specify HFG and solve
|
||||
|
|
Loading…
Reference in New Issue