address review comments
parent
098d2ce4a4
commit
d94b3199a0
|
|
@ -210,13 +210,14 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
|||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
// functor to convert from GaussianConditional to double error value.
|
||||
// functor to calculate to double error value from GaussianConditional.
|
||||
auto errorFunc =
|
||||
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
|
||||
if (conditional) {
|
||||
return conditional->error(continuousValues);
|
||||
} else {
|
||||
// return arbitrarily large error
|
||||
// Return arbitrarily large error if conditional is null
|
||||
// Conditional is null if it is pruned out.
|
||||
return 1e50;
|
||||
}
|
||||
};
|
||||
|
|
@ -227,6 +228,7 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
|||
/* *******************************************************************************/
|
||||
double GaussianMixture::error(const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto conditional = conditionals_(discreteValues);
|
||||
return conditional->error(continuousValues);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
|||
double GaussianMixtureFactor::error(
|
||||
const VectorValues &continuousValues,
|
||||
const DiscreteValues &discreteValues) const {
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto factor = factors_(discreteValues);
|
||||
return factor->error(continuousValues);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -244,13 +244,16 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
|||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree;
|
||||
|
||||
// Iterate over each factor.
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
AlgebraicDecisionTree<Key> conditional_error;
|
||||
|
||||
if (factors_.at(idx)->isHybrid()) {
|
||||
// If factor is hybrid, select based on assignment.
|
||||
// If factor is hybrid, select based on assignment and compute error.
|
||||
GaussianMixture::shared_ptr gm = this->atMixture(idx);
|
||||
conditional_error = gm->error(continuousValues);
|
||||
|
||||
// Assign for the first index, add error for subsequent ones.
|
||||
if (idx == 0) {
|
||||
error_tree = conditional_error;
|
||||
} else {
|
||||
|
|
@ -261,6 +264,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
|||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
double error = this->atGaussian(idx)->error(continuousValues);
|
||||
// Add the computed error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
|
|
@ -273,6 +277,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
|||
return error_tree;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
||||
|
|
|
|||
|
|
@ -428,6 +428,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
|
||||
// Iterate over each factor.
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
AlgebraicDecisionTree<Key> factor_error;
|
||||
|
||||
|
|
@ -435,8 +436,10 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
// If factor is hybrid, select based on assignment.
|
||||
GaussianMixtureFactor::shared_ptr gaussianMixture =
|
||||
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
|
||||
// Compute factor error.
|
||||
factor_error = gaussianMixture->error(continuousValues);
|
||||
|
||||
// If first factor, assign error, else add it.
|
||||
if (idx == 0) {
|
||||
error_tree = factor_error;
|
||||
} else {
|
||||
|
|
@ -450,7 +453,9 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
|
||||
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();
|
||||
|
||||
// Compute the error of the gaussian factor.
|
||||
double error = gaussian->error(continuousValues);
|
||||
// Add the gaussian factor error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridNonlinearFactor.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactor.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
#include <gtsam/nonlinear/Symbol.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -86,11 +87,11 @@ class MixtureFactor : public HybridFactor {
|
|||
* elements based on the number of discrete keys and the cardinality of the
|
||||
* keys, so that the decision tree is constructed appropriately.
|
||||
*
|
||||
* @tparam FACTOR The type of the factor shared pointers being passed in. Will
|
||||
* be typecast to NonlinearFactor shared pointers.
|
||||
* @tparam FACTOR The type of the factor shared pointers being passed in.
|
||||
* Will be typecast to NonlinearFactor shared pointers.
|
||||
* @param keys Vector of keys for continuous factors.
|
||||
* @param discreteKeys Vector of discrete keys.
|
||||
* @param factors Vector of shared pointers to factors.
|
||||
* @param factors Vector of nonlinear factors.
|
||||
* @param normalized Flag indicating if the factor error is already
|
||||
* normalized.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -196,8 +196,10 @@ class HybridNonlinearFactorGraph {
|
|||
|
||||
#include <gtsam/hybrid/MixtureFactor.h>
|
||||
class MixtureFactor : gtsam::HybridFactor {
|
||||
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
|
||||
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors, bool normalized = false);
|
||||
MixtureFactor(
|
||||
const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
|
||||
const gtsam::DecisionTree<gtsam::Key, gtsam::NonlinearFactor*>& factors,
|
||||
bool normalized = false);
|
||||
|
||||
template <FACTOR = {gtsam::NonlinearFactor}>
|
||||
MixtureFactor(const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ TEST(GaussianMixture, Error) {
|
|||
X(2), S2, model);
|
||||
|
||||
// Create decision tree
|
||||
DiscreteKey m1(1, 2);
|
||||
DiscreteKey m1(M(1), 2);
|
||||
GaussianMixture::Conditionals conditionals(
|
||||
{m1},
|
||||
vector<GaussianConditional::shared_ptr>{conditional0, conditional1});
|
||||
|
|
@ -115,12 +115,19 @@ TEST(GaussianMixture, Error) {
|
|||
values.insert(X(2), Vector2::Zero());
|
||||
auto error_tree = mixture.error(values);
|
||||
|
||||
// regression
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
std::vector<double> leaves = {0.5, 4.3252595};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);
|
||||
|
||||
// regression
|
||||
EXPECT(assert_equal(expected_error, error_tree, 1e-6));
|
||||
|
||||
// Regression for non-tree version.
|
||||
DiscreteValues assignment;
|
||||
assignment[M(1)] = 0;
|
||||
EXPECT_DOUBLES_EQUAL(0.5, mixture.error(values, assignment), 1e-8);
|
||||
assignment[M(1)] = 1;
|
||||
EXPECT_DOUBLES_EQUAL(4.3252595155709335, mixture.error(values, assignment), 1e-8);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
|
|
@ -178,6 +178,7 @@ TEST(GaussianMixtureFactor, Error) {
|
|||
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
// Error values for regression test
|
||||
std::vector<double> errors = {1, 4};
|
||||
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
|
||||
|
||||
|
|
|
|||
|
|
@ -216,8 +216,7 @@ TEST(HybridBayesNet, Error) {
|
|||
|
||||
// Verify error computation and check for specific error value
|
||||
DiscreteValues discrete_values;
|
||||
discrete_values[M(0)] = 1;
|
||||
discrete_values[M(1)] = 1;
|
||||
insert(discrete_values)(M(0), 1)(M(1), 1);
|
||||
|
||||
double total_error = 0;
|
||||
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ TEST(MixtureFactor, Constructor) {
|
|||
CHECK(it == factor.end());
|
||||
}
|
||||
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test .print() output.
|
||||
TEST(MixtureFactor, Printing) {
|
||||
DiscreteKey m1(1, 2);
|
||||
double between0 = 0.0;
|
||||
|
|
|
|||
Loading…
Reference in New Issue