add TableFactor constructor from DiscreteConditional
parent
c06861e7da
commit
fceb290e80
|
|
@ -181,6 +181,15 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
std::vector<double> DecisionTreeFactor::probabilities() const {
|
||||
std::vector<double> probs;
|
||||
for (auto&& [key, value] : enumerate()) {
|
||||
probs.push_back(value);
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
||||
DiscreteKeys result;
|
||||
|
|
@ -306,11 +315,10 @@ namespace gtsam {
|
|||
|
||||
// Get the probabilities in the decision tree so we can threshold.
|
||||
std::vector<double> probabilities;
|
||||
this->visitLeaf([&](const Leaf& leaf) {
|
||||
size_t nrAssignments = leaf.nrAssignments();
|
||||
double prob = leaf.constant();
|
||||
probabilities.insert(probabilities.end(), nrAssignments, prob);
|
||||
});
|
||||
// NOTE(Varun) this is potentially slow due to the cartesian product
|
||||
for (auto&& [assignment, prob] : this->enumerate()) {
|
||||
probabilities.push_back(prob);
|
||||
}
|
||||
|
||||
// The number of probabilities can be lower than max_leaves
|
||||
if (probabilities.size() <= N) {
|
||||
|
|
|
|||
|
|
@ -181,6 +181,9 @@ namespace gtsam {
|
|||
/// Enumerate all values into a map from values to double.
|
||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||
|
||||
/// Get all the probabilities in order of assignment values
|
||||
std::vector<double> probabilities() const;
|
||||
|
||||
/// Return all the discrete keys associated with this factor.
|
||||
DiscreteKeys discreteKeys() const;
|
||||
|
||||
|
|
|
|||
|
|
@ -13,11 +13,12 @@
|
|||
* @file TableFactor.cpp
|
||||
* @brief discrete factor
|
||||
* @date May 4, 2023
|
||||
* @author Yoonwoo Kim
|
||||
* @author Yoonwoo Kim, Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <gtsam/base/FastSet.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
|
|
@ -56,6 +57,10 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
|||
sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
TableFactor::TableFactor(const DiscreteConditional& c)
|
||||
: TableFactor(c.discreteKeys(), c.probabilities()) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
Eigen::SparseVector<double> TableFactor::Convert(
|
||||
const std::vector<double>& table) {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@
|
|||
/**
|
||||
* @file TableFactor.h
|
||||
* @date May 4, 2023
|
||||
* @author Yoonwoo Kim
|
||||
* @author Yoonwoo Kim, Varun Agrawal
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
|
@ -32,6 +32,7 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
class DiscreteConditional;
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
|
|
@ -75,7 +76,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
* @brief Return ith key in keys_ as a DiscreteKey
|
||||
* @param i ith key in keys_
|
||||
* @return DiscreteKey
|
||||
* */
|
||||
*/
|
||||
DiscreteKey discreteKey(size_t i) const {
|
||||
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
||||
}
|
||||
|
|
@ -142,6 +143,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
|||
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
||||
: TableFactor(DiscreteKeys{key}, row) {}
|
||||
|
||||
/** Construct from a DiscreteConditional type */
|
||||
explicit TableFactor(const DiscreteConditional& c);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
|
|
|||
|
|
@ -93,7 +93,8 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
|
|||
for (auto&& kv : measured_time) {
|
||||
cout << "dropout: " << kv.first
|
||||
<< " | TableFactor time: " << kv.second.first.count()
|
||||
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
|
||||
<< " | DecisionTreeFactor time: " << kv.second.second.count() <<
|
||||
endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -124,6 +125,13 @@ TEST(TableFactor, constructors) {
|
|||
|
||||
// Assert that error = -log(value)
|
||||
EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
|
||||
|
||||
// Construct from DiscreteConditional
|
||||
DiscreteConditional conditional(X | Y = "1/1 2/3 1/4");
|
||||
TableFactor f4(conditional);
|
||||
// Manually constructed via inspection and comparison to DecisionTreeFactor
|
||||
TableFactor expected(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8");
|
||||
EXPECT(assert_equal(expected, f4));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -156,7 +164,8 @@ TEST(TableFactor, multiplication) {
|
|||
/* ************************************************************************* */
|
||||
// Benchmark which compares runtime of multiplication of two TableFactors
|
||||
// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
|
||||
TEST(TableFactor, benchmark) {
|
||||
// NOTE: Enable to run.
|
||||
TEST_DISABLED(TableFactor, benchmark) {
|
||||
DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
|
||||
H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue