Merge pull request #1696 from borglab/model-selection-integration

AlgebraicDecisionTree Helpers
release/4.3a0
Varun Agrawal 2024-08-21 05:15:25 -04:00 committed by GitHub
commit 231d1adbbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 105 additions and 16 deletions

View File

@ -196,6 +196,42 @@ namespace gtsam {
return this->apply(g, &Ring::div); return this->apply(g, &Ring::div);
} }
/// Compute sum of all values
double sum() const {
double sum = 0;
auto visitor = [&](double y) { sum += y; };
this->visit(visitor);
return sum;
}
/**
* @brief Helper method to perform normalization such that all leaves in the
* tree sum to 1
*
* @param sum
* @return AlgebraicDecisionTree
*/
AlgebraicDecisionTree normalize(double sum) const {
return this->apply([&sum](const double& x) { return x / sum; });
}
/// Find the minimum values amongst all leaves
double min() const {
double min = std::numeric_limits<double>::max();
auto visitor = [&](double x) { min = x < min ? x : min; };
this->visit(visitor);
return min;
}
/// Find the maximum values amongst all leaves
double max() const {
// Get the most negative value
double max = -std::numeric_limits<double>::max();
auto visitor = [&](double x) { max = x > max ? x : max; };
this->visit(visitor);
return max;
}
/** sum out variable */ /** sum out variable */
AlgebraicDecisionTree sum(const L& label, size_t cardinality) const { AlgebraicDecisionTree sum(const L& label, size_t cardinality) const {
return this->combine(label, cardinality, &Ring::add); return this->combine(label, cardinality, &Ring::add);

View File

@ -593,6 +593,55 @@ TEST(ADT, zero) {
EXPECT_DOUBLES_EQUAL(0, anotb(x11), 1e-9); EXPECT_DOUBLES_EQUAL(0, anotb(x11), 1e-9);
} }
/// Example ADT from 0 to 11.
ADT exampleADT() {
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
ADT f(A & B & C, "0 6 2 8 4 10 1 7 3 9 5 11");
return f;
}
/* ************************************************************************** */
// Test sum
TEST(ADT, Sum) {
ADT a = exampleADT();
double expected_sum = 0;
for (double i = 0; i < 12; i++) {
expected_sum += i;
}
EXPECT_DOUBLES_EQUAL(expected_sum, a.sum(), 1e-9);
}
/* ************************************************************************** */
// Test normalize
TEST(ADT, Normalize) {
ADT a = exampleADT();
double sum = a.sum();
auto actual = a.normalize(sum);
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
DiscreteKeys keys = DiscreteKeys{A, B, C};
std::vector<double> cpt{0 / sum, 6 / sum, 2 / sum, 8 / sum,
4 / sum, 10 / sum, 1 / sum, 7 / sum,
3 / sum, 9 / sum, 5 / sum, 11 / sum};
ADT expected(keys, cpt);
EXPECT(assert_equal(expected, actual));
}
/* ************************************************************************** */
// Test min
TEST(ADT, Min) {
ADT a = exampleADT();
double min = a.min();
EXPECT_DOUBLES_EQUAL(0.0, min, 1e-9);
}
/* ************************************************************************** */
// Test max
TEST(ADT, Max) {
ADT a = exampleADT();
double max = a.max();
EXPECT_DOUBLES_EQUAL(11.0, max, 1e-9);
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -67,7 +67,8 @@ class GTSAM_EXPORT GaussianMixture
double logConstant_; ///< log of the normalization constant. double logConstant_; ///< log of the normalization constant.
/** /**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. * @brief Convert a DecisionTree of factors into
* a DecisionTree of Gaussian factor graphs.
*/ */
GaussianFactorGraphTree asGaussianFactorGraphTree() const; GaussianFactorGraphTree asGaussianFactorGraphTree() const;
@ -214,7 +215,8 @@ class GTSAM_EXPORT GaussianMixture
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys * @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment. * only, with the leaf values as the error for each assignment.
*/ */
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
/** /**
* @brief Compute the logProbability of this Gaussian Mixture. * @brief Compute the logProbability of this Gaussian Mixture.

View File

@ -135,7 +135,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error. * as the factors involved, and leaf values as the error.
*/ */
AlgebraicDecisionTree<Key> errorTree(const VectorValues &continuousValues) const; AlgebraicDecisionTree<Key> errorTree(
const VectorValues &continuousValues) const;
/** /**
* @brief Compute the log-likelihood, including the log-normalizing constant. * @brief Compute the log-likelihood, including the log-normalizing constant.

View File

@ -12,6 +12,7 @@
/** /**
* @file HybridValues.h * @file HybridValues.h
* @date Jul 28, 2022 * @date Jul 28, 2022
* @author Varun Agrawal
* @author Shangjie Xue * @author Shangjie Xue
*/ */
@ -54,13 +55,13 @@ class GTSAM_EXPORT HybridValues {
HybridValues() = default; HybridValues() = default;
/// Construct from DiscreteValues and VectorValues. /// Construct from DiscreteValues and VectorValues.
HybridValues(const VectorValues &cv, const DiscreteValues &dv) HybridValues(const VectorValues& cv, const DiscreteValues& dv)
: continuous_(cv), discrete_(dv){} : continuous_(cv), discrete_(dv) {}
/// Construct from all values types. /// Construct from all values types.
HybridValues(const VectorValues& cv, const DiscreteValues& dv, HybridValues(const VectorValues& cv, const DiscreteValues& dv,
const Values& v) const Values& v)
: continuous_(cv), discrete_(dv), nonlinear_(v){} : continuous_(cv), discrete_(dv), nonlinear_(v) {}
/// @} /// @}
/// @name Testable /// @name Testable
@ -101,9 +102,7 @@ class GTSAM_EXPORT HybridValues {
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); } bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); }
/// Check whether a variable with key \c j exists in values. /// Check whether a variable with key \c j exists in values.
bool existsNonlinear(Key j) { bool existsNonlinear(Key j) { return nonlinear_.exists(j); }
return nonlinear_.exists(j);
}
/// Check whether a variable with key \c j exists. /// Check whether a variable with key \c j exists.
bool exists(Key j) { bool exists(Key j) {
@ -128,9 +127,7 @@ class GTSAM_EXPORT HybridValues {
} }
/// insert_or_assign() , similar to Values.h /// insert_or_assign() , similar to Values.h
void insert_or_assign(Key j, size_t value) { void insert_or_assign(Key j, size_t value) { discrete_[j] = value; }
discrete_[j] = value;
}
/** Insert all continuous values from \c values. Throws an invalid_argument /** Insert all continuous values from \c values. Throws an invalid_argument
* exception if any keys to be inserted are already used. */ * exception if any keys to be inserted are already used. */

View File

@ -333,7 +333,6 @@ TEST(HybridEstimation, Probability) {
for (auto discrete_conditional : *discreteBayesNet) { for (auto discrete_conditional : *discreteBayesNet) {
bayesNet->add(discrete_conditional); bayesNet->add(discrete_conditional);
} }
auto discreteConditional = discreteBayesNet->at(0)->asDiscrete();
HybridValues hybrid_values = bayesNet->optimize(); HybridValues hybrid_values = bayesNet->optimize();

View File

@ -381,17 +381,22 @@ typedef gtsam::GncOptimizer<gtsam::GncParams<gtsam::LevenbergMarquardtParams>> G
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h> #include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
virtual class LevenbergMarquardtOptimizer : gtsam::NonlinearOptimizer { virtual class LevenbergMarquardtOptimizer : gtsam::NonlinearOptimizer {
LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph, LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph,
const gtsam::Values& initialValues); const gtsam::Values& initialValues,
const gtsam::LevenbergMarquardtParams& params =
gtsam::LevenbergMarquardtParams());
LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph, LevenbergMarquardtOptimizer(const gtsam::NonlinearFactorGraph& graph,
const gtsam::Values& initialValues, const gtsam::Values& initialValues,
const gtsam::LevenbergMarquardtParams& params); const gtsam::Ordering& ordering,
const gtsam::LevenbergMarquardtParams& params =
gtsam::LevenbergMarquardtParams());
double lambda() const; double lambda() const;
void print(string s = "") const; void print(string s = "") const;
}; };
#include <gtsam/nonlinear/ISAM2.h> #include <gtsam/nonlinear/ISAM2.h>
class ISAM2GaussNewtonParams { class ISAM2GaussNewtonParams {
ISAM2GaussNewtonParams(); ISAM2GaussNewtonParams(double _wildfireThreshold = 0.001);
void print(string s = "") const; void print(string s = "") const;