add TableFactor constructor from DiscreteConditional

release/4.3a0
Varun Agrawal 2023-06-28 19:09:08 -04:00
parent c06861e7da
commit fceb290e80
5 changed files with 44 additions and 15 deletions

View File

@ -181,6 +181,15 @@ namespace gtsam {
return result;
}
/* ************************************************************************ */
std::vector<double> DecisionTreeFactor::probabilities() const {
std::vector<double> probs;
for (auto&& [key, value] : enumerate()) {
probs.push_back(value);
}
return probs;
}
/* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
@ -306,11 +315,10 @@ namespace gtsam {
// Get the probabilities in the decision tree so we can threshold.
std::vector<double> probabilities;
this->visitLeaf([&](const Leaf& leaf) {
size_t nrAssignments = leaf.nrAssignments();
double prob = leaf.constant();
probabilities.insert(probabilities.end(), nrAssignments, prob);
});
// NOTE(Varun) this is potentially slow due to the cartesian product
for (auto&& [assignment, prob] : this->enumerate()) {
probabilities.push_back(prob);
}
// The number of probabilities can be lower than max_leaves
if (probabilities.size() <= N) {

View File

@ -181,6 +181,9 @@ namespace gtsam {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
/// Get all the probabilities in order of assignment values
std::vector<double> probabilities() const;
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;

View File

@ -13,11 +13,12 @@
* @file TableFactor.cpp
* @brief discrete factor
* @date May 4, 2023
* @author Yoonwoo Kim
* @author Yoonwoo Kim, Varun Agrawal
*/
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridValues.h>
@ -56,6 +57,10 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
}
/* ************************************************************************ */
TableFactor::TableFactor(const DiscreteConditional& c)
: TableFactor(c.discreteKeys(), c.probabilities()) {}
/* ************************************************************************ */
Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) {

View File

@ -12,7 +12,7 @@
/**
* @file TableFactor.h
* @date May 4, 2023
* @author Yoonwoo Kim
* @author Yoonwoo Kim, Varun Agrawal
*/
#pragma once
@ -32,6 +32,7 @@
namespace gtsam {
class DiscreteConditional;
class HybridValues;
/**
@ -75,7 +76,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
* @brief Return ith key in keys_ as a DiscreteKey
* @param i ith key in keys_
* @return DiscreteKey
* */
*/
DiscreteKey discreteKey(size_t i) const {
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
}
@ -142,6 +143,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
: TableFactor(DiscreteKeys{key}, row) {}
/** Construct from a DiscreteConditional type */
explicit TableFactor(const DiscreteConditional& c);
/// @}
/// @name Testable
/// @{

View File

@ -93,7 +93,8 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
for (auto&& kv : measured_time) {
cout << "dropout: " << kv.first
<< " | TableFactor time: " << kv.second.first.count()
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
<< " | DecisionTreeFactor time: " << kv.second.second.count() <<
endl;
}
}
@ -124,6 +125,13 @@ TEST(TableFactor, constructors) {
// Assert that error = -log(value)
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
// Construct from DiscreteConditional
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
TableFactor f4(conditional);
// Manually constructed via inspection and comparison to DecisionTreeFactor
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
EXPECT(assert_equal(expected, f4));
}
/* ************************************************************************* */
@ -156,7 +164,8 @@ TEST(TableFactor, multiplication) {
/* ************************************************************************* */
// Benchmark which compares runtime of multiplication of two TableFactors
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
TEST(TableFactor, benchmark) {
// NOTE: Enable to run.
TEST_DISABLED(TableFactor, benchmark) {
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);