commit
c37eb49547
|
@ -31,11 +31,12 @@
|
|||
|
||||
#include <boost/algorithm/string/classification.hpp>
|
||||
#include <boost/algorithm/string/split.hpp>
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace boost::algorithm;
|
||||
|
||||
|
@ -43,19 +44,30 @@ using symbol_shorthand::L;
|
|||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
|
||||
const size_t kMaxLoopCount = 2000; // Example default value
|
||||
const size_t kMaxNrHypotheses = 10;
|
||||
|
||||
auto kOpenLoopModel = noiseModel::Diagonal::Sigmas(Vector3::Ones() * 10);
|
||||
const double kOpenLoopConstant = kOpenLoopModel->negLogConstant();
|
||||
|
||||
auto kPriorNoiseModel = noiseModel::Diagonal::Sigmas(
|
||||
(Vector(3) << 0.0001, 0.0001, 0.0001).finished());
|
||||
|
||||
auto kPoseNoiseModel = noiseModel::Diagonal::Sigmas(
|
||||
(Vector(3) << 1.0 / 30.0, 1.0 / 30.0, 1.0 / 100.0).finished());
|
||||
const double kPoseNoiseConstant = kPoseNoiseModel->negLogConstant();
|
||||
|
||||
// Experiment Class
|
||||
class Experiment {
|
||||
public:
|
||||
// Parameters with default values
|
||||
size_t maxLoopCount = 3000;
|
||||
|
||||
// 3000: {1: 62s, 2: 21s, 3: 20s, 4: 31s, 5: 39s} No DT optimizations
|
||||
// 3000: {1: 65s, 2: 20s, 3: 16s, 4: 21s, 5: 28s} With DT optimizations
|
||||
// 3000: {1: 59s, 2: 19s, 3: 18s, 4: 26s, 5: 33s} With DT optimizations +
|
||||
// merge
|
||||
size_t updateFrequency = 3;
|
||||
|
||||
size_t maxNrHypotheses = 10;
|
||||
|
||||
private:
|
||||
std::string filename_;
|
||||
HybridSmoother smoother_;
|
||||
|
@ -72,7 +84,7 @@ class Experiment {
|
|||
*/
|
||||
void writeResult(const Values& result, size_t numPoses,
|
||||
const std::string& filename = "Hybrid_city10000.txt") {
|
||||
ofstream outfile;
|
||||
std::ofstream outfile;
|
||||
outfile.open(filename);
|
||||
|
||||
for (size_t i = 0; i < numPoses; ++i) {
|
||||
|
@ -98,9 +110,8 @@ class Experiment {
|
|||
auto f1 = std::make_shared<BetweenFactor<Pose2>>(
|
||||
X(keyS), X(keyT), measurement, kPoseNoiseModel);
|
||||
|
||||
std::vector<NonlinearFactorValuePair> factors{
|
||||
{f0, kOpenLoopModel->negLogConstant()},
|
||||
{f1, kPoseNoiseModel->negLogConstant()}};
|
||||
std::vector<NonlinearFactorValuePair> factors{{f0, kOpenLoopConstant},
|
||||
{f1, kPoseNoiseConstant}};
|
||||
HybridNonlinearFactor mixtureFactor(l, factors);
|
||||
return mixtureFactor;
|
||||
}
|
||||
|
@ -108,14 +119,14 @@ class Experiment {
|
|||
/// @brief Create hybrid odometry factor with discrete measurement choices.
|
||||
HybridNonlinearFactor hybridOdometryFactor(
|
||||
size_t numMeasurements, size_t keyS, size_t keyT, const DiscreteKey& m,
|
||||
const std::vector<Pose2>& poseArray,
|
||||
const SharedNoiseModel& poseNoiseModel) {
|
||||
const std::vector<Pose2>& poseArray) {
|
||||
auto f0 = std::make_shared<BetweenFactor<Pose2>>(
|
||||
X(keyS), X(keyT), poseArray[0], poseNoiseModel);
|
||||
X(keyS), X(keyT), poseArray[0], kPoseNoiseModel);
|
||||
auto f1 = std::make_shared<BetweenFactor<Pose2>>(
|
||||
X(keyS), X(keyT), poseArray[1], poseNoiseModel);
|
||||
X(keyS), X(keyT), poseArray[1], kPoseNoiseModel);
|
||||
|
||||
std::vector<NonlinearFactorValuePair> factors{{f0, 0.0}, {f1, 0.0}};
|
||||
std::vector<NonlinearFactorValuePair> factors{{f0, kPoseNoiseConstant},
|
||||
{f1, kPoseNoiseConstant}};
|
||||
HybridNonlinearFactor mixtureFactor(m, factors);
|
||||
return mixtureFactor;
|
||||
}
|
||||
|
@ -123,9 +134,9 @@ class Experiment {
|
|||
/// @brief Perform smoother update and optimize the graph.
|
||||
void smootherUpdate(HybridSmoother& smoother,
|
||||
HybridNonlinearFactorGraph& graph, const Values& initial,
|
||||
size_t kMaxNrHypotheses, Values* result) {
|
||||
size_t maxNrHypotheses, Values* result) {
|
||||
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
||||
smoother.update(linearized, kMaxNrHypotheses);
|
||||
smoother.update(linearized, maxNrHypotheses);
|
||||
// throw if x0 not in hybridBayesNet_:
|
||||
const KeySet& keys = smoother.hybridBayesNet().keys();
|
||||
if (keys.find(X(0)) == keys.end()) {
|
||||
|
@ -142,11 +153,11 @@ class Experiment {
|
|||
: filename_(filename), smoother_(0.99) {}
|
||||
|
||||
/// @brief Run the main experiment with a given maxLoopCount.
|
||||
void run(size_t maxLoopCount) {
|
||||
void run() {
|
||||
// Prepare reading
|
||||
ifstream in(filename_);
|
||||
std::ifstream in(filename_);
|
||||
if (!in.is_open()) {
|
||||
cerr << "Failed to open file: " << filename_ << endl;
|
||||
std::cerr << "Failed to open file: " << filename_ << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -167,11 +178,14 @@ class Experiment {
|
|||
|
||||
// Initial update
|
||||
clock_t beforeUpdate = clock();
|
||||
smootherUpdate(smoother_, graph_, initial_, kMaxNrHypotheses, &result_);
|
||||
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
|
||||
clock_t afterUpdate = clock();
|
||||
std::vector<std::pair<size_t, double>> smootherUpdateTimes;
|
||||
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
||||
|
||||
// Flag to decide whether to run smoother update
|
||||
size_t numberOfHybridFactors = 0;
|
||||
|
||||
// Start main loop
|
||||
size_t keyS = 0, keyT = 0;
|
||||
clock_t startTime = clock();
|
||||
|
@ -192,9 +206,6 @@ class Experiment {
|
|||
poseArray[i] = Pose2(x, y, rad);
|
||||
}
|
||||
|
||||
// Flag to decide whether to run smoother update
|
||||
bool doSmootherUpdate = false;
|
||||
|
||||
// Take the first one as the initial estimate
|
||||
Pose2 odomPose = poseArray[0];
|
||||
if (keyS == keyT - 1) {
|
||||
|
@ -202,11 +213,11 @@ class Experiment {
|
|||
if (numMeasurements > 1) {
|
||||
// Add hybrid factor
|
||||
DiscreteKey m(M(discreteCount), numMeasurements);
|
||||
HybridNonlinearFactor mixtureFactor = hybridOdometryFactor(
|
||||
numMeasurements, keyS, keyT, m, poseArray, kPoseNoiseModel);
|
||||
HybridNonlinearFactor mixtureFactor =
|
||||
hybridOdometryFactor(numMeasurements, keyS, keyT, m, poseArray);
|
||||
graph_.push_back(mixtureFactor);
|
||||
discreteCount++;
|
||||
doSmootherUpdate = true;
|
||||
numberOfHybridFactors += 1;
|
||||
std::cout << "mixtureFactor: " << keyS << " " << keyT << std::endl;
|
||||
} else {
|
||||
graph_.add(BetweenFactor<Pose2>(X(keyS), X(keyT), odomPose,
|
||||
|
@ -221,18 +232,20 @@ class Experiment {
|
|||
// print loop closure event keys:
|
||||
std::cout << "Loop closure: " << keyS << " " << keyT << std::endl;
|
||||
graph_.add(loopFactor);
|
||||
doSmootherUpdate = true;
|
||||
numberOfHybridFactors += 1;
|
||||
loopCount++;
|
||||
}
|
||||
|
||||
if (doSmootherUpdate) {
|
||||
if (numberOfHybridFactors >= updateFrequency) {
|
||||
// print the keys involved in the smoother update
|
||||
std::cout << "Smoother update: " << graph_.size() << std::endl;
|
||||
gttic_(SmootherUpdate);
|
||||
beforeUpdate = clock();
|
||||
smootherUpdate(smoother_, graph_, initial_, kMaxNrHypotheses, &result_);
|
||||
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
|
||||
afterUpdate = clock();
|
||||
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
||||
gttoc_(SmootherUpdate);
|
||||
doSmootherUpdate = false;
|
||||
numberOfHybridFactors = 0;
|
||||
}
|
||||
|
||||
// Record timing for odometry edges only
|
||||
|
@ -258,7 +271,7 @@ class Experiment {
|
|||
|
||||
// Final update
|
||||
beforeUpdate = clock();
|
||||
smootherUpdate(smoother_, graph_, initial_, kMaxNrHypotheses, &result_);
|
||||
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
|
||||
afterUpdate = clock();
|
||||
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
||||
|
||||
|
@ -288,7 +301,7 @@ class Experiment {
|
|||
// }
|
||||
|
||||
// Write timing info to file
|
||||
ofstream outfileTime;
|
||||
std::ofstream outfileTime;
|
||||
std::string timeFileName = "Hybrid_City10000_time.txt";
|
||||
outfileTime.open(timeFileName);
|
||||
for (auto accTime : timeList) {
|
||||
|
@ -300,15 +313,47 @@ class Experiment {
|
|||
};
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
// Function to parse command-line arguments
|
||||
void parseArguments(int argc, char* argv[], size_t& maxLoopCount,
|
||||
size_t& updateFrequency, size_t& maxNrHypotheses) {
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
std::string arg = argv[i];
|
||||
if (arg == "--max-loop-count" && i + 1 < argc) {
|
||||
maxLoopCount = std::stoul(argv[++i]);
|
||||
} else if (arg == "--update-frequency" && i + 1 < argc) {
|
||||
updateFrequency = std::stoul(argv[++i]);
|
||||
} else if (arg == "--max-nr-hypotheses" && i + 1 < argc) {
|
||||
maxNrHypotheses = std::stoul(argv[++i]);
|
||||
} else if (arg == "--help") {
|
||||
std::cout << "Usage: " << argv[0] << " [options]\n"
|
||||
<< "Options:\n"
|
||||
<< " --max-loop-count <value> Set the maximum loop "
|
||||
"count (default: 3000)\n"
|
||||
<< " --update-frequency <value> Set the update frequency "
|
||||
"(default: 3)\n"
|
||||
<< " --max-nr-hypotheses <value> Set the maximum number of "
|
||||
"hypotheses (default: 10)\n"
|
||||
<< " --help Show this help message\n";
|
||||
std::exit(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Main function
|
||||
int main(int argc, char* argv[]) {
|
||||
Experiment experiment(findExampleDataFile("T1_city10000_04.txt"));
|
||||
// Experiment experiment("../data/mh_T1_city10000_04.txt"); //Type #1 only
|
||||
// Experiment experiment("../data/mh_T3b_city10000_10.txt"); //Type #3 only
|
||||
// Experiment experiment("../data/mh_T1_T3_city10000_04.txt"); //Type #1 +
|
||||
// Type #3
|
||||
|
||||
// Parse command-line arguments
|
||||
parseArguments(argc, argv, experiment.maxLoopCount,
|
||||
experiment.updateFrequency, experiment.maxNrHypotheses);
|
||||
|
||||
// Run the experiment
|
||||
experiment.run(kMaxLoopCount);
|
||||
experiment.run();
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -74,8 +74,8 @@ namespace gtsam {
|
|||
|
||||
/// equality up to tolerance
|
||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
||||
if (!other) return false;
|
||||
if (!q.isLeaf()) return false;
|
||||
const Leaf* other = static_cast<const Leaf*>(&q);
|
||||
return compare(this->constant_, other->constant_);
|
||||
}
|
||||
|
||||
|
@ -202,36 +202,39 @@ namespace gtsam {
|
|||
* @param node The root node of the decision tree.
|
||||
* @return NodePtr
|
||||
*/
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
|
||||
static NodePtr Unique(const NodePtr& node) {
|
||||
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
|
||||
// Choice node, we recurse!
|
||||
// Make non-const copy so we can update
|
||||
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||
if (node->isLeaf()) return node; // Leaf node, return as is
|
||||
|
||||
// Iterate over all the branches
|
||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||
auto branch = choice->branches_[i];
|
||||
f->push_back(Unique(branch));
|
||||
}
|
||||
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||
// Choice node, we recurse!
|
||||
// Make non-const copy so we can update
|
||||
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||
|
||||
#ifdef GTSAM_DT_MERGING
|
||||
// If all the branches are the same, we can merge them into one
|
||||
if (f->allSame_) {
|
||||
assert(f->branches().size() > 0);
|
||||
NodePtr f0 = f->branches_[0];
|
||||
|
||||
NodePtr newLeaf(
|
||||
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||
return newLeaf;
|
||||
}
|
||||
#endif
|
||||
return f;
|
||||
} else {
|
||||
// Leaf node, return as is
|
||||
return node;
|
||||
// Iterate over all the branches
|
||||
for (const auto& branch : choice->branches_) {
|
||||
f->push_back(Unique(branch));
|
||||
}
|
||||
|
||||
// If all the branches are the same, we can merge them into one
|
||||
if (f->allSame_) {
|
||||
assert(f->branches().size() > 0);
|
||||
auto f0 = std::static_pointer_cast<const Leaf>(f->branches_[0]);
|
||||
return std::make_shared<Leaf>(f0->constant());
|
||||
}
|
||||
|
||||
return f;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
static NodePtr Unique(const NodePtr& node) {
|
||||
// No-op when GTSAM_DT_MERGING is not defined
|
||||
return node;
|
||||
}
|
||||
|
||||
#endif
|
||||
bool isLeaf() const override { return false; }
|
||||
|
||||
/// Constructor, given choice label and mandatory expected branch count.
|
||||
|
@ -322,9 +325,9 @@ namespace gtsam {
|
|||
const NodePtr& branch = branches_[i];
|
||||
|
||||
// Check if zero
|
||||
if (!showZero) {
|
||||
const Leaf* leaf = dynamic_cast<const Leaf*>(branch.get());
|
||||
if (leaf && valueFormatter(leaf->constant()).compare("0")) continue;
|
||||
if (!showZero && branch->isLeaf()) {
|
||||
auto leaf = std::static_pointer_cast<const Leaf>(branch);
|
||||
if (valueFormatter(leaf->constant()).compare("0")) continue;
|
||||
}
|
||||
|
||||
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
||||
|
@ -346,8 +349,8 @@ namespace gtsam {
|
|||
|
||||
/// equality
|
||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||
const Choice* other = dynamic_cast<const Choice*>(&q);
|
||||
if (!other) return false;
|
||||
if (q.isLeaf()) return false;
|
||||
const Choice* other = static_cast<const Choice*>(&q);
|
||||
if (this->label_ != other->label_) return false;
|
||||
if (branches_.size() != other->branches_.size()) return false;
|
||||
// we don't care about shared pointers being equal here
|
||||
|
@ -570,11 +573,13 @@ namespace gtsam {
|
|||
struct ApplyUnary {
|
||||
const Unary& op;
|
||||
void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||
if (auto leaf = std::dynamic_pointer_cast<Leaf>(node)) {
|
||||
if (node->isLeaf()) {
|
||||
// Apply the unary operation to the leaf's constant value
|
||||
auto leaf = std::static_pointer_cast<Leaf>(node);
|
||||
leaf->constant_ = op(leaf->constant_);
|
||||
} else if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
|
||||
} else {
|
||||
// Recurse into the choice branches
|
||||
auto choice = std::static_pointer_cast<Choice>(node);
|
||||
for (NodePtr& branch : choice->branches()) {
|
||||
(*this)(branch);
|
||||
}
|
||||
|
@ -622,8 +627,7 @@ namespace gtsam {
|
|||
for (Iterator it = begin; it != end; it++) {
|
||||
if (it->root_->isLeaf())
|
||||
continue;
|
||||
std::shared_ptr<const Choice> c =
|
||||
std::dynamic_pointer_cast<const Choice>(it->root_);
|
||||
auto c = std::static_pointer_cast<const Choice>(it->root_);
|
||||
if (!highestLabel || c->label() > *highestLabel) {
|
||||
highestLabel = c->label();
|
||||
nrChoices = c->nrChoices();
|
||||
|
@ -729,11 +733,7 @@ namespace gtsam {
|
|||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||
It begin, It end, ValueIt beginY, ValueIt endY) {
|
||||
auto node = build(begin, end, beginY, endY);
|
||||
if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
|
||||
return Choice::Unique(choice);
|
||||
} else {
|
||||
return node;
|
||||
}
|
||||
return Choice::Unique(node);
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
|
@ -742,18 +742,17 @@ namespace gtsam {
|
|||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||
const typename DecisionTree<L, X>::NodePtr& f,
|
||||
std::function<Y(const X&)> Y_of_X) {
|
||||
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
||||
using LXChoice = typename DecisionTree<L, X>::Choice;
|
||||
|
||||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
||||
if (auto leaf = std::dynamic_pointer_cast<LXLeaf>(f)) {
|
||||
if (f->isLeaf()) {
|
||||
auto leaf = std::static_pointer_cast<LXLeaf>(f);
|
||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||
}
|
||||
|
||||
// Check if Choice
|
||||
using LXChoice = typename DecisionTree<L, X>::Choice;
|
||||
auto choice = std::dynamic_pointer_cast<const LXChoice>(f);
|
||||
if (!choice) throw std::invalid_argument(
|
||||
"DecisionTree::convertFrom: Invalid NodePtr");
|
||||
// Now a Choice!
|
||||
auto choice = std::static_pointer_cast<const LXChoice>(f);
|
||||
|
||||
// Create a new Choice node with the same label
|
||||
auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||
|
@ -773,18 +772,17 @@ namespace gtsam {
|
|||
const typename DecisionTree<M, X>::NodePtr& f,
|
||||
std::function<L(const M&)> L_of_M, std::function<Y(const X&)> Y_of_X) {
|
||||
using LY = DecisionTree<L, Y>;
|
||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||
|
||||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
|
||||
if (f->isLeaf()) {
|
||||
auto leaf = std::static_pointer_cast<const MXLeaf>(f);
|
||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||
}
|
||||
|
||||
// Check if Choice
|
||||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||
auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr");
|
||||
// Now is Choice!
|
||||
auto choice = std::static_pointer_cast<const MXChoice>(f);
|
||||
|
||||
// get new label
|
||||
const M oldLabel = choice->label();
|
||||
|
@ -826,13 +824,14 @@ namespace gtsam {
|
|||
/// Do a depth-first visit on the tree rooted at node.
|
||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
|
||||
return f(leaf->constant());
|
||||
|
||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||
auto choice = std::dynamic_pointer_cast<const Choice>(node);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
|
||||
|
||||
if (node->isLeaf()) {
|
||||
auto leaf = std::static_pointer_cast<const Leaf>(node);
|
||||
return f(leaf->constant());
|
||||
}
|
||||
|
||||
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||
}
|
||||
};
|
||||
|
@ -863,13 +862,14 @@ namespace gtsam {
|
|||
/// Do a depth-first visit on the tree rooted at node.
|
||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
|
||||
return f(*leaf);
|
||||
|
||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||
auto choice = std::dynamic_pointer_cast<const Choice>(node);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
|
||||
|
||||
if (node->isLeaf()) {
|
||||
auto leaf = std::static_pointer_cast<const Leaf>(node);
|
||||
return f(*leaf);
|
||||
}
|
||||
|
||||
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||
}
|
||||
};
|
||||
|
@ -898,13 +898,16 @@ namespace gtsam {
|
|||
/// Do a depth-first visit on the tree rooted at node.
|
||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
|
||||
return f(assignment, leaf->constant());
|
||||
|
||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||
auto choice = std::dynamic_pointer_cast<const Choice>(node);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
|
||||
|
||||
if (node->isLeaf()) {
|
||||
auto leaf = std::static_pointer_cast<const Leaf>(node);
|
||||
return f(assignment, leaf->constant());
|
||||
}
|
||||
|
||||
|
||||
|
||||
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||
assignment[choice->label()] = i; // Set assignment for label to i
|
||||
|
||||
|
|
|
@ -65,6 +65,7 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
}
|
||||
|
||||
HybridBayesNet result;
|
||||
result.reserve(size());
|
||||
|
||||
// Go through all the Gaussian conditionals, restrict them according to
|
||||
// fixed values, and then prune further.
|
||||
|
@ -84,9 +85,9 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
}
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
result.push_back(prunedHybridGaussianConditional);
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// Add the non-HybridGaussianConditional conditional
|
||||
result.push_back(gc);
|
||||
} else if (conditional->isContinuous()) {
|
||||
// Add the non-Hybrid GaussianConditional conditional
|
||||
result.push_back(conditional);
|
||||
} else
|
||||
throw std::runtime_error(
|
||||
"HybrdiBayesNet::prune: Unknown HybridConditional type.");
|
||||
|
|
|
@ -23,6 +23,7 @@ namespace gtsam {
|
|||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys) {
|
||||
KeyVector allKeys;
|
||||
allKeys.reserve(continuousKeys.size() + discreteKeys.size());
|
||||
std::copy(continuousKeys.begin(), continuousKeys.end(),
|
||||
std::back_inserter(allKeys));
|
||||
std::transform(discreteKeys.begin(), discreteKeys.end(),
|
||||
|
@ -34,6 +35,7 @@ KeyVector CollectKeys(const KeyVector &continuousKeys,
|
|||
/* ************************************************************************ */
|
||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) {
|
||||
KeyVector allKeys;
|
||||
allKeys.reserve(keys1.size() + keys2.size());
|
||||
std::copy(keys1.begin(), keys1.end(), std::back_inserter(allKeys));
|
||||
std::copy(keys2.begin(), keys2.end(), std::back_inserter(allKeys));
|
||||
return allKeys;
|
||||
|
|
|
@ -467,7 +467,7 @@ std::shared_ptr<GaussianConditional> HessianFactor::eliminateCholesky(const Orde
|
|||
info_.choleskyPartial(nFrontals);
|
||||
|
||||
// TODO(frank): pre-allocate GaussianConditional and write into it
|
||||
const VerticalBlockMatrix Ab = info_.split(nFrontals);
|
||||
VerticalBlockMatrix Ab = info_.split(nFrontals);
|
||||
conditional = std::make_shared<GaussianConditional>(keys_, nFrontals, std::move(Ab));
|
||||
|
||||
// Erase the eliminated keys in this factor
|
||||
|
|
Loading…
Reference in New Issue