Restrict for hybrid factors (and discrete)

release/4.3a0
Frank Dellaert 2025-02-01 02:28:01 -05:00
parent 352c7f2efa
commit ea27bac018
18 changed files with 149 additions and 24 deletions

View File

@ -536,5 +536,11 @@ namespace gtsam {
return DecisionTreeFactor(this->discreteKeys(), thresholded);
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::restrict(
const DiscreteValues& assignment) const {
throw std::runtime_error("DecisionTreeFactor::restrict not implemented");
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -220,6 +220,10 @@ namespace gtsam {
return combine(keys, Ring::max);
}
/// Restrict the factor to the given assignment.
DiscreteFactor::shared_ptr restrict(
const DiscreteValues& assignment) const override;
/// @}
/// @name Advanced Interface
/// @{

View File

@ -178,6 +178,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
*/
virtual uint64_t nrValues() const = 0;
/// Restrict the factor to the given assignment.
virtual DiscreteFactor::shared_ptr restrict(
const DiscreteValues& assignment) const = 0;
/// @}
/// @name Wrapper support
/// @{

View File

@ -391,12 +391,12 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::sum(size_t nrFrontals) const {
return combine(nrFrontals, Ring::add);
return combine(nrFrontals, Ring::add);
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::sum(const Ordering& keys) const {
return combine(keys, Ring::add);
return combine(keys, Ring::add);
}
/* ************************************************************************ */
@ -418,7 +418,6 @@ DiscreteFactor::shared_ptr TableFactor::max(const Ordering& keys) const {
return combine(keys, Ring::max);
}
/* ************************************************************************ */
TableFactor TableFactor::apply(Unary op) const {
// Initialize new factor.
@ -781,5 +780,11 @@ TableFactor TableFactor::prune(size_t maxNrAssignments) const {
return TableFactor(this->discreteKeys(), pruned_vec);
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::restrict(
const DiscreteValues& assignment) const {
throw std::runtime_error("TableFactor::restrict not implemented");
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -342,6 +342,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
/// Restrict the factor to the given assignment.
DiscreteFactor::shared_ptr restrict(
const DiscreteValues& assignment) const override;
/// @}
/// @name Wrapper support
/// @{

View File

@ -69,11 +69,13 @@ HybridBayesNet HybridBayesNet::prune(
// Go through all the Gaussian conditionals, restrict them according to
// fixed values, and then prune further.
for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) {
for (std::shared_ptr<HybridConditional> conditional : *this) {
if (conditional->isDiscrete()) continue;
// No-op if not a HybridGaussianConditional.
if (marginalThreshold) conditional = conditional->restrict(fixed);
if (marginalThreshold)
conditional = std::static_pointer_cast<HybridConditional>(
conditional->restrict(fixed));
// Now decide on type what to do:
if (auto hgc = conditional->asHybrid()) {

View File

@ -170,8 +170,8 @@ double HybridConditional::evaluate(const HybridValues &values) const {
}
/* ************************************************************************ */
HybridConditional::shared_ptr HybridConditional::restrict(
const DiscreteValues &discreteValues) const {
std::shared_ptr<Factor> HybridConditional::restrict(
const DiscreteValues &assignment) const {
if (auto gc = asGaussian()) {
return std::make_shared<HybridConditional>(gc);
} else if (auto dc = asDiscrete()) {
@ -184,21 +184,20 @@ HybridConditional::shared_ptr HybridConditional::restrict(
"HybridConditional::restrict: conditional type not handled");
// Case 1: Fully determined, return corresponding Gaussian conditional
auto parentValues = discreteValues.filter(discreteKeys_);
auto parentValues = assignment.filter(discreteKeys_);
if (parentValues.size() == discreteKeys_.size()) {
return std::make_shared<HybridConditional>(hgc->choose(parentValues));
}
// Case 2: Some live parents remain, build a new tree
auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_);
if (!unspecifiedParentKeys.empty()) {
auto remainingKeys = assignment.missingKeys(discreteKeys_);
if (!remainingKeys.empty()) {
auto newTree = hgc->factors();
for (const auto &[key, value] : parentValues) {
newTree = newTree.choose(key, value);
}
return std::make_shared<HybridConditional>(
std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys,
newTree));
std::make_shared<HybridGaussianConditional>(remainingKeys, newTree));
}
// Case 3: No changes needed, return original

View File

@ -153,7 +153,8 @@ class GTSAM_EXPORT HybridConditional
* @return HybridGaussianConditional::shared_ptr otherwise
*/
HybridGaussianConditional::shared_ptr asHybrid() const {
return std::dynamic_pointer_cast<HybridGaussianConditional>(inner_);
if (!isHybrid()) return nullptr;
return std::static_pointer_cast<HybridGaussianConditional>(inner_);
}
/**
@ -162,7 +163,8 @@ class GTSAM_EXPORT HybridConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() const {
return std::dynamic_pointer_cast<GaussianConditional>(inner_);
if (!isContinuous()) return nullptr;
return std::static_pointer_cast<GaussianConditional>(inner_);
}
/**
@ -172,7 +174,8 @@ class GTSAM_EXPORT HybridConditional
*/
template <typename T = DiscreteConditional>
typename T::shared_ptr asDiscrete() const {
return std::dynamic_pointer_cast<T>(inner_);
if (!isDiscrete()) return nullptr;
return std::static_pointer_cast<T>(inner_);
}
/// Get the type-erased pointer to the inner type
@ -221,7 +224,8 @@ class GTSAM_EXPORT HybridConditional
* which is just a GaussianConditional. If this conditional is *not* a hybrid
* conditional, just return that.
*/
shared_ptr restrict(const DiscreteValues& discreteValues) const;
std::shared_ptr<Factor> restrict(
const DiscreteValues& assignment) const override;
/// @}

View File

@ -133,10 +133,14 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// Return only the continuous keys for this factor.
const KeyVector &continuousKeys() const { return continuousKeys_; }
/// Virtual class to compute tree of linear errors.
/// Compute tree of linear errors.
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &values) const = 0;
/// Restrict the factor to the given discrete values.
virtual std::shared_ptr<Factor> restrict(
const DiscreteValues &discreteValues) const = 0;
/// @}
private:

View File

@ -363,4 +363,12 @@ double HybridGaussianConditional::evaluate(const HybridValues &values) const {
return conditional->evaluate(values.continuous());
}
/* ************************************************************************ */
std::shared_ptr<Factor> HybridGaussianConditional::restrict(
const DiscreteValues &assignment) const {
throw std::runtime_error(
"HybridGaussianConditional::restrict not implemented");
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -241,6 +241,10 @@ class GTSAM_EXPORT HybridGaussianConditional
/// Return true if the conditional has already been pruned.
bool pruned() const { return pruned_; }
/// Restrict to the given discrete values.
std::shared_ptr<Factor> restrict(
const DiscreteValues &discreteValues) const override;
/// @}
private:

View File

@ -199,4 +199,12 @@ double HybridGaussianFactor::error(const HybridValues& values) const {
return PotentiallyPrunedComponentError(pair, values.continuous());
}
/* ************************************************************************ */
std::shared_ptr<Factor> HybridGaussianFactor::restrict(
const DiscreteValues& assignment) const {
throw std::runtime_error("HybridGaussianFactor::restrict not implemented");
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -157,6 +157,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/
virtual HybridGaussianProductFactor asProductFactor() const;
/// Restrict the factor to the given discrete values.
std::shared_ptr<Factor> restrict(
const DiscreteValues &discreteValues) const override;
/// @}
private:

View File

@ -239,4 +239,21 @@ HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune(
return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors);
}
/* ************************************************************************ */
std::shared_ptr<Factor> HybridNonlinearFactor::restrict(
const DiscreteValues& assignment) const {
auto restrictedFactors = factors_.restrict(assignment);
auto filtered = assignment.filter(discreteKeys_);
if (filtered.size() == discreteKeys_.size()) {
auto [nonlinearFactor, val] = factors_(filtered);
return nonlinearFactor;
} else {
auto remainingKeys = assignment.missingKeys(discreteKeys());
return std::make_shared<HybridNonlinearFactor>(remainingKeys,
factors_.restrict(filtered));
}
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -80,6 +80,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
}
public:
/// @name Constructors
/// @{
/// Default constructor, mainly for serialization.
HybridNonlinearFactor() = default;
@ -137,7 +140,7 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
* @return double The error of this factor.
*/
double error(const Values& continuousValues,
const DiscreteValues& discreteValues) const;
const DiscreteValues& assignment) const;
/**
* @brief Compute error of factor given hybrid values.
@ -154,7 +157,8 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
*/
size_t dim() const;
/// Testable
/// @}
/// @name Testable
/// @{
/// print to stdout
@ -165,15 +169,16 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
/// @}
/// @name Standard API
/// @{
/// Getter for NonlinearFactor decision tree
const FactorValuePairs& factors() const { return factors_; }
/// Linearize specific nonlinear factors based on the assignment in
/// discreteValues.
GaussianFactor::shared_ptr linearize(
const Values& continuousValues,
const DiscreteValues& discreteValues) const;
GaussianFactor::shared_ptr linearize(const Values& continuousValues,
const DiscreteValues& assignment) const;
/// Linearize all the continuous factors to get a HybridGaussianFactor.
std::shared_ptr<HybridGaussianFactor> linearize(
@ -183,6 +188,12 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
HybridNonlinearFactor::shared_ptr prune(
const DecisionTreeFactor& discreteProbs) const;
/// Restrict the factor to the given discrete values.
std::shared_ptr<Factor> restrict(
const DiscreteValues& assignment) const override;
/// @}
private:
/// Helper struct to assist private constructor below.
struct ConstructorHelper;

View File

@ -221,5 +221,30 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::discretePosterior(
return p / p.sum();
}
/* ************************************************************************ */
HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict(
const DiscreteValues& discreteValues) const {
using std::dynamic_pointer_cast;
HybridNonlinearFactorGraph result;
result.reserve(size());
for (auto& f : factors_) {
// First check if it is a valid factor
if (!f) {
continue;
}
// Check if it is a hybrid factor
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
result.push_back(hf->restrict(discreteValues));
} else if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
result.push_back(df->restrict(discreteValues));
} else {
result.push_back(f); // Everything else is just added as is
}
}
return result;
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -116,6 +116,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
AlgebraicDecisionTree<Key> discretePosterior(
const Values& continuousValues) const;
/// Restrict all factors in the graph to the given discrete values.
HybridNonlinearFactorGraph restrict(
const DiscreteValues& assignment) const;
/// @}
};

View File

@ -131,6 +131,18 @@ TEST(HybridNonlinearFactor, Dim) {
EXPECT_LONGS_EQUAL(1, hybridFactor.dim());
}
/* ************************************************************************* */
// Test restrict method
TEST(HybridNonlinearFactor, Restrict) {
using namespace test_constructor;
HybridNonlinearFactor factor(m1, {f0, f1});
DiscreteValues assignment = {{m1.first, 0}};
auto restricted = factor.restrict(assignment);
auto betweenFactor = dynamic_pointer_cast<BetweenFactor<double>>(restricted);
CHECK(betweenFactor);
EXPECT(assert_equal(*f0, *betweenFactor));
}
/* ************************************************************************* */
int main() {
TestResult tr;