Merge branch 'hybrid-printerrors' into model-selection-integration
commit
a80b5d4f5a
|
@ -63,7 +63,7 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> DecisionTreeFactor::error() const {
|
||||
AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
|
||||
// Get all possible assignments
|
||||
DiscreteKeys dkeys = discreteKeys();
|
||||
// Reverse to make cartesian product output a more natural ordering.
|
||||
|
|
|
@ -293,7 +293,7 @@ namespace gtsam {
|
|||
double error(const HybridValues& values) const override;
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
AlgebraicDecisionTree<Key> error() const override;
|
||||
AlgebraicDecisionTree<Key> errorTree() const override;
|
||||
|
||||
/// @}
|
||||
|
||||
|
|
|
@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
|||
double error(const HybridValues& c) const override;
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
virtual AlgebraicDecisionTree<Key> error() const = 0;
|
||||
virtual AlgebraicDecisionTree<Key> errorTree() const = 0;
|
||||
|
||||
/// Multiply in a DecisionTreeFactor and return the result as
|
||||
/// DecisionTreeFactor
|
||||
|
|
|
@ -169,8 +169,8 @@ double TableFactor::error(const HybridValues& values) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> TableFactor::error() const {
|
||||
return toDecisionTreeFactor().error();
|
||||
AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
|
||||
return toDecisionTreeFactor().errorTree();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
@ -359,7 +359,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
double error(const HybridValues& values) const override;
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
AlgebraicDecisionTree<Key> error() const override;
|
||||
AlgebraicDecisionTree<Key> errorTree() const override;
|
||||
|
||||
/// @}
|
||||
};
|
||||
|
|
|
@ -75,7 +75,7 @@ TEST(DecisionTreeFactor, Error) {
|
|||
// Create factors
|
||||
DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
||||
|
||||
auto errors = f.error();
|
||||
auto errors = f.errorTree();
|
||||
// regression
|
||||
AlgebraicDecisionTree<Key> expected(
|
||||
{X, Y, Z},
|
||||
|
|
|
@ -342,7 +342,7 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
|
||||
// Check if valid pointer
|
||||
|
@ -355,8 +355,8 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
|||
return std::numeric_limits<double>::max();
|
||||
}
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
|
||||
return errorTree;
|
||||
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
|
||||
return error_tree;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -221,7 +221,7 @@ class GTSAM_EXPORT GaussianMixture
|
|||
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
|
||||
* only, with the leaf values as the error for each assignment.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the logProbability of this Gaussian Mixture.
|
||||
|
|
|
@ -102,14 +102,14 @@ GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
|
||||
AlgebraicDecisionTree<Key> GaussianMixtureFactor::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
|
||||
return gf->error(continuousValues);
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||
return errorTree;
|
||||
DecisionTree<Key, double> error_tree(factors_, errorFunc);
|
||||
return error_tree;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -135,7 +135,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||
* as the factors involved, and leaf values as the error.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute the log-likelihood, including the log-normalizing constant.
|
||||
|
|
|
@ -434,7 +434,7 @@ HybridValues HybridBayesNet::sample() const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> result(0.0);
|
||||
|
||||
|
@ -442,7 +442,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
|||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
// If conditional is hybrid, compute error for all assignments.
|
||||
result = result + gm->error(continuousValues);
|
||||
result = result + gm->errorTree(continuousValues);
|
||||
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous, get the error and add it to the result
|
||||
|
|
|
@ -210,7 +210,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
|
||||
AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Error method using HybridValues which returns specific error for
|
||||
|
|
|
@ -99,7 +99,7 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
gmf->error(values.continuous()).print("", keyFormatter);
|
||||
gmf->errorTree(values.continuous()).print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
|
@ -113,12 +113,12 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
|
||||
} else if (hc->isDiscrete()) {
|
||||
std::cout << "error = ";
|
||||
hc->asDiscrete()->error().print("", keyFormatter);
|
||||
hc->asDiscrete()->errorTree().print("", keyFormatter);
|
||||
std::cout << "\n";
|
||||
} else {
|
||||
// Is hybrid
|
||||
std::cout << "error = ";
|
||||
hc->asMixture()->error(values.continuous()).print();
|
||||
hc->asMixture()->errorTree(values.continuous()).print();
|
||||
std::cout << "\n";
|
||||
}
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
df->error().print("", keyFormatter);
|
||||
df->errorTree().print("", keyFormatter);
|
||||
}
|
||||
|
||||
} else {
|
||||
|
@ -513,7 +513,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
|||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
|
||||
|
@ -524,7 +524,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
|
||||
if (auto gaussianMixture = dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||
// Compute factor error and add it.
|
||||
error_tree = error_tree + gaussianMixture->error(continuousValues);
|
||||
error_tree = error_tree + gaussianMixture->errorTree(continuousValues);
|
||||
} else if (auto gaussian = dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
|
@ -553,7 +553,7 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
|
|||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
||||
AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
|
||||
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
|
||||
// NOTE: The 0.5 term is handled by each factor
|
||||
return exp(-error);
|
||||
|
|
|
@ -171,7 +171,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
* @param continuousValues Continuous values at which to compute the error.
|
||||
* @return AlgebraicDecisionTree<Key>
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;
|
||||
AlgebraicDecisionTree<Key> errorTree(
|
||||
const VectorValues& continuousValues) const;
|
||||
|
||||
/**
|
||||
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
|
||||
|
|
|
@ -66,7 +66,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
mf->error(values.nonlinear()).print("", keyFormatter);
|
||||
mf->errorTree(values.nonlinear()).print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
} else if (auto gmf =
|
||||
|
@ -77,7 +77,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
gmf->error(values.continuous()).print("", keyFormatter);
|
||||
gmf->errorTree(values.continuous()).print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
} else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) {
|
||||
|
@ -87,7 +87,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
gm->error(values.continuous()).print("", keyFormatter);
|
||||
gm->errorTree(values.continuous()).print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
} else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
|
||||
|
@ -121,7 +121,7 @@ void HybridNonlinearFactorGraph::printErrors(
|
|||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = ";
|
||||
df->error().print("", keyFormatter);
|
||||
df->errorTree().print("", keyFormatter);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
|
|
|
@ -131,13 +131,13 @@ class MixtureFactor : public HybridFactor {
|
|||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||
* as the factor, and leaf values as the error.
|
||||
*/
|
||||
AlgebraicDecisionTree<Key> error(const Values& continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> errorTree(const Values& continuousValues) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc = [continuousValues](const sharedFactor& factor) {
|
||||
return factor->error(continuousValues);
|
||||
};
|
||||
DecisionTree<Key, double> errorTree(factors_, errorFunc);
|
||||
return errorTree;
|
||||
DecisionTree<Key, double> result(factors_, errorFunc);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -97,7 +97,7 @@ TEST(GaussianMixture, LogProbability) {
|
|||
/// Check error.
|
||||
TEST(GaussianMixture, Error) {
|
||||
using namespace equal_constants;
|
||||
auto actual = mixture.error(vv);
|
||||
auto actual = mixture.errorTree(vv);
|
||||
|
||||
// Check result.
|
||||
std::vector<DiscreteKey> discrete_keys = {mode};
|
||||
|
@ -134,7 +134,7 @@ TEST(GaussianMixture, Likelihood) {
|
|||
std::vector<double> leaves = {conditionals[0]->likelihood(vv)->error(vv),
|
||||
conditionals[1]->likelihood(vv)->error(vv)};
|
||||
AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
|
||||
EXPECT(assert_equal(expected, likelihood->error(vv), 1e-6));
|
||||
EXPECT(assert_equal(expected, likelihood->errorTree(vv), 1e-6));
|
||||
|
||||
// Check that the ratio of probPrime to evaluate is the same for all modes.
|
||||
std::vector<double> ratio(2);
|
||||
|
|
|
@ -178,7 +178,7 @@ TEST(GaussianMixtureFactor, Error) {
|
|||
continuousValues.insert(X(2), Vector2(1, 1));
|
||||
|
||||
// error should return a tree of errors, with nodes for each discrete value.
|
||||
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
|
||||
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.errorTree(continuousValues);
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
// Error values for regression test
|
||||
|
|
|
@ -182,7 +182,7 @@ TEST(HybridBayesNet, Error) {
|
|||
values.insert(X(1), Vector1(1));
|
||||
|
||||
AlgebraicDecisionTree<Key> actual_errors =
|
||||
bayesNet.error(values.continuous());
|
||||
bayesNet.errorTree(values.continuous());
|
||||
|
||||
// Regression.
|
||||
// Manually added all the error values from the 3 conditional types.
|
||||
|
|
|
@ -580,7 +580,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
|||
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
auto error_tree = graph.error(delta.continuous());
|
||||
auto error_tree = graph.errorTree(delta.continuous());
|
||||
|
||||
std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
|
||||
std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
|
||||
|
|
|
@ -100,7 +100,8 @@ TEST(MixtureFactor, Error) {
|
|||
continuousValues.insert<double>(X(1), 0);
|
||||
continuousValues.insert<double>(X(2), 1);
|
||||
|
||||
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
|
||||
AlgebraicDecisionTree<Key> error_tree =
|
||||
mixtureFactor.errorTree(continuousValues);
|
||||
|
||||
DiscreteKey m1(1, 2);
|
||||
std::vector<DiscreteKey> discrete_keys = {m1};
|
||||
|
|
|
@ -54,7 +54,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
|
|||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
AlgebraicDecisionTree<Key> error() const override {
|
||||
AlgebraicDecisionTree<Key> errorTree() const override {
|
||||
throw std::runtime_error("AllDiff::error not implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ class BinaryAllDiff : public Constraint {
|
|||
}
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
AlgebraicDecisionTree<Key> error() const override {
|
||||
AlgebraicDecisionTree<Key> errorTree() const override {
|
||||
throw std::runtime_error("BinaryAllDiff::error not implemented");
|
||||
}
|
||||
};
|
||||
|
|
|
@ -70,7 +70,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
|||
}
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
AlgebraicDecisionTree<Key> error() const override {
|
||||
AlgebraicDecisionTree<Key> errorTree() const override {
|
||||
throw std::runtime_error("Domain::error not implemented");
|
||||
}
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
|
|||
}
|
||||
|
||||
/// Compute error for each assignment and return as a tree
|
||||
AlgebraicDecisionTree<Key> error() const override {
|
||||
AlgebraicDecisionTree<Key> errorTree() const override {
|
||||
throw std::runtime_error("SingleValue::error not implemented");
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue