Merge branch 'develop' into fix/expressions
commit
9cc00a85f6
|
@ -121,8 +121,8 @@ namespace gtsam {
|
||||||
for (auto&& factor : factors) product = (*factor) * product;
|
for (auto&& factor : factors) product = (*factor) * product;
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
// Sum all the potentials by pretending all keys are frontal:
|
// Max over all the potentials by pretending all keys are frontal:
|
||||||
auto normalization = product.sum(product.size());
|
auto normalization = product.max(product.size());
|
||||||
|
|
||||||
// Normalize the product factor to prevent underflow.
|
// Normalize the product factor to prevent underflow.
|
||||||
product = product / (*normalization);
|
product = product / (*normalization);
|
||||||
|
@ -210,6 +210,12 @@ namespace gtsam {
|
||||||
for (auto&& factor : factors) product = (*factor) * product;
|
for (auto&& factor : factors) product = (*factor) * product;
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
|
// Max over all the potentials by pretending all keys are frontal:
|
||||||
|
auto normalization = product.max(product.size());
|
||||||
|
|
||||||
|
// Normalize the product factor to prevent underflow.
|
||||||
|
product = product / (*normalization);
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
gttic(sum);
|
gttic(sum);
|
||||||
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||||
|
|
|
@ -108,7 +108,14 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
|
|
||||||
// Test EliminateDiscrete
|
// Test EliminateDiscrete
|
||||||
const Ordering frontalKeys{0};
|
const Ordering frontalKeys{0};
|
||||||
const auto [conditional, newFactor] = EliminateDiscrete(graph, frontalKeys);
|
const auto [conditional, newFactorPtr] = EliminateDiscrete(graph, frontalKeys);
|
||||||
|
|
||||||
|
DecisionTreeFactor newFactor = *newFactorPtr;
|
||||||
|
|
||||||
|
// Normalize newFactor by max for comparison with expected
|
||||||
|
auto normalization = newFactor.max(newFactor.size());
|
||||||
|
|
||||||
|
newFactor = newFactor / *normalization;
|
||||||
|
|
||||||
// Check Conditional
|
// Check Conditional
|
||||||
CHECK(conditional);
|
CHECK(conditional);
|
||||||
|
@ -117,9 +124,13 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
EXPECT(assert_equal(expectedConditional, *conditional));
|
EXPECT(assert_equal(expectedConditional, *conditional));
|
||||||
|
|
||||||
// Check Factor
|
// Check Factor
|
||||||
CHECK(newFactor);
|
CHECK(&newFactor);
|
||||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||||
EXPECT(assert_equal(expectedFactor, *newFactor));
|
// Normalize by max.
|
||||||
|
normalization = expectedFactor.max(expectedFactor.size());
|
||||||
|
// Ensure normalization is correct.
|
||||||
|
expectedFactor = expectedFactor / *normalization;
|
||||||
|
EXPECT(assert_equal(expectedFactor, newFactor));
|
||||||
|
|
||||||
// Test using elimination tree
|
// Test using elimination tree
|
||||||
const Ordering ordering{0, 1, 2};
|
const Ordering ordering{0, 1, 2};
|
||||||
|
|
|
@ -59,17 +59,12 @@ public:
|
||||||
/// @name Advanced Constructors
|
/// @name Advanced Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
explicit PinholeBaseK(const Vector &v) :
|
explicit PinholeBaseK(const Vector& v) : PinholeBase(v) {}
|
||||||
PinholeBase(v) {
|
|
||||||
}
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
virtual ~PinholeBaseK() override {
|
|
||||||
}
|
|
||||||
|
|
||||||
/// return calibration
|
/// return calibration
|
||||||
virtual const CALIBRATION& calibration() const = 0;
|
virtual const CALIBRATION& calibration() const = 0;
|
||||||
|
|
||||||
|
|
|
@ -52,7 +52,7 @@ public:
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** Default constructor is origin */
|
/** Default constructor is origin */
|
||||||
Pose3() : R_(traits<Rot3>::Identity()), t_(traits<Point3>::Identity()) {}
|
Pose3() : R_(traits<Rot3>::Identity()), t_(traits<Point3>::Identity()) {}
|
||||||
|
|
||||||
/** Copy constructor */
|
/** Copy constructor */
|
||||||
Pose3(const Pose3& pose) :
|
Pose3(const Pose3& pose) :
|
||||||
|
|
|
@ -42,6 +42,9 @@ namespace gtsam {
|
||||||
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||||
typedef Factor Base; ///< Our base class
|
typedef Factor Base; ///< Our base class
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** Default constructor creates empty factor */
|
/** Default constructor creates empty factor */
|
||||||
GaussianFactor() {}
|
GaussianFactor() {}
|
||||||
|
|
||||||
|
@ -50,19 +53,22 @@ namespace gtsam {
|
||||||
template<typename CONTAINER>
|
template<typename CONTAINER>
|
||||||
GaussianFactor(const CONTAINER& keys) : Base(keys) {}
|
GaussianFactor(const CONTAINER& keys) : Base(keys) {}
|
||||||
|
|
||||||
/** Destructor */
|
/// @}
|
||||||
virtual ~GaussianFactor() override {}
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
// Implementing Testable interface
|
/// print with optional string
|
||||||
|
|
||||||
/// print
|
|
||||||
void print(
|
void print(
|
||||||
const std::string& s = "",
|
const std::string& s = "",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override = 0;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override = 0;
|
||||||
|
|
||||||
/** Equals for testable */
|
/// assert equality up to a tolerance
|
||||||
virtual bool equals(const GaussianFactor& lf, double tol = 1e-9) const = 0;
|
virtual bool equals(const GaussianFactor& lf, double tol = 1e-9) const = 0;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* In Gaussian factors, the error function returns either the negative log-likelihood, e.g.,
|
* In Gaussian factors, the error function returns either the negative log-likelihood, e.g.,
|
||||||
* 0.5*(A*x-b)'*D*(A*x-b)
|
* 0.5*(A*x-b)'*D*(A*x-b)
|
||||||
|
@ -144,6 +150,10 @@ namespace gtsam {
|
||||||
virtual void updateHessian(const KeyVector& keys,
|
virtual void updateHessian(const KeyVector& keys,
|
||||||
SymmetricBlockMatrix* info) const = 0;
|
SymmetricBlockMatrix* info) const = 0;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Operator interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
/// y += alpha * A'*A*x
|
/// y += alpha * A'*A*x
|
||||||
virtual void multiplyHessianAdd(double alpha, const VectorValues& x, VectorValues& y) const = 0;
|
virtual void multiplyHessianAdd(double alpha, const VectorValues& x, VectorValues& y) const = 0;
|
||||||
|
|
||||||
|
@ -156,12 +166,18 @@ namespace gtsam {
|
||||||
/// Gradient wrt a key at any values
|
/// Gradient wrt a key at any values
|
||||||
virtual Vector gradient(Key key, const VectorValues& x) const = 0;
|
virtual Vector gradient(Key key, const VectorValues& x) const = 0;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Advanced Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
// Determine position of a given key
|
// Determine position of a given key
|
||||||
template <typename CONTAINER>
|
template <typename CONTAINER>
|
||||||
static DenseIndex Slot(const CONTAINER& keys, Key key) {
|
static DenseIndex Slot(const CONTAINER& keys, Key key) {
|
||||||
return std::find(keys.begin(), keys.end(), key) - keys.begin();
|
return std::find(keys.begin(), keys.end(), key) - keys.begin();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
|
@ -171,7 +187,6 @@ namespace gtsam {
|
||||||
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
}; // GaussianFactor
|
}; // GaussianFactor
|
||||||
|
|
||||||
/// traits
|
/// traits
|
||||||
|
|
|
@ -107,9 +107,6 @@ namespace gtsam {
|
||||||
template<class DERIVEDFACTOR>
|
template<class DERIVEDFACTOR>
|
||||||
GaussianFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
GaussianFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
||||||
|
|
||||||
/** Virtual destructor */
|
|
||||||
virtual ~GaussianFactorGraph() override {}
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -130,7 +130,7 @@ namespace gtsam {
|
||||||
GTSAM_EXPORT std::ostream& operator<<(std::ostream& os, const VectorValues& v) {
|
GTSAM_EXPORT std::ostream& operator<<(std::ostream& os, const VectorValues& v) {
|
||||||
// Change print depending on whether we are using TBB
|
// Change print depending on whether we are using TBB
|
||||||
#ifdef GTSAM_USE_TBB
|
#ifdef GTSAM_USE_TBB
|
||||||
map<Key, Vector> sorted;
|
std::map<Key, Vector> sorted;
|
||||||
for (const auto& [key,value] : v) {
|
for (const auto& [key,value] : v) {
|
||||||
sorted.emplace(key, value);
|
sorted.emplace(key, value);
|
||||||
}
|
}
|
||||||
|
|
|
@ -105,9 +105,6 @@ public:
|
||||||
/// @name Standard Interface
|
/// @name Standard Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/** Destructor */
|
|
||||||
virtual ~NonlinearFactor() override {}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* In nonlinear factors, the error function returns the negative log-likelihood
|
* In nonlinear factors, the error function returns the negative log-likelihood
|
||||||
* as a non-linear function of the values in a \class Values object.
|
* as a non-linear function of the values in a \class Values object.
|
||||||
|
|
|
@ -78,9 +78,6 @@ namespace gtsam {
|
||||||
template<class DERIVEDFACTOR>
|
template<class DERIVEDFACTOR>
|
||||||
NonlinearFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
NonlinearFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
|
||||||
|
|
||||||
/// Destructor
|
|
||||||
virtual ~NonlinearFactorGraph() override {}
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -79,8 +79,6 @@ namespace gtsam {
|
||||||
/** Create symbolic version of any factor */
|
/** Create symbolic version of any factor */
|
||||||
explicit SymbolicFactor(const Factor& factor) : Base(factor.keys()) {}
|
explicit SymbolicFactor(const Factor& factor) : Base(factor.keys()) {}
|
||||||
|
|
||||||
virtual ~SymbolicFactor() override {}
|
|
||||||
|
|
||||||
/// Copy this object as its actual derived type.
|
/// Copy this object as its actual derived type.
|
||||||
SymbolicFactor::shared_ptr clone() const { return std::make_shared<This>(*this); }
|
SymbolicFactor::shared_ptr clone() const { return std::make_shared<This>(*this); }
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ Author: Frank Dellaert
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gtsam import DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
|
||||||
OrderingType = Ordering.OrderingType
|
OrderingType = Ordering.OrderingType
|
||||||
|
@ -216,5 +216,63 @@ class TestDiscreteFactorGraph(GtsamTestCase):
|
||||||
|
|
||||||
self.assertEqual(vals, [desired_state]*num_obs)
|
self.assertEqual(vals, [desired_state]*num_obs)
|
||||||
|
|
||||||
|
def test_sumProduct_chain(self):
|
||||||
|
"""
|
||||||
|
Test for numerical underflow in EliminateDiscrete on long chains.
|
||||||
|
Adapted from the toy problem of @pcl15423
|
||||||
|
Ref: https://github.com/borglab/gtsam/issues/1448
|
||||||
|
"""
|
||||||
|
num_states = 3
|
||||||
|
chain_length = 400
|
||||||
|
desired_state = 1
|
||||||
|
states = list(range(num_states))
|
||||||
|
|
||||||
|
# Helper function to mimic the behavior of gtbook.Variables discrete_series function
|
||||||
|
def make_key(character, index, cardinality):
|
||||||
|
symbol = Symbol(character, index)
|
||||||
|
key = symbol.key()
|
||||||
|
return (key, cardinality)
|
||||||
|
|
||||||
|
X = {index: make_key("X", index, len(states)) for index in range(chain_length)}
|
||||||
|
graph = DiscreteFactorGraph()
|
||||||
|
|
||||||
|
# Construct test transition matrix
|
||||||
|
transitions = np.diag([1.0, 0.5, 0.1])
|
||||||
|
transitions += 0.1/(num_states)
|
||||||
|
|
||||||
|
# Ensure that the transition matrix is Markov (columns sum to 1)
|
||||||
|
transitions /= np.sum(transitions, axis=0)
|
||||||
|
|
||||||
|
# The stationary distribution is the eigenvector corresponding to eigenvalue 1
|
||||||
|
eigvals, eigvecs = np.linalg.eig(transitions)
|
||||||
|
stationary_idx = np.where(np.isclose(eigvals, 1.0))
|
||||||
|
stationary_dist = eigvecs[:, stationary_idx]
|
||||||
|
|
||||||
|
# Ensure that the stationary distribution is positive and normalized
|
||||||
|
stationary_dist /= np.sum(stationary_dist)
|
||||||
|
expected = DecisionTreeFactor(X[chain_length-1], stationary_dist.flatten())
|
||||||
|
|
||||||
|
# The transition matrix parsed by DiscreteConditional is a row-wise CPT
|
||||||
|
transitions = transitions.T
|
||||||
|
transition_cpt = []
|
||||||
|
for i in range(0, num_states):
|
||||||
|
transition_row = "/".join([str(x) for x in transitions[i]])
|
||||||
|
transition_cpt.append(transition_row)
|
||||||
|
transition_cpt = " ".join(transition_cpt)
|
||||||
|
|
||||||
|
for i in reversed(range(1, chain_length)):
|
||||||
|
transition_conditional = DiscreteConditional(X[i], [X[i-1]], transition_cpt)
|
||||||
|
graph.push_back(transition_conditional)
|
||||||
|
|
||||||
|
# Run sum product using natural ordering so the resulting Bayes net has the form:
|
||||||
|
# X_0 <- X_1 <- ... <- X_n
|
||||||
|
sum_product = graph.sumProduct(OrderingType.NATURAL)
|
||||||
|
|
||||||
|
# Get the DiscreteConditional representing the marginal on the last factor
|
||||||
|
last_marginal = sum_product.at(chain_length - 1)
|
||||||
|
|
||||||
|
# Ensure marginal probabilities are close to the stationary distribution
|
||||||
|
self.gtsamAssertEquals(expected, last_marginal)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue