rename asMixture to asHybrid
parent
4016de7942
commit
6929d62300
|
|
@ -180,7 +180,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
// Go through all the conditionals in the
|
// Go through all the conditionals in the
|
||||||
// Bayes Net and prune them as per prunedDiscreteProbs.
|
// Bayes Net and prune them as per prunedDiscreteProbs.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto gm = conditional->asMixture()) {
|
if (auto gm = conditional->asHybrid()) {
|
||||||
// Make a copy of the hybrid Gaussian conditional and prune it!
|
// Make a copy of the hybrid Gaussian conditional and prune it!
|
||||||
auto prunedHybridGaussianConditional =
|
auto prunedHybridGaussianConditional =
|
||||||
std::make_shared<HybridGaussianConditional>(*gm);
|
std::make_shared<HybridGaussianConditional>(*gm);
|
||||||
|
|
@ -204,7 +204,7 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
const DiscreteValues &assignment) const {
|
const DiscreteValues &assignment) const {
|
||||||
GaussianBayesNet gbn;
|
GaussianBayesNet gbn;
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto gm = conditional->asMixture()) {
|
if (auto gm = conditional->asHybrid()) {
|
||||||
// If conditional is hybrid, select based on assignment.
|
// If conditional is hybrid, select based on assignment.
|
||||||
gbn.push_back((*gm)(assignment));
|
gbn.push_back((*gm)(assignment));
|
||||||
} else if (auto gc = conditional->asGaussian()) {
|
} else if (auto gc = conditional->asGaussian()) {
|
||||||
|
|
@ -291,7 +291,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
|
||||||
|
|
||||||
// Iterate over each conditional.
|
// Iterate over each conditional.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto gm = conditional->asMixture()) {
|
if (auto gm = conditional->asHybrid()) {
|
||||||
// If conditional is hybrid, compute error for all assignments.
|
// If conditional is hybrid, compute error for all assignments.
|
||||||
result = result + gm->errorTree(continuousValues);
|
result = result + gm->errorTree(continuousValues);
|
||||||
|
|
||||||
|
|
@ -321,7 +321,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
||||||
|
|
||||||
// Iterate over each conditional.
|
// Iterate over each conditional.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto gm = conditional->asMixture()) {
|
if (auto gm = conditional->asHybrid()) {
|
||||||
// If conditional is hybrid, select based on assignment and compute
|
// If conditional is hybrid, select based on assignment and compute
|
||||||
// logProbability.
|
// logProbability.
|
||||||
result = result + gm->logProbability(continuousValues);
|
result = result + gm->logProbability(continuousValues);
|
||||||
|
|
@ -369,7 +369,7 @@ HybridGaussianFactorGraph HybridBayesNet::toFactorGraph(
|
||||||
if (conditional->frontalsIn(measurements)) {
|
if (conditional->frontalsIn(measurements)) {
|
||||||
if (auto gc = conditional->asGaussian()) {
|
if (auto gc = conditional->asGaussian()) {
|
||||||
fg.push_back(gc->likelihood(measurements));
|
fg.push_back(gc->likelihood(measurements));
|
||||||
} else if (auto gm = conditional->asMixture()) {
|
} else if (auto gm = conditional->asHybrid()) {
|
||||||
fg.push_back(gm->likelihood(measurements));
|
fg.push_back(gm->likelihood(measurements));
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("Unknown conditional type");
|
throw std::runtime_error("Unknown conditional type");
|
||||||
|
|
|
||||||
|
|
@ -109,7 +109,7 @@ struct HybridAssignmentData {
|
||||||
|
|
||||||
GaussianConditional::shared_ptr conditional;
|
GaussianConditional::shared_ptr conditional;
|
||||||
if (hybrid_conditional->isHybrid()) {
|
if (hybrid_conditional->isHybrid()) {
|
||||||
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
|
conditional = (*hybrid_conditional->asHybrid())(parentData.assignment_);
|
||||||
} else if (hybrid_conditional->isContinuous()) {
|
} else if (hybrid_conditional->isContinuous()) {
|
||||||
conditional = hybrid_conditional->asGaussian();
|
conditional = hybrid_conditional->asGaussian();
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -205,7 +205,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
|
|
||||||
// If conditional is hybrid, we prune it.
|
// If conditional is hybrid, we prune it.
|
||||||
if (conditional->isHybrid()) {
|
if (conditional->isHybrid()) {
|
||||||
auto hybridGaussianCond = conditional->asMixture();
|
auto hybridGaussianCond = conditional->asHybrid();
|
||||||
|
|
||||||
hybridGaussianCond->prune(parentData.prunedDiscreteProbs);
|
hybridGaussianCond->prune(parentData.prunedDiscreteProbs);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -97,8 +97,8 @@ void HybridConditional::print(const std::string &s,
|
||||||
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
||||||
const This *e = dynamic_cast<const This *>(&other);
|
const This *e = dynamic_cast<const This *>(&other);
|
||||||
if (e == nullptr) return false;
|
if (e == nullptr) return false;
|
||||||
if (auto gm = asMixture()) {
|
if (auto gm = asHybrid()) {
|
||||||
auto other = e->asMixture();
|
auto other = e->asHybrid();
|
||||||
return other != nullptr && gm->equals(*other, tol);
|
return other != nullptr && gm->equals(*other, tol);
|
||||||
}
|
}
|
||||||
if (auto gc = asGaussian()) {
|
if (auto gc = asGaussian()) {
|
||||||
|
|
@ -119,7 +119,7 @@ double HybridConditional::error(const HybridValues &values) const {
|
||||||
if (auto gc = asGaussian()) {
|
if (auto gc = asGaussian()) {
|
||||||
return gc->error(values.continuous());
|
return gc->error(values.continuous());
|
||||||
}
|
}
|
||||||
if (auto gm = asMixture()) {
|
if (auto gm = asHybrid()) {
|
||||||
return gm->error(values);
|
return gm->error(values);
|
||||||
}
|
}
|
||||||
if (auto dc = asDiscrete()) {
|
if (auto dc = asDiscrete()) {
|
||||||
|
|
@ -134,7 +134,7 @@ double HybridConditional::logProbability(const HybridValues &values) const {
|
||||||
if (auto gc = asGaussian()) {
|
if (auto gc = asGaussian()) {
|
||||||
return gc->logProbability(values.continuous());
|
return gc->logProbability(values.continuous());
|
||||||
}
|
}
|
||||||
if (auto gm = asMixture()) {
|
if (auto gm = asHybrid()) {
|
||||||
return gm->logProbability(values);
|
return gm->logProbability(values);
|
||||||
}
|
}
|
||||||
if (auto dc = asDiscrete()) {
|
if (auto dc = asDiscrete()) {
|
||||||
|
|
@ -149,7 +149,7 @@ double HybridConditional::logNormalizationConstant() const {
|
||||||
if (auto gc = asGaussian()) {
|
if (auto gc = asGaussian()) {
|
||||||
return gc->logNormalizationConstant();
|
return gc->logNormalizationConstant();
|
||||||
}
|
}
|
||||||
if (auto gm = asMixture()) {
|
if (auto gm = asHybrid()) {
|
||||||
return gm->logNormalizationConstant(); // 0.0!
|
return gm->logNormalizationConstant(); // 0.0!
|
||||||
}
|
}
|
||||||
if (auto dc = asDiscrete()) {
|
if (auto dc = asDiscrete()) {
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,7 @@ class GTSAM_EXPORT HybridConditional
|
||||||
* @return nullptr if not a conditional
|
* @return nullptr if not a conditional
|
||||||
* @return HybridGaussianConditional::shared_ptr otherwise
|
* @return HybridGaussianConditional::shared_ptr otherwise
|
||||||
*/
|
*/
|
||||||
HybridGaussianConditional::shared_ptr asMixture() const {
|
HybridGaussianConditional::shared_ptr asHybrid() const {
|
||||||
return std::dynamic_pointer_cast<HybridGaussianConditional>(inner_);
|
return std::dynamic_pointer_cast<HybridGaussianConditional>(inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -115,7 +115,7 @@ void HybridGaussianFactorGraph::printErrors(
|
||||||
} else {
|
} else {
|
||||||
// Is hybrid
|
// Is hybrid
|
||||||
auto conditionalComponent =
|
auto conditionalComponent =
|
||||||
hc->asMixture()->operator()(values.discrete());
|
hc->asHybrid()->operator()(values.discrete());
|
||||||
conditionalComponent->print(ss.str(), keyFormatter);
|
conditionalComponent->print(ss.str(), keyFormatter);
|
||||||
std::cout << "error = " << conditionalComponent->error(values)
|
std::cout << "error = " << conditionalComponent->error(values)
|
||||||
<< "\n";
|
<< "\n";
|
||||||
|
|
@ -184,7 +184,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
} else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
|
} else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
|
||||||
result = gm->add(result);
|
result = gm->add(result);
|
||||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||||
if (auto gm = hc->asMixture()) {
|
if (auto gm = hc->asHybrid()) {
|
||||||
result = gm->add(result);
|
result = gm->add(result);
|
||||||
} else if (auto g = hc->asGaussian()) {
|
} else if (auto g = hc->asGaussian()) {
|
||||||
result = addGaussian(result, g);
|
result = addGaussian(result, g);
|
||||||
|
|
|
||||||
|
|
@ -140,7 +140,7 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
HybridGaussianConditional::shared_ptr HybridSmoother::gaussianMixture(
|
HybridGaussianConditional::shared_ptr HybridSmoother::gaussianMixture(
|
||||||
size_t index) const {
|
size_t index) const {
|
||||||
return hybridBayesNet_.at(index)->asMixture();
|
return hybridBayesNet_.at(index)->asHybrid();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ virtual class HybridConditional {
|
||||||
double logProbability(const gtsam::HybridValues& values) const;
|
double logProbability(const gtsam::HybridValues& values) const;
|
||||||
double evaluate(const gtsam::HybridValues& values) const;
|
double evaluate(const gtsam::HybridValues& values) const;
|
||||||
double operator()(const gtsam::HybridValues& values) const;
|
double operator()(const gtsam::HybridValues& values) const;
|
||||||
gtsam::HybridGaussianConditional* asMixture() const;
|
gtsam::HybridGaussianConditional* asHybrid() const;
|
||||||
gtsam::GaussianConditional* asGaussian() const;
|
gtsam::GaussianConditional* asGaussian() const;
|
||||||
gtsam::DiscreteConditional* asDiscrete() const;
|
gtsam::DiscreteConditional* asDiscrete() const;
|
||||||
gtsam::Factor* inner();
|
gtsam::Factor* inner();
|
||||||
|
|
|
||||||
|
|
@ -144,13 +144,13 @@ TEST(HybridBayesNet, Choose) {
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(4, gbn.size());
|
EXPECT_LONGS_EQUAL(4, gbn.size());
|
||||||
|
|
||||||
EXPECT(assert_equal(*(*hybridBayesNet->at(0)->asMixture())(assignment),
|
EXPECT(assert_equal(*(*hybridBayesNet->at(0)->asHybrid())(assignment),
|
||||||
*gbn.at(0)));
|
*gbn.at(0)));
|
||||||
EXPECT(assert_equal(*(*hybridBayesNet->at(1)->asMixture())(assignment),
|
EXPECT(assert_equal(*(*hybridBayesNet->at(1)->asHybrid())(assignment),
|
||||||
*gbn.at(1)));
|
*gbn.at(1)));
|
||||||
EXPECT(assert_equal(*(*hybridBayesNet->at(2)->asMixture())(assignment),
|
EXPECT(assert_equal(*(*hybridBayesNet->at(2)->asHybrid())(assignment),
|
||||||
*gbn.at(2)));
|
*gbn.at(2)));
|
||||||
EXPECT(assert_equal(*(*hybridBayesNet->at(3)->asMixture())(assignment),
|
EXPECT(assert_equal(*(*hybridBayesNet->at(3)->asHybrid())(assignment),
|
||||||
*gbn.at(3)));
|
*gbn.at(3)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -280,9 +280,9 @@ TEST(HybridBayesNet, Pruning) {
|
||||||
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
|
||||||
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
const HybridValues hybridValues{delta.continuous(), discrete_values};
|
||||||
double logProbability = 0;
|
double logProbability = 0;
|
||||||
logProbability += posterior->at(0)->asMixture()->logProbability(hybridValues);
|
logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues);
|
||||||
logProbability += posterior->at(1)->asMixture()->logProbability(hybridValues);
|
logProbability += posterior->at(1)->asHybrid()->logProbability(hybridValues);
|
||||||
logProbability += posterior->at(2)->asMixture()->logProbability(hybridValues);
|
logProbability += posterior->at(2)->asHybrid()->logProbability(hybridValues);
|
||||||
// NOTE(dellaert): the discrete errors were not added in logProbability tree!
|
// NOTE(dellaert): the discrete errors were not added in logProbability tree!
|
||||||
logProbability +=
|
logProbability +=
|
||||||
posterior->at(3)->asDiscrete()->logProbability(hybridValues);
|
posterior->at(3)->asDiscrete()->logProbability(hybridValues);
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ TEST(HybridConditional, Invariants) {
|
||||||
CHECK(hc0->isHybrid());
|
CHECK(hc0->isHybrid());
|
||||||
|
|
||||||
// Check invariants as a HybridGaussianConditional.
|
// Check invariants as a HybridGaussianConditional.
|
||||||
const auto conditional = hc0->asMixture();
|
const auto conditional = hc0->asHybrid();
|
||||||
EXPECT(HybridGaussianConditional::CheckInvariants(*conditional, values));
|
EXPECT(HybridGaussianConditional::CheckInvariants(*conditional, values));
|
||||||
|
|
||||||
// Check invariants as a HybridConditional.
|
// Check invariants as a HybridConditional.
|
||||||
|
|
|
||||||
|
|
@ -333,13 +333,13 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
|
||||||
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
|
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
|
||||||
EXPECT_LONGS_EQUAL(4, incrementalHybrid.size());
|
EXPECT_LONGS_EQUAL(4, incrementalHybrid.size());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
2, incrementalHybrid[X(0)]->conditional()->asMixture()->nrComponents());
|
2, incrementalHybrid[X(0)]->conditional()->asHybrid()->nrComponents());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
3, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents());
|
3, incrementalHybrid[X(1)]->conditional()->asHybrid()->nrComponents());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
5, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
|
5, incrementalHybrid[X(2)]->conditional()->asHybrid()->nrComponents());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents());
|
5, incrementalHybrid[X(3)]->conditional()->asHybrid()->nrComponents());
|
||||||
|
|
||||||
/***** Run Round 2 *****/
|
/***** Run Round 2 *****/
|
||||||
HybridGaussianFactorGraph graph2;
|
HybridGaussianFactorGraph graph2;
|
||||||
|
|
@ -354,9 +354,9 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
|
||||||
// with 5 (pruned) leaves.
|
// with 5 (pruned) leaves.
|
||||||
CHECK_EQUAL(5, incrementalHybrid.size());
|
CHECK_EQUAL(5, incrementalHybrid.size());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents());
|
5, incrementalHybrid[X(3)]->conditional()->asHybrid()->nrComponents());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
5, incrementalHybrid[X(4)]->conditional()->asMixture()->nrComponents());
|
5, incrementalHybrid[X(4)]->conditional()->asHybrid()->nrComponents());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************/
|
/* ************************************************************************/
|
||||||
|
|
@ -548,7 +548,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
||||||
|
|
||||||
// Test if pruning worked correctly by checking that we only have 3 leaves in
|
// Test if pruning worked correctly by checking that we only have 3 leaves in
|
||||||
// the last node.
|
// the last node.
|
||||||
auto lastConditional = inc[X(3)]->conditional()->asMixture();
|
auto lastConditional = inc[X(3)]->conditional()->asHybrid();
|
||||||
EXPECT_LONGS_EQUAL(3, lastConditional->nrComponents());
|
EXPECT_LONGS_EQUAL(3, lastConditional->nrComponents());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -358,13 +358,13 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
|
||||||
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
|
// each with 2, 4, 8, and 5 (pruned) leaves respetively.
|
||||||
EXPECT_LONGS_EQUAL(4, bayesTree.size());
|
EXPECT_LONGS_EQUAL(4, bayesTree.size());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
2, bayesTree[X(0)]->conditional()->asMixture()->nrComponents());
|
2, bayesTree[X(0)]->conditional()->asHybrid()->nrComponents());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
3, bayesTree[X(1)]->conditional()->asMixture()->nrComponents());
|
3, bayesTree[X(1)]->conditional()->asHybrid()->nrComponents());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
5, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
|
5, bayesTree[X(2)]->conditional()->asHybrid()->nrComponents());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents());
|
5, bayesTree[X(3)]->conditional()->asHybrid()->nrComponents());
|
||||||
|
|
||||||
/***** Run Round 2 *****/
|
/***** Run Round 2 *****/
|
||||||
HybridGaussianFactorGraph graph2;
|
HybridGaussianFactorGraph graph2;
|
||||||
|
|
@ -382,9 +382,9 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
|
||||||
// with 5 (pruned) leaves.
|
// with 5 (pruned) leaves.
|
||||||
CHECK_EQUAL(5, bayesTree.size());
|
CHECK_EQUAL(5, bayesTree.size());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents());
|
5, bayesTree[X(3)]->conditional()->asHybrid()->nrComponents());
|
||||||
EXPECT_LONGS_EQUAL(
|
EXPECT_LONGS_EQUAL(
|
||||||
5, bayesTree[X(4)]->conditional()->asMixture()->nrComponents());
|
5, bayesTree[X(4)]->conditional()->asHybrid()->nrComponents());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************/
|
/* ************************************************************************/
|
||||||
|
|
@ -569,7 +569,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
|
||||||
|
|
||||||
// Test if pruning worked correctly by checking that
|
// Test if pruning worked correctly by checking that
|
||||||
// we only have 3 leaves in the last node.
|
// we only have 3 leaves in the last node.
|
||||||
auto lastConditional = bayesTree[X(3)]->conditional()->asMixture();
|
auto lastConditional = bayesTree[X(3)]->conditional()->asHybrid();
|
||||||
EXPECT_LONGS_EQUAL(3, lastConditional->nrComponents());
|
EXPECT_LONGS_EQUAL(3, lastConditional->nrComponents());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue