diff --git a/gtsam/hybrid/MixtureFactor.h b/gtsam/hybrid/MixtureFactor.h index fc1a9a2b8..5285dd191 100644 --- a/gtsam/hybrid/MixtureFactor.h +++ b/gtsam/hybrid/MixtureFactor.h @@ -162,14 +162,20 @@ class MixtureFactor : public HybridFactor { } /// Error for HybridValues is not provided for nonlinear hybrid factor. - double error(const HybridValues &values) const override { + double error(const HybridValues& values) const override { throw std::runtime_error( "MixtureFactor::error(HybridValues) not implemented."); } + /** + * @brief Get the dimension of the factor (number of rows on linearization). + * Returns the dimension of the first component factor. + * @return size_t + */ size_t dim() const { - // TODO(Varun) - throw std::runtime_error("MixtureFactor::dim not implemented."); + const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_); + auto factor = factors_(assignments.at(0)); + return factor->dim(); } /// Testable diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index fe3212eda..9e4d66bf2 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -70,8 +70,7 @@ MixtureFactor } /* ************************************************************************* */ -// Test the error of the MixtureFactor -TEST(MixtureFactor, Error) { +static MixtureFactor getMixtureFactor() { DiscreteKey m1(1, 2); double between0 = 0.0; @@ -86,7 +85,13 @@ TEST(MixtureFactor, Error) { boost::make_shared>(X(1), X(2), between1, model); std::vector factors{f0, f1}; - MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors); + return MixtureFactor({X(1), X(2)}, {m1}, factors); +} + +/* ************************************************************************* */ +// Test the error of the MixtureFactor +TEST(MixtureFactor, Error) { + auto mixtureFactor = getMixtureFactor(); Values continuousValues; continuousValues.insert(X(1), 0); @@ -94,6 +99,7 @@ TEST(MixtureFactor, Error) { AlgebraicDecisionTree error_tree = mixtureFactor.error(continuousValues); + DiscreteKey m1(1, 2); std::vector discrete_keys = {m1}; std::vector errors = {0.5, 0}; AlgebraicDecisionTree expected_error(discrete_keys, errors); @@ -101,6 +107,13 @@ TEST(MixtureFactor, Error) { EXPECT(assert_equal(expected_error, error_tree)); } +/* ************************************************************************* */ +// Test dim of the MixtureFactor +TEST(MixtureFactor, Dim) { + auto mixtureFactor = getMixtureFactor(); + EXPECT_LONGS_EQUAL(1, mixtureFactor.dim()); +} + /* ************************************************************************* */ int main() { TestResult tr;