extra test for HybridBayesNet optimize
parent
cb2d2e678d
commit
86320ff3b5
|
|
@ -125,8 +125,14 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
const DiscreteValues &assignment) const {
|
const DiscreteValues &assignment) const {
|
||||||
GaussianBayesNet gbn;
|
GaussianBayesNet gbn;
|
||||||
for (size_t idx = 0; idx < size(); idx++) {
|
for (size_t idx = 0; idx < size(); idx++) {
|
||||||
GaussianMixture gm = *this->atGaussian(idx);
|
try {
|
||||||
gbn.push_back(gm(assignment));
|
GaussianMixture gm = *this->atGaussian(idx);
|
||||||
|
gbn.push_back(gm(assignment));
|
||||||
|
|
||||||
|
} catch (std::exception &exc) {
|
||||||
|
// if factor at `idx` is discrete-only, just continue.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return gbn;
|
return gbn;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -135,9 +135,9 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
||||||
for (auto &fp : factors) {
|
for (auto &fp : factors) {
|
||||||
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
|
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
|
||||||
gfg.push_back(ptr->inner());
|
gfg.push_back(ptr->inner());
|
||||||
} else if (auto p =
|
} else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) {
|
||||||
boost::static_pointer_cast<HybridConditional>(fp)->inner()) {
|
gfg.push_back(
|
||||||
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
|
boost::static_pointer_cast<GaussianConditional>(ptr->inner()));
|
||||||
} else {
|
} else {
|
||||||
// It is an orphan wrapped conditional
|
// It is an orphan wrapped conditional
|
||||||
}
|
}
|
||||||
|
|
@ -401,4 +401,20 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
|
||||||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
const Ordering HybridGaussianFactorGraph::getHybridOrdering(
|
||||||
|
OptionalOrderingType orderingType) const {
|
||||||
|
KeySet discrete_keys;
|
||||||
|
for (auto &factor : factors_) {
|
||||||
|
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||||
|
discrete_keys.insert(k.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const VariableIndex index(factors_);
|
||||||
|
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||||
|
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
|
||||||
|
return ordering;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -160,6 +160,15 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
Base::push_back(sharedFactor);
|
Base::push_back(sharedFactor);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief
|
||||||
|
*
|
||||||
|
* @param orderingType
|
||||||
|
* @return const Ordering
|
||||||
|
*/
|
||||||
|
const Ordering getHybridOrdering(
|
||||||
|
OptionalOrderingType orderingType = boost::none) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
|
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||||
|
|
||||||
#include "Switching.h"
|
#include "Switching.h"
|
||||||
|
|
||||||
|
|
@ -87,7 +88,7 @@ TEST(HybridBayesNet, Choose) {
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test bayes net optimize
|
// Test bayes net optimize
|
||||||
TEST(HybridBayesNet, Optimize) {
|
TEST(HybridBayesNet, OptimizeAssignment) {
|
||||||
Switching s(4);
|
Switching s(4);
|
||||||
|
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
|
|
@ -119,6 +120,42 @@ TEST(HybridBayesNet, Optimize) {
|
||||||
EXPECT(assert_equal(expected_delta, delta));
|
EXPECT(assert_equal(expected_delta, delta));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test bayes net optimize
|
||||||
|
TEST(HybridBayesNet, Optimize) {
|
||||||
|
Switching s(4);
|
||||||
|
|
||||||
|
Ordering ordering;
|
||||||
|
for (auto&& kvp : s.linearizationPoint) {
|
||||||
|
ordering += kvp.key;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
|
||||||
|
HybridBayesNet::shared_ptr hybridBayesNet =
|
||||||
|
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
|
||||||
|
|
||||||
|
HybridValues delta = hybridBayesNet->optimize();
|
||||||
|
|
||||||
|
delta.print();
|
||||||
|
VectorValues correct;
|
||||||
|
correct.insert(X(1), 0 * Vector1::Ones());
|
||||||
|
correct.insert(X(2), 1 * Vector1::Ones());
|
||||||
|
correct.insert(X(3), 2 * Vector1::Ones());
|
||||||
|
correct.insert(X(4), 3 * Vector1::Ones());
|
||||||
|
|
||||||
|
DiscreteValues assignment111;
|
||||||
|
assignment111[M(1)] = 1;
|
||||||
|
assignment111[M(2)] = 1;
|
||||||
|
assignment111[M(3)] = 1;
|
||||||
|
std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl;
|
||||||
|
|
||||||
|
DiscreteValues assignment101;
|
||||||
|
assignment101[M(1)] = 1;
|
||||||
|
assignment101[M(2)] = 0;
|
||||||
|
assignment101[M(3)] = 1;
|
||||||
|
std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue