Store the values
parent
acccef8024
commit
3797996e89
|
@ -27,14 +27,16 @@
|
||||||
#include <gtsam/linear/GaussianFactor.h>
|
#include <gtsam/linear/GaussianFactor.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
|
#include "gtsam/base/types.h"
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianFactor::Factors HybridGaussianFactor::augment(
|
HybridGaussianFactor::FactorValuePairs HybridGaussianFactor::augment(
|
||||||
const FactorValuePairs &factors) {
|
const FactorValuePairs &factors) {
|
||||||
// Find the minimum value so we can "proselytize" to positive values.
|
// Find the minimum value so we can "proselytize" to positive values.
|
||||||
// Done because we can't have sqrt of negative numbers.
|
// Done because we can't have sqrt of negative numbers.
|
||||||
Factors gaussianFactors;
|
DecisionTree<Key, GaussianFactor::shared_ptr> gaussianFactors;
|
||||||
AlgebraicDecisionTree<Key> valueTree;
|
AlgebraicDecisionTree<Key> valueTree;
|
||||||
std::tie(gaussianFactors, valueTree) = unzip(factors);
|
std::tie(gaussianFactors, valueTree) = unzip(factors);
|
||||||
|
|
||||||
|
@ -42,16 +44,16 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment(
|
||||||
double min_value = valueTree.min();
|
double min_value = valueTree.min();
|
||||||
|
|
||||||
// Finally, update the [A|b] matrices.
|
// Finally, update the [A|b] matrices.
|
||||||
auto update = [&min_value](const GaussianFactorValuePair &gfv) {
|
auto update = [&min_value](const auto &gfv) -> GaussianFactorValuePair {
|
||||||
auto [gf, value] = gfv;
|
auto [gf, value] = gfv;
|
||||||
|
|
||||||
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
|
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
|
||||||
if (!jf) return gf;
|
if (!jf) return {gf, 0.0}; // should this be zero or infinite?
|
||||||
|
|
||||||
double normalized_value = value - min_value;
|
double normalized_value = value - min_value;
|
||||||
|
|
||||||
// If the value is 0, do nothing
|
// If the value is 0, do nothing
|
||||||
if (normalized_value == 0.0) return gf;
|
if (normalized_value == 0.0) return {gf, 0.0};
|
||||||
|
|
||||||
GaussianFactorGraph gfg;
|
GaussianFactorGraph gfg;
|
||||||
gfg.push_back(jf);
|
gfg.push_back(jf);
|
||||||
|
@ -62,18 +64,16 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment(
|
||||||
auto constantFactor = std::make_shared<JacobianFactor>(c);
|
auto constantFactor = std::make_shared<JacobianFactor>(c);
|
||||||
|
|
||||||
gfg.push_back(constantFactor);
|
gfg.push_back(constantFactor);
|
||||||
return std::dynamic_pointer_cast<GaussianFactor>(
|
return {std::make_shared<JacobianFactor>(gfg), normalized_value};
|
||||||
std::make_shared<JacobianFactor>(gfg));
|
|
||||||
};
|
};
|
||||||
return Factors(factors, update);
|
return FactorValuePairs(factors, update);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
struct HybridGaussianFactor::ConstructorHelper {
|
struct HybridGaussianFactor::ConstructorHelper {
|
||||||
KeyVector continuousKeys; // Continuous keys extracted from factors
|
KeyVector continuousKeys; // Continuous keys extracted from factors
|
||||||
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
|
DiscreteKeys discreteKeys; // Discrete keys provided to the constructors
|
||||||
FactorValuePairs pairs; // Used only if factorsTree is empty
|
FactorValuePairs pairs; // The decision tree with factors and scalars
|
||||||
Factors factorsTree;
|
|
||||||
|
|
||||||
ConstructorHelper(const DiscreteKey &discreteKey,
|
ConstructorHelper(const DiscreteKey &discreteKey,
|
||||||
const std::vector<GaussianFactor::shared_ptr> &factors)
|
const std::vector<GaussianFactor::shared_ptr> &factors)
|
||||||
|
@ -85,9 +85,10 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Build the FactorValuePairs DecisionTree
|
||||||
// Build the DecisionTree from the factor vector
|
pairs = FactorValuePairs(
|
||||||
factorsTree = Factors(discreteKeys, factors);
|
DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
|
||||||
|
[](const auto &f) { return std::pair{f, 0.0}; });
|
||||||
}
|
}
|
||||||
|
|
||||||
ConstructorHelper(const DiscreteKey &discreteKey,
|
ConstructorHelper(const DiscreteKey &discreteKey,
|
||||||
|
@ -109,6 +110,7 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
const FactorValuePairs &factorPairs)
|
const FactorValuePairs &factorPairs)
|
||||||
: discreteKeys(discreteKeys) {
|
: discreteKeys(discreteKeys) {
|
||||||
// Extract continuous keys from the first non-null factor
|
// Extract continuous keys from the first non-null factor
|
||||||
|
// TODO: just stop after first non-null factor
|
||||||
factorPairs.visit([&](const GaussianFactorValuePair &pair) {
|
factorPairs.visit([&](const GaussianFactorValuePair &pair) {
|
||||||
if (pair.first && continuousKeys.empty()) {
|
if (pair.first && continuousKeys.empty()) {
|
||||||
continuousKeys = pair.first->keys();
|
continuousKeys = pair.first->keys();
|
||||||
|
@ -123,14 +125,13 @@ struct HybridGaussianFactor::ConstructorHelper {
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper)
|
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper)
|
||||||
: Base(helper.continuousKeys, helper.discreteKeys),
|
: Base(helper.continuousKeys, helper.discreteKeys),
|
||||||
factors_(helper.factorsTree.empty() ? augment(helper.pairs)
|
factors_(augment(helper.pairs)) {}
|
||||||
: helper.factorsTree) {}
|
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianFactor::HybridGaussianFactor(
|
HybridGaussianFactor::HybridGaussianFactor(
|
||||||
const DiscreteKey &discreteKey,
|
const DiscreteKey &discreteKey,
|
||||||
const std::vector<GaussianFactor::shared_ptr> &factors)
|
const std::vector<GaussianFactor::shared_ptr> &factorPairs)
|
||||||
: HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}
|
: HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianFactor::HybridGaussianFactor(
|
HybridGaussianFactor::HybridGaussianFactor(
|
||||||
|
@ -140,8 +141,8 @@ HybridGaussianFactor::HybridGaussianFactor(
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
|
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
|
||||||
const FactorValuePairs &factors)
|
const FactorValuePairs &factorPairs)
|
||||||
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}
|
: HybridGaussianFactor(ConstructorHelper(discreteKeys, factorPairs)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
|
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
|
@ -153,10 +154,12 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
if (factors_.empty() ^ e->factors_.empty()) return false;
|
if (factors_.empty() ^ e->factors_.empty()) return false;
|
||||||
|
|
||||||
// Check the base and the factors:
|
// Check the base and the factors:
|
||||||
return Base::equals(*e, tol) &&
|
auto compareFunc = [tol](const auto &pair1, const auto &pair2) {
|
||||||
factors_.equals(e->factors_, [tol](const auto &f1, const auto &f2) {
|
auto f1 = pair1.first, f2 = pair2.first;
|
||||||
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
|
||||||
});
|
return match && gtsam::equal(pair1.second, pair2.second, tol);
|
||||||
|
};
|
||||||
|
return Base::equals(*e, tol) && factors_.equals(e->factors_, compareFunc);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
@ -171,15 +174,16 @@ void HybridGaussianFactor::print(const std::string &s,
|
||||||
} else {
|
} else {
|
||||||
factors_.print(
|
factors_.print(
|
||||||
"", [&](Key k) { return formatter(k); },
|
"", [&](Key k) { return formatter(k); },
|
||||||
[&](const sharedFactor &gf) -> std::string {
|
[&](const auto &pair) -> std::string {
|
||||||
RedirectCout rd;
|
RedirectCout rd;
|
||||||
std::cout << ":\n";
|
std::cout << ":\n";
|
||||||
if (gf) {
|
if (pair.first) {
|
||||||
gf->print("", formatter);
|
pair.first->print("", formatter);
|
||||||
return rd.str();
|
return rd.str();
|
||||||
} else {
|
} else {
|
||||||
return "nullptr";
|
return "nullptr";
|
||||||
}
|
}
|
||||||
|
std::cout << "scalar: " << pair.second << "\n";
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
std::cout << "}" << std::endl;
|
std::cout << "}" << std::endl;
|
||||||
|
@ -188,7 +192,7 @@ void HybridGaussianFactor::print(const std::string &s,
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
|
HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
|
||||||
const DiscreteValues &assignment) const {
|
const DiscreteValues &assignment) const {
|
||||||
return factors_(assignment);
|
return factors_(assignment).first;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
@ -207,7 +211,7 @@ GaussianFactorGraphTree HybridGaussianFactor::add(
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
|
GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
|
||||||
const {
|
const {
|
||||||
auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; };
|
auto wrap = [](const auto &pair) { return GaussianFactorGraph{pair.first}; };
|
||||||
return {factors_, wrap};
|
return {factors_, wrap};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -229,8 +233,8 @@ static double PotentiallyPrunedComponentError(
|
||||||
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
// functor to convert from sharedFactor to double error value.
|
// functor to convert from sharedFactor to double error value.
|
||||||
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
|
auto errorFunc = [this, &continuousValues](const auto &pair) {
|
||||||
return PotentiallyPrunedComponentError(gf, continuousValues);
|
return PotentiallyPrunedComponentError(pair.first, continuousValues);
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> error_tree(factors_, errorFunc);
|
DecisionTree<Key, double> error_tree(factors_, errorFunc);
|
||||||
return error_tree;
|
return error_tree;
|
||||||
|
@ -239,8 +243,8 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double HybridGaussianFactor::error(const HybridValues &values) const {
|
double HybridGaussianFactor::error(const HybridValues &values) const {
|
||||||
// Directly index to get the component, no need to build the whole tree.
|
// Directly index to get the component, no need to build the whole tree.
|
||||||
const sharedFactor gf = factors_(values.discrete());
|
const auto pair = factors_(values.discrete());
|
||||||
return PotentiallyPrunedComponentError(gf, values.continuous());
|
return PotentiallyPrunedComponentError(pair.first, values.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -66,12 +66,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
|
|
||||||
/// typedef for Decision Tree of Gaussian factors and arbitrary value.
|
/// typedef for Decision Tree of Gaussian factors and arbitrary value.
|
||||||
using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;
|
using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;
|
||||||
/// typedef for Decision Tree of Gaussian factors.
|
|
||||||
using Factors = DecisionTree<Key, sharedFactor>;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Decision tree of Gaussian factors indexed by discrete keys.
|
/// Decision tree of Gaussian factors indexed by discrete keys.
|
||||||
Factors factors_;
|
FactorValuePairs factors_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
|
@ -110,10 +108,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
* The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m.
|
* The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m.
|
||||||
*
|
*
|
||||||
* @param discreteKeys Discrete variables and their cardinalities.
|
* @param discreteKeys Discrete variables and their cardinalities.
|
||||||
* @param factors The decision tree of Gaussian factor/scalar pairs.
|
* @param factorPairs The decision tree of Gaussian factor/scalar pairs.
|
||||||
*/
|
*/
|
||||||
HybridGaussianFactor(const DiscreteKeys &discreteKeys,
|
HybridGaussianFactor(const DiscreteKeys &discreteKeys,
|
||||||
const FactorValuePairs &factors);
|
const FactorValuePairs &factorPairs);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
@ -158,7 +156,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
double error(const HybridValues &values) const override;
|
double error(const HybridValues &values) const override;
|
||||||
|
|
||||||
/// Getter for GaussianFactor decision tree
|
/// Getter for GaussianFactor decision tree
|
||||||
const Factors &factors() const { return factors_; }
|
const FactorValuePairs &factors() const { return factors_; }
|
||||||
|
|
||||||
/// Add HybridNonlinearFactor to a Sum, syntactic sugar.
|
/// Add HybridNonlinearFactor to a Sum, syntactic sugar.
|
||||||
friend GaussianFactorGraphTree &operator+=(
|
friend GaussianFactorGraphTree &operator+=(
|
||||||
|
@ -184,10 +182,9 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
* value in the `b` vector as an additional row.
|
* value in the `b` vector as an additional row.
|
||||||
*
|
*
|
||||||
* @param factors DecisionTree of GaussianFactors and arbitrary scalars.
|
* @param factors DecisionTree of GaussianFactors and arbitrary scalars.
|
||||||
* Gaussian factor in factors.
|
* @return FactorValuePairs
|
||||||
* @return HybridGaussianFactor::Factors
|
|
||||||
*/
|
*/
|
||||||
static Factors augment(const FactorValuePairs &factors);
|
static FactorValuePairs augment(const FactorValuePairs &factors);
|
||||||
|
|
||||||
/// Helper struct to assist private constructor below.
|
/// Helper struct to assist private constructor below.
|
||||||
struct ConstructorHelper;
|
struct ConstructorHelper;
|
||||||
|
|
|
@ -238,8 +238,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||||
// Case where we have a HybridGaussianFactor with no continuous keys.
|
// Case where we have a HybridGaussianFactor with no continuous keys.
|
||||||
// In this case, compute discrete probabilities.
|
// In this case, compute discrete probabilities.
|
||||||
auto logProbability =
|
auto logProbability = [&](const auto &pair) -> double {
|
||||||
[&](const GaussianFactor::shared_ptr &factor) -> double {
|
auto [factor, _] = pair;
|
||||||
if (!factor) return 0.0;
|
if (!factor) return 0.0;
|
||||||
return factor->error(VectorValues());
|
return factor->error(VectorValues());
|
||||||
};
|
};
|
||||||
|
|
|
@ -196,8 +196,8 @@ std::shared_ptr<HybridGaussianFactor> HybridNonlinearFactor::linearize(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
DecisionTree<Key, std::pair<GaussianFactor::shared_ptr, double>>
|
HybridGaussianFactor::FactorValuePairs linearized_factors(factors_,
|
||||||
linearized_factors(factors_, linearizeDT);
|
linearizeDT);
|
||||||
|
|
||||||
return std::make_shared<HybridGaussianFactor>(discreteKeys_,
|
return std::make_shared<HybridGaussianFactor>(discreteKeys_,
|
||||||
linearized_factors);
|
linearized_factors);
|
||||||
|
|
|
@ -52,11 +52,11 @@ BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
|
||||||
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
|
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
|
||||||
|
|
||||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor, "gtsam_HybridGaussianFactor");
|
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor, "gtsam_HybridGaussianFactor");
|
||||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors,
|
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs,
|
||||||
"gtsam_HybridGaussianFactor_Factors");
|
"gtsam_HybridGaussianFactor_Factors");
|
||||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Leaf,
|
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs::Leaf,
|
||||||
"gtsam_HybridGaussianFactor_Factors_Leaf");
|
"gtsam_HybridGaussianFactor_Factors_Leaf");
|
||||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Choice,
|
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs::Choice,
|
||||||
"gtsam_HybridGaussianFactor_Factors_Choice");
|
"gtsam_HybridGaussianFactor_Factors_Choice");
|
||||||
|
|
||||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional,
|
BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional,
|
||||||
|
|
Loading…
Reference in New Issue