Merge pull request #1928 from borglab/fix-table-factor
commit
137a503746
|
@ -48,7 +48,7 @@ namespace gtsam {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
||||||
return ADT::equals(f, tol);
|
return Base::equals(other, tol) && ADT::equals(f, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,11 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
bool DiscreteFactor::equals(const DiscreteFactor& lf, double tol) const {
|
||||||
|
return Base::equals(lf, tol) && cardinalities_ == lf.cardinalities_;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteKeys DiscreteFactor::discreteKeys() const {
|
DiscreteKeys DiscreteFactor::discreteKeys() const {
|
||||||
DiscreteKeys result;
|
DiscreteKeys result;
|
||||||
|
|
|
@ -77,7 +77,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// equals
|
/// equals
|
||||||
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0;
|
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const;
|
||||||
|
|
||||||
/// print
|
/// print
|
||||||
void print(
|
void print(
|
||||||
|
|
|
@ -92,13 +92,28 @@ TableFactor::TableFactor(const DiscreteKeys& dkeys,
|
||||||
const DecisionTreeFactor& dtf)
|
const DecisionTreeFactor& dtf)
|
||||||
: TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}
|
: TableFactor(dkeys, ComputeLeafOrdering(dkeys, dtf)) {}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
TableFactor::TableFactor(const DecisionTreeFactor& dtf)
|
||||||
|
: TableFactor(dtf.discreteKeys(),
|
||||||
|
ComputeLeafOrdering(dtf.discreteKeys(), dtf)) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::TableFactor(const DiscreteConditional& c)
|
TableFactor::TableFactor(const DiscreteConditional& c)
|
||||||
: TableFactor(c.discreteKeys(), c) {}
|
: TableFactor(c.discreteKeys(), c) {}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
Eigen::SparseVector<double> TableFactor::Convert(
|
Eigen::SparseVector<double> TableFactor::Convert(
|
||||||
const std::vector<double>& table) {
|
const DiscreteKeys& keys, const std::vector<double>& table) {
|
||||||
|
size_t max_size = 1;
|
||||||
|
for (auto&& [_, cardinality] : keys.cardinalities()) {
|
||||||
|
max_size *= cardinality;
|
||||||
|
}
|
||||||
|
if (table.size() != max_size) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"The cardinalities of the keys don't match the number of values in the "
|
||||||
|
"input.");
|
||||||
|
}
|
||||||
|
|
||||||
Eigen::SparseVector<double> sparse_table(table.size());
|
Eigen::SparseVector<double> sparse_table(table.size());
|
||||||
// Count number of nonzero elements in table and reserve the space.
|
// Count number of nonzero elements in table and reserve the space.
|
||||||
const uint64_t nnz = std::count_if(table.begin(), table.end(),
|
const uint64_t nnz = std::count_if(table.begin(), table.end(),
|
||||||
|
@ -113,13 +128,14 @@ Eigen::SparseVector<double> TableFactor::Convert(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
|
Eigen::SparseVector<double> TableFactor::Convert(const DiscreteKeys& keys,
|
||||||
|
const std::string& table) {
|
||||||
// Convert string to doubles.
|
// Convert string to doubles.
|
||||||
std::vector<double> ys;
|
std::vector<double> ys;
|
||||||
std::istringstream iss(table);
|
std::istringstream iss(table);
|
||||||
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
|
std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
|
||||||
std::back_inserter(ys));
|
std::back_inserter(ys));
|
||||||
return Convert(ys);
|
return Convert(keys, ys);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
@ -128,7 +144,8 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
const auto& f(static_cast<const TableFactor&>(other));
|
const auto& f(static_cast<const TableFactor&>(other));
|
||||||
return sparse_table_.isApprox(f.sparse_table_, tol);
|
return Base::equals(other, tol) &&
|
||||||
|
sparse_table_.isApprox(f.sparse_table_, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,7 +267,8 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
|
||||||
for (auto&& kv : assignment) {
|
for (auto&& kv : assignment) {
|
||||||
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
|
cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
|
||||||
}
|
}
|
||||||
cout << " | " << it.value() << " | " << it.index() << endl;
|
cout << " | " << std::setw(10) << std::left << it.value() << " | "
|
||||||
|
<< it.index() << endl;
|
||||||
}
|
}
|
||||||
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
|
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
|
||||||
}
|
}
|
||||||
|
|
|
@ -80,12 +80,16 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert probability table given as doubles to SparseVector.
|
/**
|
||||||
/// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
|
* Convert probability table given as doubles to SparseVector.
|
||||||
static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
|
* Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
|
||||||
|
*/
|
||||||
|
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
|
||||||
|
const std::vector<double>& table);
|
||||||
|
|
||||||
/// Convert probability table given as string to SparseVector.
|
/// Convert probability table given as string to SparseVector.
|
||||||
static Eigen::SparseVector<double> Convert(const std::string& table);
|
static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
|
||||||
|
const std::string& table);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
|
@ -111,11 +115,11 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
|
|
||||||
/** Constructor from doubles */
|
/** Constructor from doubles */
|
||||||
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
|
TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
|
||||||
: TableFactor(keys, Convert(table)) {}
|
: TableFactor(keys, Convert(keys, table)) {}
|
||||||
|
|
||||||
/** Constructor from string */
|
/** Constructor from string */
|
||||||
TableFactor(const DiscreteKeys& keys, const std::string& table)
|
TableFactor(const DiscreteKeys& keys, const std::string& table)
|
||||||
: TableFactor(keys, Convert(table)) {}
|
: TableFactor(keys, Convert(keys, table)) {}
|
||||||
|
|
||||||
/// Single-key specialization
|
/// Single-key specialization
|
||||||
template <class SOURCE>
|
template <class SOURCE>
|
||||||
|
@ -128,6 +132,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
|
|
||||||
/// Constructor from DecisionTreeFactor
|
/// Constructor from DecisionTreeFactor
|
||||||
TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf);
|
TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf);
|
||||||
|
TableFactor(const DecisionTreeFactor& dtf);
|
||||||
|
|
||||||
/// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree
|
/// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree
|
||||||
TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree);
|
TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree);
|
||||||
|
|
|
@ -134,14 +134,17 @@ TEST(TableFactor, constructors) {
|
||||||
EXPECT(assert_equal(expected, f4));
|
EXPECT(assert_equal(expected, f4));
|
||||||
|
|
||||||
// Test for 9=3x3 values.
|
// Test for 9=3x3 values.
|
||||||
DiscreteKey V(0, 3), W(1, 3);
|
DiscreteKey V(0, 3), W(1, 3), O(100, 3);
|
||||||
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
|
DiscreteConditional conditional5(V | W = "1/2/3 5/6/7 9/10/11");
|
||||||
TableFactor f5(conditional5);
|
TableFactor f5(conditional5);
|
||||||
// GTSAM_PRINT(f5);
|
|
||||||
TableFactor expected_f5(
|
std::string expected_values =
|
||||||
X & Y,
|
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667";
|
||||||
"0.166667 0.277778 0.3 0.333333 0.333333 0.333333 0.5 0.388889 0.366667");
|
TableFactor expected_f5(V & W, expected_values);
|
||||||
EXPECT(assert_equal(expected_f5, f5, 1e-6));
|
EXPECT(assert_equal(expected_f5, f5, 1e-6));
|
||||||
|
|
||||||
|
TableFactor f5_with_wrong_keys(V & O, expected_values);
|
||||||
|
EXPECT(assert_inequal(f5_with_wrong_keys, f5, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue