Merge branch 'develop' into city10000

release/4.3a0
Varun Agrawal 2025-02-03 16:02:03 -05:00
commit 69424c6b29
31 changed files with 314 additions and 123 deletions

View File

@ -68,12 +68,15 @@ class Experiment {
size_t maxNrHypotheses = 10;
size_t reLinearizationFrequency = 10;
double marginalThreshold = 0.9999;
private:
std::string filename_;
HybridSmoother smoother_;
HybridNonlinearFactorGraph graph_;
HybridNonlinearFactorGraph newFactors_, allFactors_;
Values initial_;
Values result_;
/**
* @brief Write the result of optimization to file.
@ -83,7 +86,7 @@ class Experiment {
* @param filename The file name to save the result to.
*/
void writeResult(const Values& result, size_t numPoses,
const std::string& filename = "Hybrid_city10000.txt") {
const std::string& filename = "Hybrid_city10000.txt") const {
std::ofstream outfile;
outfile.open(filename);
@ -100,9 +103,9 @@ class Experiment {
* @brief Create a hybrid loop closure factor where
* 0 - loose noise model and 1 - loop noise model.
*/
HybridNonlinearFactor hybridLoopClosureFactor(size_t loopCounter, size_t keyS,
size_t keyT,
const Pose2& measurement) {
HybridNonlinearFactor hybridLoopClosureFactor(
size_t loopCounter, size_t keyS, size_t keyT,
const Pose2& measurement) const {
DiscreteKey l(L(loopCounter), 2);
auto f0 = std::make_shared<BetweenFactor<Pose2>>(
@ -119,7 +122,7 @@ class Experiment {
/// @brief Create hybrid odometry factor with discrete measurement choices.
HybridNonlinearFactor hybridOdometryFactor(
size_t numMeasurements, size_t keyS, size_t keyT, const DiscreteKey& m,
const std::vector<Pose2>& poseArray) {
const std::vector<Pose2>& poseArray) const {
auto f0 = std::make_shared<BetweenFactor<Pose2>>(
X(keyS), X(keyT), poseArray[0], kPoseNoiseModel);
auto f1 = std::make_shared<BetweenFactor<Pose2>>(
@ -132,25 +135,59 @@ class Experiment {
}
/// @brief Perform smoother update and optimize the graph.
void smootherUpdate(HybridSmoother& smoother,
HybridNonlinearFactorGraph& graph, const Values& initial,
size_t maxNrHypotheses, Values* result) {
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
smoother.update(linearized, maxNrHypotheses);
// throw if x0 not in hybridBayesNet_:
const KeySet& keys = smoother.hybridBayesNet().keys();
if (keys.find(X(0)) == keys.end()) {
throw std::runtime_error("x0 not in hybridBayesNet_");
auto smootherUpdate(size_t maxNrHypotheses) {
std::cout << "Smoother update: " << newFactors_.size() << std::endl;
gttic_(SmootherUpdate);
clock_t beforeUpdate = clock();
auto linearized = newFactors_.linearize(initial_);
smoother_.update(*linearized, maxNrHypotheses);
allFactors_.push_back(newFactors_);
newFactors_.resize(0);
clock_t afterUpdate = clock();
return afterUpdate - beforeUpdate;
}
/// @brief Re-linearize, solve ALL, and re-initialize smoother.
auto reInitialize() {
std::cout << "================= Re-Initialize: " << allFactors_.size()
<< std::endl;
clock_t beforeUpdate = clock();
allFactors_ = allFactors_.restrict(smoother_.fixedValues());
auto linearized = allFactors_.linearize(initial_);
auto bayesNet = linearized->eliminateSequential();
HybridValues delta = bayesNet->optimize();
initial_ = initial_.retract(delta.continuous());
smoother_.reInitialize(std::move(*bayesNet));
clock_t afterUpdate = clock();
std::cout << "Took " << (afterUpdate - beforeUpdate) / CLOCKS_PER_SEC
<< " seconds." << std::endl;
return afterUpdate - beforeUpdate;
}
// Parse line from file
std::pair<std::vector<Pose2>, std::pair<size_t, size_t>> parseLine(
const std::string& line) const {
std::vector<std::string> parts;
split(parts, line, is_any_of(" "));
size_t keyS = stoi(parts[1]);
size_t keyT = stoi(parts[3]);
int numMeasurements = stoi(parts[5]);
std::vector<Pose2> poseArray(numMeasurements);
for (int i = 0; i < numMeasurements; ++i) {
double x = stod(parts[6 + 3 * i]);
double y = stod(parts[7 + 3 * i]);
double rad = stod(parts[8 + 3 * i]);
poseArray[i] = Pose2(x, y, rad);
}
graph.resize(0);
// HybridValues delta = smoother.hybridBayesNet().optimize();
// result->insert_or_assign(initial.retract(delta.continuous()));
return {poseArray, {keyS, keyT}};
}
public:
/// Construct with filename of experiment to run
explicit Experiment(const std::string& filename)
: filename_(filename), smoother_(0.99) {}
: filename_(filename), smoother_(marginalThreshold) {}
/// @brief Run the main experiment with a given maxLoopCount.
void run() {
@ -162,49 +199,34 @@ class Experiment {
}
// Initialize local variables
size_t discreteCount = 0, index = 0;
size_t loopCount = 0;
size_t discreteCount = 0, index = 0, loopCount = 0, updateCount = 0;
std::list<double> timeList;
// Set up initial prior
double x = 0.0;
double y = 0.0;
double rad = 0.0;
Pose2 priorPose(x, y, rad);
Pose2 priorPose(0, 0, 0);
initial_.insert(X(0), priorPose);
graph_.push_back(PriorFactor<Pose2>(X(0), priorPose, kPriorNoiseModel));
newFactors_.push_back(
PriorFactor<Pose2>(X(0), priorPose, kPriorNoiseModel));
// Initial update
clock_t beforeUpdate = clock();
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
clock_t afterUpdate = clock();
auto time = smootherUpdate(maxNrHypotheses);
std::vector<std::pair<size_t, double>> smootherUpdateTimes;
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
smootherUpdateTimes.push_back({index, time});
// Flag to decide whether to run smoother update
size_t numberOfHybridFactors = 0;
// Start main loop
Values result;
size_t keyS = 0, keyT = 0;
clock_t startTime = clock();
std::string line;
while (getline(in, line) && index < maxLoopCount) {
std::vector<std::string> parts;
split(parts, line, is_any_of(" "));
keyS = stoi(parts[1]);
keyT = stoi(parts[3]);
int numMeasurements = stoi(parts[5]);
std::vector<Pose2> poseArray(numMeasurements);
for (int i = 0; i < numMeasurements; ++i) {
x = stod(parts[6 + 3 * i]);
y = stod(parts[7 + 3 * i]);
rad = stod(parts[8 + 3 * i]);
poseArray[i] = Pose2(x, y, rad);
}
auto [poseArray, keys] = parseLine(line);
keyS = keys.first;
keyT = keys.second;
size_t numMeasurements = poseArray.size();
// Take the first one as the initial estimate
Pose2 odomPose = poseArray[0];
@ -215,13 +237,13 @@ class Experiment {
DiscreteKey m(M(discreteCount), numMeasurements);
HybridNonlinearFactor mixtureFactor =
hybridOdometryFactor(numMeasurements, keyS, keyT, m, poseArray);
graph_.push_back(mixtureFactor);
newFactors_.push_back(mixtureFactor);
discreteCount++;
numberOfHybridFactors += 1;
std::cout << "mixtureFactor: " << keyS << " " << keyT << std::endl;
} else {
graph_.add(BetweenFactor<Pose2>(X(keyS), X(keyT), odomPose,
kPoseNoiseModel));
newFactors_.add(BetweenFactor<Pose2>(X(keyS), X(keyT), odomPose,
kPoseNoiseModel));
}
// Insert next pose initial guess
initial_.insert(X(keyT), initial_.at<Pose2>(X(keyS)) * odomPose);
@ -231,21 +253,20 @@ class Experiment {
hybridLoopClosureFactor(loopCount, keyS, keyT, odomPose);
// print loop closure event keys:
std::cout << "Loop closure: " << keyS << " " << keyT << std::endl;
graph_.add(loopFactor);
newFactors_.add(loopFactor);
numberOfHybridFactors += 1;
loopCount++;
}
if (numberOfHybridFactors >= updateFrequency) {
// print the keys involved in the smoother update
std::cout << "Smoother update: " << graph_.size() << std::endl;
gttic_(SmootherUpdate);
beforeUpdate = clock();
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
afterUpdate = clock();
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
gttoc_(SmootherUpdate);
auto time = smootherUpdate(maxNrHypotheses);
smootherUpdateTimes.push_back({index, time});
numberOfHybridFactors = 0;
updateCount++;
if (updateCount % reLinearizationFrequency == 0) {
reInitialize();
}
}
// Record timing for odometry edges only
@ -270,17 +291,15 @@ class Experiment {
}
// Final update
beforeUpdate = clock();
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
afterUpdate = clock();
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
time = smootherUpdate(maxNrHypotheses);
smootherUpdateTimes.push_back({index, time});
// Final optimize
gttic_(HybridSmootherOptimize);
HybridValues delta = smoother_.optimize();
gttoc_(HybridSmootherOptimize);
result_.insert_or_assign(initial_.retract(delta.continuous()));
result.insert_or_assign(initial_.retract(delta.continuous()));
std::cout << "Final error: " << smoother_.hybridBayesNet().error(delta)
<< std::endl;
@ -291,7 +310,7 @@ class Experiment {
<< std::endl;
// Write results to file
writeResult(result_, keyT + 1, "Hybrid_City10000.txt");
writeResult(result, keyT + 1, "Hybrid_City10000.txt");
// TODO Write to file
// for (size_t i = 0; i < smoother_update_times.size(); i++) {

View File

@ -393,6 +393,13 @@ namespace gtsam {
return DecisionTree(newRoot);
}
/** Choose multiple values. */
DecisionTree restrict(const Assignment<L>& assignment) const {
NodePtr newRoot = root_;
for (const auto& [l, v] : assignment) newRoot = newRoot->choose(l, v);
return DecisionTree(newRoot);
}
/** combine subtrees on key with binary operation "op" */
DecisionTree combine(const L& label, size_t cardinality,
const Binary& op) const;

View File

@ -540,5 +540,11 @@ namespace gtsam {
return DecisionTreeFactor(this->discreteKeys(), thresholded);
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::restrict(
const DiscreteValues& assignment) const {
throw std::runtime_error("DecisionTreeFactor::restrict not implemented");
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -220,6 +220,10 @@ namespace gtsam {
return combine(keys, Ring::max);
}
/// Restrict the factor to the given assignment.
DiscreteFactor::shared_ptr restrict(
const DiscreteValues& assignment) const override;
/// @}
/// @name Advanced Interface
/// @{

View File

@ -167,8 +167,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/**
* @brief Scale the factor values by the maximum
* to prevent underflow/overflow.
*
* @return DiscreteFactor::shared_ptr
*
* @return DiscreteFactor::shared_ptr
*/
DiscreteFactor::shared_ptr scale() const;
@ -178,6 +178,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
*/
virtual uint64_t nrValues() const = 0;
/// Restrict the factor to the given assignment.
virtual DiscreteFactor::shared_ptr restrict(
const DiscreteValues& assignment) const = 0;
/// @}
/// @name Wrapper support
/// @{

View File

@ -391,12 +391,12 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::sum(size_t nrFrontals) const {
return combine(nrFrontals, Ring::add);
return combine(nrFrontals, Ring::add);
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::sum(const Ordering& keys) const {
return combine(keys, Ring::add);
return combine(keys, Ring::add);
}
/* ************************************************************************ */
@ -418,7 +418,6 @@ DiscreteFactor::shared_ptr TableFactor::max(const Ordering& keys) const {
return combine(keys, Ring::max);
}
/* ************************************************************************ */
TableFactor TableFactor::apply(Unary op) const {
// Initialize new factor.
@ -781,5 +780,11 @@ TableFactor TableFactor::prune(size_t maxNrAssignments) const {
return TableFactor(this->discreteKeys(), pruned_vec);
}
/* ************************************************************************ */
DiscreteFactor::shared_ptr TableFactor::restrict(
const DiscreteValues& assignment) const {
throw std::runtime_error("TableFactor::restrict not implemented");
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -342,6 +342,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
/// Restrict the factor to the given assignment.
DiscreteFactor::shared_ptr restrict(
const DiscreteValues& assignment) const override;
/// @}
/// @name Wrapper support
/// @{

View File

@ -59,21 +59,30 @@ TEST(ADT, arithmetic) {
// Negate and subtraction
CHECK(assert_equal(-a, zero - a));
#ifdef GTSAM_DT_MERGING
CHECK(assert_equal({zero}, a - a));
#else
CHECK(assert_equal({A, 0, 0}, a - a));
#endif
CHECK(assert_equal(a + b, b + a));
CHECK(assert_equal({A, 3, 4}, a + 2));
CHECK(assert_equal({B, 1, 2}, b - 2));
// Multiplication
#ifdef GTSAM_DT_MERGING
CHECK(assert_equal(zero, zero * a));
CHECK(assert_equal(zero, a * zero));
#else
CHECK(assert_equal({A, 0, 0}, zero * a));
#endif
CHECK(assert_equal(a, one * a));
CHECK(assert_equal(a, a * one));
CHECK(assert_equal(a * b, b * a));
#ifdef GTSAM_DT_MERGING
// division
// CHECK(assert_equal(a, (a * b) / b)); // not true because no pruning
CHECK(assert_equal(b, (a * b) / a));
#endif
}
/* ************************************************************************** */

View File

@ -228,9 +228,9 @@ TEST(DecisionTree, Example) {
// Test choose 0
DT actual0 = notba.choose(A, 0);
#ifdef GTSAM_DT_MERGING
EXPECT(assert_equal(DT(0.0), actual0));
EXPECT(assert_equal(DT(0), actual0));
#else
EXPECT(assert_equal(DT({0.0, 0.0}), actual0));
EXPECT(assert_equal(DT(B, 0, 0), actual0));
#endif
DOT(actual0);
@ -618,6 +618,21 @@ TEST(DecisionTree, ApplyWithAssignment) {
#endif
}
/* ************************************************************************** */
// Test apply with assignment.
TEST(DecisionTree, Restrict) {
// Create three level tree
const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
DT::LabelC("A", 2)};
DT tree(keys, "1 2 3 4 5 6 7 8");
DT restrictedTree = tree.restrict({{"A", 0}, {"B", 1}});
EXPECT(assert_equal(DT({DT::LabelC("C", 2)}, "3 7"), restrictedTree));
DT restrictMore = tree.restrict({{"A", 1}, {"B", 1}, {"C", 1}});
EXPECT(assert_equal(DT(8), restrictMore));
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -124,8 +124,10 @@ TEST(DecisionTreeFactor, Divide) {
EXPECT(assert_inequal(pS, s));
// The underlying data should be the same
#ifdef GTSAM_DT_MERGING
using ADT = AlgebraicDecisionTree<Key>;
EXPECT(assert_equal(ADT(pS), ADT(s)));
#endif
KeySet keys(joint.keys());
keys.insert(pA.keys().begin(), pA.keys().end());

View File

@ -69,11 +69,13 @@ HybridBayesNet HybridBayesNet::prune(
// Go through all the Gaussian conditionals, restrict them according to
// fixed values, and then prune further.
for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) {
for (std::shared_ptr<HybridConditional> conditional : *this) {
if (conditional->isDiscrete()) continue;
// No-op if not a HybridGaussianConditional.
if (marginalThreshold) conditional = conditional->restrict(fixed);
if (marginalThreshold)
conditional = std::static_pointer_cast<HybridConditional>(
conditional->restrict(fixed));
// Now decide on type what to do:
if (auto hgc = conditional->asHybrid()) {

View File

@ -170,8 +170,8 @@ double HybridConditional::evaluate(const HybridValues &values) const {
}
/* ************************************************************************ */
HybridConditional::shared_ptr HybridConditional::restrict(
const DiscreteValues &discreteValues) const {
std::shared_ptr<Factor> HybridConditional::restrict(
const DiscreteValues &assignment) const {
if (auto gc = asGaussian()) {
return std::make_shared<HybridConditional>(gc);
} else if (auto dc = asDiscrete()) {
@ -184,21 +184,20 @@ HybridConditional::shared_ptr HybridConditional::restrict(
"HybridConditional::restrict: conditional type not handled");
// Case 1: Fully determined, return corresponding Gaussian conditional
auto parentValues = discreteValues.filter(discreteKeys_);
auto parentValues = assignment.filter(discreteKeys_);
if (parentValues.size() == discreteKeys_.size()) {
return std::make_shared<HybridConditional>(hgc->choose(parentValues));
}
// Case 2: Some live parents remain, build a new tree
auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_);
if (!unspecifiedParentKeys.empty()) {
auto remainingKeys = assignment.missingKeys(discreteKeys_);
if (!remainingKeys.empty()) {
auto newTree = hgc->factors();
for (const auto &[key, value] : parentValues) {
newTree = newTree.choose(key, value);
}
return std::make_shared<HybridConditional>(
std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys,
newTree));
std::make_shared<HybridGaussianConditional>(remainingKeys, newTree));
}
// Case 3: No changes needed, return original

View File

@ -153,7 +153,8 @@ class GTSAM_EXPORT HybridConditional
* @return HybridGaussianConditional::shared_ptr otherwise
*/
HybridGaussianConditional::shared_ptr asHybrid() const {
return std::dynamic_pointer_cast<HybridGaussianConditional>(inner_);
if (!isHybrid()) return nullptr;
return std::static_pointer_cast<HybridGaussianConditional>(inner_);
}
/**
@ -162,7 +163,8 @@ class GTSAM_EXPORT HybridConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() const {
return std::dynamic_pointer_cast<GaussianConditional>(inner_);
if (!isContinuous()) return nullptr;
return std::static_pointer_cast<GaussianConditional>(inner_);
}
/**
@ -172,7 +174,8 @@ class GTSAM_EXPORT HybridConditional
*/
template <typename T = DiscreteConditional>
typename T::shared_ptr asDiscrete() const {
return std::dynamic_pointer_cast<T>(inner_);
if (!isDiscrete()) return nullptr;
return std::static_pointer_cast<T>(inner_);
}
/// Get the type-erased pointer to the inner type
@ -221,7 +224,8 @@ class GTSAM_EXPORT HybridConditional
* which is just a GaussianConditional. If this conditional is *not* a hybrid
* conditional, just return that.
*/
shared_ptr restrict(const DiscreteValues& discreteValues) const;
std::shared_ptr<Factor> restrict(
const DiscreteValues& assignment) const override;
/// @}

View File

@ -133,10 +133,14 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// Return only the continuous keys for this factor.
const KeyVector &continuousKeys() const { return continuousKeys_; }
/// Virtual class to compute tree of linear errors.
/// Compute tree of linear errors.
virtual AlgebraicDecisionTree<Key> errorTree(
const VectorValues &values) const = 0;
/// Restrict the factor to the given discrete values.
virtual std::shared_ptr<Factor> restrict(
const DiscreteValues &discreteValues) const = 0;
/// @}
private:

View File

@ -363,4 +363,12 @@ double HybridGaussianConditional::evaluate(const HybridValues &values) const {
return conditional->evaluate(values.continuous());
}
/* ************************************************************************ */
std::shared_ptr<Factor> HybridGaussianConditional::restrict(
const DiscreteValues &assignment) const {
throw std::runtime_error(
"HybridGaussianConditional::restrict not implemented");
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -241,6 +241,10 @@ class GTSAM_EXPORT HybridGaussianConditional
/// Return true if the conditional has already been pruned.
bool pruned() const { return pruned_; }
/// Restrict to the given discrete values.
std::shared_ptr<Factor> restrict(
const DiscreteValues &discreteValues) const override;
/// @}
private:

View File

@ -199,4 +199,12 @@ double HybridGaussianFactor::error(const HybridValues& values) const {
return PotentiallyPrunedComponentError(pair, values.continuous());
}
/* ************************************************************************ */
std::shared_ptr<Factor> HybridGaussianFactor::restrict(
const DiscreteValues& assignment) const {
throw std::runtime_error("HybridGaussianFactor::restrict not implemented");
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -157,6 +157,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/
virtual HybridGaussianProductFactor asProductFactor() const;
/// Restrict the factor to the given discrete values.
std::shared_ptr<Factor> restrict(
const DiscreteValues &discreteValues) const override;
/// @}
private:

View File

@ -239,4 +239,21 @@ HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune(
return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors);
}
/* ************************************************************************ */
std::shared_ptr<Factor> HybridNonlinearFactor::restrict(
const DiscreteValues& assignment) const {
auto restrictedFactors = factors_.restrict(assignment);
auto filtered = assignment.filter(discreteKeys_);
if (filtered.size() == discreteKeys_.size()) {
auto [nonlinearFactor, val] = factors_(filtered);
return nonlinearFactor;
} else {
auto remainingKeys = assignment.missingKeys(discreteKeys());
return std::make_shared<HybridNonlinearFactor>(remainingKeys,
factors_.restrict(filtered));
}
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -80,6 +80,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
}
public:
/// @name Constructors
/// @{
/// Default constructor, mainly for serialization.
HybridNonlinearFactor() = default;
@ -137,7 +140,7 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
* @return double The error of this factor.
*/
double error(const Values& continuousValues,
const DiscreteValues& discreteValues) const;
const DiscreteValues& assignment) const;
/**
* @brief Compute error of factor given hybrid values.
@ -154,7 +157,8 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
*/
size_t dim() const;
/// Testable
/// @}
/// @name Testable
/// @{
/// print to stdout
@ -165,15 +169,16 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
/// @}
/// @name Standard API
/// @{
/// Getter for NonlinearFactor decision tree
const FactorValuePairs& factors() const { return factors_; }
/// Linearize specific nonlinear factors based on the assignment in
/// discreteValues.
GaussianFactor::shared_ptr linearize(
const Values& continuousValues,
const DiscreteValues& discreteValues) const;
GaussianFactor::shared_ptr linearize(const Values& continuousValues,
const DiscreteValues& assignment) const;
/// Linearize all the continuous factors to get a HybridGaussianFactor.
std::shared_ptr<HybridGaussianFactor> linearize(
@ -183,6 +188,12 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
HybridNonlinearFactor::shared_ptr prune(
const DecisionTreeFactor& discreteProbs) const;
/// Restrict the factor to the given discrete values.
std::shared_ptr<Factor> restrict(
const DiscreteValues& assignment) const override;
/// @}
private:
/// Helper struct to assist private constructor below.
struct ConstructorHelper;

View File

@ -221,5 +221,30 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::discretePosterior(
return p / p.sum();
}
/* ************************************************************************ */
HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict(
const DiscreteValues& discreteValues) const {
using std::dynamic_pointer_cast;
HybridNonlinearFactorGraph result;
result.reserve(size());
for (auto& f : factors_) {
// First check if it is a valid factor
if (!f) {
continue;
}
// Check if it is a hybrid factor
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
result.push_back(hf->restrict(discreteValues));
} else if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
result.push_back(df->restrict(discreteValues));
} else {
result.push_back(f); // Everything else is just added as is
}
}
return result;
}
/* ************************************************************************ */
} // namespace gtsam

View File

@ -116,6 +116,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
AlgebraicDecisionTree<Key> discretePosterior(
const Values& continuousValues) const;
/// Restrict all factors in the graph to the given discrete values.
HybridNonlinearFactorGraph restrict(
const DiscreteValues& assignment) const;
/// @}
};

View File

@ -27,7 +27,6 @@ namespace gtsam {
class GTSAM_EXPORT HybridSmoother {
private:
HybridBayesNet hybridBayesNet_;
HybridGaussianFactorGraph remainingFactorGraph_;
/// The threshold above which we make a decision about a mode.
std::optional<double> marginalThreshold_;
@ -44,6 +43,16 @@ class GTSAM_EXPORT HybridSmoother {
HybridSmoother(const std::optional<double> marginalThreshold = {})
: marginalThreshold_(marginalThreshold) {}
/// Return fixed values:
const DiscreteValues& fixedValues() const { return fixedValues_; }
/**
* Re-initialize the smoother from a new hybrid Bayes Net.
*/
void reInitialize(HybridBayesNet&& hybridBayesNet) {
hybridBayesNet_ = std::move(hybridBayesNet);
}
/**
* Given new factors, perform an incremental update.
* The relevant densities in the `hybridBayesNet` will be added to the input

View File

@ -318,21 +318,27 @@ TEST(HybridGaussianConditional, Restrict) {
const auto hc =
std::make_shared<HybridConditional>(two_mode_measurement::hgc);
const HybridConditional::shared_ptr same = hc->restrict({});
const auto same =
std::dynamic_pointer_cast<HybridConditional>(hc->restrict({}));
CHECK(same);
EXPECT(same->isHybrid());
EXPECT(same->asHybrid()->nrComponents() == 4);
const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}});
const auto oneParent =
std::dynamic_pointer_cast<HybridConditional>(hc->restrict({{M(1), 0}}));
CHECK(oneParent);
EXPECT(oneParent->isHybrid());
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
const HybridConditional::shared_ptr oneParent2 =
hc->restrict({{M(7), 0}, {M(1), 0}});
const auto oneParent2 = std::dynamic_pointer_cast<HybridConditional>(
hc->restrict({{M(7), 0}, {M(1), 0}}));
CHECK(oneParent2);
EXPECT(oneParent2->isHybrid());
EXPECT(oneParent2->asHybrid()->nrComponents() == 2);
const HybridConditional::shared_ptr gaussian =
hc->restrict({{M(1), 0}, {M(2), 1}});
const auto gaussian = std::dynamic_pointer_cast<HybridConditional>(
hc->restrict({{M(1), 0}, {M(2), 1}}));
CHECK(gaussian);
EXPECT(gaussian->asGaussian());
}

View File

@ -131,6 +131,18 @@ TEST(HybridNonlinearFactor, Dim) {
EXPECT_LONGS_EQUAL(1, hybridFactor.dim());
}
/* ************************************************************************* */
// Test restrict method
TEST(HybridNonlinearFactor, Restrict) {
using namespace test_constructor;
HybridNonlinearFactor factor(m1, {f0, f1});
DiscreteValues assignment = {{m1.first, 0}};
auto restricted = factor.restrict(assignment);
auto betweenFactor = dynamic_pointer_cast<BetweenFactor<double>>(restricted);
CHECK(betweenFactor);
EXPECT(assert_equal(*f0, *betweenFactor));
}
/* ************************************************************************* */
int main() {
TestResult tr;

View File

@ -102,12 +102,16 @@ TEST(HybridSmoother, IncrementalSmoother) {
graph.resize(0);
}
EXPECT_LONGS_EQUAL(11,
smoother.hybridBayesNet().at(5)->asDiscrete()->nrValues());
auto& hybridBayesNet = smoother.hybridBayesNet();
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(11, hybridBayesNet.at(5)->asDiscrete()->nrValues());
#else
EXPECT_LONGS_EQUAL(16, hybridBayesNet.at(5)->asDiscrete()->nrValues());
#endif
// Get the continuous delta update as well as
// the optimal discrete assignment.
HybridValues delta = smoother.hybridBayesNet().optimize();
HybridValues delta = hybridBayesNet.optimize();
// Check discrete assignment
DiscreteValues expected_discrete;
@ -156,8 +160,12 @@ TEST(HybridSmoother, ValidPruningError) {
graph.resize(0);
}
EXPECT_LONGS_EQUAL(14,
smoother.hybridBayesNet().at(8)->asDiscrete()->nrValues());
auto& hybridBayesNet = smoother.hybridBayesNet();
#ifdef GTSAM_DT_MERGING
EXPECT_LONGS_EQUAL(14, hybridBayesNet.at(8)->asDiscrete()->nrValues());
#else
EXPECT_LONGS_EQUAL(128, hybridBayesNet.at(8)->asDiscrete()->nrValues());
#endif
// Get the continuous delta update as well as
// the optimal discrete assignment.

View File

@ -53,11 +53,6 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Multiply into a decisiontree
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("AllDiff::error not implemented");
}
/*
* Ensure Arc-consistency by checking every possible value of domain j.
* @param j domain to be checked

View File

@ -91,11 +91,6 @@ class BinaryAllDiff : public Constraint {
const Domains&) const override {
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
}
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("BinaryAllDiff::error not implemented");
}
};
} // namespace gtsam

View File

@ -125,6 +125,17 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
return toDecisionTreeFactor().max(keys);
}
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("Constraint::error not implemented");
}
/// Compute error for each assignment and return as a tree
DiscreteFactor::shared_ptr restrict(
const DiscreteValues& assignment) const override {
throw std::runtime_error("Constraint::restrict not implemented");
}
/// @}
/// @name Wrapper support
/// @{

View File

@ -69,11 +69,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
}
}
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("Domain::error not implemented");
}
// Return concise string representation, mostly to debug arc consistency.
// Converts from base 0 to base1.
std::string base1Str() const;

View File

@ -49,11 +49,6 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
}
}
/// Compute error for each assignment and return as a tree
AlgebraicDecisionTree<Key> errorTree() const override {
throw std::runtime_error("SingleValue::error not implemented");
}
/// Calculate value
double evaluate(const Assignment<Key>& values) const override;