Merge pull request #1954 from borglab/hybrid-with-tablefactor
commit
30670ab1a5
|
|
@ -20,12 +20,12 @@
|
|||
|
||||
#include <gtsam/base/utilities.h>
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteEliminationTree.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/discrete/DiscreteJunctionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/discrete/DiscreteValues.h>
|
||||
#include <gtsam/discrete/TableFactor.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridEliminationTree.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
|
|
@ -241,18 +241,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
|
|||
/* ************************************************************************ */
|
||||
/**
|
||||
* @brief Take negative log-values, shift them so that the minimum value is 0,
|
||||
* and then exponentiate to create a DecisionTreeFactor (not normalized yet!).
|
||||
* and then exponentiate to create a TableFactor (not normalized yet!).
|
||||
*
|
||||
* @param errors DecisionTree of (unnormalized) errors.
|
||||
* @return DecisionTreeFactor::shared_ptr
|
||||
* @return TableFactor::shared_ptr
|
||||
*/
|
||||
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
|
||||
static TableFactor::shared_ptr DiscreteFactorFromErrors(
|
||||
const DiscreteKeys &discreteKeys,
|
||||
const AlgebraicDecisionTree<Key> &errors) {
|
||||
double min_log = errors.min();
|
||||
AlgebraicDecisionTree<Key> potentials(
|
||||
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
|
||||
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
|
||||
return std::make_shared<TableFactor>(discreteKeys, potentials);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
|
|
|||
|
|
@ -114,10 +114,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
|
|||
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
|
||||
|
||||
// Check that factor is discrete and correct
|
||||
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
|
||||
auto factor = std::dynamic_pointer_cast<TableFactor>(result.second);
|
||||
CHECK(factor);
|
||||
// regression test
|
||||
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
|
||||
EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
@ -329,7 +329,7 @@ TEST(HybridBayesNet, Switching) {
|
|||
|
||||
// Check the remaining factor for x1
|
||||
CHECK(factor_x1);
|
||||
auto phi_x1 = std::dynamic_pointer_cast<DecisionTreeFactor>(factor_x1);
|
||||
auto phi_x1 = std::dynamic_pointer_cast<TableFactor>(factor_x1);
|
||||
CHECK(phi_x1);
|
||||
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
|
||||
// We can't really check the error of the decision tree factor phi_x1, because
|
||||
|
|
|
|||
|
|
@ -368,10 +368,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
|
|||
EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents());
|
||||
|
||||
// This is now a discreteFactor
|
||||
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
|
||||
auto discreteFactor = dynamic_pointer_cast<TableFactor>(factorOnModes);
|
||||
CHECK(discreteFactor);
|
||||
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
|
||||
EXPECT(discreteFactor->root_->isLeaf() == false);
|
||||
}
|
||||
|
||||
/****************************************************************************
|
||||
|
|
|
|||
Loading…
Reference in New Issue