Merge branch 'develop' into discrete-elimination-refactor

release/4.3a0
Varun Agrawal 2024-12-10 10:26:39 -05:00
commit cc4e9cb4db
19 changed files with 32 additions and 24 deletions

View File

@ -129,7 +129,7 @@ protected:
result_.addFailure(Failure(name_, __FILE__, __LINE__, #expected, #actual)); } result_.addFailure(Failure(name_, __FILE__, __LINE__, #expected, #actual)); }
#define CHECK_EQUAL(expected,actual)\ #define CHECK_EQUAL(expected,actual)\
{ if (!((expected) == (actual))) { result_.addFailure(Failure(name_, __FILE__, __LINE__, std::to_string(expected), std::to_string(actual))); } } { if (!((expected) == (actual))) { result_.addFailure(Failure(name_, __FILE__, __LINE__, std::to_string(expected), std::to_string(actual))); return; } }
#define LONGS_EQUAL(expected,actual)\ #define LONGS_EQUAL(expected,actual)\
{ long actualTemp = actual; \ { long actualTemp = actual; \

View File

@ -195,7 +195,7 @@ namespace gtsam {
// Construct unordered_map with values // Construct unordered_map with values
std::vector<std::pair<DiscreteValues, double>> result; std::vector<std::pair<DiscreteValues, double>> result;
for (const auto& assignment : assignments) { for (const auto& assignment : assignments) {
result.emplace_back(assignment, operator()(assignment)); result.emplace_back(assignment, evaluate(assignment));
} }
return result; return result;
} }

View File

@ -137,10 +137,13 @@ namespace gtsam {
/// Calculate probability for given values, /// Calculate probability for given values,
/// is just look up in AlgebraicDecisionTree. /// is just look up in AlgebraicDecisionTree.
double operator()(const Assignment<Key>& values) const override { virtual double evaluate(const Assignment<Key>& values) const override {
return ADT::operator()(values); return ADT::operator()(values);
} }
/// Disambiguate to use DiscreteFactor version. Mainly for wrapper
using DiscreteFactor::operator();
/// Calculate error for DiscreteValues `x`, is -log(probability). /// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override; double error(const DiscreteValues& values) const override;

View File

@ -169,7 +169,7 @@ class GTSAM_EXPORT DiscreteConditional
} }
using BaseFactor::error; ///< DiscreteValues version using BaseFactor::error; ///< DiscreteValues version
using BaseFactor::evaluate; // DiscreteValues version using BaseFactor::evaluate; ///< DiscreteValues version
using BaseFactor::operator(); ///< DiscreteValues version using BaseFactor::operator(); ///< DiscreteValues version
/** /**

View File

@ -107,12 +107,12 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
* @param values Discrete assignment. * @param values Discrete assignment.
* @return double * @return double
*/ */
double evaluate(const Assignment<Key>& values) const { virtual double evaluate(const Assignment<Key>& values) const = 0;
return operator()(values);
}
/// Find value for given assignment of values to variables /// Find value for given assignment of values to variables
virtual double operator()(const Assignment<Key>& values) const = 0; double operator()(const DiscreteValues& values) const {
return evaluate(values);
}
/// Error is just -log(value) /// Error is just -log(value)
virtual double error(const DiscreteValues& values) const; virtual double error(const DiscreteValues& values) const;

View File

@ -133,7 +133,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
} }
/* ************************************************************************ */ /* ************************************************************************ */
double TableFactor::operator()(const Assignment<Key>& values) const { double TableFactor::evaluate(const Assignment<Key>& values) const {
// a b c d => D * (C * (B * (a) + b) + c) + d // a b c d => D * (C * (B * (a) + b) + c) + d
uint64_t idx = 0, card = 1; uint64_t idx = 0, card = 1;
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {

View File

@ -152,7 +152,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
// /// @{ // /// @{
/// Evaluate probability distribution, is just look up in TableFactor. /// Evaluate probability distribution, is just look up in TableFactor.
double operator()(const Assignment<Key>& values) const override; double evaluate(const Assignment<Key>& values) const override;
/// Calculate error for DiscreteValues `x`, is -log(probability). /// Calculate error for DiscreteValues `x`, is -log(probability).
double error(const DiscreteValues& values) const override; double error(const DiscreteValues& values) const override;

View File

@ -61,14 +61,14 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table); DecisionTreeFactor(const std::vector<gtsam::DiscreteKey>& keys, string table);
DecisionTreeFactor(const gtsam::DiscreteConditional& c); DecisionTreeFactor(const gtsam::DiscreteConditional& c);
void print(string s = "DecisionTreeFactor\n", void print(string s = "DecisionTreeFactor\n",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
size_t cardinality(gtsam::Key j) const; size_t cardinality(gtsam::Key j) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
size_t cardinality(gtsam::Key j) const; size_t cardinality(gtsam::Key j) const;

View File

@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double AllDiff::operator()(const Assignment<Key>& values) const { double AllDiff::evaluate(const Assignment<Key>& values) const {
std::set<size_t> taken; // record values taken by keys std::set<size_t> taken; // record values taken by keys
for (Key dkey : keys_) { for (Key dkey : keys_) {
size_t value = values.at(dkey); // get the value for that key size_t value = values.at(dkey); // get the value for that key

View File

@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
} }
/// Calculate value = expensive ! /// Calculate value = expensive !
double operator()(const Assignment<Key>& values) const override; double evaluate(const Assignment<Key>& values) const override;
/// Convert into a decisiontree, can be *very* expensive ! /// Convert into a decisiontree, can be *very* expensive !
DecisionTreeFactor toDecisionTreeFactor() const override; DecisionTreeFactor toDecisionTreeFactor() const override;

View File

@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint {
} }
/// Calculate value /// Calculate value
double operator()(const Assignment<Key>& values) const override { double evaluate(const Assignment<Key>& values) const override {
return (double)(values.at(keys_[0]) != values.at(keys_[1])); return (double)(values.at(keys_[0]) != values.at(keys_[1]));
} }

View File

@ -30,7 +30,7 @@ string Domain::base1Str() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double Domain::operator()(const Assignment<Key>& values) const { double Domain::evaluate(const Assignment<Key>& values) const {
return contains(values.at(key())); return contains(values.at(key()));
} }

View File

@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
bool contains(size_t value) const { return values_.count(value) > 0; } bool contains(size_t value) const { return values_.count(value) > 0; }
/// Calculate value /// Calculate value
double operator()(const Assignment<Key>& values) const override; double evaluate(const Assignment<Key>& values) const override;
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override; DecisionTreeFactor toDecisionTreeFactor() const override;

View File

@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double SingleValue::operator()(const Assignment<Key>& values) const { double SingleValue::evaluate(const Assignment<Key>& values) const {
return (double)(values.at(keys_[0]) == value_); return (double)(values.at(keys_[0]) == value_);
} }

View File

@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
} }
/// Calculate value /// Calculate value
double operator()(const Assignment<Key>& values) const override; double evaluate(const Assignment<Key>& values) const override;
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override; DecisionTreeFactor toDecisionTreeFactor() const override;

View File

@ -13,9 +13,10 @@ Author: Frank Dellaert
import unittest import unittest
from gtsam.utils.test_case import GtsamTestCase
from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues, from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues,
Ordering) Ordering)
from gtsam.utils.test_case import GtsamTestCase
class TestDecisionTreeFactor(GtsamTestCase): class TestDecisionTreeFactor(GtsamTestCase):

View File

@ -19,8 +19,8 @@ from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
DiscreteConditional, DiscreteFactorGraph, DiscreteConditional, DiscreteFactorGraph, DiscreteValues,
DiscreteValues, Ordering) Ordering)
class TestDiscreteBayesNet(GtsamTestCase): class TestDiscreteBayesNet(GtsamTestCase):

View File

@ -13,9 +13,10 @@ Author: Varun Agrawal
import unittest import unittest
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
# Some DiscreteKeys for binary variables: # Some DiscreteKeys for binary variables:
A = 0, 2 A = 0, 2
B = 1, 2 B = 1, 2

View File

@ -14,9 +14,12 @@ Author: Frank Dellaert
import unittest import unittest
import numpy as np import numpy as np
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
from gtsam import (DecisionTreeFactor, DiscreteConditional,
DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering,
Symbol)
OrderingType = Ordering.OrderingType OrderingType = Ordering.OrderingType