Merge branch 'develop' into city10000
commit
69424c6b29
|
@ -68,12 +68,15 @@ class Experiment {
|
||||||
|
|
||||||
size_t maxNrHypotheses = 10;
|
size_t maxNrHypotheses = 10;
|
||||||
|
|
||||||
|
size_t reLinearizationFrequency = 10;
|
||||||
|
|
||||||
|
double marginalThreshold = 0.9999;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string filename_;
|
std::string filename_;
|
||||||
HybridSmoother smoother_;
|
HybridSmoother smoother_;
|
||||||
HybridNonlinearFactorGraph graph_;
|
HybridNonlinearFactorGraph newFactors_, allFactors_;
|
||||||
Values initial_;
|
Values initial_;
|
||||||
Values result_;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Write the result of optimization to file.
|
* @brief Write the result of optimization to file.
|
||||||
|
@ -83,7 +86,7 @@ class Experiment {
|
||||||
* @param filename The file name to save the result to.
|
* @param filename The file name to save the result to.
|
||||||
*/
|
*/
|
||||||
void writeResult(const Values& result, size_t numPoses,
|
void writeResult(const Values& result, size_t numPoses,
|
||||||
const std::string& filename = "Hybrid_city10000.txt") {
|
const std::string& filename = "Hybrid_city10000.txt") const {
|
||||||
std::ofstream outfile;
|
std::ofstream outfile;
|
||||||
outfile.open(filename);
|
outfile.open(filename);
|
||||||
|
|
||||||
|
@ -100,9 +103,9 @@ class Experiment {
|
||||||
* @brief Create a hybrid loop closure factor where
|
* @brief Create a hybrid loop closure factor where
|
||||||
* 0 - loose noise model and 1 - loop noise model.
|
* 0 - loose noise model and 1 - loop noise model.
|
||||||
*/
|
*/
|
||||||
HybridNonlinearFactor hybridLoopClosureFactor(size_t loopCounter, size_t keyS,
|
HybridNonlinearFactor hybridLoopClosureFactor(
|
||||||
size_t keyT,
|
size_t loopCounter, size_t keyS, size_t keyT,
|
||||||
const Pose2& measurement) {
|
const Pose2& measurement) const {
|
||||||
DiscreteKey l(L(loopCounter), 2);
|
DiscreteKey l(L(loopCounter), 2);
|
||||||
|
|
||||||
auto f0 = std::make_shared<BetweenFactor<Pose2>>(
|
auto f0 = std::make_shared<BetweenFactor<Pose2>>(
|
||||||
|
@ -119,7 +122,7 @@ class Experiment {
|
||||||
/// @brief Create hybrid odometry factor with discrete measurement choices.
|
/// @brief Create hybrid odometry factor with discrete measurement choices.
|
||||||
HybridNonlinearFactor hybridOdometryFactor(
|
HybridNonlinearFactor hybridOdometryFactor(
|
||||||
size_t numMeasurements, size_t keyS, size_t keyT, const DiscreteKey& m,
|
size_t numMeasurements, size_t keyS, size_t keyT, const DiscreteKey& m,
|
||||||
const std::vector<Pose2>& poseArray) {
|
const std::vector<Pose2>& poseArray) const {
|
||||||
auto f0 = std::make_shared<BetweenFactor<Pose2>>(
|
auto f0 = std::make_shared<BetweenFactor<Pose2>>(
|
||||||
X(keyS), X(keyT), poseArray[0], kPoseNoiseModel);
|
X(keyS), X(keyT), poseArray[0], kPoseNoiseModel);
|
||||||
auto f1 = std::make_shared<BetweenFactor<Pose2>>(
|
auto f1 = std::make_shared<BetweenFactor<Pose2>>(
|
||||||
|
@ -132,25 +135,59 @@ class Experiment {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// @brief Perform smoother update and optimize the graph.
|
/// @brief Perform smoother update and optimize the graph.
|
||||||
void smootherUpdate(HybridSmoother& smoother,
|
auto smootherUpdate(size_t maxNrHypotheses) {
|
||||||
HybridNonlinearFactorGraph& graph, const Values& initial,
|
std::cout << "Smoother update: " << newFactors_.size() << std::endl;
|
||||||
size_t maxNrHypotheses, Values* result) {
|
gttic_(SmootherUpdate);
|
||||||
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
clock_t beforeUpdate = clock();
|
||||||
smoother.update(linearized, maxNrHypotheses);
|
auto linearized = newFactors_.linearize(initial_);
|
||||||
// throw if x0 not in hybridBayesNet_:
|
smoother_.update(*linearized, maxNrHypotheses);
|
||||||
const KeySet& keys = smoother.hybridBayesNet().keys();
|
allFactors_.push_back(newFactors_);
|
||||||
if (keys.find(X(0)) == keys.end()) {
|
newFactors_.resize(0);
|
||||||
throw std::runtime_error("x0 not in hybridBayesNet_");
|
clock_t afterUpdate = clock();
|
||||||
|
return afterUpdate - beforeUpdate;
|
||||||
}
|
}
|
||||||
graph.resize(0);
|
|
||||||
// HybridValues delta = smoother.hybridBayesNet().optimize();
|
/// @brief Re-linearize, solve ALL, and re-initialize smoother.
|
||||||
// result->insert_or_assign(initial.retract(delta.continuous()));
|
auto reInitialize() {
|
||||||
|
std::cout << "================= Re-Initialize: " << allFactors_.size()
|
||||||
|
<< std::endl;
|
||||||
|
clock_t beforeUpdate = clock();
|
||||||
|
allFactors_ = allFactors_.restrict(smoother_.fixedValues());
|
||||||
|
auto linearized = allFactors_.linearize(initial_);
|
||||||
|
auto bayesNet = linearized->eliminateSequential();
|
||||||
|
HybridValues delta = bayesNet->optimize();
|
||||||
|
initial_ = initial_.retract(delta.continuous());
|
||||||
|
smoother_.reInitialize(std::move(*bayesNet));
|
||||||
|
clock_t afterUpdate = clock();
|
||||||
|
std::cout << "Took " << (afterUpdate - beforeUpdate) / CLOCKS_PER_SEC
|
||||||
|
<< " seconds." << std::endl;
|
||||||
|
return afterUpdate - beforeUpdate;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse line from file
|
||||||
|
std::pair<std::vector<Pose2>, std::pair<size_t, size_t>> parseLine(
|
||||||
|
const std::string& line) const {
|
||||||
|
std::vector<std::string> parts;
|
||||||
|
split(parts, line, is_any_of(" "));
|
||||||
|
|
||||||
|
size_t keyS = stoi(parts[1]);
|
||||||
|
size_t keyT = stoi(parts[3]);
|
||||||
|
|
||||||
|
int numMeasurements = stoi(parts[5]);
|
||||||
|
std::vector<Pose2> poseArray(numMeasurements);
|
||||||
|
for (int i = 0; i < numMeasurements; ++i) {
|
||||||
|
double x = stod(parts[6 + 3 * i]);
|
||||||
|
double y = stod(parts[7 + 3 * i]);
|
||||||
|
double rad = stod(parts[8 + 3 * i]);
|
||||||
|
poseArray[i] = Pose2(x, y, rad);
|
||||||
|
}
|
||||||
|
return {poseArray, {keyS, keyT}};
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// Construct with filename of experiment to run
|
/// Construct with filename of experiment to run
|
||||||
explicit Experiment(const std::string& filename)
|
explicit Experiment(const std::string& filename)
|
||||||
: filename_(filename), smoother_(0.99) {}
|
: filename_(filename), smoother_(marginalThreshold) {}
|
||||||
|
|
||||||
/// @brief Run the main experiment with a given maxLoopCount.
|
/// @brief Run the main experiment with a given maxLoopCount.
|
||||||
void run() {
|
void run() {
|
||||||
|
@ -162,49 +199,34 @@ class Experiment {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize local variables
|
// Initialize local variables
|
||||||
size_t discreteCount = 0, index = 0;
|
size_t discreteCount = 0, index = 0, loopCount = 0, updateCount = 0;
|
||||||
size_t loopCount = 0;
|
|
||||||
|
|
||||||
std::list<double> timeList;
|
std::list<double> timeList;
|
||||||
|
|
||||||
// Set up initial prior
|
// Set up initial prior
|
||||||
double x = 0.0;
|
Pose2 priorPose(0, 0, 0);
|
||||||
double y = 0.0;
|
|
||||||
double rad = 0.0;
|
|
||||||
|
|
||||||
Pose2 priorPose(x, y, rad);
|
|
||||||
initial_.insert(X(0), priorPose);
|
initial_.insert(X(0), priorPose);
|
||||||
graph_.push_back(PriorFactor<Pose2>(X(0), priorPose, kPriorNoiseModel));
|
newFactors_.push_back(
|
||||||
|
PriorFactor<Pose2>(X(0), priorPose, kPriorNoiseModel));
|
||||||
|
|
||||||
// Initial update
|
// Initial update
|
||||||
clock_t beforeUpdate = clock();
|
auto time = smootherUpdate(maxNrHypotheses);
|
||||||
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
|
|
||||||
clock_t afterUpdate = clock();
|
|
||||||
std::vector<std::pair<size_t, double>> smootherUpdateTimes;
|
std::vector<std::pair<size_t, double>> smootherUpdateTimes;
|
||||||
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
smootherUpdateTimes.push_back({index, time});
|
||||||
|
|
||||||
// Flag to decide whether to run smoother update
|
// Flag to decide whether to run smoother update
|
||||||
size_t numberOfHybridFactors = 0;
|
size_t numberOfHybridFactors = 0;
|
||||||
|
|
||||||
// Start main loop
|
// Start main loop
|
||||||
|
Values result;
|
||||||
size_t keyS = 0, keyT = 0;
|
size_t keyS = 0, keyT = 0;
|
||||||
clock_t startTime = clock();
|
clock_t startTime = clock();
|
||||||
std::string line;
|
std::string line;
|
||||||
while (getline(in, line) && index < maxLoopCount) {
|
while (getline(in, line) && index < maxLoopCount) {
|
||||||
std::vector<std::string> parts;
|
auto [poseArray, keys] = parseLine(line);
|
||||||
split(parts, line, is_any_of(" "));
|
keyS = keys.first;
|
||||||
|
keyT = keys.second;
|
||||||
keyS = stoi(parts[1]);
|
size_t numMeasurements = poseArray.size();
|
||||||
keyT = stoi(parts[3]);
|
|
||||||
|
|
||||||
int numMeasurements = stoi(parts[5]);
|
|
||||||
std::vector<Pose2> poseArray(numMeasurements);
|
|
||||||
for (int i = 0; i < numMeasurements; ++i) {
|
|
||||||
x = stod(parts[6 + 3 * i]);
|
|
||||||
y = stod(parts[7 + 3 * i]);
|
|
||||||
rad = stod(parts[8 + 3 * i]);
|
|
||||||
poseArray[i] = Pose2(x, y, rad);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Take the first one as the initial estimate
|
// Take the first one as the initial estimate
|
||||||
Pose2 odomPose = poseArray[0];
|
Pose2 odomPose = poseArray[0];
|
||||||
|
@ -215,12 +237,12 @@ class Experiment {
|
||||||
DiscreteKey m(M(discreteCount), numMeasurements);
|
DiscreteKey m(M(discreteCount), numMeasurements);
|
||||||
HybridNonlinearFactor mixtureFactor =
|
HybridNonlinearFactor mixtureFactor =
|
||||||
hybridOdometryFactor(numMeasurements, keyS, keyT, m, poseArray);
|
hybridOdometryFactor(numMeasurements, keyS, keyT, m, poseArray);
|
||||||
graph_.push_back(mixtureFactor);
|
newFactors_.push_back(mixtureFactor);
|
||||||
discreteCount++;
|
discreteCount++;
|
||||||
numberOfHybridFactors += 1;
|
numberOfHybridFactors += 1;
|
||||||
std::cout << "mixtureFactor: " << keyS << " " << keyT << std::endl;
|
std::cout << "mixtureFactor: " << keyS << " " << keyT << std::endl;
|
||||||
} else {
|
} else {
|
||||||
graph_.add(BetweenFactor<Pose2>(X(keyS), X(keyT), odomPose,
|
newFactors_.add(BetweenFactor<Pose2>(X(keyS), X(keyT), odomPose,
|
||||||
kPoseNoiseModel));
|
kPoseNoiseModel));
|
||||||
}
|
}
|
||||||
// Insert next pose initial guess
|
// Insert next pose initial guess
|
||||||
|
@ -231,21 +253,20 @@ class Experiment {
|
||||||
hybridLoopClosureFactor(loopCount, keyS, keyT, odomPose);
|
hybridLoopClosureFactor(loopCount, keyS, keyT, odomPose);
|
||||||
// print loop closure event keys:
|
// print loop closure event keys:
|
||||||
std::cout << "Loop closure: " << keyS << " " << keyT << std::endl;
|
std::cout << "Loop closure: " << keyS << " " << keyT << std::endl;
|
||||||
graph_.add(loopFactor);
|
newFactors_.add(loopFactor);
|
||||||
numberOfHybridFactors += 1;
|
numberOfHybridFactors += 1;
|
||||||
loopCount++;
|
loopCount++;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (numberOfHybridFactors >= updateFrequency) {
|
if (numberOfHybridFactors >= updateFrequency) {
|
||||||
// print the keys involved in the smoother update
|
auto time = smootherUpdate(maxNrHypotheses);
|
||||||
std::cout << "Smoother update: " << graph_.size() << std::endl;
|
smootherUpdateTimes.push_back({index, time});
|
||||||
gttic_(SmootherUpdate);
|
|
||||||
beforeUpdate = clock();
|
|
||||||
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
|
|
||||||
afterUpdate = clock();
|
|
||||||
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
|
||||||
gttoc_(SmootherUpdate);
|
|
||||||
numberOfHybridFactors = 0;
|
numberOfHybridFactors = 0;
|
||||||
|
updateCount++;
|
||||||
|
|
||||||
|
if (updateCount % reLinearizationFrequency == 0) {
|
||||||
|
reInitialize();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record timing for odometry edges only
|
// Record timing for odometry edges only
|
||||||
|
@ -270,17 +291,15 @@ class Experiment {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Final update
|
// Final update
|
||||||
beforeUpdate = clock();
|
time = smootherUpdate(maxNrHypotheses);
|
||||||
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
|
smootherUpdateTimes.push_back({index, time});
|
||||||
afterUpdate = clock();
|
|
||||||
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
|
||||||
|
|
||||||
// Final optimize
|
// Final optimize
|
||||||
gttic_(HybridSmootherOptimize);
|
gttic_(HybridSmootherOptimize);
|
||||||
HybridValues delta = smoother_.optimize();
|
HybridValues delta = smoother_.optimize();
|
||||||
gttoc_(HybridSmootherOptimize);
|
gttoc_(HybridSmootherOptimize);
|
||||||
|
|
||||||
result_.insert_or_assign(initial_.retract(delta.continuous()));
|
result.insert_or_assign(initial_.retract(delta.continuous()));
|
||||||
|
|
||||||
std::cout << "Final error: " << smoother_.hybridBayesNet().error(delta)
|
std::cout << "Final error: " << smoother_.hybridBayesNet().error(delta)
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
@ -291,7 +310,7 @@ class Experiment {
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
|
||||||
// Write results to file
|
// Write results to file
|
||||||
writeResult(result_, keyT + 1, "Hybrid_City10000.txt");
|
writeResult(result, keyT + 1, "Hybrid_City10000.txt");
|
||||||
|
|
||||||
// TODO Write to file
|
// TODO Write to file
|
||||||
// for (size_t i = 0; i < smoother_update_times.size(); i++) {
|
// for (size_t i = 0; i < smoother_update_times.size(); i++) {
|
||||||
|
|
|
@ -393,6 +393,13 @@ namespace gtsam {
|
||||||
return DecisionTree(newRoot);
|
return DecisionTree(newRoot);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Choose multiple values. */
|
||||||
|
DecisionTree restrict(const Assignment<L>& assignment) const {
|
||||||
|
NodePtr newRoot = root_;
|
||||||
|
for (const auto& [l, v] : assignment) newRoot = newRoot->choose(l, v);
|
||||||
|
return DecisionTree(newRoot);
|
||||||
|
}
|
||||||
|
|
||||||
/** combine subtrees on key with binary operation "op" */
|
/** combine subtrees on key with binary operation "op" */
|
||||||
DecisionTree combine(const L& label, size_t cardinality,
|
DecisionTree combine(const L& label, size_t cardinality,
|
||||||
const Binary& op) const;
|
const Binary& op) const;
|
||||||
|
|
|
@ -540,5 +540,11 @@ namespace gtsam {
|
||||||
return DecisionTreeFactor(this->discreteKeys(), thresholded);
|
return DecisionTreeFactor(this->discreteKeys(), thresholded);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactor::shared_ptr DecisionTreeFactor::restrict(
|
||||||
|
const DiscreteValues& assignment) const {
|
||||||
|
throw std::runtime_error("DecisionTreeFactor::restrict not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -220,6 +220,10 @@ namespace gtsam {
|
||||||
return combine(keys, Ring::max);
|
return combine(keys, Ring::max);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Restrict the factor to the given assignment.
|
||||||
|
DiscreteFactor::shared_ptr restrict(
|
||||||
|
const DiscreteValues& assignment) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -178,6 +178,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
*/
|
*/
|
||||||
virtual uint64_t nrValues() const = 0;
|
virtual uint64_t nrValues() const = 0;
|
||||||
|
|
||||||
|
/// Restrict the factor to the given assignment.
|
||||||
|
virtual DiscreteFactor::shared_ptr restrict(
|
||||||
|
const DiscreteValues& assignment) const = 0;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -391,12 +391,12 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteFactor::shared_ptr TableFactor::sum(size_t nrFrontals) const {
|
DiscreteFactor::shared_ptr TableFactor::sum(size_t nrFrontals) const {
|
||||||
return combine(nrFrontals, Ring::add);
|
return combine(nrFrontals, Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteFactor::shared_ptr TableFactor::sum(const Ordering& keys) const {
|
DiscreteFactor::shared_ptr TableFactor::sum(const Ordering& keys) const {
|
||||||
return combine(keys, Ring::add);
|
return combine(keys, Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
@ -418,7 +418,6 @@ DiscreteFactor::shared_ptr TableFactor::max(const Ordering& keys) const {
|
||||||
return combine(keys, Ring::max);
|
return combine(keys, Ring::max);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor TableFactor::apply(Unary op) const {
|
TableFactor TableFactor::apply(Unary op) const {
|
||||||
// Initialize new factor.
|
// Initialize new factor.
|
||||||
|
@ -781,5 +780,11 @@ TableFactor TableFactor::prune(size_t maxNrAssignments) const {
|
||||||
return TableFactor(this->discreteKeys(), pruned_vec);
|
return TableFactor(this->discreteKeys(), pruned_vec);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactor::shared_ptr TableFactor::restrict(
|
||||||
|
const DiscreteValues& assignment) const {
|
||||||
|
throw std::runtime_error("TableFactor::restrict not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -342,6 +342,10 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
*/
|
*/
|
||||||
uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
|
uint64_t nrValues() const override { return sparse_table_.nonZeros(); }
|
||||||
|
|
||||||
|
/// Restrict the factor to the given assignment.
|
||||||
|
DiscreteFactor::shared_ptr restrict(
|
||||||
|
const DiscreteValues& assignment) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -59,21 +59,30 @@ TEST(ADT, arithmetic) {
|
||||||
|
|
||||||
// Negate and subtraction
|
// Negate and subtraction
|
||||||
CHECK(assert_equal(-a, zero - a));
|
CHECK(assert_equal(-a, zero - a));
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
CHECK(assert_equal({zero}, a - a));
|
CHECK(assert_equal({zero}, a - a));
|
||||||
|
#else
|
||||||
|
CHECK(assert_equal({A, 0, 0}, a - a));
|
||||||
|
#endif
|
||||||
CHECK(assert_equal(a + b, b + a));
|
CHECK(assert_equal(a + b, b + a));
|
||||||
CHECK(assert_equal({A, 3, 4}, a + 2));
|
CHECK(assert_equal({A, 3, 4}, a + 2));
|
||||||
CHECK(assert_equal({B, 1, 2}, b - 2));
|
CHECK(assert_equal({B, 1, 2}, b - 2));
|
||||||
|
|
||||||
// Multiplication
|
// Multiplication
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
CHECK(assert_equal(zero, zero * a));
|
CHECK(assert_equal(zero, zero * a));
|
||||||
CHECK(assert_equal(zero, a * zero));
|
#else
|
||||||
|
CHECK(assert_equal({A, 0, 0}, zero * a));
|
||||||
|
#endif
|
||||||
CHECK(assert_equal(a, one * a));
|
CHECK(assert_equal(a, one * a));
|
||||||
CHECK(assert_equal(a, a * one));
|
CHECK(assert_equal(a, a * one));
|
||||||
CHECK(assert_equal(a * b, b * a));
|
CHECK(assert_equal(a * b, b * a));
|
||||||
|
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
// division
|
// division
|
||||||
// CHECK(assert_equal(a, (a * b) / b)); // not true because no pruning
|
// CHECK(assert_equal(a, (a * b) / b)); // not true because no pruning
|
||||||
CHECK(assert_equal(b, (a * b) / a));
|
CHECK(assert_equal(b, (a * b) / a));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
|
@ -228,9 +228,9 @@ TEST(DecisionTree, Example) {
|
||||||
// Test choose 0
|
// Test choose 0
|
||||||
DT actual0 = notba.choose(A, 0);
|
DT actual0 = notba.choose(A, 0);
|
||||||
#ifdef GTSAM_DT_MERGING
|
#ifdef GTSAM_DT_MERGING
|
||||||
EXPECT(assert_equal(DT(0.0), actual0));
|
EXPECT(assert_equal(DT(0), actual0));
|
||||||
#else
|
#else
|
||||||
EXPECT(assert_equal(DT({0.0, 0.0}), actual0));
|
EXPECT(assert_equal(DT(B, 0, 0), actual0));
|
||||||
#endif
|
#endif
|
||||||
DOT(actual0);
|
DOT(actual0);
|
||||||
|
|
||||||
|
@ -618,6 +618,21 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test apply with assignment.
|
||||||
|
TEST(DecisionTree, Restrict) {
|
||||||
|
// Create three level tree
|
||||||
|
const vector<DT::LabelC> keys{DT::LabelC("C", 2), DT::LabelC("B", 2),
|
||||||
|
DT::LabelC("A", 2)};
|
||||||
|
DT tree(keys, "1 2 3 4 5 6 7 8");
|
||||||
|
|
||||||
|
DT restrictedTree = tree.restrict({{"A", 0}, {"B", 1}});
|
||||||
|
EXPECT(assert_equal(DT({DT::LabelC("C", 2)}, "3 7"), restrictedTree));
|
||||||
|
|
||||||
|
DT restrictMore = tree.restrict({{"A", 1}, {"B", 1}, {"C", 1}});
|
||||||
|
EXPECT(assert_equal(DT(8), restrictMore));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -124,8 +124,10 @@ TEST(DecisionTreeFactor, Divide) {
|
||||||
EXPECT(assert_inequal(pS, s));
|
EXPECT(assert_inequal(pS, s));
|
||||||
|
|
||||||
// The underlying data should be the same
|
// The underlying data should be the same
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
using ADT = AlgebraicDecisionTree<Key>;
|
using ADT = AlgebraicDecisionTree<Key>;
|
||||||
EXPECT(assert_equal(ADT(pS), ADT(s)));
|
EXPECT(assert_equal(ADT(pS), ADT(s)));
|
||||||
|
#endif
|
||||||
|
|
||||||
KeySet keys(joint.keys());
|
KeySet keys(joint.keys());
|
||||||
keys.insert(pA.keys().begin(), pA.keys().end());
|
keys.insert(pA.keys().begin(), pA.keys().end());
|
||||||
|
|
|
@ -69,11 +69,13 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
|
|
||||||
// Go through all the Gaussian conditionals, restrict them according to
|
// Go through all the Gaussian conditionals, restrict them according to
|
||||||
// fixed values, and then prune further.
|
// fixed values, and then prune further.
|
||||||
for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) {
|
for (std::shared_ptr<HybridConditional> conditional : *this) {
|
||||||
if (conditional->isDiscrete()) continue;
|
if (conditional->isDiscrete()) continue;
|
||||||
|
|
||||||
// No-op if not a HybridGaussianConditional.
|
// No-op if not a HybridGaussianConditional.
|
||||||
if (marginalThreshold) conditional = conditional->restrict(fixed);
|
if (marginalThreshold)
|
||||||
|
conditional = std::static_pointer_cast<HybridConditional>(
|
||||||
|
conditional->restrict(fixed));
|
||||||
|
|
||||||
// Now decide on type what to do:
|
// Now decide on type what to do:
|
||||||
if (auto hgc = conditional->asHybrid()) {
|
if (auto hgc = conditional->asHybrid()) {
|
||||||
|
|
|
@ -170,8 +170,8 @@ double HybridConditional::evaluate(const HybridValues &values) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
HybridConditional::shared_ptr HybridConditional::restrict(
|
std::shared_ptr<Factor> HybridConditional::restrict(
|
||||||
const DiscreteValues &discreteValues) const {
|
const DiscreteValues &assignment) const {
|
||||||
if (auto gc = asGaussian()) {
|
if (auto gc = asGaussian()) {
|
||||||
return std::make_shared<HybridConditional>(gc);
|
return std::make_shared<HybridConditional>(gc);
|
||||||
} else if (auto dc = asDiscrete()) {
|
} else if (auto dc = asDiscrete()) {
|
||||||
|
@ -184,21 +184,20 @@ HybridConditional::shared_ptr HybridConditional::restrict(
|
||||||
"HybridConditional::restrict: conditional type not handled");
|
"HybridConditional::restrict: conditional type not handled");
|
||||||
|
|
||||||
// Case 1: Fully determined, return corresponding Gaussian conditional
|
// Case 1: Fully determined, return corresponding Gaussian conditional
|
||||||
auto parentValues = discreteValues.filter(discreteKeys_);
|
auto parentValues = assignment.filter(discreteKeys_);
|
||||||
if (parentValues.size() == discreteKeys_.size()) {
|
if (parentValues.size() == discreteKeys_.size()) {
|
||||||
return std::make_shared<HybridConditional>(hgc->choose(parentValues));
|
return std::make_shared<HybridConditional>(hgc->choose(parentValues));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case 2: Some live parents remain, build a new tree
|
// Case 2: Some live parents remain, build a new tree
|
||||||
auto unspecifiedParentKeys = discreteValues.missingKeys(discreteKeys_);
|
auto remainingKeys = assignment.missingKeys(discreteKeys_);
|
||||||
if (!unspecifiedParentKeys.empty()) {
|
if (!remainingKeys.empty()) {
|
||||||
auto newTree = hgc->factors();
|
auto newTree = hgc->factors();
|
||||||
for (const auto &[key, value] : parentValues) {
|
for (const auto &[key, value] : parentValues) {
|
||||||
newTree = newTree.choose(key, value);
|
newTree = newTree.choose(key, value);
|
||||||
}
|
}
|
||||||
return std::make_shared<HybridConditional>(
|
return std::make_shared<HybridConditional>(
|
||||||
std::make_shared<HybridGaussianConditional>(unspecifiedParentKeys,
|
std::make_shared<HybridGaussianConditional>(remainingKeys, newTree));
|
||||||
newTree));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case 3: No changes needed, return original
|
// Case 3: No changes needed, return original
|
||||||
|
|
|
@ -153,7 +153,8 @@ class GTSAM_EXPORT HybridConditional
|
||||||
* @return HybridGaussianConditional::shared_ptr otherwise
|
* @return HybridGaussianConditional::shared_ptr otherwise
|
||||||
*/
|
*/
|
||||||
HybridGaussianConditional::shared_ptr asHybrid() const {
|
HybridGaussianConditional::shared_ptr asHybrid() const {
|
||||||
return std::dynamic_pointer_cast<HybridGaussianConditional>(inner_);
|
if (!isHybrid()) return nullptr;
|
||||||
|
return std::static_pointer_cast<HybridGaussianConditional>(inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -162,7 +163,8 @@ class GTSAM_EXPORT HybridConditional
|
||||||
* @return GaussianConditional::shared_ptr otherwise
|
* @return GaussianConditional::shared_ptr otherwise
|
||||||
*/
|
*/
|
||||||
GaussianConditional::shared_ptr asGaussian() const {
|
GaussianConditional::shared_ptr asGaussian() const {
|
||||||
return std::dynamic_pointer_cast<GaussianConditional>(inner_);
|
if (!isContinuous()) return nullptr;
|
||||||
|
return std::static_pointer_cast<GaussianConditional>(inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -172,7 +174,8 @@ class GTSAM_EXPORT HybridConditional
|
||||||
*/
|
*/
|
||||||
template <typename T = DiscreteConditional>
|
template <typename T = DiscreteConditional>
|
||||||
typename T::shared_ptr asDiscrete() const {
|
typename T::shared_ptr asDiscrete() const {
|
||||||
return std::dynamic_pointer_cast<T>(inner_);
|
if (!isDiscrete()) return nullptr;
|
||||||
|
return std::static_pointer_cast<T>(inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the type-erased pointer to the inner type
|
/// Get the type-erased pointer to the inner type
|
||||||
|
@ -221,7 +224,8 @@ class GTSAM_EXPORT HybridConditional
|
||||||
* which is just a GaussianConditional. If this conditional is *not* a hybrid
|
* which is just a GaussianConditional. If this conditional is *not* a hybrid
|
||||||
* conditional, just return that.
|
* conditional, just return that.
|
||||||
*/
|
*/
|
||||||
shared_ptr restrict(const DiscreteValues& discreteValues) const;
|
std::shared_ptr<Factor> restrict(
|
||||||
|
const DiscreteValues& assignment) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
|
@ -133,10 +133,14 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
||||||
/// Return only the continuous keys for this factor.
|
/// Return only the continuous keys for this factor.
|
||||||
const KeyVector &continuousKeys() const { return continuousKeys_; }
|
const KeyVector &continuousKeys() const { return continuousKeys_; }
|
||||||
|
|
||||||
/// Virtual class to compute tree of linear errors.
|
/// Compute tree of linear errors.
|
||||||
virtual AlgebraicDecisionTree<Key> errorTree(
|
virtual AlgebraicDecisionTree<Key> errorTree(
|
||||||
const VectorValues &values) const = 0;
|
const VectorValues &values) const = 0;
|
||||||
|
|
||||||
|
/// Restrict the factor to the given discrete values.
|
||||||
|
virtual std::shared_ptr<Factor> restrict(
|
||||||
|
const DiscreteValues &discreteValues) const = 0;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -363,4 +363,12 @@ double HybridGaussianConditional::evaluate(const HybridValues &values) const {
|
||||||
return conditional->evaluate(values.continuous());
|
return conditional->evaluate(values.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::shared_ptr<Factor> HybridGaussianConditional::restrict(
|
||||||
|
const DiscreteValues &assignment) const {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridGaussianConditional::restrict not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -241,6 +241,10 @@ class GTSAM_EXPORT HybridGaussianConditional
|
||||||
/// Return true if the conditional has already been pruned.
|
/// Return true if the conditional has already been pruned.
|
||||||
bool pruned() const { return pruned_; }
|
bool pruned() const { return pruned_; }
|
||||||
|
|
||||||
|
/// Restrict to the given discrete values.
|
||||||
|
std::shared_ptr<Factor> restrict(
|
||||||
|
const DiscreteValues &discreteValues) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -199,4 +199,12 @@ double HybridGaussianFactor::error(const HybridValues& values) const {
|
||||||
return PotentiallyPrunedComponentError(pair, values.continuous());
|
return PotentiallyPrunedComponentError(pair, values.continuous());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::shared_ptr<Factor> HybridGaussianFactor::restrict(
|
||||||
|
const DiscreteValues& assignment) const {
|
||||||
|
throw std::runtime_error("HybridGaussianFactor::restrict not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -157,6 +157,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
*/
|
*/
|
||||||
virtual HybridGaussianProductFactor asProductFactor() const;
|
virtual HybridGaussianProductFactor asProductFactor() const;
|
||||||
|
|
||||||
|
/// Restrict the factor to the given discrete values.
|
||||||
|
std::shared_ptr<Factor> restrict(
|
||||||
|
const DiscreteValues &discreteValues) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
@ -239,4 +239,21 @@ HybridNonlinearFactor::shared_ptr HybridNonlinearFactor::prune(
|
||||||
return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors);
|
return std::make_shared<HybridNonlinearFactor>(discreteKeys(), prunedFactors);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
std::shared_ptr<Factor> HybridNonlinearFactor::restrict(
|
||||||
|
const DiscreteValues& assignment) const {
|
||||||
|
auto restrictedFactors = factors_.restrict(assignment);
|
||||||
|
auto filtered = assignment.filter(discreteKeys_);
|
||||||
|
if (filtered.size() == discreteKeys_.size()) {
|
||||||
|
auto [nonlinearFactor, val] = factors_(filtered);
|
||||||
|
return nonlinearFactor;
|
||||||
|
} else {
|
||||||
|
auto remainingKeys = assignment.missingKeys(discreteKeys());
|
||||||
|
return std::make_shared<HybridNonlinearFactor>(remainingKeys,
|
||||||
|
factors_.restrict(filtered));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -80,6 +80,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
/// @name Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
/// Default constructor, mainly for serialization.
|
/// Default constructor, mainly for serialization.
|
||||||
HybridNonlinearFactor() = default;
|
HybridNonlinearFactor() = default;
|
||||||
|
|
||||||
|
@ -137,7 +140,7 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
|
||||||
* @return double The error of this factor.
|
* @return double The error of this factor.
|
||||||
*/
|
*/
|
||||||
double error(const Values& continuousValues,
|
double error(const Values& continuousValues,
|
||||||
const DiscreteValues& discreteValues) const;
|
const DiscreteValues& assignment) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute error of factor given hybrid values.
|
* @brief Compute error of factor given hybrid values.
|
||||||
|
@ -154,7 +157,8 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
|
||||||
*/
|
*/
|
||||||
size_t dim() const;
|
size_t dim() const;
|
||||||
|
|
||||||
/// Testable
|
/// @}
|
||||||
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// print to stdout
|
/// print to stdout
|
||||||
|
@ -165,15 +169,16 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
|
||||||
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
|
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
/// @name Standard API
|
||||||
|
/// @{
|
||||||
|
|
||||||
/// Getter for NonlinearFactor decision tree
|
/// Getter for NonlinearFactor decision tree
|
||||||
const FactorValuePairs& factors() const { return factors_; }
|
const FactorValuePairs& factors() const { return factors_; }
|
||||||
|
|
||||||
/// Linearize specific nonlinear factors based on the assignment in
|
/// Linearize specific nonlinear factors based on the assignment in
|
||||||
/// discreteValues.
|
/// discreteValues.
|
||||||
GaussianFactor::shared_ptr linearize(
|
GaussianFactor::shared_ptr linearize(const Values& continuousValues,
|
||||||
const Values& continuousValues,
|
const DiscreteValues& assignment) const;
|
||||||
const DiscreteValues& discreteValues) const;
|
|
||||||
|
|
||||||
/// Linearize all the continuous factors to get a HybridGaussianFactor.
|
/// Linearize all the continuous factors to get a HybridGaussianFactor.
|
||||||
std::shared_ptr<HybridGaussianFactor> linearize(
|
std::shared_ptr<HybridGaussianFactor> linearize(
|
||||||
|
@ -183,6 +188,12 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
|
||||||
HybridNonlinearFactor::shared_ptr prune(
|
HybridNonlinearFactor::shared_ptr prune(
|
||||||
const DecisionTreeFactor& discreteProbs) const;
|
const DecisionTreeFactor& discreteProbs) const;
|
||||||
|
|
||||||
|
/// Restrict the factor to the given discrete values.
|
||||||
|
std::shared_ptr<Factor> restrict(
|
||||||
|
const DiscreteValues& assignment) const override;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// Helper struct to assist private constructor below.
|
/// Helper struct to assist private constructor below.
|
||||||
struct ConstructorHelper;
|
struct ConstructorHelper;
|
||||||
|
|
|
@ -221,5 +221,30 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::discretePosterior(
|
||||||
return p / p.sum();
|
return p / p.sum();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict(
|
||||||
|
const DiscreteValues& discreteValues) const {
|
||||||
|
using std::dynamic_pointer_cast;
|
||||||
|
|
||||||
|
HybridNonlinearFactorGraph result;
|
||||||
|
result.reserve(size());
|
||||||
|
for (auto& f : factors_) {
|
||||||
|
// First check if it is a valid factor
|
||||||
|
if (!f) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// Check if it is a hybrid factor
|
||||||
|
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
|
||||||
|
result.push_back(hf->restrict(discreteValues));
|
||||||
|
} else if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
|
||||||
|
result.push_back(df->restrict(discreteValues));
|
||||||
|
} else {
|
||||||
|
result.push_back(f); // Everything else is just added as is
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -116,6 +116,10 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
|
||||||
AlgebraicDecisionTree<Key> discretePosterior(
|
AlgebraicDecisionTree<Key> discretePosterior(
|
||||||
const Values& continuousValues) const;
|
const Values& continuousValues) const;
|
||||||
|
|
||||||
|
/// Restrict all factors in the graph to the given discrete values.
|
||||||
|
HybridNonlinearFactorGraph restrict(
|
||||||
|
const DiscreteValues& assignment) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,6 @@ namespace gtsam {
|
||||||
class GTSAM_EXPORT HybridSmoother {
|
class GTSAM_EXPORT HybridSmoother {
|
||||||
private:
|
private:
|
||||||
HybridBayesNet hybridBayesNet_;
|
HybridBayesNet hybridBayesNet_;
|
||||||
HybridGaussianFactorGraph remainingFactorGraph_;
|
|
||||||
|
|
||||||
/// The threshold above which we make a decision about a mode.
|
/// The threshold above which we make a decision about a mode.
|
||||||
std::optional<double> marginalThreshold_;
|
std::optional<double> marginalThreshold_;
|
||||||
|
@ -44,6 +43,16 @@ class GTSAM_EXPORT HybridSmoother {
|
||||||
HybridSmoother(const std::optional<double> marginalThreshold = {})
|
HybridSmoother(const std::optional<double> marginalThreshold = {})
|
||||||
: marginalThreshold_(marginalThreshold) {}
|
: marginalThreshold_(marginalThreshold) {}
|
||||||
|
|
||||||
|
/// Return fixed values:
|
||||||
|
const DiscreteValues& fixedValues() const { return fixedValues_; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Re-initialize the smoother from a new hybrid Bayes Net.
|
||||||
|
*/
|
||||||
|
void reInitialize(HybridBayesNet&& hybridBayesNet) {
|
||||||
|
hybridBayesNet_ = std::move(hybridBayesNet);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Given new factors, perform an incremental update.
|
* Given new factors, perform an incremental update.
|
||||||
* The relevant densities in the `hybridBayesNet` will be added to the input
|
* The relevant densities in the `hybridBayesNet` will be added to the input
|
||||||
|
|
|
@ -318,21 +318,27 @@ TEST(HybridGaussianConditional, Restrict) {
|
||||||
const auto hc =
|
const auto hc =
|
||||||
std::make_shared<HybridConditional>(two_mode_measurement::hgc);
|
std::make_shared<HybridConditional>(two_mode_measurement::hgc);
|
||||||
|
|
||||||
const HybridConditional::shared_ptr same = hc->restrict({});
|
const auto same =
|
||||||
|
std::dynamic_pointer_cast<HybridConditional>(hc->restrict({}));
|
||||||
|
CHECK(same);
|
||||||
EXPECT(same->isHybrid());
|
EXPECT(same->isHybrid());
|
||||||
EXPECT(same->asHybrid()->nrComponents() == 4);
|
EXPECT(same->asHybrid()->nrComponents() == 4);
|
||||||
|
|
||||||
const HybridConditional::shared_ptr oneParent = hc->restrict({{M(1), 0}});
|
const auto oneParent =
|
||||||
|
std::dynamic_pointer_cast<HybridConditional>(hc->restrict({{M(1), 0}}));
|
||||||
|
CHECK(oneParent);
|
||||||
EXPECT(oneParent->isHybrid());
|
EXPECT(oneParent->isHybrid());
|
||||||
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
|
EXPECT(oneParent->asHybrid()->nrComponents() == 2);
|
||||||
|
|
||||||
const HybridConditional::shared_ptr oneParent2 =
|
const auto oneParent2 = std::dynamic_pointer_cast<HybridConditional>(
|
||||||
hc->restrict({{M(7), 0}, {M(1), 0}});
|
hc->restrict({{M(7), 0}, {M(1), 0}}));
|
||||||
|
CHECK(oneParent2);
|
||||||
EXPECT(oneParent2->isHybrid());
|
EXPECT(oneParent2->isHybrid());
|
||||||
EXPECT(oneParent2->asHybrid()->nrComponents() == 2);
|
EXPECT(oneParent2->asHybrid()->nrComponents() == 2);
|
||||||
|
|
||||||
const HybridConditional::shared_ptr gaussian =
|
const auto gaussian = std::dynamic_pointer_cast<HybridConditional>(
|
||||||
hc->restrict({{M(1), 0}, {M(2), 1}});
|
hc->restrict({{M(1), 0}, {M(2), 1}}));
|
||||||
|
CHECK(gaussian);
|
||||||
EXPECT(gaussian->asGaussian());
|
EXPECT(gaussian->asGaussian());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -131,6 +131,18 @@ TEST(HybridNonlinearFactor, Dim) {
|
||||||
EXPECT_LONGS_EQUAL(1, hybridFactor.dim());
|
EXPECT_LONGS_EQUAL(1, hybridFactor.dim());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Test restrict method
|
||||||
|
TEST(HybridNonlinearFactor, Restrict) {
|
||||||
|
using namespace test_constructor;
|
||||||
|
HybridNonlinearFactor factor(m1, {f0, f1});
|
||||||
|
DiscreteValues assignment = {{m1.first, 0}};
|
||||||
|
auto restricted = factor.restrict(assignment);
|
||||||
|
auto betweenFactor = dynamic_pointer_cast<BetweenFactor<double>>(restricted);
|
||||||
|
CHECK(betweenFactor);
|
||||||
|
EXPECT(assert_equal(*f0, *betweenFactor));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
@ -102,12 +102,16 @@ TEST(HybridSmoother, IncrementalSmoother) {
|
||||||
graph.resize(0);
|
graph.resize(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(11,
|
auto& hybridBayesNet = smoother.hybridBayesNet();
|
||||||
smoother.hybridBayesNet().at(5)->asDiscrete()->nrValues());
|
#ifdef GTSAM_DT_MERGING
|
||||||
|
EXPECT_LONGS_EQUAL(11, hybridBayesNet.at(5)->asDiscrete()->nrValues());
|
||||||
|
#else
|
||||||
|
EXPECT_LONGS_EQUAL(16, hybridBayesNet.at(5)->asDiscrete()->nrValues());
|
||||||
|
#endif
|
||||||
|
|
||||||
// Get the continuous delta update as well as
|
// Get the continuous delta update as well as
|
||||||
// the optimal discrete assignment.
|
// the optimal discrete assignment.
|
||||||
HybridValues delta = smoother.hybridBayesNet().optimize();
|
HybridValues delta = hybridBayesNet.optimize();
|
||||||
|
|
||||||
// Check discrete assignment
|
// Check discrete assignment
|
||||||
DiscreteValues expected_discrete;
|
DiscreteValues expected_discrete;
|
||||||
|
@ -156,8 +160,12 @@ TEST(HybridSmoother, ValidPruningError) {
|
||||||
graph.resize(0);
|
graph.resize(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(14,
|
auto& hybridBayesNet = smoother.hybridBayesNet();
|
||||||
smoother.hybridBayesNet().at(8)->asDiscrete()->nrValues());
|
#ifdef GTSAM_DT_MERGING
|
||||||
|
EXPECT_LONGS_EQUAL(14, hybridBayesNet.at(8)->asDiscrete()->nrValues());
|
||||||
|
#else
|
||||||
|
EXPECT_LONGS_EQUAL(128, hybridBayesNet.at(8)->asDiscrete()->nrValues());
|
||||||
|
#endif
|
||||||
|
|
||||||
// Get the continuous delta update as well as
|
// Get the continuous delta update as well as
|
||||||
// the optimal discrete assignment.
|
// the optimal discrete assignment.
|
||||||
|
|
|
@ -53,11 +53,6 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
|
||||||
/// Multiply into a decisiontree
|
/// Multiply into a decisiontree
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
|
||||||
|
|
||||||
/// Compute error for each assignment and return as a tree
|
|
||||||
AlgebraicDecisionTree<Key> errorTree() const override {
|
|
||||||
throw std::runtime_error("AllDiff::error not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Ensure Arc-consistency by checking every possible value of domain j.
|
* Ensure Arc-consistency by checking every possible value of domain j.
|
||||||
* @param j domain to be checked
|
* @param j domain to be checked
|
||||||
|
|
|
@ -91,11 +91,6 @@ class BinaryAllDiff : public Constraint {
|
||||||
const Domains&) const override {
|
const Domains&) const override {
|
||||||
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
|
throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute error for each assignment and return as a tree
|
|
||||||
AlgebraicDecisionTree<Key> errorTree() const override {
|
|
||||||
throw std::runtime_error("BinaryAllDiff::error not implemented");
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -125,6 +125,17 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
|
||||||
return toDecisionTreeFactor().max(keys);
|
return toDecisionTreeFactor().max(keys);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Compute error for each assignment and return as a tree
|
||||||
|
AlgebraicDecisionTree<Key> errorTree() const override {
|
||||||
|
throw std::runtime_error("Constraint::error not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute error for each assignment and return as a tree
|
||||||
|
DiscreteFactor::shared_ptr restrict(
|
||||||
|
const DiscreteValues& assignment) const override {
|
||||||
|
throw std::runtime_error("Constraint::restrict not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Wrapper support
|
/// @name Wrapper support
|
||||||
/// @{
|
/// @{
|
||||||
|
|
|
@ -69,11 +69,6 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute error for each assignment and return as a tree
|
|
||||||
AlgebraicDecisionTree<Key> errorTree() const override {
|
|
||||||
throw std::runtime_error("Domain::error not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return concise string representation, mostly to debug arc consistency.
|
// Return concise string representation, mostly to debug arc consistency.
|
||||||
// Converts from base 0 to base1.
|
// Converts from base 0 to base1.
|
||||||
std::string base1Str() const;
|
std::string base1Str() const;
|
||||||
|
|
|
@ -49,11 +49,6 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compute error for each assignment and return as a tree
|
|
||||||
AlgebraicDecisionTree<Key> errorTree() const override {
|
|
||||||
throw std::runtime_error("SingleValue::error not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Calculate value
|
/// Calculate value
|
||||||
double evaluate(const Assignment<Key>& values) const override;
|
double evaluate(const Assignment<Key>& values) const override;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue