Restrict for hybrid factors (and discrete)

release/4.3a0
Frank Dellaert 2025-02-01 02:28:01 -05:00
parent 352c7f2efa
commit ea27bac018
18 changed files with 149 additions and 24 deletions

View File

@ -536,5 +536,11 @@ namespace gtsam {
return DecisionTreeFactor(this->discreteKeys(), thresholded); return DecisionTreeFactor(this->discreteKeys(), thresholded);
} }
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::restrict(
const DiscreteValues& assignment) const {
throw std::runtime_error("DecisionTreeFactor::restrict not implemented");
}
/* ************************************************************************ */ /* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -220,6 +220,10 @@ namespace gtsam {
return combine(keys, Ring::max); return combine(keys, Ring::max);
} }
/// Restrict the factor to the given assignment.
DiscreteFactor::shared_ptr restrict(
const DiscreteValues& assignment) const override;
/// @} /// @}
/// @name Advanced Interface /// @name Advanced Interface
/// @{ /// @{

View File

@ -167,8 +167,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/** /**
* @brief Scale the factor values by the maximum * @brief Scale the factor values by the maximum
* to prevent underflow/overflow. * to prevent underflow/overflow.
* *
* @return DiscreteFactor::shared_ptr * @return DiscreteFactor::shared_ptr
*/ */
DiscreteFactor::shared_ptr scale() const; DiscreteFactor::shared_ptr scale() const;
@ -178,6 +178,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
*/ */
virtual uint64_t nrValues() const = 0; virtual uint64_t nrValues() const = 0;
/// Restrict the factor to the given assignment.
virtual DiscreteFactor::shared_ptr restrict(
const DiscreteValues& assignment) const = 0;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -391,12 +391,12 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::sum(size_t nrFrontals) const { DiscreteFactor::shared_ptr TableFactor::sum(size_t nrFrontals) const {
return combine(nrFrontals, Ring::add); return combine(nrFrontals, Ring::add);
} }
/* ************************************************************************ */ /* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::sum(const Ordering& keys) const { DiscreteFactor::shared_ptr TableFactor::sum(const Ordering& keys) const {
return combine(keys, Ring::add); return combine(keys, Ring::add);
} }
/* ************************************************************************ */ /* ************************************************************************ */
@ -418,7 +418,6 @@ DiscreteFactor::shared_ptr TableFactor::max(const Ordering& keys) const {
return combine(keys, Ring::max); return combine(keys, Ring::max);
} }
/* ************************************************************************ */ /* ************************************************************************ */
TableFactor TableFactor::apply(Unary op) const { TableFactor TableFactor::apply(Unary op) const {
// Initialize new factor. // Initialize new factor.
@ -781,5 +780,11 @@ TableFactor TableFactor::prune(size_t maxNrAssignments) const {
return TableFactor(this->discreteKeys(), pruned_vec); return TableFactor(this->discreteKeys(), pruned_vec);
} }
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::restrict(
const DiscreteValues& assignment) const {
throw std::runtime_error("TableFactor::restrict not implemented");
}
/* ************************************************************************ */ /* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -342,6 +342,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/ */
uint64_t nrValues() const override { return sparse_table_.nonZeros(); } uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
/// Restrict the factor to the given assignment.
DiscreteFactor::shared_ptr restrict(
const DiscreteValues& assignment) const override;
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -69,11 +69,13 @@ HybridBayesNet HybridBayesNet::prune(
// Go through all the Gaussian conditionals, restrict them according to // Go through all the Gaussian conditionals, restrict them according to
// fixed values, and then prune further. // fixed values, and then prune further.
for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) { for (std::shared_ptr<HybridConditional> conditional : *this) {
if (conditional->isDiscrete()) continue; if (conditional->isDiscrete()) continue;
// No-op if not a HybridGaussianConditional. // No-op if not a HybridGaussianConditional.
if (marginalThreshold) conditional = conditional->restrict(fixed); if (marginalThreshold)
conditional = std::static_pointer_cast<HybridConditional>(
conditional->restrict(fixed));
// Now decide on type what to do: // Now decide on type what to do:
if (auto hgc = conditional->asHybrid()) { if (auto hgc = conditional->asHybrid()) {

View File

@ -170,8 +170,8 @@ double HybridConditional::evaluate(const HybridValues &values) const {
} }
/* ************************************************************************ */ /* ************************************************************************ */
HybridConditional::shared_ptr HybridConditional::restrict( std::shared_ptr<Factor> HybridConditional::restrict(
const DiscreteValues &discreteValues) const { const DiscreteValues &assignment) const {
if (auto gc = asGaussian()) { if (auto gc = asGaussian()) {
return std::make_shared<HybridConditional>(gc); return std::make_shared<HybridConditional>(gc);
} else if (auto dc = asDiscrete()) { } else if (auto dc = asDiscrete()) {
@ -184,21 +184,20 @@ HybridConditional::shared_ptr HybridConditional::restrict(
"HybridConditional::restrict: conditional type not handled"); "HybridConditional::restrict: conditional type not handled");
// Case 1: Fully determined, return corresponding Gaussian conditional // Case 1: Fully determined, return corresponding Gaussian conditional
auto parentValues = discreteValues.filter(discreteKeys_); auto parentValues = assignment.filter(discreteKeys_);
if (parentValues.size() == discreteKeys_.size()) { if (parentValues.size() == discreteKeys_.size()) {
return std::make_shared<HybridConditional>(hgc->choose(parentValues)); return std::make_shared<HybridConditional>(hgc->choose(parentValues));
} }
// Case 2: Some live parents remain, build a new tree // Case 2: Some live parents remain, build a new tree
auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_); auto remainingKeys = assignment.missingKeys(discreteKeys_);
if (!unspecifiedParentKeys.empty()) { if (!remainingKeys.empty()) {
auto newTree = hgc->factors(); auto newTree = hgc->factors();
for (const auto &[key, value] : parentValues) { for (const auto &[key, value] : parentValues) {
newTree = newTree.choose(key, value); newTree = newTree.choose(key, value);
} }
return std::make_shared<HybridConditional>( return std::make_shared<HybridConditional>(
std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys, std::make_shared<HybridGaussianConditional>(remainingKeys, newTree));
newTree));
} }
// Case 3: No changes needed, return original // Case 3: No changes needed, return original

View File

@ -153,7 +153,8 @@ class GTSAM_EXPORT HybridConditional
* @return HybridGaussianConditional::shared_ptr otherwise * @return HybridGaussianConditional::shared_ptr otherwise
*/ */
HybridGaussianConditional::shared_ptr asHybrid() const { HybridGaussianConditional::shared_ptr asHybrid() const {
return std::dynamic_pointer_cast<HybridGaussianConditional>(inner_); if (!isHybrid()) return nullptr;
return std::static_pointer_cast<HybridGaussianConditional>(inner_);
} }
/** /**
@ -162,7 +163,8 @@ class GTSAM_EXPORT HybridConditional
* @return GaussianConditional::shared_ptr otherwise * @return GaussianConditional::shared_ptr otherwise
*/ */
GaussianConditional::shared_ptr asGaussian() const { GaussianConditional::shared_ptr asGaussian() const {
return std::dynamic_pointer_cast<GaussianConditional>(inner_); if (!isContinuous()) return nullptr;
return std::static_pointer_cast<GaussianConditional>(inner_);
} }
/** /**
@ -172,7 +174,8 @@ class GTSAM_EXPORT HybridConditional
*/ */
template <typename T = DiscreteConditional> template <typename T = DiscreteConditional>
typename T::shared_ptr asDiscrete() const { typename T::shared_ptr asDiscrete() const {
return std::dynamic_pointer_cast<T>(inner_); if (!isDiscrete()) return nullptr;
return std::static_pointer_cast<T>(inner_);
} }
/// Get the type-erased pointer to the inner type /// Get the type-erased pointer to the inner type
@ -221,7 +224,8 @@ class GTSAM_EXPORT HybridConditional
* which is just a GaussianConditional. If this conditional is *not* a hybrid * which is just a GaussianConditional. If this conditional is *not* a hybrid
* conditional, just return that. * conditional, just return that.
*/ */
shared_ptr restrict(const DiscreteValues& discreteValues) const; std::shared_ptr<Factor> restrict(
const DiscreteValues& assignment) const override;
/// @} /// @}

View File

@ -133,10 +133,14 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// Return only the continuous keys for this factor. /// Return only the continuous keys for this factor.
const KeyVector &continuousKeys() const { return continuousKeys_; } const KeyVector &continuousKeys() const { return continuousKeys_; }
/// Virtual class to compute tree of linear errors. /// Compute tree of linear errors.
virtual AlgebraicDecisionTree<Key> errorTree( virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &values) const = 0; const VectorValues &values) const = 0;
/// Restrict the factor to the given discrete values.
virtual std::shared_ptr<Factor> restrict(
const DiscreteValues &discreteValues) const = 0;
/// @} /// @}
private: private:

View File

@ -363,4 +363,12 @@ double HybridGaussianConditional::evaluate(const HybridValues &values) const {
return conditional->evaluate(values.continuous()); return conditional->evaluate(values.continuous());
} }
/* ************************************************************************ */
std::shared_ptr<Factor> HybridGaussianConditional::restrict(
const DiscreteValues &assignment) const {
throw std::runtime_error(
"HybridGaussianConditional::restrict not implemented");
}
/* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -241,6 +241,10 @@ class GTSAM_EXPORT HybridGaussianConditional
/// Return true if the conditional has already been pruned. /// Return true if the conditional has already been pruned.
bool pruned() const { return pruned_; } bool pruned() const { return pruned_; }
/// Restrict to the given discrete values.
std::shared_ptr<Factor> restrict(
const DiscreteValues &discreteValues) const override;
/// @} /// @}
private: private:

View File

@ -199,4 +199,12 @@ double HybridGaussianFactor::error(const HybridValues& values) const {
return PotentiallyPrunedComponentError(pair, values.continuous()); return PotentiallyPrunedComponentError(pair, values.continuous());
} }
/* ************************************************************************ */
std::shared_ptr<Factor> HybridGaussianFactor::restrict(
const DiscreteValues& assignment) const {
throw std::runtime_error("HybridGaussianFactor::restrict not implemented");
}
/* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -157,6 +157,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/ */
virtual HybridGaussianProductFactor asProductFactor() const; virtual HybridGaussianProductFactor asProductFactor() const;
/// Restrict the factor to the given discrete values.
std::shared_ptr<Factor> restrict(
const DiscreteValues &discreteValues) const override;
/// @} /// @}
private: private:

View File

@ -239,4 +239,21 @@ HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune(
return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors); return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors);
} }
/* ************************************************************************ */
std::shared_ptr<Factor> HybridNonlinearFactor::restrict(
const DiscreteValues& assignment) const {
auto restrictedFactors = factors_.restrict(assignment);
auto filtered = assignment.filter(discreteKeys_);
if (filtered.size() == discreteKeys_.size()) {
auto [nonlinearFactor, val] = factors_(filtered);
return nonlinearFactor;
} else {
auto remainingKeys = assignment.missingKeys(discreteKeys());
return std::make_shared<HybridNonlinearFactor>(remainingKeys,
factors_.restrict(filtered));
}
}
/* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -80,6 +80,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
} }
public: public:
/// @name Constructors
/// @{
/// Default constructor, mainly for serialization. /// Default constructor, mainly for serialization.
HybridNonlinearFactor() = default; HybridNonlinearFactor() = default;
@ -137,7 +140,7 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
* @return double The error of this factor. * @return double The error of this factor.
*/ */
double error(const Values& continuousValues, double error(const Values& continuousValues,
const DiscreteValues& discreteValues) const; const DiscreteValues& assignment) const;
/** /**
* @brief Compute error of factor given hybrid values. * @brief Compute error of factor given hybrid values.
@ -154,7 +157,8 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
*/ */
size_t dim() const; size_t dim() const;
/// Testable /// @}
/// @name Testable
/// @{ /// @{
/// print to stdout /// print to stdout
@ -165,15 +169,16 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
bool equals(const HybridFactor& other, double tol = 1e-9) const override; bool equals(const HybridFactor& other, double tol = 1e-9) const override;
/// @} /// @}
/// @name Standard API
/// @{
/// Getter for NonlinearFactor decision tree /// Getter for NonlinearFactor decision tree
const FactorValuePairs& factors() const { return factors_; } const FactorValuePairs& factors() const { return factors_; }
/// Linearize specific nonlinear factors based on the assignment in /// Linearize specific nonlinear factors based on the assignment in
/// discreteValues. /// discreteValues.
GaussianFactor::shared_ptr linearize( GaussianFactor::shared_ptr linearize(const Values& continuousValues,
const Values& continuousValues, const DiscreteValues& assignment) const;
const DiscreteValues& discreteValues) const;
/// Linearize all the continuous factors to get a HybridGaussianFactor. /// Linearize all the continuous factors to get a HybridGaussianFactor.
std::shared_ptr<HybridGaussianFactor> linearize( std::shared_ptr<HybridGaussianFactor> linearize(
@ -183,6 +188,12 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
HybridNonlinearFactor::shared_ptr prune( HybridNonlinearFactor::shared_ptr prune(
const DecisionTreeFactor& discreteProbs) const; const DecisionTreeFactor& discreteProbs) const;
/// Restrict the factor to the given discrete values.
std::shared_ptr<Factor> restrict(
const DiscreteValues& assignment) const override;
/// @}
private: private:
/// Helper struct to assist private constructor below. /// Helper struct to assist private constructor below.
struct ConstructorHelper; struct ConstructorHelper;

View File

@ -221,5 +221,30 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::discretePosterior(
return p / p.sum(); return p / p.sum();
} }
/* ************************************************************************ */
HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict(
const DiscreteValues& discreteValues) const {
using std::dynamic_pointer_cast;
HybridNonlinearFactorGraph result;
result.reserve(size());
for (auto& f : factors_) {
// First check if it is a valid factor
if (!f) {
continue;
}
// Check if it is a hybrid factor
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
result.push_back(hf->restrict(discreteValues));
} else if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
result.push_back(df->restrict(discreteValues));
} else {
result.push_back(f); // Everything else is just added as is
}
}
return result;
}
/* ************************************************************************ */ /* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -116,6 +116,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
AlgebraicDecisionTree<Key> discretePosterior( AlgebraicDecisionTree<Key> discretePosterior(
const Values& continuousValues) const; const Values& continuousValues) const;
/// Restrict all factors in the graph to the given discrete values.
HybridNonlinearFactorGraph restrict(
const DiscreteValues& assignment) const;
/// @} /// @}
}; };

View File

@ -131,6 +131,18 @@ TEST(HybridNonlinearFactor, Dim) {
EXPECT_LONGS_EQUAL(1, hybridFactor.dim()); EXPECT_LONGS_EQUAL(1, hybridFactor.dim());
} }
/* ************************************************************************* */
// Test restrict method
TEST(HybridNonlinearFactor, Restrict) {
using namespace test_constructor;
HybridNonlinearFactor factor(m1, {f0, f1});
DiscreteValues assignment = {{m1.first, 0}};
auto restricted = factor.restrict(assignment);
auto betweenFactor = dynamic_pointer_cast<BetweenFactor<double>>(restricted);
CHECK(betweenFactor);
EXPECT(assert_equal(*f0, *betweenFactor));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;