extra test for HybridBayesNet optimize
parent
cb2d2e678d
commit
86320ff3b5
|
|
@ -125,8 +125,14 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
const DiscreteValues &assignment) const {
|
||||
GaussianBayesNet gbn;
|
||||
for (size_t idx = 0; idx < size(); idx++) {
|
||||
GaussianMixture gm = *this->atGaussian(idx);
|
||||
gbn.push_back(gm(assignment));
|
||||
try {
|
||||
GaussianMixture gm = *this->atGaussian(idx);
|
||||
gbn.push_back(gm(assignment));
|
||||
|
||||
} catch (std::exception &exc) {
|
||||
// if factor at `idx` is discrete-only, just continue.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return gbn;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -135,9 +135,9 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
|||
for (auto &fp : factors) {
|
||||
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
|
||||
gfg.push_back(ptr->inner());
|
||||
} else if (auto p =
|
||||
boost::static_pointer_cast<HybridConditional>(fp)->inner()) {
|
||||
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
|
||||
} else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) {
|
||||
gfg.push_back(
|
||||
boost::static_pointer_cast<GaussianConditional>(ptr->inner()));
|
||||
} else {
|
||||
// It is an orphan wrapped conditional
|
||||
}
|
||||
|
|
@ -401,4 +401,20 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
|
|||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
const Ordering HybridGaussianFactorGraph::getHybridOrdering(
|
||||
OptionalOrderingType orderingType) const {
|
||||
KeySet discrete_keys;
|
||||
for (auto &factor : factors_) {
|
||||
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||
discrete_keys.insert(k.first);
|
||||
}
|
||||
}
|
||||
|
||||
const VariableIndex index(factors_);
|
||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||
return ordering;
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -160,6 +160,15 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
Base::push_back(sharedFactor);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief
|
||||
*
|
||||
* @param orderingType
|
||||
* @return const Ordering
|
||||
*/
|
||||
const Ordering getHybridOrdering(
|
||||
OptionalOrderingType orderingType = boost::none) const;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
|
||||
#include "Switching.h"
|
||||
|
||||
|
|
@ -87,7 +88,7 @@ TEST(HybridBayesNet, Choose) {
|
|||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net optimize
|
||||
TEST(HybridBayesNet, Optimize) {
|
||||
TEST(HybridBayesNet, OptimizeAssignment) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering ordering;
|
||||
|
|
@ -119,6 +120,42 @@ TEST(HybridBayesNet, Optimize) {
|
|||
EXPECT(assert_equal(expected_delta, delta));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test bayes net optimize
|
||||
TEST(HybridBayesNet, Optimize) {
|
||||
Switching s(4);
|
||||
|
||||
Ordering ordering;
|
||||
for (auto&& kvp : s.linearizationPoint) {
|
||||
ordering += kvp.key;
|
||||
}
|
||||
|
||||
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
delta.print();
|
||||
VectorValues correct;
|
||||
correct.insert(X(1), 0 * Vector1::Ones());
|
||||
correct.insert(X(2), 1 * Vector1::Ones());
|
||||
correct.insert(X(3), 2 * Vector1::Ones());
|
||||
correct.insert(X(4), 3 * Vector1::Ones());
|
||||
|
||||
DiscreteValues assignment111;
|
||||
assignment111[M(1)] = 1;
|
||||
assignment111[M(2)] = 1;
|
||||
assignment111[M(3)] = 1;
|
||||
std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl;
|
||||
|
||||
DiscreteValues assignment101;
|
||||
assignment101[M(1)] = 1;
|
||||
assignment101[M(2)] = 0;
|
||||
assignment101[M(3)] = 1;
|
||||
std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
|
|||
Loading…
Reference in New Issue