Merge pull request #1854 from borglab/feature/more_testing
More tests, some small bugfixesrelease/4.3a0
commit
caa3821b2b
|
@ -58,7 +58,12 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
|||
// sample each node in turn in topological sort order (parents first)
|
||||
for (auto it = std::make_reverse_iterator(end());
|
||||
it != std::make_reverse_iterator(begin()); ++it) {
|
||||
(*it)->sampleInPlace(&result);
|
||||
const DiscreteConditional::shared_ptr& conditional = *it;
|
||||
// Sample the conditional only if value for j not already in result
|
||||
const Key j = conditional->firstFrontalKey();
|
||||
if (result.count(j) == 0) {
|
||||
conditional->sampleInPlace(&result);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -259,8 +259,18 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
|||
|
||||
/* ************************************************************************** */
|
||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||
assert(nrFrontals() == 1);
|
||||
Key j = (firstFrontalKey());
|
||||
// throw if more than one frontal:
|
||||
if (nrFrontals() != 1) {
|
||||
throw std::invalid_argument(
|
||||
"DiscreteConditional::sampleInPlace can only be called on single "
|
||||
"variable conditionals");
|
||||
}
|
||||
Key j = firstFrontalKey();
|
||||
// throw if values already contains j:
|
||||
if (values->count(j) > 0) {
|
||||
throw std::invalid_argument(
|
||||
"DiscreteConditional::sampleInPlace: values already contains j");
|
||||
}
|
||||
size_t sampled = sample(*values); // Sample variable given parents
|
||||
(*values)[j] = sampled; // store result in partial solution
|
||||
}
|
||||
|
@ -467,9 +477,7 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double DiscreteConditional::negLogConstant() const {
|
||||
return 0.0;
|
||||
}
|
||||
double DiscreteConditional::negLogConstant() const { return 0.0; }
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
||||
|
|
|
@ -168,7 +168,7 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
static_cast<const BaseConditional*>(this)->print(s, formatter);
|
||||
}
|
||||
|
||||
/// Evaluate, just look up in AlgebraicDecisonTree
|
||||
/// Evaluate, just look up in AlgebraicDecisionTree
|
||||
double evaluate(const DiscreteValues& values) const {
|
||||
return ADT::operator()(values);
|
||||
}
|
||||
|
|
|
@ -206,7 +206,7 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asHybrid()) {
|
||||
// If conditional is hybrid, select based on assignment.
|
||||
gbn.push_back((*gm)(assignment));
|
||||
gbn.push_back(gm->choose(assignment));
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous only, add Gaussian conditional.
|
||||
gbn.push_back(gc);
|
||||
|
|
|
@ -127,6 +127,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
|
||||
* value assignment.
|
||||
*
|
||||
* @note Any pure discrete factors are ignored.
|
||||
*
|
||||
* @param assignment The discrete value assignment for the discrete keys.
|
||||
* @return GaussianBayesNet
|
||||
*/
|
||||
|
|
|
@ -168,7 +168,7 @@ size_t HybridGaussianConditional::nrComponents() const {
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianConditional::shared_ptr HybridGaussianConditional::operator()(
|
||||
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
||||
const DiscreteValues &discreteValues) const {
|
||||
auto &ptr = conditionals_(discreteValues);
|
||||
if (!ptr) return nullptr;
|
||||
|
@ -192,10 +192,9 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
|
|||
|
||||
// Check the base and the factors:
|
||||
return BaseFactor::equals(*e, tol) &&
|
||||
conditionals_.equals(e->conditionals_,
|
||||
[tol](const GaussianConditional::shared_ptr &f1,
|
||||
const GaussianConditional::shared_ptr &f2) {
|
||||
return f1->equals(*(f2), tol);
|
||||
conditionals_.equals(
|
||||
e->conditionals_, [tol](const auto &f1, const auto &f2) {
|
||||
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -159,9 +159,15 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
/// @{
|
||||
|
||||
/// @brief Return the conditional Gaussian for the given discrete assignment.
|
||||
GaussianConditional::shared_ptr operator()(
|
||||
GaussianConditional::shared_ptr choose(
|
||||
const DiscreteValues &discreteValues) const;
|
||||
|
||||
/// @brief Syntactic sugar for choose.
|
||||
GaussianConditional::shared_ptr operator()(
|
||||
const DiscreteValues &discreteValues) const {
|
||||
return choose(discreteValues);
|
||||
}
|
||||
|
||||
/// Returns the total number of continuous components
|
||||
size_t nrComponents() const;
|
||||
|
||||
|
|
|
@ -154,9 +154,8 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
|
|||
|
||||
// Check the base and the factors:
|
||||
return Base::equals(*e, tol) &&
|
||||
factors_.equals(e->factors_,
|
||||
[tol](const sharedFactor &f1, const sharedFactor &f2) {
|
||||
return f1->equals(*f2, tol);
|
||||
factors_.equals(e->factors_, [tol](const auto &f1, const auto &f2) {
|
||||
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -213,16 +212,15 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
double HybridGaussianFactor::potentiallyPrunedComponentError(
|
||||
const sharedFactor &gf, const VectorValues &values) const {
|
||||
/// Helper method to compute the error of a component.
|
||||
static double PotentiallyPrunedComponentError(
|
||||
const GaussianFactor::shared_ptr &gf, const VectorValues &values) {
|
||||
// Check if valid pointer
|
||||
if (gf) {
|
||||
return gf->error(values);
|
||||
} else {
|
||||
// If not valid, pointer, it means this component was pruned,
|
||||
// so we return maximum error.
|
||||
// This way the negative exponential will give
|
||||
// a probability value close to 0.0.
|
||||
// If nullptr this component was pruned, so we return maximum error. This
|
||||
// way the negative exponential will give a probability value close to 0.0.
|
||||
return std::numeric_limits<double>::max();
|
||||
}
|
||||
}
|
||||
|
@ -231,8 +229,8 @@ double HybridGaussianFactor::potentiallyPrunedComponentError(
|
|||
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
// functor to convert from sharedFactor to double error value.
|
||||
auto errorFunc = [this, &continuousValues](const sharedFactor &gf) {
|
||||
return this->potentiallyPrunedComponentError(gf, continuousValues);
|
||||
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
|
||||
return PotentiallyPrunedComponentError(gf, continuousValues);
|
||||
};
|
||||
DecisionTree<Key, double> error_tree(factors_, errorFunc);
|
||||
return error_tree;
|
||||
|
@ -242,7 +240,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
|||
double HybridGaussianFactor::error(const HybridValues &values) const {
|
||||
// Directly index to get the component, no need to build the whole tree.
|
||||
const sharedFactor gf = factors_(values.discrete());
|
||||
return potentiallyPrunedComponentError(gf, values.continuous());
|
||||
return PotentiallyPrunedComponentError(gf, values.continuous());
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -189,10 +189,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
*/
|
||||
static Factors augment(const FactorValuePairs &factors);
|
||||
|
||||
/// Helper method to compute the error of a component.
|
||||
double potentiallyPrunedComponentError(
|
||||
const sharedFactor &gf, const VectorValues &continuousValues) const;
|
||||
|
||||
/// Helper struct to assist private constructor below.
|
||||
struct ConstructorHelper;
|
||||
|
||||
|
|
|
@ -74,6 +74,32 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
|
|||
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static void printFactor(const std::shared_ptr<Factor> &factor,
|
||||
const DiscreteValues &assignment,
|
||||
const KeyFormatter &keyFormatter) {
|
||||
if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||
hgf->operator()(assignment)
|
||||
->print("HybridGaussianFactor, component:", keyFormatter);
|
||||
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||
factor->print("GaussianFactor:\n", keyFormatter);
|
||||
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
factor->print("DiscreteFactor:\n", keyFormatter);
|
||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
if (hc->isContinuous()) {
|
||||
factor->print("GaussianConditional:\n", keyFormatter);
|
||||
} else if (hc->isDiscrete()) {
|
||||
factor->print("DiscreteConditional:\n", keyFormatter);
|
||||
} else {
|
||||
hc->asHybrid()
|
||||
->choose(assignment)
|
||||
->print("HybridConditional, component:\n", keyFormatter);
|
||||
}
|
||||
} else {
|
||||
factor->print("Unknown factor type\n", keyFormatter);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
void HybridGaussianFactorGraph::printErrors(
|
||||
const HybridValues &values, const std::string &str,
|
||||
|
@ -83,69 +109,19 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
&printCondition) const {
|
||||
std::cout << str << "size: " << size() << std::endl << std::endl;
|
||||
|
||||
std::stringstream ss;
|
||||
|
||||
for (size_t i = 0; i < factors_.size(); i++) {
|
||||
auto &&factor = factors_[i];
|
||||
std::cout << "Factor " << i << ": ";
|
||||
|
||||
// Clear the stringstream
|
||||
ss.str(std::string());
|
||||
|
||||
if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
|
||||
if (factor == nullptr) {
|
||||
std::cout << "nullptr"
|
||||
<< "\n";
|
||||
} else {
|
||||
hgf->operator()(values.discrete())->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << factor->error(values) << std::endl;
|
||||
std::cout << "Factor " << i << ": nullptr\n";
|
||||
continue;
|
||||
}
|
||||
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
|
||||
if (factor == nullptr) {
|
||||
std::cout << "nullptr"
|
||||
<< "\n";
|
||||
} else {
|
||||
if (hc->isContinuous()) {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
|
||||
} else if (hc->isDiscrete()) {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
|
||||
<< "\n";
|
||||
} else {
|
||||
// Is hybrid
|
||||
auto conditionalComponent =
|
||||
hc->asHybrid()->operator()(values.discrete());
|
||||
conditionalComponent->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << conditionalComponent->error(values)
|
||||
<< "\n";
|
||||
}
|
||||
}
|
||||
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
|
||||
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
|
||||
const double errorValue = factor->error(values);
|
||||
if (!printCondition(factor.get(), errorValue, i))
|
||||
continue; // User-provided filter did not pass
|
||||
|
||||
if (factor == nullptr) {
|
||||
std::cout << "nullptr"
|
||||
<< "\n";
|
||||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << errorValue << "\n";
|
||||
}
|
||||
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
|
||||
if (factor == nullptr) {
|
||||
std::cout << "nullptr"
|
||||
<< "\n";
|
||||
} else {
|
||||
factor->print(ss.str(), keyFormatter);
|
||||
std::cout << "error = " << df->error(values.discrete()) << std::endl;
|
||||
}
|
||||
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Print the factor
|
||||
std::cout << "Factor " << i << ", error = " << errorValue << "\n";
|
||||
printFactor(factor, values.discrete(), keyFormatter);
|
||||
std::cout << "\n";
|
||||
}
|
||||
std::cout.flush();
|
||||
|
|
|
@ -231,4 +231,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
GaussianFactorGraph operator()(const DiscreteValues& assignment) const;
|
||||
};
|
||||
|
||||
// traits
|
||||
template <>
|
||||
struct traits<HybridGaussianFactorGraph>
|
||||
: public Testable<HybridGaussianFactorGraph> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -114,11 +114,11 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
|
|||
return {new_order, levels};
|
||||
}
|
||||
|
||||
/* ***************************************************************************
|
||||
*/
|
||||
/* ****************************************************************************/
|
||||
using MotionModel = BetweenFactor<double>;
|
||||
|
||||
// Test fixture with switching network.
|
||||
/// ϕ(X(0)) .. ϕ(X(k),X(k+1)) .. ϕ(X(k);z_k) .. ϕ(M(0)) .. ϕ(M(k),M(k+1))
|
||||
struct Switching {
|
||||
size_t K;
|
||||
DiscreteKeys modes;
|
||||
|
@ -140,8 +140,8 @@ struct Switching {
|
|||
: K(K) {
|
||||
using noiseModel::Isotropic;
|
||||
|
||||
// Create DiscreteKeys for binary K modes.
|
||||
for (size_t k = 0; k < K; k++) {
|
||||
// Create DiscreteKeys for K-1 binary modes.
|
||||
for (size_t k = 0; k < K - 1; k++) {
|
||||
modes.emplace_back(M(k), 2);
|
||||
}
|
||||
|
||||
|
@ -153,25 +153,26 @@ struct Switching {
|
|||
}
|
||||
|
||||
// Create hybrid factor graph.
|
||||
// Add a prior on X(0).
|
||||
|
||||
// Add a prior ϕ(X(0)) on X(0).
|
||||
nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
|
||||
X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma));
|
||||
|
||||
// Add "motion models".
|
||||
// Add "motion models" ϕ(X(k),X(k+1)).
|
||||
for (size_t k = 0; k < K - 1; k++) {
|
||||
auto motion_models = motionModels(k, between_sigma);
|
||||
nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(modes[k],
|
||||
motion_models);
|
||||
}
|
||||
|
||||
// Add measurement factors
|
||||
// Add measurement factors ϕ(X(k);z_k).
|
||||
auto measurement_noise = Isotropic::Sigma(1, prior_sigma);
|
||||
for (size_t k = 1; k < K; k++) {
|
||||
nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
|
||||
X(k), measurements.at(k), measurement_noise);
|
||||
}
|
||||
|
||||
// Add "mode chain"
|
||||
// Add "mode chain" ϕ(M(0)) ϕ(M(0),M(1)) ... ϕ(M(K-3),M(K-2))
|
||||
addModeChain(&nonlinearFactorGraph, discrete_transition_prob);
|
||||
|
||||
// Create the linearization point.
|
||||
|
@ -179,8 +180,6 @@ struct Switching {
|
|||
linearizationPoint.insert<double>(X(k), static_cast<double>(k + 1));
|
||||
}
|
||||
|
||||
// The ground truth is robot moving forward
|
||||
// and one less than the linearization point
|
||||
linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
|
||||
}
|
||||
|
||||
|
@ -196,7 +195,7 @@ struct Switching {
|
|||
}
|
||||
|
||||
/**
|
||||
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
|
||||
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-1).
|
||||
* E.g. if K=4, we want M0, M1 and M2.
|
||||
*
|
||||
* @param fg The factor graph to which the mode chain is added.
|
||||
|
|
|
@ -62,32 +62,117 @@ TEST(HybridBayesNet, Add) {
|
|||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test evaluate for a pure discrete Bayes net P(Asia).
|
||||
// Test API for a pure discrete Bayes net P(Asia).
|
||||
TEST(HybridBayesNet, EvaluatePureDiscrete) {
|
||||
HybridBayesNet bayesNet;
|
||||
bayesNet.emplace_shared<DiscreteConditional>(Asia, "4/6");
|
||||
HybridValues values;
|
||||
values.insert(asiaKey, 0);
|
||||
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9);
|
||||
const auto pAsia = std::make_shared<DiscreteConditional>(Asia, "4/6");
|
||||
bayesNet.push_back(pAsia);
|
||||
HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}};
|
||||
|
||||
// choose
|
||||
GaussianBayesNet empty;
|
||||
EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9));
|
||||
|
||||
// evaluate
|
||||
EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9);
|
||||
|
||||
// optimize
|
||||
EXPECT(assert_equal(one, bayesNet.optimize()));
|
||||
EXPECT(assert_equal(VectorValues{}, bayesNet.optimize(one.discrete())));
|
||||
|
||||
// sample
|
||||
std::mt19937_64 rng(42);
|
||||
EXPECT(assert_equal(zero, bayesNet.sample(&rng)));
|
||||
EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
|
||||
EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
|
||||
|
||||
// error
|
||||
EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
|
||||
|
||||
// logProbability
|
||||
EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
|
||||
|
||||
// toFactorGraph
|
||||
HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({});
|
||||
EXPECT(assert_equal(expectedFG, fg));
|
||||
|
||||
// prune, imperative :-(
|
||||
EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
|
||||
EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test creation of a tiny hybrid Bayes net.
|
||||
TEST(HybridBayesNet, Tiny) {
|
||||
auto bn = tiny::createHybridBayesNet();
|
||||
EXPECT_LONGS_EQUAL(3, bn.size());
|
||||
auto bayesNet = tiny::createHybridBayesNet(); // P(z|x,mode)P(x)P(mode)
|
||||
EXPECT_LONGS_EQUAL(3, bayesNet.size());
|
||||
|
||||
const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}};
|
||||
auto fg = bn.toFactorGraph(vv);
|
||||
HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}};
|
||||
|
||||
// Check Invariants for components
|
||||
HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid();
|
||||
GaussianConditional::shared_ptr gc0 = hgc->choose(zero.discrete()),
|
||||
gc1 = hgc->choose(one.discrete());
|
||||
GaussianConditional::shared_ptr px = bayesNet.at(1)->asGaussian();
|
||||
GaussianConditional::CheckInvariants(*gc0, vv);
|
||||
GaussianConditional::CheckInvariants(*gc1, vv);
|
||||
GaussianConditional::CheckInvariants(*px, vv);
|
||||
HybridGaussianConditional::CheckInvariants(*hgc, zero);
|
||||
HybridGaussianConditional::CheckInvariants(*hgc, one);
|
||||
|
||||
// choose
|
||||
GaussianBayesNet expectedChosen;
|
||||
expectedChosen.push_back(gc0);
|
||||
expectedChosen.push_back(px);
|
||||
auto chosen0 = bayesNet.choose(zero.discrete());
|
||||
auto chosen1 = bayesNet.choose(one.discrete());
|
||||
EXPECT(assert_equal(expectedChosen, chosen0, 1e-9));
|
||||
|
||||
// logProbability
|
||||
const double logP0 = chosen0.logProbability(vv) + log(0.4); // 0.4 is prior
|
||||
const double logP1 = chosen1.logProbability(vv) + log(0.6); // 0.6 is prior
|
||||
EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9);
|
||||
|
||||
// evaluate
|
||||
EXPECT_DOUBLES_EQUAL(exp(logP0), bayesNet.evaluate(zero), 1e-9);
|
||||
|
||||
// optimize
|
||||
EXPECT(assert_equal(one, bayesNet.optimize()));
|
||||
EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete())));
|
||||
|
||||
// sample
|
||||
std::mt19937_64 rng(42);
|
||||
EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete()));
|
||||
|
||||
// error
|
||||
const double error0 = chosen0.error(vv) + gc0->negLogConstant() -
|
||||
px->negLogConstant() - log(0.4);
|
||||
const double error1 = chosen1.error(vv) + gc1->negLogConstant() -
|
||||
px->negLogConstant() - log(0.6);
|
||||
EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9);
|
||||
EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9);
|
||||
|
||||
// toFactorGraph
|
||||
auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}});
|
||||
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||
|
||||
// Check that the ratio of probPrime to evaluate is the same for all modes.
|
||||
std::vector<double> ratio(2);
|
||||
for (size_t mode : {0, 1}) {
|
||||
const HybridValues hv{vv, {{M(0), mode}}};
|
||||
ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv);
|
||||
}
|
||||
ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
|
||||
ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
|
||||
EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
|
||||
|
||||
// prune, imperative :-(
|
||||
auto pruned = bayesNet.prune(1);
|
||||
EXPECT_LONGS_EQUAL(1, pruned.at(0)->asHybrid()->nrComponents());
|
||||
EXPECT(!pruned.equals(bayesNet));
|
||||
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
|
@ -223,12 +308,15 @@ TEST(HybridBayesNet, Optimize) {
|
|||
/* ****************************************************************************/
|
||||
// Test Bayes net error
|
||||
TEST(HybridBayesNet, Pruning) {
|
||||
// Create switching network with three continuous variables and two discrete:
|
||||
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
|
||||
Switching s(3);
|
||||
|
||||
HybridBayesNet::shared_ptr posterior =
|
||||
s.linearizedFactorGraph.eliminateSequential();
|
||||
EXPECT_LONGS_EQUAL(5, posterior->size());
|
||||
|
||||
// Optimize
|
||||
HybridValues delta = posterior->optimize();
|
||||
auto actualTree = posterior->evaluate(delta.continuous());
|
||||
|
||||
|
@ -254,7 +342,6 @@ TEST(HybridBayesNet, Pruning) {
|
|||
logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues);
|
||||
logProbability += posterior->at(1)->asHybrid()->logProbability(hybridValues);
|
||||
logProbability += posterior->at(2)->asHybrid()->logProbability(hybridValues);
|
||||
// NOTE(dellaert): the discrete errors were not added in logProbability tree!
|
||||
logProbability +=
|
||||
posterior->at(3)->asDiscrete()->logProbability(hybridValues);
|
||||
logProbability +=
|
||||
|
@ -316,10 +403,9 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
#endif
|
||||
|
||||
// regression
|
||||
DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
|
||||
DecisionTreeFactor::ADT potentials(
|
||||
dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
|
||||
DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
|
||||
s.modes, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
|
||||
DiscreteConditional expected_discrete_conditionals(1, s.modes, potentials);
|
||||
|
||||
// Prune!
|
||||
posterior->prune(maxNrLeaves);
|
||||
|
|
|
@ -168,6 +168,9 @@ TEST(HybridGaussianConditional, ContinuousParents) {
|
|||
// Check that the continuous parent keys are correct:
|
||||
EXPECT(continuousParentKeys.size() == 1);
|
||||
EXPECT(continuousParentKeys[0] == X(0));
|
||||
|
||||
EXPECT(HybridGaussianConditional::CheckInvariants(hybrid_conditional, hv0));
|
||||
EXPECT(HybridGaussianConditional::CheckInvariants(hybrid_conditional, hv1));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue