improve HybridGaussianProductFactor

release/4.3a0
Varun Agrawal 2024-10-08 12:03:35 -04:00
parent 874ba67693
commit 21b4c4c8d3
2 changed files with 18 additions and 8 deletions

View File

@ -26,39 +26,46 @@ namespace gtsam {
using Y = HybridGaussianProductFactor::Y;
/* *******************************************************************************/
static Y add(const Y& y1, const Y& y2) {
GaussianFactorGraph result = y1.first;
result.push_back(y2.first);
return {result, y1.second + y2.second};
};
/* *******************************************************************************/
HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a,
const HybridGaussianProductFactor& b) {
return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add));
}
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const HybridGaussianFactor& factor) const {
return *this + factor.asProductFactor();
}
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const GaussianFactor::shared_ptr& factor) const {
return *this + HybridGaussianProductFactor(factor);
}
/* *******************************************************************************/
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
const GaussianFactor::shared_ptr& factor) {
*this = *this + factor;
return *this;
}
/* *******************************************************************************/
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
const HybridGaussianFactor& factor) {
*this = *this + factor;
return *this;
}
/* *******************************************************************************/
void HybridGaussianProductFactor::print(const std::string& s,
const KeyFormatter& formatter) const {
KeySet keys;
@ -76,11 +83,19 @@ void HybridGaussianProductFactor::print(const std::string& s,
}
}
/* *******************************************************************************/
bool HybridGaussianProductFactor::equals(
const HybridGaussianProductFactor& other, double tol) const {
return Base::equals(other, [tol](const Y& a, const Y& b) {
return a.first.equals(b.first, tol) && std::abs(a.second - b.second) < tol;
});
}
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
auto emptyGaussian = [](const Y& y) {
bool hasNull =
std::any_of(y.first.begin(),
y.first.end(),
std::any_of(y.first.begin(), y.first.end(),
[](const GaussianFactor::shared_ptr& ptr) { return !ptr; });
return hasNull ? Y{GaussianFactorGraph(), 0.0} : y;
};

View File

@ -94,12 +94,7 @@ class HybridGaussianProductFactor
* @return true if equal, false otherwise
*/
bool equals(const HybridGaussianProductFactor& other,
double tol = 1e-9) const {
return Base::equals(other, [tol](const Y& a, const Y& b) {
return a.first.equals(b.first, tol) &&
std::abs(a.second - b.second) < tol;
});
}
double tol = 1e-9) const;
/// @}