Merge branch 'common-discrete-evaluate' into common-ops
commit
246a1fa25b
|
@ -195,7 +195,7 @@ namespace gtsam {
|
||||||
// Construct unordered_map with values
|
// Construct unordered_map with values
|
||||||
std::vector<std::pair<DiscreteValues, double>> result;
|
std::vector<std::pair<DiscreteValues, double>> result;
|
||||||
for (const auto& assignment : assignments) {
|
for (const auto& assignment : assignments) {
|
||||||
result.emplace_back(assignment, operator()(assignment));
|
result.emplace_back(assignment, evaluate(assignment));
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -137,10 +137,13 @@ namespace gtsam {
|
||||||
|
|
||||||
/// Calculate probability for given values,
|
/// Calculate probability for given values,
|
||||||
/// is just look up in AlgebraicDecisionTree.
|
/// is just look up in AlgebraicDecisionTree.
|
||||||
double operator()(const Assignment<Key>& values) const override {
|
virtual double evaluate(const Assignment<Key>& values) const override {
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Disambiguate to use DiscreteFactor version. Mainly for wrapper
|
||||||
|
using DiscreteFactor::operator();
|
||||||
|
|
||||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||||
double error(const DiscreteValues& values) const override;
|
double error(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
|
|
|
@ -169,12 +169,12 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Evaluate, just look up in AlgebraicDecisionTree
|
/// Evaluate, just look up in AlgebraicDecisionTree
|
||||||
double evaluate(const DiscreteValues& values) const {
|
virtual double evaluate(const Assignment<Key>& values) const override {
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
using DecisionTreeFactor::error; ///< DiscreteValues version
|
using DecisionTreeFactor::error; ///< DiscreteValues version
|
||||||
using DecisionTreeFactor::operator(); ///< DiscreteValues version
|
using DiscreteFactor::operator(); ///< DiscreteValues version
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief restrict to given *parent* values.
|
* @brief restrict to given *parent* values.
|
||||||
|
|
|
@ -106,12 +106,12 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
* @param values Discrete assignment.
|
* @param values Discrete assignment.
|
||||||
* @return double
|
* @return double
|
||||||
*/
|
*/
|
||||||
double evaluate(const Assignment<Key>& values) const {
|
virtual double evaluate(const Assignment<Key>& values) const = 0;
|
||||||
return operator()(values);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Find value for given assignment of values to variables
|
/// Find value for given assignment of values to variables
|
||||||
virtual double operator()(const Assignment<Key>& values) const = 0;
|
double operator()(const DiscreteValues& values) const {
|
||||||
|
return evaluate(values);
|
||||||
|
}
|
||||||
|
|
||||||
/// Error is just -log(value)
|
/// Error is just -log(value)
|
||||||
virtual double error(const DiscreteValues& values) const;
|
virtual double error(const DiscreteValues& values) const;
|
||||||
|
|
|
@ -133,7 +133,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double TableFactor::operator()(const Assignment<Key>& values) const {
|
double TableFactor::evaluate(const Assignment<Key>& values) const {
|
||||||
// a b c d => D * (C * (B * (a) + b) + c) + d
|
// a b c d => D * (C * (B * (a) + b) + c) + d
|
||||||
uint64_t idx = 0, card = 1;
|
uint64_t idx = 0, card = 1;
|
||||||
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
|
for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
|
||||||
|
|
|
@ -152,7 +152,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
// /// @{
|
// /// @{
|
||||||
|
|
||||||
/// Evaluate probability distribution, is just look up in TableFactor.
|
/// Evaluate probability distribution, is just look up in TableFactor.
|
||||||
double operator()(const Assignment<Key>& values) const override;
|
double evaluate(const Assignment<Key>& values) const override;
|
||||||
|
|
||||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||||
double error(const DiscreteValues& values) const override;
|
double error(const DiscreteValues& values) const override;
|
||||||
|
|
|
@ -41,7 +41,7 @@ virtual class DiscreteFactor : gtsam::Factor {
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
bool equals(const gtsam::DiscreteFactor& other, double tol = 1e-9) const;
|
||||||
double operator()(const gtsam::Assignment<gtsam::Key>& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
@ -69,7 +69,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
|
||||||
|
|
||||||
size_t cardinality(gtsam::Key j) const;
|
size_t cardinality(gtsam::Key j) const;
|
||||||
|
|
||||||
double operator()(const gtsam::Assignment<gtsam::Key>& values) const;
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const;
|
||||||
size_t cardinality(gtsam::Key j) const;
|
size_t cardinality(gtsam::Key j) const;
|
||||||
gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const;
|
gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const;
|
||||||
|
@ -248,6 +248,7 @@ class DiscreteBayesTree {
|
||||||
void saveGraph(string s,
|
void saveGraph(string s,
|
||||||
const gtsam::KeyFormatter& keyFormatter =
|
const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
double operator()(const gtsam::DiscreteValues& values) const;
|
||||||
|
|
||||||
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
string markdown(const gtsam::KeyFormatter& keyFormatter =
|
||||||
gtsam::DefaultKeyFormatter) const;
|
gtsam::DefaultKeyFormatter) const;
|
||||||
|
|
|
@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double AllDiff::operator()(const Assignment<Key>& values) const {
|
double AllDiff::evaluate(const Assignment<Key>& values) const {
|
||||||
std::set<size_t> taken; // record values taken by keys
|
std::set<size_t> taken; // record values taken by keys
|
||||||
for (Key dkey : keys_) {
|
for (Key dkey : keys_) {
|
||||||
size_t value = values.at(dkey); // get the value for that key
|
size_t value = values.at(dkey); // get the value for that key
|
||||||
|
|
|
@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate value = expensive !
|
/// Calculate value = expensive !
|
||||||
double operator()(const Assignment<Key>& values) const override;
|
double evaluate(const Assignment<Key>& values) const override;
|
||||||
|
|
||||||
/// Convert into a decisiontree, can be *very* expensive !
|
/// Convert into a decisiontree, can be *very* expensive !
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||||
|
|
|
@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate value
|
/// Calculate value
|
||||||
double operator()(const Assignment<Key>& values) const override {
|
double evaluate(const Assignment<Key>& values) const override {
|
||||||
return (double)(values.at(keys_[0]) != values.at(keys_[1]));
|
return (double)(values.at(keys_[0]) != values.at(keys_[1]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ string Domain::base1Str() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double Domain::operator()(const Assignment<Key>& values) const {
|
double Domain::evaluate(const Assignment<Key>& values) const {
|
||||||
return contains(values.at(key()));
|
return contains(values.at(key()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
||||||
bool contains(size_t value) const { return values_.count(value) > 0; }
|
bool contains(size_t value) const { return values_.count(value) > 0; }
|
||||||
|
|
||||||
/// Calculate value
|
/// Calculate value
|
||||||
double operator()(const Assignment<Key>& values) const override;
|
double evaluate(const Assignment<Key>& values) const override;
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Convert into a decisiontree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||||
|
|
|
@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double SingleValue::operator()(const Assignment<Key>& values) const {
|
double SingleValue::evaluate(const Assignment<Key>& values) const {
|
||||||
return (double)(values.at(keys_[0]) == value_);
|
return (double)(values.at(keys_[0]) == value_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate value
|
/// Calculate value
|
||||||
double operator()(const Assignment<Key>& values) const override;
|
double evaluate(const Assignment<Key>& values) const override;
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Convert into a decisiontree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override;
|
DecisionTreeFactor toDecisionTreeFactor() const override;
|
||||||
|
|
|
@ -13,9 +13,10 @@ Author: Frank Dellaert
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues,
|
from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues,
|
||||||
Ordering)
|
Ordering)
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
|
||||||
|
|
||||||
|
|
||||||
class TestDecisionTreeFactor(GtsamTestCase):
|
class TestDecisionTreeFactor(GtsamTestCase):
|
||||||
|
|
|
@ -19,8 +19,8 @@ from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
import gtsam
|
import gtsam
|
||||||
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique,
|
||||||
DiscreteConditional, DiscreteFactorGraph,
|
DiscreteConditional, DiscreteFactorGraph, DiscreteValues,
|
||||||
DiscreteValues, Ordering)
|
Ordering)
|
||||||
|
|
||||||
|
|
||||||
class TestDiscreteBayesNet(GtsamTestCase):
|
class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
|
|
|
@ -13,9 +13,10 @@ Author: Varun Agrawal
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
|
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys
|
||||||
|
|
||||||
# Some DiscreteKeys for binary variables:
|
# Some DiscreteKeys for binary variables:
|
||||||
A = 0, 2
|
A = 0, 2
|
||||||
B = 1, 2
|
B = 1, 2
|
||||||
|
|
|
@ -14,9 +14,12 @@ Author: Frank Dellaert
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
|
from gtsam import (DecisionTreeFactor, DiscreteConditional,
|
||||||
|
DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering,
|
||||||
|
Symbol)
|
||||||
|
|
||||||
OrderingType = Ordering.OrderingType
|
OrderingType = Ordering.OrderingType
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue