HybridBayesTree::optimize

release/4.3a0
Varun Agrawal 2022-08-26 11:36:13 -04:00
parent f0df82ac04
commit cb2d2e678d
3 changed files with 152 additions and 0 deletions

View File

@ -35,4 +35,41 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}
/* ************************************************************************* */
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesNet gbn;
KeyVector added_keys;
// Iterate over all the nodes in the BayesTree
for (auto&& node : nodes()) {
// Check if conditional being added is already in the Bayes net.
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
added_keys.end()) {
// Access the clique and get the underlying hybrid conditional
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();
KeyVector frontals(conditional->frontals().begin(),
conditional->frontals().end());
// Record the key being added
added_keys.insert(added_keys.end(), frontals.begin(), frontals.end());
// If conditional is hybrid (and not discrete-only), we get the Gaussian
// Conditional corresponding to the assignment and add it to the Gaussian
// Bayes Net.
if (conditional->isHybrid()) {
auto gm = conditional->asMixture();
GaussianConditional::shared_ptr gaussian_conditional =
(*gm)(assignment);
gbn.push_back(gaussian_conditional);
}
}
}
// Return the optimized bayes net.
return gbn.optimize();
}
} // namespace gtsam

View File

@ -70,6 +70,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;
/**
* @brief Recursively optimize the BayesTree to produce a vector solution.
*
* @param assignment The discrete values assignment to select the Gaussian
* mixtures.
* @return VectorValues
*/
VectorValues optimize(const DiscreteValues& assignment) const;
/// @}
};

View File

@ -0,0 +1,106 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testHybridBayesTree.cpp
* @brief Unit tests for HybridBayesTree
* @author Varun Agrawal
* @date August 2022
*/
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
#include "Switching.h"
// Include for test suite
#include <CppUnitLite/TestHarness.h>
using namespace std;
using namespace gtsam;
using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
/* ****************************************************************************/
// Test for optimizing a HybridBayesTree.
TEST(HybridBayesTree, Optimize) {
Switching s(4);
HybridGaussianISAM isam;
HybridGaussianFactorGraph graph1;
// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
for (size_t i = 1; i < 4; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
// Add the Gaussian factors, 1 prior on X(1),
// 3 measurements on X(2), X(3), X(4)
graph1.push_back(s.linearizedFactorGraph.at(0));
for (size_t i = 4; i <= 7; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}
isam.update(graph1);
DiscreteValues assignment;
assignment[M(1)] = 1;
assignment[M(2)] = 1;
assignment[M(3)] = 1;
VectorValues delta = isam.optimize(assignment);
// The linearization point has the same value as the key index,
// e.g. X(1) = 1, X(2) = 2,
// but the factors specify X(k) = k-1, so delta should be -1.
VectorValues expected_delta;
expected_delta.insert(make_pair(X(1), -Vector1::Ones()));
expected_delta.insert(make_pair(X(2), -Vector1::Ones()));
expected_delta.insert(make_pair(X(3), -Vector1::Ones()));
expected_delta.insert(make_pair(X(4), -Vector1::Ones()));
EXPECT(assert_equal(expected_delta, delta));
// Create ordering.
Ordering ordering;
for (size_t k = 1; k <= s.K; k++) ordering += X(k);
HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
hybridBayesNet->print();
GaussianBayesNet gbn = hybridBayesNet->choose(assignment);
// EXPECT_LONGS_EQUAL(4, gbn.size());
// EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
// hybridBayesNet->atGaussian(0)))(assignment),
// *gbn.at(0)));
// EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
// hybridBayesNet->atGaussian(1)))(assignment),
// *gbn.at(1)));
// EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
// hybridBayesNet->atGaussian(2)))(assignment),
// *gbn.at(2)));
// EXPECT(assert_equal(*(*boost::dynamic_pointer_cast<GaussianMixture>(
// hybridBayesNet->atGaussian(3)))(assignment),
// *gbn.at(3)));
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */