address review comments

release/4.3a0
Varun Agrawal 2022-12-22 09:22:34 +05:30
parent 098d2ce4a4
commit d94b3199a0
10 changed files with 37 additions and 13 deletions

View File

@ -210,13 +210,14 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousValues) const {
// functor to convert from GaussianConditional to double error value.
// functor to calculate to double error value from GaussianConditional.
auto errorFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
return conditional->error(continuousValues);
} else {
// return arbitrarily large error
// Return arbitrarily large error if conditional is null
// Conditional is null if it is pruned out.
return 1e50;
}
};
@ -227,6 +228,7 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
/* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(discreteValues);
return conditional->error(continuousValues);
}

View File

@ -112,6 +112,7 @@ AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
double GaussianMixtureFactor::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto factor = factors_(discreteValues);
return factor->error(continuousValues);
}

View File

@ -244,13 +244,16 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree;
// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> conditional_error;
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
// If factor is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = this->atMixture(idx);
conditional_error = gm->error(continuousValues);
// Assign for the first index, add error for subsequent ones.
if (idx == 0) {
error_tree = conditional_error;
} else {
@ -261,6 +264,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
// If continuous only, get the (double) error
// and add it to the error_tree
double error = this->atGaussian(idx)->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
@ -273,6 +277,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
return error_tree;
}
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);

View File

@ -428,6 +428,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> factor_error;
@ -435,8 +436,10 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// If factor is hybrid, select based on assignment.
GaussianMixtureFactor::shared_ptr gaussianMixture =
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
// Compute factor error.
factor_error = gaussianMixture->error(continuousValues);
// If first factor, assign error, else add it.
if (idx == 0) {
error_tree = factor_error;
} else {
@ -450,7 +453,9 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();
// Compute the error of the gaussian factor.
double error = gaussian->error(continuousValues);
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

View File

@ -23,6 +23,7 @@
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/nonlinear/NonlinearFactor.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/nonlinear/Symbol.h>
#include <algorithm>
@ -86,11 +87,11 @@ class MixtureFactor : public HybridFactor {
* elements based on the number of discrete keys and the cardinality of the
* keys, so that the decision tree is constructed appropriately.
*
* @tparam FACTOR The type of the factor shared pointers being passed in. Will
* be typecast to NonlinearFactor shared pointers.
* @tparam FACTOR The type of the factor shared pointers being passed in.
* Will be typecast to NonlinearFactor shared pointers.
* @param keys Vector of keys for continuous factors.
* @param discreteKeys Vector of discrete keys.
* @param factors Vector of shared pointers to factors.
* @param factors Vector of nonlinear factors.
* @param normalized Flag indicating if the factor error is already
* normalized.
*/

View File

@ -196,8 +196,10 @@ class HybridNonlinearFactorGraph {
#include <gtsam/hybrid/MixtureFactor.h>
class MixtureFactor : gtsam::HybridFactor {
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors, bool normalized = false);
MixtureFactor(
const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors,
bool normalized = false);
template <FACTOR = {gtsam::NonlinearFactor}>
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,

View File

@ -104,7 +104,7 @@ TEST(GaussianMixture, Error) {
X(2), S2, model);
// Create decision tree
DiscreteKey m1(1, 2);
DiscreteKey m1(M(1), 2);
GaussianMixture::Conditionals conditionals(
{m1},
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
@ -115,12 +115,19 @@ TEST(GaussianMixture, Error) {
values.insert(X(2), Vector2::Zero());
auto error_tree = mixture.error(values);
// regression
std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> leaves = {0.5, 4.3252595};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
// regression
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
// Regression for non-tree version.
DiscreteValues assignment;
assignment[M(1)] = 0;
EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8);
assignment[M(1)] = 1;
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), 1e-8);
}
/* ************************************************************************* */

View File

@ -178,6 +178,7 @@ TEST(GaussianMixtureFactor, Error) {
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
std::vector<DiscreteKey> discrete_keys = {m1};
// Error values for regression test
std::vector<double> errors = {1, 4};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);

View File

@ -216,8 +216,7 @@ TEST(HybridBayesNet, Error) {
// Verify error computation and check for specific error value
DiscreteValues discrete_values;
discrete_values[M(0)] = 1;
discrete_values[M(1)] = 1;
insert(discrete_values)(M(0), 1)(M(1), 1);
double total_error = 0;
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {

View File

@ -41,7 +41,8 @@ TEST(MixtureFactor, Constructor) {
CHECK(it == factor.end());
}
/* ************************************************************************* */
// Test .print() output.
TEST(MixtureFactor, Printing) {
DiscreteKey m1(1, 2);
double between0 = 0.0;