pass DiscreteConditional& for pruning instead of shared_ptr
parent
b7bddde82b
commit
d6bc1e11a6
|
@ -53,8 +53,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||||
DiscreteConditional joint;
|
DiscreteConditional joint;
|
||||||
for (auto &&conditional : marginal) {
|
for (auto &&conditional : marginal) {
|
||||||
// The last discrete conditional may be a TableDistribution
|
// The last discrete conditional may be a TableDistribution
|
||||||
if (auto dtc =
|
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(conditional)) {
|
||||||
std::dynamic_pointer_cast<TableDistribution>(conditional)) {
|
|
||||||
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
|
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
|
||||||
joint = joint * dc;
|
joint = joint * dc;
|
||||||
} else {
|
} else {
|
||||||
|
@ -81,7 +80,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto hgc = conditional->asHybrid()) {
|
if (auto hgc = conditional->asHybrid()) {
|
||||||
// Prune the hybrid Gaussian conditional!
|
// Prune the hybrid Gaussian conditional!
|
||||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
auto prunedHybridGaussianConditional = hgc->prune(*pruned);
|
||||||
|
|
||||||
// Type-erase and add to the pruned Bayes Net fragment.
|
// Type-erase and add to the pruned Bayes Net fragment.
|
||||||
result.push_back(prunedHybridGaussianConditional);
|
result.push_back(prunedHybridGaussianConditional);
|
||||||
|
|
|
@ -236,7 +236,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
if (!hybridGaussianCond->pruned()) {
|
if (!hybridGaussianCond->pruned()) {
|
||||||
// Imperative
|
// Imperative
|
||||||
clique->conditional() = std::make_shared<HybridConditional>(
|
clique->conditional() = std::make_shared<HybridConditional>(
|
||||||
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
|
hybridGaussianCond->prune(*parentData.prunedDiscreteProbs));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return parentData;
|
return parentData;
|
||||||
|
|
|
@ -304,18 +304,18 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||||
const DiscreteConditional::shared_ptr &discreteProbs) const {
|
const DiscreteConditional &discreteProbs) const {
|
||||||
// Find keys in discreteProbs->keys() but not in this->keys():
|
// Find keys in discreteProbs.keys() but not in this->keys():
|
||||||
std::set<Key> mine(this->keys().begin(), this->keys().end());
|
std::set<Key> mine(this->keys().begin(), this->keys().end());
|
||||||
std::set<Key> theirs(discreteProbs->keys().begin(),
|
std::set<Key> theirs(discreteProbs.keys().begin(),
|
||||||
discreteProbs->keys().end());
|
discreteProbs.keys().end());
|
||||||
std::vector<Key> diff;
|
std::vector<Key> diff;
|
||||||
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
|
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
|
||||||
std::back_inserter(diff));
|
std::back_inserter(diff));
|
||||||
|
|
||||||
// Find maximum probability value for every combination of our keys.
|
// Find maximum probability value for every combination of our keys.
|
||||||
Ordering keys(diff);
|
Ordering keys(diff);
|
||||||
auto max = discreteProbs->max(keys);
|
auto max = discreteProbs.max(keys);
|
||||||
|
|
||||||
// Check the max value for every combination of our keys.
|
// Check the max value for every combination of our keys.
|
||||||
// If the max value is 0.0, we can prune the corresponding conditional.
|
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||||
|
|
|
@ -236,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
* @return Shared pointer to possibly a pruned HybridGaussianConditional
|
* @return Shared pointer to possibly a pruned HybridGaussianConditional
|
||||||
*/
|
*/
|
||||||
HybridGaussianConditional::shared_ptr prune(
|
HybridGaussianConditional::shared_ptr prune(
|
||||||
const DiscreteConditional::shared_ptr &discreteProbs) const;
|
const DiscreteConditional &discreteProbs) const;
|
||||||
|
|
||||||
/// Return true if the conditional has already been pruned.
|
/// Return true if the conditional has already been pruned.
|
||||||
bool pruned() const { return pruned_; }
|
bool pruned() const { return pruned_; }
|
||||||
|
|
|
@ -261,8 +261,8 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
potentials[i] = 1;
|
potentials[i] = 1;
|
||||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||||
// Prune the HybridGaussianConditional
|
// Prune the HybridGaussianConditional
|
||||||
const auto pruned = hgc.prune(std::make_shared<DiscreteConditional>(
|
const auto pruned =
|
||||||
keys.size(), decisionTreeFactor));
|
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||||
// Check that the pruned HybridGaussianConditional has 1 conditional
|
// Check that the pruned HybridGaussianConditional has 1 conditional
|
||||||
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
||||||
}
|
}
|
||||||
|
@ -272,8 +272,8 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
0, 0, 0.5, 0};
|
0, 0, 0.5, 0};
|
||||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||||
|
|
||||||
const auto pruned = hgc.prune(
|
const auto pruned =
|
||||||
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor));
|
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||||
|
|
||||||
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
||||||
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
||||||
|
@ -288,8 +288,8 @@ TEST(HybridGaussianConditional, Prune) {
|
||||||
0, 0, 0.5, 0};
|
0, 0, 0.5, 0};
|
||||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||||
|
|
||||||
const auto pruned = hgc.prune(
|
const auto pruned =
|
||||||
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor));
|
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||||
|
|
||||||
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
||||||
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
||||||
|
|
Loading…
Reference in New Issue