Store the values

release/4.3a0
Frank Dellaert 2024-09-28 10:48:51 -07:00
parent acccef8024
commit 3797996e89
5 changed files with 49 additions and 48 deletions

View File

@ -27,14 +27,16 @@
#include <gtsam/linear/GaussianFactor.h> #include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
#include "gtsam/base/types.h"
namespace gtsam { namespace gtsam {
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianFactor::Factors HybridGaussianFactor::augment( HybridGaussianFactor::FactorValuePairs HybridGaussianFactor::augment(
const FactorValuePairs &factors) { const FactorValuePairs &factors) {
// Find the minimum value so we can "proselytize" to positive values. // Find the minimum value so we can "proselytize" to positive values.
// Done because we can't have sqrt of negative numbers. // Done because we can't have sqrt of negative numbers.
Factors gaussianFactors; DecisionTree<Key, GaussianFactor::shared_ptr> gaussianFactors;
AlgebraicDecisionTree<Key> valueTree; AlgebraicDecisionTree<Key> valueTree;
std::tie(gaussianFactors, valueTree) = unzip(factors); std::tie(gaussianFactors, valueTree) = unzip(factors);
@ -42,16 +44,16 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment(
double min_value = valueTree.min(); double min_value = valueTree.min();
// Finally, update the [A|b] matrices. // Finally, update the [A|b] matrices.
auto update = [&min_value](const GaussianFactorValuePair &gfv) { auto update = [&min_value](const auto &gfv) -> GaussianFactorValuePair {
auto [gf, value] = gfv; auto [gf, value] = gfv;
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf); auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
if (!jf) return gf; if (!jf) return {gf, 0.0}; // should this be zero or infinite?
double normalized_value = value - min_value; double normalized_value = value - min_value;
// If the value is 0, do nothing // If the value is 0, do nothing
if (normalized_value == 0.0) return gf; if (normalized_value == 0.0) return {gf, 0.0};
GaussianFactorGraph gfg; GaussianFactorGraph gfg;
gfg.push_back(jf); gfg.push_back(jf);
@ -62,18 +64,16 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment(
auto constantFactor = std::make_shared<JacobianFactor>(c); auto constantFactor = std::make_shared<JacobianFactor>(c);
gfg.push_back(constantFactor); gfg.push_back(constantFactor);
return std::dynamic_pointer_cast<GaussianFactor>( return {std::make_shared<JacobianFactor>(gfg), normalized_value};
std::make_shared<JacobianFactor>(gfg));
}; };
return Factors(factors, update); return FactorValuePairs(factors, update);
} }
/* *******************************************************************************/ /* *******************************************************************************/
struct HybridGaussianFactor::ConstructorHelper { struct HybridGaussianFactor::ConstructorHelper {
KeyVector continuousKeys; // Continuous keys extracted from factors KeyVector continuousKeys; // Continuous keys extracted from factors
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
FactorValuePairs pairs; // Used only if factorsTree is empty FactorValuePairs pairs; // The decision tree with factors and scalars
Factors factorsTree;
ConstructorHelper(const DiscreteKey &discreteKey, ConstructorHelper(const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors) const std::vector<GaussianFactor::shared_ptr> &factors)
@ -85,9 +85,10 @@ struct HybridGaussianFactor::ConstructorHelper {
break; break;
} }
} }
// Build the FactorValuePairs DecisionTree
// Build the DecisionTree from the factor vector pairs = FactorValuePairs(
factorsTree = Factors(discreteKeys, factors); DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
[](const auto &f) { return std::pair{f, 0.0}; });
} }
ConstructorHelper(const DiscreteKey &discreteKey, ConstructorHelper(const DiscreteKey &discreteKey,
@ -109,6 +110,7 @@ struct HybridGaussianFactor::ConstructorHelper {
const FactorValuePairs &factorPairs) const FactorValuePairs &factorPairs)
: discreteKeys(discreteKeys) { : discreteKeys(discreteKeys) {
// Extract continuous keys from the first non-null factor // Extract continuous keys from the first non-null factor
// TODO: just stop after first non-null factor
factorPairs.visit([&](const GaussianFactorValuePair &pair) { factorPairs.visit([&](const GaussianFactorValuePair &pair) {
if (pair.first && continuousKeys.empty()) { if (pair.first && continuousKeys.empty()) {
continuousKeys = pair.first->keys(); continuousKeys = pair.first->keys();
@ -123,14 +125,13 @@ struct HybridGaussianFactor::ConstructorHelper {
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper) HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper)
: Base(helper.continuousKeys, helper.discreteKeys), : Base(helper.continuousKeys, helper.discreteKeys),
factors_(helper.factorsTree.empty() ? augment(helper.pairs) factors_(augment(helper.pairs)) {}
: helper.factorsTree) {}
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor( HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey, const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors) const std::vector<GaussianFactor::shared_ptr> &factorPairs)
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {} : HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor( HybridGaussianFactor::HybridGaussianFactor(
@ -140,8 +141,8 @@ HybridGaussianFactor::HybridGaussianFactor(
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys, HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factors) const FactorValuePairs &factorPairs)
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {} : HybridGaussianFactor(ConstructorHelper(discreteKeys, factorPairs)) {}
/* *******************************************************************************/ /* *******************************************************************************/
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
@ -153,10 +154,12 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
if (factors_.empty() ^ e->factors_.empty()) return false; if (factors_.empty() ^ e->factors_.empty()) return false;
// Check the base and the factors: // Check the base and the factors:
return Base::equals(*e, tol) && auto compareFunc = [tol](const auto &pair1, const auto &pair2) {
factors_.equals(e->factors_, [tol](const auto &f1, const auto &f2) { auto f1 = pair1.first, f2 = pair2.first;
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol)); bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
}); return match && gtsam::equal(pair1.second, pair2.second, tol);
};
return Base::equals(*e, tol) && factors_.equals(e->factors_, compareFunc);
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -171,15 +174,16 @@ void HybridGaussianFactor::print(const std::string &s,
} else { } else {
factors_.print( factors_.print(
"", [&](Key k) { return formatter(k); }, "", [&](Key k) { return formatter(k); },
[&](const sharedFactor &gf) -> std::string { [&](const auto &pair) -> std::string {
RedirectCout rd; RedirectCout rd;
std::cout << ":\n"; std::cout << ":\n";
if (gf) { if (pair.first) {
gf->print("", formatter); pair.first->print("", formatter);
return rd.str(); return rd.str();
} else { } else {
return "nullptr"; return "nullptr";
} }
std::cout << "scalar: " << pair.second << "\n";
}); });
} }
std::cout << "}" << std::endl; std::cout << "}" << std::endl;
@ -188,7 +192,7 @@ void HybridGaussianFactor::print(const std::string &s,
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()( HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
const DiscreteValues &assignment) const { const DiscreteValues &assignment) const {
return factors_(assignment); return factors_(assignment).first;
} }
/* *******************************************************************************/ /* *******************************************************************************/
@ -207,7 +211,7 @@ GaussianFactorGraphTree HybridGaussianFactor::add(
/* *******************************************************************************/ /* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
const { const {
auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; }; auto wrap = [](const auto &pair) { return GaussianFactorGraph{pair.first}; };
return {factors_, wrap}; return {factors_, wrap};
} }
@ -229,8 +233,8 @@ static double PotentiallyPrunedComponentError(
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree( AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const sharedFactor &gf) { auto errorFunc = [this, &continuousValues](const auto &pair) {
return PotentiallyPrunedComponentError(gf, continuousValues); return PotentiallyPrunedComponentError(pair.first, continuousValues);
}; };
DecisionTree<Key, double> error_tree(factors_, errorFunc); DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree; return error_tree;
@ -239,8 +243,8 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
/* *******************************************************************************/ /* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues &values) const { double HybridGaussianFactor::error(const HybridValues &values) const {
// Directly index to get the component, no need to build the whole tree. // Directly index to get the component, no need to build the whole tree.
const sharedFactor gf = factors_(values.discrete()); const auto pair = factors_(values.discrete());
return PotentiallyPrunedComponentError(gf, values.continuous()); return PotentiallyPrunedComponentError(pair.first, values.continuous());
} }
} // namespace gtsam } // namespace gtsam

View File

@ -66,12 +66,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// typedef for Decision Tree of Gaussian factors and arbitrary value. /// typedef for Decision Tree of Gaussian factors and arbitrary value.
using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>; using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;
/// typedef for Decision Tree of Gaussian factors.
using Factors = DecisionTree<Key, sharedFactor>;
private: private:
/// Decision tree of Gaussian factors indexed by discrete keys. /// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_; FactorValuePairs factors_;
public: public:
/// @name Constructors /// @name Constructors
@ -110,10 +108,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m. * The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m.
* *
* @param discreteKeys Discrete variables and their cardinalities. * @param discreteKeys Discrete variables and their cardinalities.
* @param factors The decision tree of Gaussian factor/scalar pairs. * @param factorPairs The decision tree of Gaussian factor/scalar pairs.
*/ */
HybridGaussianFactor(const DiscreteKeys &discreteKeys, HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factors); const FactorValuePairs &factorPairs);
/// @} /// @}
/// @name Testable /// @name Testable
@ -158,7 +156,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
/// Getter for GaussianFactor decision tree /// Getter for GaussianFactor decision tree
const Factors &factors() const { return factors_; } const FactorValuePairs &factors() const { return factors_; }
/// Add HybridNonlinearFactor to a Sum, syntactic sugar. /// Add HybridNonlinearFactor to a Sum, syntactic sugar.
friend GaussianFactorGraphTree &operator+=( friend GaussianFactorGraphTree &operator+=(
@ -184,10 +182,9 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* value in the `b` vector as an additional row. * value in the `b` vector as an additional row.
* *
* @param factors DecisionTree of GaussianFactors and arbitrary scalars. * @param factors DecisionTree of GaussianFactors and arbitrary scalars.
* Gaussian factor in factors. * @return FactorValuePairs
* @return HybridGaussianFactor::Factors
*/ */
static Factors augment(const FactorValuePairs &factors); static FactorValuePairs augment(const FactorValuePairs &factors);
/// Helper struct to assist private constructor below. /// Helper struct to assist private constructor below.
struct ConstructorHelper; struct ConstructorHelper;

View File

@ -238,8 +238,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
// Case where we have a HybridGaussianFactor with no continuous keys. // Case where we have a HybridGaussianFactor with no continuous keys.
// In this case, compute discrete probabilities. // In this case, compute discrete probabilities.
auto logProbability = auto logProbability = [&](const auto &pair) -> double {
[&](const GaussianFactor::shared_ptr &factor) -> double { auto [factor, _] = pair;
if (!factor) return 0.0; if (!factor) return 0.0;
return factor->error(VectorValues()); return factor->error(VectorValues());
}; };

View File

@ -196,8 +196,8 @@ std::shared_ptr<HybridGaussianFactor> HybridNonlinearFactor::linearize(
} }
}; };
DecisionTree<Key, std::pair<GaussianFactor::shared_ptr, double>> HybridGaussianFactor::FactorValuePairs linearized_factors(factors_,
linearized_factors(factors_, linearizeDT); linearizeDT);
return std::make_shared<HybridGaussianFactor>(discreteKeys_, return std::make_shared<HybridGaussianFactor>(discreteKeys_,
linearized_factors); linearized_factors);

View File

@ -52,11 +52,11 @@ BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice") BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor, "gtsam_HybridGaussianFactor"); BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor, "gtsam_HybridGaussianFactor");
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors, BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs,
"gtsam_HybridGaussianFactor_Factors"); "gtsam_HybridGaussianFactor_Factors");
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Leaf, BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs::Leaf,
"gtsam_HybridGaussianFactor_Factors_Leaf"); "gtsam_HybridGaussianFactor_Factors_Leaf");
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Choice, BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs::Choice,
"gtsam_HybridGaussianFactor_Factors_Choice"); "gtsam_HybridGaussianFactor_Factors_Choice");
BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional, BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional,