Merge pull request #1854 from borglab/feature/more_testing

More tests, some small bugfixes
release/4.3a0
Frank Dellaert 2024-09-30 15:26:28 -07:00 committed by GitHub
commit caa3821b2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 200 additions and 117 deletions

View File

@ -58,7 +58,12 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// sample each node in turn in topological sort order (parents first) // sample each node in turn in topological sort order (parents first)
for (auto it = std::make_reverse_iterator(end()); for (auto it = std::make_reverse_iterator(end());
it != std::make_reverse_iterator(begin()); ++it) { it != std::make_reverse_iterator(begin()); ++it) {
(*it)->sampleInPlace(&result); const DiscreteConditional::shared_ptr& conditional = *it;
// Sample the conditional only if value for j not already in result
const Key j = conditional->firstFrontalKey();
if (result.count(j) == 0) {
conditional->sampleInPlace(&result);
}
} }
return result; return result;
} }

View File

@ -259,8 +259,18 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
/* ************************************************************************** */ /* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1); // throw if more than one frontal:
Key j = (firstFrontalKey()); if (nrFrontals() != 1) {
throw std::invalid_argument(
"DiscreteConditional::sampleInPlace can only be called on single "
"variable conditionals");
}
Key j = firstFrontalKey();
// throw if values already contains j:
if (values->count(j) > 0) {
throw std::invalid_argument(
"DiscreteConditional::sampleInPlace: values already contains j");
}
size_t sampled = sample(*values); // Sample variable given parents size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution (*values)[j] = sampled; // store result in partial solution
} }
@ -467,9 +477,7 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
double DiscreteConditional::negLogConstant() const { double DiscreteConditional::negLogConstant() const { return 0.0; }
return 0.0;
}
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -168,7 +168,7 @@ class GTSAM_EXPORT DiscreteConditional
static_cast<const BaseConditional*>(this)->print(s, formatter); static_cast<const BaseConditional*>(this)->print(s, formatter);
} }
/// Evaluate, just look up in AlgebraicDecisonTree /// Evaluate, just look up in AlgebraicDecisionTree
double evaluate(const DiscreteValues& values) const { double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values); return ADT::operator()(values);
} }

View File

@ -206,7 +206,7 @@ GaussianBayesNet HybridBayesNet::choose(
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto gm = conditional->asHybrid()) { if (auto gm = conditional->asHybrid()) {
// If conditional is hybrid, select based on assignment. // If conditional is hybrid, select based on assignment.
gbn.push_back((*gm)(assignment)); gbn.push_back(gm->choose(assignment));
} else if (auto gc = conditional->asGaussian()) { } else if (auto gc = conditional->asGaussian()) {
// If continuous only, add Gaussian conditional. // If continuous only, add Gaussian conditional.
gbn.push_back(gc); gbn.push_back(gc);

View File

@ -127,6 +127,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
* value assignment. * value assignment.
* *
* @note Any pure discrete factors are ignored.
*
* @param assignment The discrete value assignment for the discrete keys. * @param assignment The discrete value assignment for the discrete keys.
* @return GaussianBayesNet * @return GaussianBayesNet
*/ */

View File

@ -168,7 +168,7 @@ size_t HybridGaussianConditional::nrComponents() const {
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianConditional::shared_ptr HybridGaussianConditional::operator()( GaussianConditional::shared_ptr HybridGaussianConditional::choose(
const DiscreteValues &discreteValues) const { const DiscreteValues &discreteValues) const {
auto &ptr = conditionals_(discreteValues); auto &ptr = conditionals_(discreteValues);
if (!ptr) return nullptr; if (!ptr) return nullptr;
@ -192,11 +192,10 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
// Check the base and the factors: // Check the base and the factors:
return BaseFactor::equals(*e, tol) && return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_, conditionals_.equals(
[tol](const GaussianConditional::shared_ptr &f1, e->conditionals_, [tol](const auto &f1, const auto &f2) {
const GaussianConditional::shared_ptr &f2) { return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
return f1->equals(*(f2), tol); });
});
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -159,9 +159,15 @@ class GTSAM_EXPORT HybridGaussianConditional
/// @{ /// @{
/// @brief Return the conditional Gaussian for the given discrete assignment. /// @brief Return the conditional Gaussian for the given discrete assignment.
GaussianConditional::shared_ptr operator()( GaussianConditional::shared_ptr choose(
const DiscreteValues &discreteValues) const; const DiscreteValues &discreteValues) const;
/// @brief Syntactic sugar for choose.
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const {
return choose(discreteValues);
}
/// Returns the total number of continuous components /// Returns the total number of continuous components
size_t nrComponents() const; size_t nrComponents() const;

View File

@ -154,10 +154,9 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
// Check the base and the factors: // Check the base and the factors:
return Base::equals(*e, tol) && return Base::equals(*e, tol) &&
factors_.equals(e->factors_, factors_.equals(e->factors_, [tol](const auto &f1, const auto &f2) {
[tol](const sharedFactor &f1, const sharedFactor &f2) { return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
return f1->equals(*f2, tol); });
});
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -213,16 +212,15 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
} }
/* *******************************************************************************/ /* *******************************************************************************/
double HybridGaussianFactor::potentiallyPrunedComponentError( /// Helper method to compute the error of a component.
const sharedFactor &gf, const VectorValues &values) const { static double PotentiallyPrunedComponentError(
const GaussianFactor::shared_ptr &gf, const VectorValues &values) {
// Check if valid pointer // Check if valid pointer
if (gf) { if (gf) {
return gf->error(values); return gf->error(values);
} else { } else {
// If not valid, pointer, it means this component was pruned, // If nullptr this component was pruned, so we return maximum error. This
// so we return maximum error. // way the negative exponential will give a probability value close to 0.0.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max(); return std::numeric_limits<double>::max();
} }
} }
@ -231,8 +229,8 @@ double HybridGaussianFactor::potentiallyPrunedComponentError(
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree( AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = [this, &continuousValues](const sharedFactor &gf) { auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return this->potentiallyPrunedComponentError(gf, continuousValues); return PotentiallyPrunedComponentError(gf, continuousValues);
}; };
DecisionTree<Key, double> error_tree(factors_, errorFunc); DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree; return error_tree;
@ -242,7 +240,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
double HybridGaussianFactor::error(const HybridValues &values) const { double HybridGaussianFactor::error(const HybridValues &values) const {
// Directly index to get the component, no need to build the whole tree. // Directly index to get the component, no need to build the whole tree.
const sharedFactor gf = factors_(values.discrete()); const sharedFactor gf = factors_(values.discrete());
return potentiallyPrunedComponentError(gf, values.continuous()); return PotentiallyPrunedComponentError(gf, values.continuous());
} }
} // namespace gtsam } // namespace gtsam

View File

@ -189,10 +189,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/ */
static Factors augment(const FactorValuePairs &factors); static Factors augment(const FactorValuePairs &factors);
/// Helper method to compute the error of a component.
double potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &continuousValues) const;
/// Helper struct to assist private constructor below. /// Helper struct to assist private constructor below.
struct ConstructorHelper; struct ConstructorHelper;

View File

@ -74,6 +74,32 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true); index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
} }
/* ************************************************************************ */
static void printFactor(const std::shared_ptr<Factor> &factor,
const DiscreteValues &assignment,
const KeyFormatter &keyFormatter) {
if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
hgf->operator()(assignment)
->print("HybridGaussianFactor, component:", keyFormatter);
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
factor->print("GaussianFactor:\n", keyFormatter);
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
factor->print("DiscreteFactor:\n", keyFormatter);
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (hc->isContinuous()) {
factor->print("GaussianConditional:\n", keyFormatter);
} else if (hc->isDiscrete()) {
factor->print("DiscreteConditional:\n", keyFormatter);
} else {
hc->asHybrid()
->choose(assignment)
->print("HybridConditional, component:\n", keyFormatter);
}
} else {
factor->print("Unknown factor type\n", keyFormatter);
}
}
/* ************************************************************************ */ /* ************************************************************************ */
void HybridGaussianFactorGraph::printErrors( void HybridGaussianFactorGraph::printErrors(
const HybridValues &values, const std::string &str, const HybridValues &values, const std::string &str,
@ -83,69 +109,19 @@ void HybridGaussianFactorGraph::printErrors(
&printCondition) const { &printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl; std::cout << str << "size: " << size() << std::endl << std::endl;
std::stringstream ss;
for (size_t i = 0; i < factors_.size(); i++) { for (size_t i = 0; i < factors_.size(); i++) {
auto &&factor = factors_[i]; auto &&factor = factors_[i];
std::cout << "Factor " << i << ": "; if (factor == nullptr) {
std::cout << "Factor " << i << ": nullptr\n";
// Clear the stringstream
ss.str(std::string());
if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
hgf->operator()(values.discrete())->print(ss.str(), keyFormatter);
std::cout << "error = " << factor->error(values) << std::endl;
}
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
if (hc->isContinuous()) {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
<< "\n";
} else {
// Is hybrid
auto conditionalComponent =
hc->asHybrid()->operator()(values.discrete());
conditionalComponent->print(ss.str(), keyFormatter);
std::cout << "error = " << conditionalComponent->error(values)
<< "\n";
}
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << df->error(values.discrete()) << std::endl;
}
} else {
continue; continue;
} }
const double errorValue = factor->error(values);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass
// Print the factor
std::cout << "Factor " << i << ", error = " << errorValue << "\n";
printFactor(factor, values.discrete(), keyFormatter);
std::cout << "\n"; std::cout << "\n";
} }
std::cout.flush(); std::cout.flush();

View File

@ -231,4 +231,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
GaussianFactorGraph operator()(const DiscreteValues& assignment) const; GaussianFactorGraph operator()(const DiscreteValues& assignment) const;
}; };
// traits
template <>
struct traits<HybridGaussianFactorGraph>
: public Testable<HybridGaussianFactorGraph> {};
} // namespace gtsam } // namespace gtsam

View File

@ -114,11 +114,11 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
return {new_order, levels}; return {new_order, levels};
} }
/* *************************************************************************** /* ****************************************************************************/
*/
using MotionModel = BetweenFactor<double>; using MotionModel = BetweenFactor<double>;
// Test fixture with switching network. // Test fixture with switching network.
/// ϕ(X(0)) .. ϕ(X(k),X(k+1)) .. ϕ(X(k);z_k) .. ϕ(M(0)) .. ϕ(M(k),M(k+1))
struct Switching { struct Switching {
size_t K; size_t K;
DiscreteKeys modes; DiscreteKeys modes;
@ -140,8 +140,8 @@ struct Switching {
: K(K) { : K(K) {
using noiseModel::Isotropic; using noiseModel::Isotropic;
// Create DiscreteKeys for binary K modes. // Create DiscreteKeys for K-1 binary modes.
for (size_t k = 0; k < K; k++) { for (size_t k = 0; k < K - 1; k++) {
modes.emplace_back(M(k), 2); modes.emplace_back(M(k), 2);
} }
@ -153,25 +153,26 @@ struct Switching {
} }
// Create hybrid factor graph. // Create hybrid factor graph.
// Add a prior on X(0).
// Add a prior ϕ(X(0)) on X(0).
nonlinearFactorGraph.emplace_shared<PriorFactor<double>>( nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma)); X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma));
// Add "motion models". // Add "motion models" ϕ(X(k),X(k+1)).
for (size_t k = 0; k < K - 1; k++) { for (size_t k = 0; k < K - 1; k++) {
auto motion_models = motionModels(k, between_sigma); auto motion_models = motionModels(k, between_sigma);
nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(modes[k], nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(modes[k],
motion_models); motion_models);
} }
// Add measurement factors // Add measurement factors ϕ(X(k);z_k).
auto measurement_noise = Isotropic::Sigma(1, prior_sigma); auto measurement_noise = Isotropic::Sigma(1, prior_sigma);
for (size_t k = 1; k < K; k++) { for (size_t k = 1; k < K; k++) {
nonlinearFactorGraph.emplace_shared<PriorFactor<double>>( nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
X(k), measurements.at(k), measurement_noise); X(k), measurements.at(k), measurement_noise);
} }
// Add "mode chain" // Add "mode chain" ϕ(M(0)) ϕ(M(0),M(1)) ... ϕ(M(K-3),M(K-2))
addModeChain(&nonlinearFactorGraph, discrete_transition_prob); addModeChain(&nonlinearFactorGraph, discrete_transition_prob);
// Create the linearization point. // Create the linearization point.
@ -179,8 +180,6 @@ struct Switching {
linearizationPoint.insert<double>(X(k), static_cast<double>(k + 1)); linearizationPoint.insert<double>(X(k), static_cast<double>(k + 1));
} }
// The ground truth is robot moving forward
// and one less than the linearization point
linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint); linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
} }
@ -196,7 +195,7 @@ struct Switching {
} }
/** /**
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2). * @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-1).
* E.g. if K=4, we want M0, M1 and M2. * E.g. if K=4, we want M0, M1 and M2.
* *
* @param fg The factor graph to which the mode chain is added. * @param fg The factor graph to which the mode chain is added.

View File

@ -62,32 +62,117 @@ TEST(HybridBayesNet, Add) {
} }
/* ****************************************************************************/ /* ****************************************************************************/
// Test evaluate for a pure discrete Bayes net P(Asia). // Test API for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, EvaluatePureDiscrete) { TEST(HybridBayesNet, EvaluatePureDiscrete) {
HybridBayesNet bayesNet; HybridBayesNet bayesNet;
bayesNet.emplace_shared<DiscreteConditional>(Asia, "4/6"); const auto pAsia = std::make_shared<DiscreteConditional>(Asia, "4/6");
HybridValues values; bayesNet.push_back(pAsia);
values.insert(asiaKey, 0); HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}};
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9);
// choose
GaussianBayesNet empty;
EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9));
// evaluate
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9);
// optimize
EXPECT(assert_equal(one, bayesNet.optimize()));
EXPECT(assert_equal(VectorValues{}, bayesNet.optimize(one.discrete())));
// sample
std::mt19937_64 rng(42);
EXPECT(assert_equal(zero, bayesNet.sample(&rng)));
EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
// error
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
// logProbability
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
// toFactorGraph
HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({});
EXPECT(assert_equal(expectedFG, fg));
// prune, imperative :-(
EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
} }
/* ****************************************************************************/ /* ****************************************************************************/
// Test creation of a tiny hybrid Bayes net. // Test creation of a tiny hybrid Bayes net.
TEST(HybridBayesNet, Tiny) { TEST(HybridBayesNet, Tiny) {
auto bn = tiny::createHybridBayesNet(); auto bayesNet = tiny::createHybridBayesNet(); // P(z|x,mode)P(x)P(mode)
EXPECT_LONGS_EQUAL(3, bn.size()); EXPECT_LONGS_EQUAL(3, bayesNet.size());
const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}}; const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}};
auto fg = bn.toFactorGraph(vv); HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}};
// Check Invariants for components
HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid();
GaussianConditional::shared_ptr gc0 = hgc->choose(zero.discrete()),
gc1 = hgc->choose(one.discrete());
GaussianConditional::shared_ptr px = bayesNet.at(1)->asGaussian();
GaussianConditional::CheckInvariants(*gc0, vv);
GaussianConditional::CheckInvariants(*gc1, vv);
GaussianConditional::CheckInvariants(*px, vv);
HybridGaussianConditional::CheckInvariants(*hgc, zero);
HybridGaussianConditional::CheckInvariants(*hgc, one);
// choose
GaussianBayesNet expectedChosen;
expectedChosen.push_back(gc0);
expectedChosen.push_back(px);
auto chosen0 = bayesNet.choose(zero.discrete());
auto chosen1 = bayesNet.choose(one.discrete());
EXPECT(assert_equal(expectedChosen, chosen0, 1e-9));
// logProbability
const double logP0 = chosen0.logProbability(vv) + log(0.4); // 0.4 is prior
const double logP1 = chosen1.logProbability(vv) + log(0.6); // 0.6 is prior
EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9);
// evaluate
EXPECT_DOUBLES_EQUAL(exp(logP0), bayesNet.evaluate(zero), 1e-9);
// optimize
EXPECT(assert_equal(one, bayesNet.optimize()));
EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete())));
// sample
std::mt19937_64 rng(42);
EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete()));
// error
const double error0 = chosen0.error(vv) + gc0->negLogConstant() -
px->negLogConstant() - log(0.4);
const double error1 = chosen1.error(vv) + gc1->negLogConstant() -
px->negLogConstant() - log(0.6);
EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9);
EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9);
EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9);
// toFactorGraph
auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}});
EXPECT_LONGS_EQUAL(3, fg.size()); EXPECT_LONGS_EQUAL(3, fg.size());
// Check that the ratio of probPrime to evaluate is the same for all modes. // Check that the ratio of probPrime to evaluate is the same for all modes.
std::vector<double> ratio(2); std::vector<double> ratio(2);
for (size_t mode : {0, 1}) { ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
const HybridValues hv{vv, {{M(0), mode}}}; ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv);
}
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8); EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
// prune, imperative :-(
auto pruned = bayesNet.prune(1);
EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
EXPECT(!pruned.equals(bayesNet));
} }
/* ****************************************************************************/ /* ****************************************************************************/
@ -223,12 +308,15 @@ TEST(HybridBayesNet, Optimize) {
/* ****************************************************************************/ /* ****************************************************************************/
// Test Bayes net error // Test Bayes net error
TEST(HybridBayesNet, Pruning) { TEST(HybridBayesNet, Pruning) {
// Create switching network with three continuous variables and two discrete:
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
Switching s(3); Switching s(3);
HybridBayesNet::shared_ptr posterior = HybridBayesNet::shared_ptr posterior =
s.linearizedFactorGraph.eliminateSequential(); s.linearizedFactorGraph.eliminateSequential();
EXPECT_LONGS_EQUAL(5, posterior->size()); EXPECT_LONGS_EQUAL(5, posterior->size());
// Optimize
HybridValues delta = posterior->optimize(); HybridValues delta = posterior->optimize();
auto actualTree = posterior->evaluate(delta.continuous()); auto actualTree = posterior->evaluate(delta.continuous());
@ -254,7 +342,6 @@ TEST(HybridBayesNet, Pruning) {
logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues); logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues);
logProbability += posterior->at(1)->asHybrid()->logProbability(hybridValues); logProbability += posterior->at(1)->asHybrid()->logProbability(hybridValues);
logProbability += posterior->at(2)->asHybrid()->logProbability(hybridValues); logProbability += posterior->at(2)->asHybrid()->logProbability(hybridValues);
// NOTE(dellaert): the discrete errors were not added in logProbability tree!
logProbability += logProbability +=
posterior->at(3)->asDiscrete()->logProbability(hybridValues); posterior->at(3)->asDiscrete()->logProbability(hybridValues);
logProbability += logProbability +=
@ -316,10 +403,9 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
#endif #endif
// regression // regression
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
DecisionTreeFactor::ADT potentials( DecisionTreeFactor::ADT potentials(
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577}); s.modes, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials); DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials);
// Prune! // Prune!
posterior->prune(maxNrLeaves); posterior->prune(maxNrLeaves);

View File

@ -168,6 +168,9 @@ TEST(HybridGaussianConditional, ContinuousParents) {
// Check that the continuous parent keys are correct: // Check that the continuous parent keys are correct:
EXPECT(continuousParentKeys.size() == 1); EXPECT(continuousParentKeys.size() == 1);
EXPECT(continuousParentKeys[0] == X(0)); EXPECT(continuousParentKeys[0] == X(0));
EXPECT(HybridGaussianConditional::CheckInvariants(hybrid_conditional, hv0));
EXPECT(HybridGaussianConditional::CheckInvariants(hybrid_conditional, hv1));
} }
/* ************************************************************************* */ /* ************************************************************************* */