Merge pull request #1954 from borglab/hybrid-with-tablefactor

release/4.3a0
Varun Agrawal 2024-12-31 19:42:08 -05:00 committed by GitHub
commit 30670ab1a5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 10 deletions

View File

@ -20,12 +20,12 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/hybrid/HybridFactor.h>
@ -241,18 +241,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
/* ************************************************************************ */
/**
* @brief Take negative log-values, shift them so that the minimum value is 0,
* and then exponentiate to create a DecisionTreeFactor (not normalized yet!).
* and then exponentiate to create a TableFactor (not normalized yet!).
*
* @param errors DecisionTree of (unnormalized) errors.
* @return DecisionTreeFactor::shared_ptr
* @return TableFactor::shared_ptr
*/
static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
static TableFactor::shared_ptr DiscreteFactorFromErrors(
const DiscreteKeys &discreteKeys,
const AlgebraicDecisionTree<Key> &errors) {
double min_log = errors.min();
AlgebraicDecisionTree<Key> potentials(
errors, [&min_log](const double x) { return exp(-(x - min_log)); });
return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
return std::make_shared<TableFactor>(discreteKeys, potentials);
}
/* ************************************************************************ */

View File

@ -114,10 +114,10 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
EXPECT(HybridConditional::CheckInvariants(*result.first, values));
// Check that factor is discrete and correct
auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
auto factor = std::dynamic_pointer_cast<TableFactor>(result.second);
CHECK(factor);
// regression test
EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
EXPECT(assert_equal(TableFactor{m1, "1 1"}, *factor, 1e-5));
}
/* ************************************************************************* */
@ -329,7 +329,7 @@ TEST(HybridBayesNet, Switching) {
// Check the remaining factor for x1
CHECK(factor_x1);
auto phi_x1 = std::dynamic_pointer_cast<DecisionTreeFactor>(factor_x1);
auto phi_x1 = std::dynamic_pointer_cast<TableFactor>(factor_x1);
CHECK(phi_x1);
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
// We can't really check the error of the decision tree factor phi_x1, because

View File

@ -368,10 +368,9 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
EXPECT_LONGS_EQUAL(1, hybridGaussianConditional->nrParents());
// This is now a discreteFactor
auto discreteFactor = dynamic_pointer_cast<DecisionTreeFactor>(factorOnModes);
auto discreteFactor = dynamic_pointer_cast<TableFactor>(factorOnModes);
CHECK(discreteFactor);
EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size());
EXPECT(discreteFactor->root_->isLeaf() == false);
}
/****************************************************************************