Merge pull request #1785 from borglab/discrete-improv-2
commit
1422b6c431
|
|
@ -236,6 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||||
|
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Initialize
|
||||||
size_t maxValue = 0;
|
size_t maxValue = 0;
|
||||||
double maxP = 0;
|
double maxP = 0;
|
||||||
DiscreteValues values = parentsValues;
|
DiscreteValues values = parentsValues;
|
||||||
|
|
@ -254,6 +257,33 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||||
return maxValue;
|
return maxValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
void DiscreteConditional::argmaxInPlace(DiscreteValues* values) const {
|
||||||
|
ADT pFS = choose(*values, true); // P(F|S=parentsValues)
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
DiscreteValues mpe;
|
||||||
|
double maxP = 0;
|
||||||
|
|
||||||
|
// Get all Possible Configurations
|
||||||
|
const auto allPosbValues = frontalAssignments();
|
||||||
|
|
||||||
|
// Find the maximum
|
||||||
|
for (const auto& frontalVals : allPosbValues) {
|
||||||
|
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
|
||||||
|
// Update maximum solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
mpe = frontalVals;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// set values (inPlace) to maximum
|
||||||
|
for (Key j : frontals()) {
|
||||||
|
(*values)[j] = mpe[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
assert(nrFrontals() == 1);
|
assert(nrFrontals() == 1);
|
||||||
|
|
@ -459,7 +489,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteConditional::evaluate(const HybridValues& x) const{
|
double DiscreteConditional::evaluate(const HybridValues& x) const {
|
||||||
return this->evaluate(x.discrete());
|
return this->evaluate(x.discrete());
|
||||||
}
|
}
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -18,9 +18,9 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -39,7 +39,7 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
public Conditional<DecisionTreeFactor, DiscreteConditional> {
|
public Conditional<DecisionTreeFactor, DiscreteConditional> {
|
||||||
public:
|
public:
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
typedef DiscreteConditional This; ///< Typedef to this class
|
typedef DiscreteConditional This; ///< Typedef to this class
|
||||||
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||||
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
typedef DecisionTreeFactor BaseFactor; ///< Typedef to our factor base class
|
||||||
typedef Conditional<BaseFactor, This>
|
typedef Conditional<BaseFactor, This>
|
||||||
|
|
@ -159,9 +159,7 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Log-probability is just -error(x).
|
/// Log-probability is just -error(x).
|
||||||
double logProbability(const DiscreteValues& x) const {
|
double logProbability(const DiscreteValues& x) const { return -error(x); }
|
||||||
return -error(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// print index signature only
|
/// print index signature only
|
||||||
void printSignature(
|
void printSignature(
|
||||||
|
|
@ -214,11 +212,18 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
size_t sample() const;
|
size_t sample() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return assignment that maximizes distribution.
|
* @brief Return assignment for single frontal variable that maximizes value.
|
||||||
* @return Optimal assignment (1 frontal variable).
|
* @param parentsValues Known assignments for the parents.
|
||||||
|
* @return maximizing assignment for the frontal variable.
|
||||||
*/
|
*/
|
||||||
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
|
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calculate assignment for frontal variables that maximizes value.
|
||||||
|
* @param (in/out) parentsValues Known assignments for the parents.
|
||||||
|
*/
|
||||||
|
void argmaxInPlace(DiscreteValues* parentsValues) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -244,7 +249,6 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name HybridValues methods.
|
/// @name HybridValues methods.
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,8 @@ DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet(
|
||||||
|
|
||||||
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
|
DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const {
|
||||||
// Argmax each node in turn in topological sort order (parents first).
|
// Argmax each node in turn in topological sort order (parents first).
|
||||||
for (auto it = std::make_reverse_iterator(end()); it != std::make_reverse_iterator(begin()); ++it) {
|
for (auto it = std::make_reverse_iterator(end());
|
||||||
|
it != std::make_reverse_iterator(begin()); ++it) {
|
||||||
// dereference to get the sharedFactor to the lookup table
|
// dereference to get the sharedFactor to the lookup table
|
||||||
(*it)->argmaxInPlace(&result);
|
(*it)->argmaxInPlace(&result);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue