add TableFactor constructor from DiscreteConditional

release/4.3a0
Varun Agrawal 2023-06-28 19:09:08 -04:00
parent c06861e7da
commit fceb290e80
5 changed files with 44 additions and 15 deletions

View File

@ -181,6 +181,15 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************ */
std::vector<double> DecisionTreeFactor::probabilities() const {
std::vector<double> probs;
for (auto&& [key, value] : enumerate()) {
probs.push_back(value);
}
return probs;
}
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const { DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result; DiscreteKeys result;
@ -306,11 +315,10 @@ namespace gtsam {
// Get the probabilities in the decision tree so we can threshold. // Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities; std::vector<double> probabilities;
this->visitLeaf([&](const Leaf& leaf) { // NOTE(Varun) this is potentially slow due to the cartesian product
size_t nrAssignments = leaf.nrAssignments(); for (auto&& [assignment, prob] : this->enumerate()) {
double prob = leaf.constant(); probabilities.push_back(prob);
probabilities.insert(probabilities.end(), nrAssignments, prob); }
});
// The number of probabilities can be lower than max_leaves // The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) { if (probabilities.size() <= N) {

View File

@ -181,6 +181,9 @@ namespace gtsam {
/// Enumerate all values into a map from values to double. /// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Get all the probabilities in order of assignment values
std::vector<double> probabilities() const;
/// Return all the discrete keys associated with this factor. /// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const; DiscreteKeys discreteKeys() const;

View File

@ -13,11 +13,12 @@
* @file TableFactor.cpp * @file TableFactor.cpp
* @brief discrete factor * @brief discrete factor
* @date May 4, 2023 * @date May 4, 2023
* @author Yoonwoo Kim * @author Yoonwoo Kim, Varun Agrawal
*/ */
#include <gtsam/base/FastSet.h> #include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h> #include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
@ -56,6 +57,10 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
} }
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c)
: TableFactor(c.discreteKeys(), c.probabilities()) {}
/* ************************************************************************ */ /* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert( Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) { const std::vector<double>& table) {

View File

@ -12,7 +12,7 @@
/** /**
* @file TableFactor.h * @file TableFactor.h
* @date May 4, 2023 * @date May 4, 2023
* @author Yoonwoo Kim * @author Yoonwoo Kim, Varun Agrawal
*/ */
#pragma once #pragma once
@ -32,6 +32,7 @@
namespace gtsam { namespace gtsam {
class DiscreteConditional;
class HybridValues; class HybridValues;
/** /**
@ -57,10 +58,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/** /**
* @brief Uses lazy cartesian product to find nth entry in the cartesian * @brief Uses lazy cartesian product to find nth entry in the cartesian
* product of arrays in O(1) * product of arrays in O(1)
* Example) * Example)
* v0 | v1 | val * v0 | v1 | val
* 0 | 0 | 10 * 0 | 0 | 10
* 0 | 1 | 21 * 0 | 1 | 21
* 1 | 0 | 32 * 1 | 0 | 32
* 1 | 1 | 43 * 1 | 1 | 43
@ -75,13 +76,13 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
* @brief Return ith key in keys_ as a DiscreteKey * @brief Return ith key in keys_ as a DiscreteKey
* @param i ith key in keys_ * @param i ith key in keys_
* @return DiscreteKey * @return DiscreteKey
* */ */
DiscreteKey discreteKey(size_t i) const { DiscreteKey discreteKey(size_t i) const {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
} }
/// Convert probability table given as doubles to SparseVector. /// Convert probability table given as doubles to SparseVector.
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
static Eigen::SparseVector<double> Convert(const std::vector<double>& table); static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
/// Convert probability table given as string to SparseVector. /// Convert probability table given as string to SparseVector.
@ -142,6 +143,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
TableFactor(const DiscreteKey& key, const std::vector<double>& row) TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {} : TableFactor(DiscreteKeys{key}, row) {}
/** Construct from a DiscreteConditional type */
explicit TableFactor(const DiscreteConditional& c);
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{

View File

@ -93,7 +93,8 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
for (auto&& kv : measured_time) { for (auto&& kv : measured_time) {
cout << "dropout: " << kv.first cout << "dropout: " << kv.first
<< " | TableFactor time: " << kv.second.first.count() << " | TableFactor time: " << kv.second.first.count()
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl; << " | DecisionTreeFactor time: " << kv.second.second.count() <<
endl;
} }
} }
@ -124,6 +125,13 @@ TEST(TableFactor, constructors) {
// Assert that error = -log(value) // Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
// Construct from DiscreteConditional
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
TableFactor f4(conditional);
// Manually constructed via inspection and comparison to DecisionTreeFactor
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
EXPECT(assert_equal(expected, f4));
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -156,7 +164,8 @@ TEST(TableFactor, multiplication) {
/* ************************************************************************* */ /* ************************************************************************* */
// Benchmark which compares runtime of multiplication of two TableFactors // Benchmark which compares runtime of multiplication of two TableFactors
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity. // and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
TEST(TableFactor, benchmark) { // NOTE: Enable to run.
TEST_DISABLED(TableFactor, benchmark) {
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3), DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);