Merge pull request #1369 from borglab/hybrid/various-fixes

release/4.3a0
Varun Agrawal 2023-01-04 10:28:52 -05:00 committed by GitHub
commit 78926f7d51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 49 additions and 33 deletions

View File

@ -156,9 +156,9 @@ namespace gtsam {
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const { const {
// Get all possible assignments // Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs = discreteKeys(); DiscreteKeys pairs = discreteKeys();
// Reverse to make cartesian product output a more natural ordering. // Reverse to make cartesian product output a more natural ordering.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend()); DiscreteKeys rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = DiscreteValues::CartesianProduct(rpairs); const auto assignments = DiscreteValues::CartesianProduct(rpairs);
// Construct unordered_map with values // Construct unordered_map with values

View File

@ -69,8 +69,7 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(conditional); result.push_back(conditional);
if (conditional) { if (conditional) {
return GraphAndConstant( return GraphAndConstant(result, conditional->logNormalizationConstant());
result, conditional->logNormalizationConstant());
} else { } else {
return GraphAndConstant(result, 0.0); return GraphAndConstant(result, 0.0);
} }
@ -163,7 +162,13 @@ KeyVector GaussianMixture::continuousParents() const {
/* ************************************************************************* */ /* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood( boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &frontals) const { const VectorValues &frontals) const {
// TODO(dellaert): check that values has all frontals // Check that values has all frontals
for (auto &&kv : frontals) {
if (frontals.find(kv.first) == frontals.end()) {
throw std::runtime_error("GaussianMixture: frontals missing factor key.");
}
}
const DiscreteKeys discreteParentKeys = discreteKeys(); const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents(); const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods( const GaussianMixtureFactor::Factors likelihoods(

View File

@ -26,7 +26,6 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************ */ /* ************************************************************************ */
// TODO(fan): THIS IS VERY VERY DIRTY! We need to get DiscreteFactor right!
HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other) HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
: Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other) : Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other)
->discreteKeys()), ->discreteKeys()),

View File

@ -60,10 +60,10 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianFactorGraphTree addGaussian( static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &sum, const GaussianFactorGraphTree &gfgTree,
const GaussianFactor::shared_ptr &factor) { const GaussianFactor::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it. // If the decision tree is not initialized, then initialize it.
if (sum.empty()) { if (gfgTree.empty()) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);
return GaussianFactorGraphTree(GraphAndConstant(result, 0.0)); return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
@ -74,20 +74,18 @@ static GaussianFactorGraphTree addGaussian(
result.push_back(factor); result.push_back(factor);
return GraphAndConstant(result, graph_z.constant); return GraphAndConstant(result, graph_z.constant);
}; };
return sum.apply(add); return gfgTree.apply(add);
} }
} }
/* ************************************************************************ */ /* ************************************************************************ */
// TODO(dellaert): We need to document why deferredFactors need to be // TODO(dellaert): Implementation-wise, it's probably more efficient to first
// added last, which I would undo if possible. Implementation-wise, it's // collect the discrete keys, and then loop over all assignments to populate a
// probably more efficient to first collect the discrete keys, and then loop // vector.
// over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
gttic(assembleGraphTree); gttic(assembleGraphTree);
GaussianFactorGraphTree result; GaussianFactorGraphTree result;
std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors_) { for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor. // TODO(dellaert): just use a virtual method defined in HybridFactor.
@ -101,10 +99,10 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
} else if (f->isContinuous()) { } else if (f->isContinuous()) {
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) { if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
deferredFactors.push_back(gf->inner()); result = addGaussian(result, gf->inner());
} }
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) { if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
deferredFactors.push_back(cg->asGaussian()); result = addGaussian(result, cg->asGaussian());
} }
} else if (f->isDiscrete()) { } else if (f->isDiscrete()) {
@ -126,10 +124,6 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
} }
} }
for (auto &f : deferredFactors) {
result = addGaussian(result, f);
}
gttoc(assembleGraphTree); gttoc(assembleGraphTree);
return result; return result;

View File

@ -99,9 +99,11 @@ void HybridNonlinearISAM::print(const string& s,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter) const {
cout << s << "ReorderInterval: " << reorderInterval_ cout << s << "ReorderInterval: " << reorderInterval_
<< " Current Count: " << reorderCounter_ << endl; << " Current Count: " << reorderCounter_ << endl;
isam_.print("HybridGaussianISAM:\n", keyFormatter); std::cout << "HybridGaussianISAM:" << std::endl;
isam_.print("", keyFormatter);
linPoint_.print("Linearization Point:\n", keyFormatter); linPoint_.print("Linearization Point:\n", keyFormatter);
factors_.print("Nonlinear Graph:\n", keyFormatter); std::cout << "Nonlinear Graph:" << std::endl;
factors_.print("", keyFormatter);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -90,7 +90,7 @@ class GTSAM_EXPORT HybridNonlinearISAM {
const Values& getLinearizationPoint() const { return linPoint_; } const Values& getLinearizationPoint() const { return linPoint_; }
/** Return the current discrete assignment */ /** Return the current discrete assignment */
const DiscreteValues& getAssignment() const { return assignment_; } const DiscreteValues& assignment() const { return assignment_; }
/** get underlying nonlinear graph */ /** get underlying nonlinear graph */
const HybridNonlinearFactorGraph& getFactorsUnsafe() const { const HybridNonlinearFactorGraph& getFactorsUnsafe() const {

View File

@ -162,14 +162,20 @@ class MixtureFactor : public HybridFactor {
} }
/// Error for HybridValues is not provided for nonlinear hybrid factor. /// Error for HybridValues is not provided for nonlinear hybrid factor.
double error(const HybridValues &values) const override { double error(const HybridValues& values) const override {
throw std::runtime_error( throw std::runtime_error(
"MixtureFactor::error(HybridValues) not implemented."); "MixtureFactor::error(HybridValues) not implemented.");
} }
/**
* @brief Get the dimension of the factor (number of rows on linearization).
* Returns the dimension of the first component factor.
* @return size_t
*/
size_t dim() const { size_t dim() const {
// TODO(Varun) const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_);
throw std::runtime_error("MixtureFactor::dim not implemented."); auto factor = factors_(assignments.at(0));
return factor->dim();
} }
/// Testable /// Testable

View File

@ -114,7 +114,7 @@ TEST(HybridEstimation, Full) {
/****************************************************************************/ /****************************************************************************/
// Test approximate inference with an additional pruning step. // Test approximate inference with an additional pruning step.
TEST_DISABLED(HybridEstimation, Incremental) { TEST(HybridEstimation, Incremental) {
size_t K = 15; size_t K = 15;
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
7, 8, 9, 9, 9, 10, 11, 11, 11, 11}; 7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
@ -151,9 +151,6 @@ TEST_DISABLED(HybridEstimation, Incremental) {
graph.resize(0); graph.resize(0);
} }
/*TODO(Varun) Gives degenerate result due to probability underflow.
Need to normalize probabilities.
*/
HybridValues delta = smoother.hybridBayesNet().optimize(); HybridValues delta = smoother.hybridBayesNet().optimize();
Values result = initial.retract(delta.continuous()); Values result = initial.retract(delta.continuous());

View File

@ -70,8 +70,7 @@ MixtureFactor
} }
/* ************************************************************************* */ /* ************************************************************************* */
// Test the error of the MixtureFactor static MixtureFactor getMixtureFactor() {
TEST(MixtureFactor, Error) {
DiscreteKey m1(1, 2); DiscreteKey m1(1, 2);
double between0 = 0.0; double between0 = 0.0;
@ -86,7 +85,13 @@ TEST(MixtureFactor, Error) {
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between1, model); boost::make_shared<BetweenFactor<double>>(X(1), X(2), between1, model);
std::vector<NonlinearFactor::shared_ptr> factors{f0, f1}; std::vector<NonlinearFactor::shared_ptr> factors{f0, f1};
MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); return MixtureFactor({X(1), X(2)}, {m1}, factors);
}
/* ************************************************************************* */
// Test the error of the MixtureFactor
TEST(MixtureFactor, Error) {
auto mixtureFactor = getMixtureFactor();
Values continuousValues; Values continuousValues;
continuousValues.insert<double>(X(1), 0); continuousValues.insert<double>(X(1), 0);
@ -94,6 +99,7 @@ TEST(MixtureFactor, Error) {
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues); AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
DiscreteKey m1(1, 2);
std::vector<DiscreteKey> discrete_keys = {m1}; std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> errors = {0.5, 0}; std::vector<double> errors = {0.5, 0};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors); AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
@ -101,6 +107,13 @@ TEST(MixtureFactor, Error) {
EXPECT(assert_equal(expected_error, error_tree)); EXPECT(assert_equal(expected_error, error_tree));
} }
/* ************************************************************************* */
// Test dim of the MixtureFactor
TEST(MixtureFactor, Dim) {
auto mixtureFactor = getMixtureFactor();
EXPECT_LONGS_EQUAL(1, mixtureFactor.dim());
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;