Merge pull request #1785 from borglab/discrete-improv-2

release/4.3a0
Varun Agrawal 2024-07-21 11:39:43 -04:00 committed by GitHub
commit 1422b6c431
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 10 deletions

View File

@ -236,6 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
/* ************************************************************************** */ /* ************************************************************************** */
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const { size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Initialize
size_t maxValue = 0; size_t maxValue = 0;
double maxP = 0; double maxP = 0;
DiscreteValues values = parentsValues; DiscreteValues values = parentsValues;
@ -254,6 +257,33 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
return maxValue; return maxValue;
} }
/* ************************************************************************** */
void DiscreteConditional::argmaxInPlace(DiscreteValues* values) const {
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
// Initialize
DiscreteValues mpe;
double maxP = 0;
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();
// Find the maximum
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}
/* ************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1); assert(nrFrontals() == 1);

View File

@ -18,9 +18,9 @@
#pragma once #pragma once
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <memory> #include <memory>
#include <string> #include <string>
@ -159,9 +159,7 @@ class GTSAM_EXPORT DiscreteConditional
/// @{ /// @{
/// Log-probability is just -error(x). /// Log-probability is just -error(x).
double logProbability(const DiscreteValues& x) const { double logProbability(const DiscreteValues& x) const { return -error(x); }
return -error(x);
}
/// print index signature only /// print index signature only
void printSignature( void printSignature(
@ -214,11 +212,18 @@ class GTSAM_EXPORT DiscreteConditional
size_t sample() const; size_t sample() const;
/** /**
* @brief Return assignment that maximizes distribution. * @brief Return assignment for single frontal variable that maximizes value.
* @return Optimal assignment (1 frontal variable). * @param parentsValues Known assignments for the parents.
* @return maximizing assignment for the frontal variable.
*/ */
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(DiscreteValues* parentsValues) const;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{
@ -244,7 +249,6 @@ class GTSAM_EXPORT DiscreteConditional
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
/// @name HybridValues methods. /// @name HybridValues methods.
/// @{ /// @{

View File

@ -119,7 +119,8 @@ DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
// Argmax each node in turn in topological sort order (parents first). // Argmax each node in turn in topological sort order (parents first).
for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) { for (auto it = std::make_reverse_iterator(end());
it != std::make_reverse_iterator(begin()); ++it) {
// dereference to get the sharedFactor to the lookup table // dereference to get the sharedFactor to the lookup table
(*it)->argmaxInPlace(&result); (*it)->argmaxInPlace(&result);
} }