Merge pull request #1696 from borglab/model-selection-integration
AlgebraicDecisionTree Helpersrelease/4.3a0
commit
231d1adbbd
|
@ -196,6 +196,42 @@ namespace gtsam {
|
||||||
return this->apply(g, &Ring::div);
|
return this->apply(g, &Ring::div);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute sum of all values
|
||||||
|
double sum() const {
|
||||||
|
double sum = 0;
|
||||||
|
auto visitor = [&](double y) { sum += y; };
|
||||||
|
this->visit(visitor);
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Helper method to perform normalization such that all leaves in the
|
||||||
|
* tree sum to 1
|
||||||
|
*
|
||||||
|
* @param sum
|
||||||
|
* @return AlgebraicDecisionTree
|
||||||
|
*/
|
||||||
|
AlgebraicDecisionTree normalize(double sum) const {
|
||||||
|
return this->apply([&sum](const double& x) { return x / sum; });
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the minimum values amongst all leaves
|
||||||
|
double min() const {
|
||||||
|
double min = std::numeric_limits<double>::max();
|
||||||
|
auto visitor = [&](double x) { min = x < min ? x : min; };
|
||||||
|
this->visit(visitor);
|
||||||
|
return min;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the maximum values amongst all leaves
|
||||||
|
double max() const {
|
||||||
|
// Get the most negative value
|
||||||
|
double max = -std::numeric_limits<double>::max();
|
||||||
|
auto visitor = [&](double x) { max = x > max ? x : max; };
|
||||||
|
this->visit(visitor);
|
||||||
|
return max;
|
||||||
|
}
|
||||||
|
|
||||||
/** sum out variable */
|
/** sum out variable */
|
||||||
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
|
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
|
||||||
return this->combine(label, cardinality, &Ring::add);
|
return this->combine(label, cardinality, &Ring::add);
|
||||||
|
|
|
@ -593,6 +593,55 @@ TEST(ADT, zero) {
|
||||||
EXPECT_DOUBLES_EQUAL(0, anotb(x11), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0, anotb(x11), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Example ADT from 0 to 11.
|
||||||
|
ADT exampleADT() {
|
||||||
|
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||||
|
ADT f(A & B & C, "0 6 2 8 4 10 1 7 3 9 5 11");
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test sum
|
||||||
|
TEST(ADT, Sum) {
|
||||||
|
ADT a = exampleADT();
|
||||||
|
double expected_sum = 0;
|
||||||
|
for (double i = 0; i < 12; i++) {
|
||||||
|
expected_sum += i;
|
||||||
|
}
|
||||||
|
EXPECT_DOUBLES_EQUAL(expected_sum, a.sum(), 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test normalize
|
||||||
|
TEST(ADT, Normalize) {
|
||||||
|
ADT a = exampleADT();
|
||||||
|
double sum = a.sum();
|
||||||
|
auto actual = a.normalize(sum);
|
||||||
|
|
||||||
|
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||||
|
DiscreteKeys keys = DiscreteKeys{A, B, C};
|
||||||
|
std::vector<double> cpt{0 / sum, 6 / sum, 2 / sum, 8 / sum,
|
||||||
|
4 / sum, 10 / sum, 1 / sum, 7 / sum,
|
||||||
|
3 / sum, 9 / sum, 5 / sum, 11 / sum};
|
||||||
|
ADT expected(keys, cpt);
|
||||||
|
EXPECT(assert_equal(expected, actual));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test min
|
||||||
|
TEST(ADT, Min) {
|
||||||
|
ADT a = exampleADT();
|
||||||
|
double min = a.min();
|
||||||
|
EXPECT_DOUBLES_EQUAL(0.0, min, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test max
|
||||||
|
TEST(ADT, Max) {
|
||||||
|
ADT a = exampleADT();
|
||||||
|
double max = a.max();
|
||||||
|
EXPECT_DOUBLES_EQUAL(11.0, max, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -67,7 +67,8 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
double logConstant_; ///< log of the normalization constant.
|
double logConstant_; ///< log of the normalization constant.
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
* @brief Convert a DecisionTree of factors into
|
||||||
|
* a DecisionTree of Gaussian factor graphs.
|
||||||
*/
|
*/
|
||||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||||
|
|
||||||
|
@ -214,7 +215,8 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
|
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
|
||||||
* only, with the leaf values as the error for each assignment.
|
* only, with the leaf values as the error for each assignment.
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
|
AlgebraicDecisionTree<Key> errorTree(
|
||||||
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the logProbability of this Gaussian Mixture.
|
* @brief Compute the logProbability of this Gaussian Mixture.
|
||||||
|
|
|
@ -135,7 +135,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
|
||||||
* as the factors involved, and leaf values as the error.
|
* as the factors involved, and leaf values as the error.
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const;
|
AlgebraicDecisionTree<Key> errorTree(
|
||||||
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute the log-likelihood, including the log-normalizing constant.
|
* @brief Compute the log-likelihood, including the log-normalizing constant.
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
/**
|
/**
|
||||||
* @file HybridValues.h
|
* @file HybridValues.h
|
||||||
* @date Jul 28, 2022
|
* @date Jul 28, 2022
|
||||||
|
* @author Varun Agrawal
|
||||||
* @author Shangjie Xue
|
* @author Shangjie Xue
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
@ -54,13 +55,13 @@ class GTSAM_EXPORT HybridValues {
|
||||||
HybridValues() = default;
|
HybridValues() = default;
|
||||||
|
|
||||||
/// Construct from DiscreteValues and VectorValues.
|
/// Construct from DiscreteValues and VectorValues.
|
||||||
HybridValues(const VectorValues &cv, const DiscreteValues &dv)
|
HybridValues(const VectorValues& cv, const DiscreteValues& dv)
|
||||||
: continuous_(cv), discrete_(dv){}
|
: continuous_(cv), discrete_(dv) {}
|
||||||
|
|
||||||
/// Construct from all values types.
|
/// Construct from all values types.
|
||||||
HybridValues(const VectorValues& cv, const DiscreteValues& dv,
|
HybridValues(const VectorValues& cv, const DiscreteValues& dv,
|
||||||
const Values& v)
|
const Values& v)
|
||||||
: continuous_(cv), discrete_(dv), nonlinear_(v){}
|
: continuous_(cv), discrete_(dv), nonlinear_(v) {}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
@ -101,9 +102,7 @@ class GTSAM_EXPORT HybridValues {
|
||||||
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); }
|
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); }
|
||||||
|
|
||||||
/// Check whether a variable with key \c j exists in values.
|
/// Check whether a variable with key \c j exists in values.
|
||||||
bool existsNonlinear(Key j) {
|
bool existsNonlinear(Key j) { return nonlinear_.exists(j); }
|
||||||
return nonlinear_.exists(j);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check whether a variable with key \c j exists.
|
/// Check whether a variable with key \c j exists.
|
||||||
bool exists(Key j) {
|
bool exists(Key j) {
|
||||||
|
@ -128,9 +127,7 @@ class GTSAM_EXPORT HybridValues {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// insert_or_assign() , similar to Values.h
|
/// insert_or_assign() , similar to Values.h
|
||||||
void insert_or_assign(Key j, size_t value) {
|
void insert_or_assign(Key j, size_t value) { discrete_[j] = value; }
|
||||||
discrete_[j] = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Insert all continuous values from \c values. Throws an invalid_argument
|
/** Insert all continuous values from \c values. Throws an invalid_argument
|
||||||
* exception if any keys to be inserted are already used. */
|
* exception if any keys to be inserted are already used. */
|
||||||
|
|
|
@ -333,7 +333,6 @@ TEST(HybridEstimation, Probability) {
|
||||||
for (auto discrete_conditional : *discreteBayesNet) {
|
for (auto discrete_conditional : *discreteBayesNet) {
|
||||||
bayesNet->add(discrete_conditional);
|
bayesNet->add(discrete_conditional);
|
||||||
}
|
}
|
||||||
auto discreteConditional = discreteBayesNet->at(0)->asDiscrete();
|
|
||||||
|
|
||||||
HybridValues hybrid_values = bayesNet->optimize();
|
HybridValues hybrid_values = bayesNet->optimize();
|
||||||
|
|
||||||
|
|
|
@ -381,17 +381,22 @@ typedef gtsam::GncOptimizer<gtsam::GncParams<gtsam::LevenbergMarquardtParams>> G
|
||||||
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
||||||
virtual class LevenbergMarquardtOptimizer : gtsam::NonlinearOptimizer {
|
virtual class LevenbergMarquardtOptimizer : gtsam::NonlinearOptimizer {
|
||||||
LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph,
|
LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph,
|
||||||
const gtsam::Values& initialValues);
|
const gtsam::Values& initialValues,
|
||||||
|
const gtsam::LevenbergMarquardtParams& params =
|
||||||
|
gtsam::LevenbergMarquardtParams());
|
||||||
LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph,
|
LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph,
|
||||||
const gtsam::Values& initialValues,
|
const gtsam::Values& initialValues,
|
||||||
const gtsam::LevenbergMarquardtParams& params);
|
const gtsam::Ordering& ordering,
|
||||||
|
const gtsam::LevenbergMarquardtParams& params =
|
||||||
|
gtsam::LevenbergMarquardtParams());
|
||||||
|
|
||||||
double lambda() const;
|
double lambda() const;
|
||||||
void print(string s = "") const;
|
void print(string s = "") const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/nonlinear/ISAM2.h>
|
#include <gtsam/nonlinear/ISAM2.h>
|
||||||
class ISAM2GaussNewtonParams {
|
class ISAM2GaussNewtonParams {
|
||||||
ISAM2GaussNewtonParams();
|
ISAM2GaussNewtonParams(double _wildfireThreshold = 0.001);
|
||||||
|
|
||||||
void print(string s = "") const;
|
void print(string s = "") const;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue