improve HybridGaussianProductFactor

release/4.3a0
Varun Agrawal 2024-10-08 12:03:35 -04:00
parent 874ba67693
commit 21b4c4c8d3
2 changed files with 18 additions and 8 deletions

View File

@ -26,39 +26,46 @@ namespace gtsam {
using Y = HybridGaussianProductFactor::Y; using Y = HybridGaussianProductFactor::Y;
/* *******************************************************************************/
static Y add(const Y& y1, const Y& y2) { static Y add(const Y& y1, const Y& y2) {
GaussianFactorGraph result = y1.first; GaussianFactorGraph result = y1.first;
result.push_back(y2.first); result.push_back(y2.first);
return {result, y1.second + y2.second}; return {result, y1.second + y2.second};
}; };
/* *******************************************************************************/
HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a, HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a,
const HybridGaussianProductFactor& b) { const HybridGaussianProductFactor& b) {
return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add)); return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add));
} }
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::operator+( HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const HybridGaussianFactor& factor) const { const HybridGaussianFactor& factor) const {
return *this + factor.asProductFactor(); return *this + factor.asProductFactor();
} }
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::operator+( HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const GaussianFactor::shared_ptr& factor) const { const GaussianFactor::shared_ptr& factor) const {
return *this + HybridGaussianProductFactor(factor); return *this + HybridGaussianProductFactor(factor);
} }
/* *******************************************************************************/
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=( HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
const GaussianFactor::shared_ptr& factor) { const GaussianFactor::shared_ptr& factor) {
*this = *this + factor; *this = *this + factor;
return *this; return *this;
} }
/* *******************************************************************************/
HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=( HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
const HybridGaussianFactor& factor) { const HybridGaussianFactor& factor) {
*this = *this + factor; *this = *this + factor;
return *this; return *this;
} }
/* *******************************************************************************/
void HybridGaussianProductFactor::print(const std::string& s, void HybridGaussianProductFactor::print(const std::string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
KeySet keys; KeySet keys;
@ -76,11 +83,19 @@ void HybridGaussianProductFactor::print(const std::string& s,
} }
} }
/* *******************************************************************************/
bool HybridGaussianProductFactor::equals(
const HybridGaussianProductFactor& other, double tol) const {
return Base::equals(other, [tol](const Y& a, const Y& b) {
return a.first.equals(b.first, tol) && std::abs(a.second - b.second) < tol;
});
}
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const { HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
auto emptyGaussian = [](const Y& y) { auto emptyGaussian = [](const Y& y) {
bool hasNull = bool hasNull =
std::any_of(y.first.begin(), std::any_of(y.first.begin(), y.first.end(),
y.first.end(),
[](const GaussianFactor::shared_ptr& ptr) { return !ptr; }); [](const GaussianFactor::shared_ptr& ptr) { return !ptr; });
return hasNull ? Y{GaussianFactorGraph(), 0.0} : y; return hasNull ? Y{GaussianFactorGraph(), 0.0} : y;
}; };

View File

@ -94,12 +94,7 @@ class HybridGaussianProductFactor
* @return true if equal, false otherwise * @return true if equal, false otherwise
*/ */
bool equals(const HybridGaussianProductFactor& other, bool equals(const HybridGaussianProductFactor& other,
double tol = 1e-9) const { double tol = 1e-9) const;
return Base::equals(other, [tol](const Y& a, const Y& b) {
return a.first.equals(b.first, tol) &&
std::abs(a.second - b.second) < tol;
});
}
/// @} /// @}