address review comments

release/4.3a0
Frank Dellaert 2024-09-29 09:40:30 -07:00
parent 3d8603b23b
commit 2cf2100710
3 changed files with 34 additions and 36 deletions

View File

@ -51,7 +51,7 @@ struct HybridGaussianConditional::Helper {
std::vector<GC::shared_ptr> gcs; std::vector<GC::shared_ptr> gcs;
fvs.reserve(p.size()); fvs.reserve(p.size());
gcs.reserve(p.size()); gcs.reserve(p.size());
for (const auto &[mean, sigma] : p) { for (auto &&[mean, sigma] : p) {
auto gaussianConditional = auto gaussianConditional =
GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma); GC::sharedMeanAndStddev(std::forward<Args>(args)..., mean, sigma);
double value = gaussianConditional->negLogConstant(); double value = gaussianConditional->negLogConstant();
@ -96,38 +96,34 @@ HybridGaussianConditional::HybridGaussianConditional(
conditionals_(helper.conditionals), conditionals_(helper.conditionals),
negLogConstant_(helper.minNegLogConstant) {} negLogConstant_(helper.minNegLogConstant) {}
/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &mode, const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals) const std::vector<GaussianConditional::shared_ptr> &conditionals)
: HybridGaussianConditional(DiscreteKeys{mode}, : HybridGaussianConditional(DiscreteKeys{discreteParent},
Conditionals({mode}, conditionals)) {} Conditionals({discreteParent}, conditionals)) {}
/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &mode, Key key, // const DiscreteKey &discreteParent, Key key, //
const std::vector<std::pair<Vector, double>> &parameters) const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(DiscreteKeys{mode}, : HybridGaussianConditional(DiscreteKeys{discreteParent},
Helper(mode, parameters, key)) {} Helper(discreteParent, parameters, key)) {}
/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &mode, Key key, // const DiscreteKey &discreteParent, Key key, //
const Matrix &A, Key parent, const Matrix &A, Key parent,
const std::vector<std::pair<Vector, double>> &parameters) const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional(DiscreteKeys{mode}, : HybridGaussianConditional(
Helper(mode, parameters, key, A, parent)) {} DiscreteKeys{discreteParent},
Helper(discreteParent, parameters, key, A, parent)) {}
/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKey &mode, Key key, // const DiscreteKey &discreteParent, Key key, //
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2, const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
const std::vector<std::pair<Vector, double>> &parameters) const std::vector<std::pair<Vector, double>> &parameters)
: HybridGaussianConditional( : HybridGaussianConditional(
DiscreteKeys{mode}, DiscreteKeys{discreteParent},
Helper(mode, parameters, key, A1, parent1, A2, parent2)) {} Helper(discreteParent, parameters, key, A1, parent1, A2, parent2)) {}
/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional( HybridGaussianConditional::HybridGaussianConditional(
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,
const HybridGaussianConditional::Conditionals &conditionals) const HybridGaussianConditional::Conditionals &conditionals)

View File

@ -79,45 +79,45 @@ class GTSAM_EXPORT HybridGaussianConditional
/** /**
* @brief Construct from one discrete key and vector of conditionals. * @brief Construct from one discrete key and vector of conditionals.
* *
* @param mode Single discrete parent variable * @param discreteParent Single discrete parent variable
* @param conditionals Vector of conditionals with the same size as the * @param conditionals Vector of conditionals with the same size as the
* cardinality of the discrete parent. * cardinality of the discrete parent.
*/ */
HybridGaussianConditional( HybridGaussianConditional(
const DiscreteKey &mode, const DiscreteKey &discreteParent,
const std::vector<GaussianConditional::shared_ptr> &conditionals); const std::vector<GaussianConditional::shared_ptr> &conditionals);
/** /**
* @brief Constructs a HybridGaussianConditional with means mu_i and * @brief Constructs a HybridGaussianConditional with means mu_i and
* standard deviations sigma_i. * standard deviations sigma_i.
* *
* @param mode The discrete mode key. * @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable. * @param key The key for this conditional variable.
* @param parameters A vector of pairs (mu_i, sigma_i). * @param parameters A vector of pairs (mu_i, sigma_i).
*/ */
HybridGaussianConditional( HybridGaussianConditional(
const DiscreteKey &mode, Key key, const DiscreteKey &discreteParent, Key key,
const std::vector<std::pair<Vector, double>> &parameters); const std::vector<std::pair<Vector, double>> &parameters);
/** /**
* @brief Constructs a HybridGaussianConditional with conditional means * @brief Constructs a HybridGaussianConditional with conditional means
* A × parent + b_i and standard deviations sigma_i. * A × parent + b_i and standard deviations sigma_i.
* *
* @param mode The discrete mode key. * @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable. * @param key The key for this conditional variable.
* @param A The matrix A. * @param A The matrix A.
* @param parent The key of the parent variable. * @param parent The key of the parent variable.
* @param parameters A vector of pairs (b_i, sigma_i). * @param parameters A vector of pairs (b_i, sigma_i).
*/ */
HybridGaussianConditional( HybridGaussianConditional(
const DiscreteKey &mode, Key key, const Matrix &A, Key parent, const DiscreteKey &discreteParent, Key key, const Matrix &A, Key parent,
const std::vector<std::pair<Vector, double>> &parameters); const std::vector<std::pair<Vector, double>> &parameters);
/** /**
* @brief Constructs a HybridGaussianConditional with conditional means * @brief Constructs a HybridGaussianConditional with conditional means
* A1 × parent1 + A2 × parent2 + b_i and standard deviations sigma_i. * A1 × parent1 + A2 × parent2 + b_i and standard deviations sigma_i.
* *
* @param mode The discrete mode key. * @param discreteParent The discrete parent or "mode" key.
* @param key The key for this conditional variable. * @param key The key for this conditional variable.
* @param A1 The first matrix. * @param A1 The first matrix.
* @param parent1 The key of the first parent variable. * @param parent1 The key of the first parent variable.
@ -126,7 +126,7 @@ class GTSAM_EXPORT HybridGaussianConditional
* @param parameters A vector of pairs (b_i, sigma_i). * @param parameters A vector of pairs (b_i, sigma_i).
*/ */
HybridGaussianConditional( HybridGaussianConditional(
const DiscreteKey &mode, Key key, // const DiscreteKey &discreteParent, Key key, //
const Matrix &A1, Key parent1, const Matrix &A2, Key parent2, const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
const std::vector<std::pair<Vector, double>> &parameters); const std::vector<std::pair<Vector, double>> &parameters);

View File

@ -79,15 +79,16 @@ TEST(GaussianMixture, GaussianMixtureModel) {
double mu0 = 1.0, mu1 = 3.0; double mu0 = 1.0, mu1 = 3.0;
double sigma = 2.0; double sigma = 2.0;
HybridBayesNet hbn; // Create a Gaussian mixture model p(z|m) with same sigma.
HybridBayesNet gmm;
std::vector<std::pair<Vector, double>> parameters{{Vector1(mu0), sigma}, std::vector<std::pair<Vector, double>> parameters{{Vector1(mu0), sigma},
{Vector1(mu1), sigma}}; {Vector1(mu1), sigma}};
hbn.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters); gmm.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters);
hbn.push_back(mixing); gmm.push_back(mixing);
// At the halfway point between the means, we should get P(m|z)=0.5 // At the halfway point between the means, we should get P(m|z)=0.5
double midway = mu1 - mu0; double midway = mu1 - mu0;
auto pMid = SolveHBN(hbn, midway); auto pMid = SolveHBN(gmm, midway);
EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid)); EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid));
// Everywhere else, the result should be a sigmoid. // Everywhere else, the result should be a sigmoid.
@ -96,7 +97,7 @@ TEST(GaussianMixture, GaussianMixtureModel) {
const double expected = prob_m_z(mu0, mu1, sigma, sigma, z); const double expected = prob_m_z(mu0, mu1, sigma, sigma, z);
// Workflow 1: convert HBN to HFG and solve // Workflow 1: convert HBN to HFG and solve
auto posterior1 = SolveHBN(hbn, z); auto posterior1 = SolveHBN(gmm, z);
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8); EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
// Workflow 2: directly specify HFG and solve // Workflow 2: directly specify HFG and solve
@ -117,16 +118,17 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
double mu0 = 1.0, mu1 = 3.0; double mu0 = 1.0, mu1 = 3.0;
double sigma0 = 8.0, sigma1 = 4.0; double sigma0 = 8.0, sigma1 = 4.0;
HybridBayesNet hbn; // Create a Gaussian mixture model p(z|m) with same sigma.
HybridBayesNet gmm;
std::vector<std::pair<Vector, double>> parameters{{Vector1(mu0), sigma0}, std::vector<std::pair<Vector, double>> parameters{{Vector1(mu0), sigma0},
{Vector1(mu1), sigma1}}; {Vector1(mu1), sigma1}};
hbn.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters); gmm.emplace_shared<HybridGaussianConditional>(m, Z(0), parameters);
hbn.push_back(mixing); gmm.push_back(mixing);
// We get zMax=3.1333 by finding the maximum value of the function, at which // We get zMax=3.1333 by finding the maximum value of the function, at which
// point the mode m==1 is about twice as probable as m==0. // point the mode m==1 is about twice as probable as m==0.
double zMax = 3.133; double zMax = 3.133;
auto pMax = SolveHBN(hbn, zMax); auto pMax = SolveHBN(gmm, zMax);
EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4)); EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4));
// Everywhere else, the result should be a bell curve like function. // Everywhere else, the result should be a bell curve like function.
@ -135,7 +137,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
const double expected = prob_m_z(mu0, mu1, sigma0, sigma1, z); const double expected = prob_m_z(mu0, mu1, sigma0, sigma1, z);
// Workflow 1: convert HBN to HFG and solve // Workflow 1: convert HBN to HFG and solve
auto posterior1 = SolveHBN(hbn, z); auto posterior1 = SolveHBN(gmm, z);
EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8); EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);
// Workflow 2: directly specify HFG and solve // Workflow 2: directly specify HFG and solve