pass DiscreteConditional& for pruning instead of shared_ptr
parent
b7bddde82b
commit
d6bc1e11a6
|
@ -53,8 +53,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
|||
DiscreteConditional joint;
|
||||
for (auto &&conditional : marginal) {
|
||||
// The last discrete conditional may be a TableDistribution
|
||||
if (auto dtc =
|
||||
std::dynamic_pointer_cast<TableDistribution>(conditional)) {
|
||||
if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(conditional)) {
|
||||
DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor());
|
||||
joint = joint * dc;
|
||||
} else {
|
||||
|
@ -81,7 +80,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
|||
for (auto &&conditional : *this) {
|
||||
if (auto hgc = conditional->asHybrid()) {
|
||||
// Prune the hybrid Gaussian conditional!
|
||||
auto prunedHybridGaussianConditional = hgc->prune(pruned);
|
||||
auto prunedHybridGaussianConditional = hgc->prune(*pruned);
|
||||
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
result.push_back(prunedHybridGaussianConditional);
|
||||
|
|
|
@ -236,7 +236,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
|||
if (!hybridGaussianCond->pruned()) {
|
||||
// Imperative
|
||||
clique->conditional() = std::make_shared<HybridConditional>(
|
||||
hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
|
||||
hybridGaussianCond->prune(*parentData.prunedDiscreteProbs));
|
||||
}
|
||||
}
|
||||
return parentData;
|
||||
|
|
|
@ -304,18 +304,18 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
|||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
|
||||
const DiscreteConditional::shared_ptr &discreteProbs) const {
|
||||
// Find keys in discreteProbs->keys() but not in this->keys():
|
||||
const DiscreteConditional &discreteProbs) const {
|
||||
// Find keys in discreteProbs.keys() but not in this->keys():
|
||||
std::set<Key> mine(this->keys().begin(), this->keys().end());
|
||||
std::set<Key> theirs(discreteProbs->keys().begin(),
|
||||
discreteProbs->keys().end());
|
||||
std::set<Key> theirs(discreteProbs.keys().begin(),
|
||||
discreteProbs.keys().end());
|
||||
std::vector<Key> diff;
|
||||
std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
|
||||
std::back_inserter(diff));
|
||||
|
||||
// Find maximum probability value for every combination of our keys.
|
||||
Ordering keys(diff);
|
||||
auto max = discreteProbs->max(keys);
|
||||
auto max = discreteProbs.max(keys);
|
||||
|
||||
// Check the max value for every combination of our keys.
|
||||
// If the max value is 0.0, we can prune the corresponding conditional.
|
||||
|
|
|
@ -236,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
* @return Shared pointer to possibly a pruned HybridGaussianConditional
|
||||
*/
|
||||
HybridGaussianConditional::shared_ptr prune(
|
||||
const DiscreteConditional::shared_ptr &discreteProbs) const;
|
||||
const DiscreteConditional &discreteProbs) const;
|
||||
|
||||
/// Return true if the conditional has already been pruned.
|
||||
bool pruned() const { return pruned_; }
|
||||
|
|
|
@ -261,8 +261,8 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
potentials[i] = 1;
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
// Prune the HybridGaussianConditional
|
||||
const auto pruned = hgc.prune(std::make_shared<DiscreteConditional>(
|
||||
keys.size(), decisionTreeFactor));
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
// Check that the pruned HybridGaussianConditional has 1 conditional
|
||||
EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
|
||||
}
|
||||
|
@ -272,8 +272,8 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
0, 0, 0.5, 0};
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned = hgc.prune(
|
||||
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor));
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 2 conditionals
|
||||
EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
|
||||
|
@ -288,8 +288,8 @@ TEST(HybridGaussianConditional, Prune) {
|
|||
0, 0, 0.5, 0};
|
||||
const DecisionTreeFactor decisionTreeFactor(keys, potentials);
|
||||
|
||||
const auto pruned = hgc.prune(
|
||||
std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor));
|
||||
const auto pruned =
|
||||
hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor));
|
||||
|
||||
// Check that the pruned HybridGaussianConditional has 3 conditionals
|
||||
EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
|
||||
|
|
Loading…
Reference in New Issue