extra test for HybridBayesNet optimize

release/4.3a0
Varun Agrawal 2022-08-26 16:45:44 -04:00
parent cb2d2e678d
commit 86320ff3b5
4 changed files with 74 additions and 6 deletions


@@ -125,8 +125,14 @@ GaussianBayesNet HybridBayesNet::choose(
     const DiscreteValues &assignment) const {
   GaussianBayesNet gbn;
   for (size_t idx = 0; idx < size(); idx++) {
-    GaussianMixture gm = *this->atGaussian(idx);
-    gbn.push_back(gm(assignment));
+    try {
+      GaussianMixture gm = *this->atGaussian(idx);
+      gbn.push_back(gm(assignment));
+    } catch (std::exception &exc) {
+      // if factor at `idx` is discrete-only, just continue.
+      continue;
+    }
   }
   return gbn;
 }


@@ -135,9 +135,9 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
   for (auto &fp : factors) {
     if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
       gfg.push_back(ptr->inner());
-    } else if (auto p =
-                   boost::static_pointer_cast<HybridConditional>(fp)->inner()) {
-      gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
+    } else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) {
+      gfg.push_back(
+          boost::static_pointer_cast<GaussianConditional>(ptr->inner()));
     } else {
       // It is an orphan wrapped conditional
     }
@@ -401,4 +401,20 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
   FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
 }
 
+/* ************************************************************************ */
+const Ordering HybridGaussianFactorGraph::getHybridOrdering(
+    OptionalOrderingType orderingType) const {
+  KeySet discrete_keys;
+  for (auto &factor : factors_) {
+    for (const DiscreteKey &k : factor->discreteKeys()) {
+      discrete_keys.insert(k.first);
+    }
+  }
+
+  const VariableIndex index(factors_);
+  Ordering ordering = Ordering::ColamdConstrainedLast(
+      index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
+  return ordering;
+}
+
 } // namespace gtsam


@@ -160,6 +160,15 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
       Base::push_back(sharedFactor);
     }
   }
 
+  /**
+   * @brief Get a constrained elimination ordering in which the discrete keys
+   * are eliminated after the continuous keys (COLAMD constrained-last).
+   *
+   * @param orderingType Optional ordering type (currently unused).
+   * @return const Ordering
+   */
+  const Ordering getHybridOrdering(
+      OptionalOrderingType orderingType = boost::none) const;
 };
 
 } // namespace gtsam


@@ -19,6 +19,7 @@
  */
 
 #include <gtsam/hybrid/HybridBayesNet.h>
+#include <gtsam/nonlinear/NonlinearFactorGraph.h>
 
 #include "Switching.h"
@@ -87,7 +88,7 @@ TEST(HybridBayesNet, Choose) {
 
 /* ****************************************************************************/
 // Test bayes net optimize
-TEST(HybridBayesNet, Optimize) {
+TEST(HybridBayesNet, OptimizeAssignment) {
   Switching s(4);
 
   Ordering ordering;
@@ -119,6 +120,42 @@ TEST(HybridBayesNet, Optimize) {
   EXPECT(assert_equal(expected_delta, delta));
 }
 
+/* ****************************************************************************/
+// Test bayes net optimize
+TEST(HybridBayesNet, Optimize) {
+  Switching s(4);
+
+  Ordering ordering;
+  for (auto&& kvp : s.linearizationPoint) {
+    ordering += kvp.key;
+  }
+
+  Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
+  HybridBayesNet::shared_ptr hybridBayesNet =
+      s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
+
+  HybridValues delta = hybridBayesNet->optimize();
+  delta.print();
+
+  // Continuous values at which to evaluate the error of each chosen Bayes net.
+  VectorValues correct;
+  correct.insert(X(1), 0 * Vector1::Ones());
+  correct.insert(X(2), 1 * Vector1::Ones());
+  correct.insert(X(3), 2 * Vector1::Ones());
+  correct.insert(X(4), 3 * Vector1::Ones());
+
+  // Error of the Bayes net selected by the all-ones mode assignment.
+  DiscreteValues assignment111;
+  assignment111[M(1)] = 1;
+  assignment111[M(2)] = 1;
+  assignment111[M(3)] = 1;
+  std::cout << hybridBayesNet->choose(assignment111).error(correct)
+            << std::endl;
+
+  // Error of the Bayes net selected by the (1, 0, 1) mode assignment.
+  DiscreteValues assignment101;
+  assignment101[M(1)] = 1;
+  assignment101[M(2)] = 0;
+  assignment101[M(3)] = 1;
+  std::cout << hybridBayesNet->choose(assignment101).error(correct)
+            << std::endl;
+}
+
 /* ************************************************************************* */
 int main() {
   TestResult tr;