add TableFactor constructor from DiscreteConditional
parent
c06861e7da
commit
fceb290e80
|
|
@ -181,6 +181,15 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::vector<double> DecisionTreeFactor::probabilities() const {
|
||||||
|
std::vector<double> probs;
|
||||||
|
for (auto&& [key, value] : enumerate()) {
|
||||||
|
probs.push_back(value);
|
||||||
|
}
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
||||||
DiscreteKeys result;
|
DiscreteKeys result;
|
||||||
|
|
@ -306,11 +315,10 @@ namespace gtsam {
|
||||||
|
|
||||||
// Get the probabilities in the decision tree so we can threshold.
|
// Get the probabilities in the decision tree so we can threshold.
|
||||||
std::vector<double> probabilities;
|
std::vector<double> probabilities;
|
||||||
this->visitLeaf([&](const Leaf& leaf) {
|
// NOTE(Varun) this is potentially slow due to the cartesian product
|
||||||
size_t nrAssignments = leaf.nrAssignments();
|
for (auto&& [assignment, prob] : this->enumerate()) {
|
||||||
double prob = leaf.constant();
|
probabilities.push_back(prob);
|
||||||
probabilities.insert(probabilities.end(), nrAssignments, prob);
|
}
|
||||||
});
|
|
||||||
|
|
||||||
// The number of probabilities can be lower than max_leaves
|
// The number of probabilities can be lower than max_leaves
|
||||||
if (probabilities.size() <= N) {
|
if (probabilities.size() <= N) {
|
||||||
|
|
|
||||||
|
|
@ -181,6 +181,9 @@ namespace gtsam {
|
||||||
/// Enumerate all values into a map from values to double.
|
/// Enumerate all values into a map from values to double.
|
||||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
|
|
||||||
|
/// Get all the probabilities in order of assignment values
|
||||||
|
std::vector<double> probabilities() const;
|
||||||
|
|
||||||
/// Return all the discrete keys associated with this factor.
|
/// Return all the discrete keys associated with this factor.
|
||||||
DiscreteKeys discreteKeys() const;
|
DiscreteKeys discreteKeys() const;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,12 @@
|
||||||
* @file TableFactor.cpp
|
* @file TableFactor.cpp
|
||||||
* @brief discrete factor
|
* @brief discrete factor
|
||||||
* @date May 4, 2023
|
* @date May 4, 2023
|
||||||
* @author Yoonwoo Kim
|
* @author Yoonwoo Kim, Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/TableFactor.h>
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
|
|
@ -56,6 +57,10 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
TableFactor::TableFactor(const DiscreteConditional& c)
|
||||||
|
: TableFactor(c.discreteKeys(), c.probabilities()) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
Eigen::SparseVector<double> TableFactor::Convert(
|
Eigen::SparseVector<double> TableFactor::Convert(
|
||||||
const std::vector<double>& table) {
|
const std::vector<double>& table) {
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
/**
|
/**
|
||||||
* @file TableFactor.h
|
* @file TableFactor.h
|
||||||
* @date May 4, 2023
|
* @date May 4, 2023
|
||||||
* @author Yoonwoo Kim
|
* @author Yoonwoo Kim, Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
@ -32,6 +32,7 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
class DiscreteConditional;
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -57,10 +58,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Uses lazy cartesian product to find nth entry in the cartesian
|
* @brief Uses lazy cartesian product to find nth entry in the cartesian
|
||||||
* product of arrays in O(1)
|
* product of arrays in O(1)
|
||||||
* Example)
|
* Example)
|
||||||
* v0 | v1 | val
|
* v0 | v1 | val
|
||||||
* 0 | 0 | 10
|
* 0 | 0 | 10
|
||||||
* 0 | 1 | 21
|
* 0 | 1 | 21
|
||||||
* 1 | 0 | 32
|
* 1 | 0 | 32
|
||||||
* 1 | 1 | 43
|
* 1 | 1 | 43
|
||||||
|
|
@ -75,13 +76,13 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
* @brief Return ith key in keys_ as a DiscreteKey
|
* @brief Return ith key in keys_ as a DiscreteKey
|
||||||
* @param i ith key in keys_
|
* @param i ith key in keys_
|
||||||
* @return DiscreteKey
|
* @return DiscreteKey
|
||||||
* */
|
*/
|
||||||
DiscreteKey discreteKey(size_t i) const {
|
DiscreteKey discreteKey(size_t i) const {
|
||||||
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert probability table given as doubles to SparseVector.
|
/// Convert probability table given as doubles to SparseVector.
|
||||||
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
|
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
|
||||||
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
|
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
|
||||||
|
|
||||||
/// Convert probability table given as string to SparseVector.
|
/// Convert probability table given as string to SparseVector.
|
||||||
|
|
@ -142,6 +143,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
||||||
: TableFactor(DiscreteKeys{key}, row) {}
|
: TableFactor(DiscreteKeys{key}, row) {}
|
||||||
|
|
||||||
|
/** Construct from a DiscreteConditional type */
|
||||||
|
explicit TableFactor(const DiscreteConditional& c);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,8 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
|
||||||
for (auto&& kv : measured_time) {
|
for (auto&& kv : measured_time) {
|
||||||
cout << "dropout: " << kv.first
|
cout << "dropout: " << kv.first
|
||||||
<< " | TableFactor time: " << kv.second.first.count()
|
<< " | TableFactor time: " << kv.second.first.count()
|
||||||
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
|
<< " | DecisionTreeFactor time: " << kv.second.second.count() <<
|
||||||
|
endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -124,6 +125,13 @@ TEST(TableFactor, constructors) {
|
||||||
|
|
||||||
// Assert that error = -log(value)
|
// Assert that error = -log(value)
|
||||||
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
||||||
|
|
||||||
|
// Construct from DiscreteConditional
|
||||||
|
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
|
||||||
|
TableFactor f4(conditional);
|
||||||
|
// Manually constructed via inspection and comparison to DecisionTreeFactor
|
||||||
|
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||||
|
EXPECT(assert_equal(expected, f4));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -156,7 +164,8 @@ TEST(TableFactor, multiplication) {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Benchmark which compares runtime of multiplication of two TableFactors
|
// Benchmark which compares runtime of multiplication of two TableFactors
|
||||||
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
|
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
|
||||||
TEST(TableFactor, benchmark) {
|
// NOTE: Enable to run.
|
||||||
|
TEST_DISABLED(TableFactor, benchmark) {
|
||||||
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
|
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
|
||||||
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
|
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue