improve DiscreteConditional::argmax method to accept parent values
parent
0a7db41290
commit
89f7f7f721
|
@ -235,16 +235,16 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::argmax() const {
|
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||||
size_t maxValue = 0;
|
size_t maxValue = 0;
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
|
DiscreteValues values = parentsValues;
|
||||||
|
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
assert(nrParents() == 0);
|
|
||||||
DiscreteValues frontals;
|
|
||||||
Key j = firstFrontalKey();
|
Key j = firstFrontalKey();
|
||||||
for (size_t value = 0; value < cardinality(j); value++) {
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
frontals[j] = value;
|
values[j] = value;
|
||||||
double pValueS = (*this)(frontals);
|
double pValueS = (*this)(values);
|
||||||
// Update MPE solution if better
|
// Update MPE solution if better
|
||||||
if (pValueS > maxP) {
|
if (pValueS > maxP) {
|
||||||
maxP = pValueS;
|
maxP = pValueS;
|
||||||
|
|
|
@ -217,7 +217,7 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
* @brief Return assignment that maximizes distribution.
|
* @brief Return assignment that maximizes distribution.
|
||||||
* @return Optimal assignment (1 frontal variable).
|
* @return Optimal assignment (1 frontal variable).
|
||||||
*/
|
*/
|
||||||
size_t argmax() const;
|
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
|
|
|
@ -289,6 +289,29 @@ TEST(DiscreteConditional, choose) {
|
||||||
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
EXPECT(assert_equal(expected3, *actual3, 1e-9));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check argmax on P(C|D) and P(D)
|
||||||
|
TEST(DiscreteConditional, Argmax) {
|
||||||
|
DiscreteKey C(2, 2), D(4, 2);
|
||||||
|
DiscreteConditional D_cond(D, "1/3");
|
||||||
|
DiscreteConditional C_given_DE((C | D) = "1/4 1/1");
|
||||||
|
|
||||||
|
// Case 1: No parents
|
||||||
|
size_t actual1 = D_cond.argmax();
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual1);
|
||||||
|
|
||||||
|
// Case 2: Given parent values
|
||||||
|
DiscreteValues given;
|
||||||
|
given[D.first] = 1;
|
||||||
|
size_t actual2 = C_given_DE.argmax(given);
|
||||||
|
// Should be 0 since D=1 gives 0.5/0.5
|
||||||
|
EXPECT_LONGS_EQUAL(0, actual2);
|
||||||
|
|
||||||
|
given[D.first] = 0;
|
||||||
|
size_t actual3 = C_given_DE.argmax(given);
|
||||||
|
EXPECT_LONGS_EQUAL(1, actual3);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// Check markdown representation looks as expected, no parents.
|
// Check markdown representation looks as expected, no parents.
|
||||||
TEST(DiscreteConditional, markdown_prior) {
|
TEST(DiscreteConditional, markdown_prior) {
|
||||||
|
|
Loading…
Reference in New Issue