Merge branch 'develop' into discrete-elimination-refactor
commit
858c64e167
|
@ -31,11 +31,12 @@
|
||||||
|
|
||||||
#include <boost/algorithm/string/classification.hpp>
|
#include <boost/algorithm/string/classification.hpp>
|
||||||
#include <boost/algorithm/string/split.hpp>
|
#include <boost/algorithm/string/split.hpp>
|
||||||
|
#include <cstdlib>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
using namespace boost::algorithm;
|
using namespace boost::algorithm;
|
||||||
|
|
||||||
|
@ -43,19 +44,30 @@ using symbol_shorthand::L;
|
||||||
using symbol_shorthand::M;
|
using symbol_shorthand::M;
|
||||||
using symbol_shorthand::X;
|
using symbol_shorthand::X;
|
||||||
|
|
||||||
const size_t kMaxLoopCount = 2000; // Example default value
|
|
||||||
const size_t kMaxNrHypotheses = 10;
|
|
||||||
|
|
||||||
auto kOpenLoopModel = noiseModel::Diagonal::Sigmas(Vector3::Ones() * 10);
|
auto kOpenLoopModel = noiseModel::Diagonal::Sigmas(Vector3::Ones() * 10);
|
||||||
|
const double kOpenLoopConstant = kOpenLoopModel->negLogConstant();
|
||||||
|
|
||||||
auto kPriorNoiseModel = noiseModel::Diagonal::Sigmas(
|
auto kPriorNoiseModel = noiseModel::Diagonal::Sigmas(
|
||||||
(Vector(3) << 0.0001, 0.0001, 0.0001).finished());
|
(Vector(3) << 0.0001, 0.0001, 0.0001).finished());
|
||||||
|
|
||||||
auto kPoseNoiseModel = noiseModel::Diagonal::Sigmas(
|
auto kPoseNoiseModel = noiseModel::Diagonal::Sigmas(
|
||||||
(Vector(3) << 1.0 / 30.0, 1.0 / 30.0, 1.0 / 100.0).finished());
|
(Vector(3) << 1.0 / 30.0, 1.0 / 30.0, 1.0 / 100.0).finished());
|
||||||
|
const double kPoseNoiseConstant = kPoseNoiseModel->negLogConstant();
|
||||||
|
|
||||||
// Experiment Class
|
// Experiment Class
|
||||||
class Experiment {
|
class Experiment {
|
||||||
|
public:
|
||||||
|
// Parameters with default values
|
||||||
|
size_t maxLoopCount = 3000;
|
||||||
|
|
||||||
|
// 3000: {1: 62s, 2: 21s, 3: 20s, 4: 31s, 5: 39s} No DT optimizations
|
||||||
|
// 3000: {1: 65s, 2: 20s, 3: 16s, 4: 21s, 5: 28s} With DT optimizations
|
||||||
|
// 3000: {1: 59s, 2: 19s, 3: 18s, 4: 26s, 5: 33s} With DT optimizations +
|
||||||
|
// merge
|
||||||
|
size_t updateFrequency = 3;
|
||||||
|
|
||||||
|
size_t maxNrHypotheses = 10;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string filename_;
|
std::string filename_;
|
||||||
HybridSmoother smoother_;
|
HybridSmoother smoother_;
|
||||||
|
@ -72,7 +84,7 @@ class Experiment {
|
||||||
*/
|
*/
|
||||||
void writeResult(const Values& result, size_t numPoses,
|
void writeResult(const Values& result, size_t numPoses,
|
||||||
const std::string& filename = "Hybrid_city10000.txt") {
|
const std::string& filename = "Hybrid_city10000.txt") {
|
||||||
ofstream outfile;
|
std::ofstream outfile;
|
||||||
outfile.open(filename);
|
outfile.open(filename);
|
||||||
|
|
||||||
for (size_t i = 0; i < numPoses; ++i) {
|
for (size_t i = 0; i < numPoses; ++i) {
|
||||||
|
@ -98,9 +110,8 @@ class Experiment {
|
||||||
auto f1 = std::make_shared<BetweenFactor<Pose2>>(
|
auto f1 = std::make_shared<BetweenFactor<Pose2>>(
|
||||||
X(keyS), X(keyT), measurement, kPoseNoiseModel);
|
X(keyS), X(keyT), measurement, kPoseNoiseModel);
|
||||||
|
|
||||||
std::vector<NonlinearFactorValuePair> factors{
|
std::vector<NonlinearFactorValuePair> factors{{f0, kOpenLoopConstant},
|
||||||
{f0, kOpenLoopModel->negLogConstant()},
|
{f1, kPoseNoiseConstant}};
|
||||||
{f1, kPoseNoiseModel->negLogConstant()}};
|
|
||||||
HybridNonlinearFactor mixtureFactor(l, factors);
|
HybridNonlinearFactor mixtureFactor(l, factors);
|
||||||
return mixtureFactor;
|
return mixtureFactor;
|
||||||
}
|
}
|
||||||
|
@ -114,9 +125,8 @@ class Experiment {
|
||||||
auto f1 = std::make_shared<BetweenFactor<Pose2>>(
|
auto f1 = std::make_shared<BetweenFactor<Pose2>>(
|
||||||
X(keyS), X(keyT), poseArray[1], kPoseNoiseModel);
|
X(keyS), X(keyT), poseArray[1], kPoseNoiseModel);
|
||||||
|
|
||||||
std::vector<NonlinearFactorValuePair> factors{
|
std::vector<NonlinearFactorValuePair> factors{{f0, kPoseNoiseConstant},
|
||||||
{f0, kPoseNoiseModel->negLogConstant()},
|
{f1, kPoseNoiseConstant}};
|
||||||
{f1, kPoseNoiseModel->negLogConstant()}};
|
|
||||||
HybridNonlinearFactor mixtureFactor(m, factors);
|
HybridNonlinearFactor mixtureFactor(m, factors);
|
||||||
return mixtureFactor;
|
return mixtureFactor;
|
||||||
}
|
}
|
||||||
|
@ -124,9 +134,9 @@ class Experiment {
|
||||||
/// @brief Perform smoother update and optimize the graph.
|
/// @brief Perform smoother update and optimize the graph.
|
||||||
void smootherUpdate(HybridSmoother& smoother,
|
void smootherUpdate(HybridSmoother& smoother,
|
||||||
HybridNonlinearFactorGraph& graph, const Values& initial,
|
HybridNonlinearFactorGraph& graph, const Values& initial,
|
||||||
size_t kMaxNrHypotheses, Values* result) {
|
size_t maxNrHypotheses, Values* result) {
|
||||||
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
HybridGaussianFactorGraph linearized = *graph.linearize(initial);
|
||||||
smoother.update(linearized, kMaxNrHypotheses);
|
smoother.update(linearized, maxNrHypotheses);
|
||||||
// throw if x0 not in hybridBayesNet_:
|
// throw if x0 not in hybridBayesNet_:
|
||||||
const KeySet& keys = smoother.hybridBayesNet().keys();
|
const KeySet& keys = smoother.hybridBayesNet().keys();
|
||||||
if (keys.find(X(0)) == keys.end()) {
|
if (keys.find(X(0)) == keys.end()) {
|
||||||
|
@ -143,11 +153,11 @@ class Experiment {
|
||||||
: filename_(filename), smoother_(0.99) {}
|
: filename_(filename), smoother_(0.99) {}
|
||||||
|
|
||||||
/// @brief Run the main experiment with a given maxLoopCount.
|
/// @brief Run the main experiment with a given maxLoopCount.
|
||||||
void run(size_t maxLoopCount) {
|
void run() {
|
||||||
// Prepare reading
|
// Prepare reading
|
||||||
ifstream in(filename_);
|
std::ifstream in(filename_);
|
||||||
if (!in.is_open()) {
|
if (!in.is_open()) {
|
||||||
cerr << "Failed to open file: " << filename_ << endl;
|
std::cerr << "Failed to open file: " << filename_ << std::endl;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,11 +178,14 @@ class Experiment {
|
||||||
|
|
||||||
// Initial update
|
// Initial update
|
||||||
clock_t beforeUpdate = clock();
|
clock_t beforeUpdate = clock();
|
||||||
smootherUpdate(smoother_, graph_, initial_, kMaxNrHypotheses, &result_);
|
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
|
||||||
clock_t afterUpdate = clock();
|
clock_t afterUpdate = clock();
|
||||||
std::vector<std::pair<size_t, double>> smootherUpdateTimes;
|
std::vector<std::pair<size_t, double>> smootherUpdateTimes;
|
||||||
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
||||||
|
|
||||||
|
// Flag to decide whether to run smoother update
|
||||||
|
size_t numberOfHybridFactors = 0;
|
||||||
|
|
||||||
// Start main loop
|
// Start main loop
|
||||||
size_t keyS = 0, keyT = 0;
|
size_t keyS = 0, keyT = 0;
|
||||||
clock_t startTime = clock();
|
clock_t startTime = clock();
|
||||||
|
@ -193,9 +206,6 @@ class Experiment {
|
||||||
poseArray[i] = Pose2(x, y, rad);
|
poseArray[i] = Pose2(x, y, rad);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flag to decide whether to run smoother update
|
|
||||||
bool doSmootherUpdate = false;
|
|
||||||
|
|
||||||
// Take the first one as the initial estimate
|
// Take the first one as the initial estimate
|
||||||
Pose2 odomPose = poseArray[0];
|
Pose2 odomPose = poseArray[0];
|
||||||
if (keyS == keyT - 1) {
|
if (keyS == keyT - 1) {
|
||||||
|
@ -207,8 +217,8 @@ class Experiment {
|
||||||
hybridOdometryFactor(numMeasurements, keyS, keyT, m, poseArray);
|
hybridOdometryFactor(numMeasurements, keyS, keyT, m, poseArray);
|
||||||
graph_.push_back(mixtureFactor);
|
graph_.push_back(mixtureFactor);
|
||||||
discreteCount++;
|
discreteCount++;
|
||||||
doSmootherUpdate = true;
|
numberOfHybridFactors += 1;
|
||||||
// std::cout << "mixtureFactor: " << keyS << " " << keyT << std::endl;
|
std::cout << "mixtureFactor: " << keyS << " " << keyT << std::endl;
|
||||||
} else {
|
} else {
|
||||||
graph_.add(BetweenFactor<Pose2>(X(keyS), X(keyT), odomPose,
|
graph_.add(BetweenFactor<Pose2>(X(keyS), X(keyT), odomPose,
|
||||||
kPoseNoiseModel));
|
kPoseNoiseModel));
|
||||||
|
@ -220,20 +230,22 @@ class Experiment {
|
||||||
HybridNonlinearFactor loopFactor =
|
HybridNonlinearFactor loopFactor =
|
||||||
hybridLoopClosureFactor(loopCount, keyS, keyT, odomPose);
|
hybridLoopClosureFactor(loopCount, keyS, keyT, odomPose);
|
||||||
// print loop closure event keys:
|
// print loop closure event keys:
|
||||||
// std::cout << "Loop closure: " << keyS << " " << keyT << std::endl;
|
std::cout << "Loop closure: " << keyS << " " << keyT << std::endl;
|
||||||
graph_.add(loopFactor);
|
graph_.add(loopFactor);
|
||||||
doSmootherUpdate = true;
|
numberOfHybridFactors += 1;
|
||||||
loopCount++;
|
loopCount++;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (doSmootherUpdate) {
|
if (numberOfHybridFactors >= updateFrequency) {
|
||||||
|
// print the keys involved in the smoother update
|
||||||
|
std::cout << "Smoother update: " << graph_.size() << std::endl;
|
||||||
gttic_(SmootherUpdate);
|
gttic_(SmootherUpdate);
|
||||||
beforeUpdate = clock();
|
beforeUpdate = clock();
|
||||||
smootherUpdate(smoother_, graph_, initial_, kMaxNrHypotheses, &result_);
|
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
|
||||||
afterUpdate = clock();
|
afterUpdate = clock();
|
||||||
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
||||||
gttoc_(SmootherUpdate);
|
gttoc_(SmootherUpdate);
|
||||||
doSmootherUpdate = false;
|
numberOfHybridFactors = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Record timing for odometry edges only
|
// Record timing for odometry edges only
|
||||||
|
@ -259,7 +271,7 @@ class Experiment {
|
||||||
|
|
||||||
// Final update
|
// Final update
|
||||||
beforeUpdate = clock();
|
beforeUpdate = clock();
|
||||||
smootherUpdate(smoother_, graph_, initial_, kMaxNrHypotheses, &result_);
|
smootherUpdate(smoother_, graph_, initial_, maxNrHypotheses, &result_);
|
||||||
afterUpdate = clock();
|
afterUpdate = clock();
|
||||||
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
smootherUpdateTimes.push_back({index, afterUpdate - beforeUpdate});
|
||||||
|
|
||||||
|
@ -289,7 +301,7 @@ class Experiment {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// Write timing info to file
|
// Write timing info to file
|
||||||
ofstream outfileTime;
|
std::ofstream outfileTime;
|
||||||
std::string timeFileName = "Hybrid_City10000_time.txt";
|
std::string timeFileName = "Hybrid_City10000_time.txt";
|
||||||
outfileTime.open(timeFileName);
|
outfileTime.open(timeFileName);
|
||||||
for (auto accTime : timeList) {
|
for (auto accTime : timeList) {
|
||||||
|
@ -301,15 +313,47 @@ class Experiment {
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
// Function to parse command-line arguments
|
||||||
|
void parseArguments(int argc, char* argv[], size_t& maxLoopCount,
|
||||||
|
size_t& updateFrequency, size_t& maxNrHypotheses) {
|
||||||
|
for (int i = 1; i < argc; ++i) {
|
||||||
|
std::string arg = argv[i];
|
||||||
|
if (arg == "--max-loop-count" && i + 1 < argc) {
|
||||||
|
maxLoopCount = std::stoul(argv[++i]);
|
||||||
|
} else if (arg == "--update-frequency" && i + 1 < argc) {
|
||||||
|
updateFrequency = std::stoul(argv[++i]);
|
||||||
|
} else if (arg == "--max-nr-hypotheses" && i + 1 < argc) {
|
||||||
|
maxNrHypotheses = std::stoul(argv[++i]);
|
||||||
|
} else if (arg == "--help") {
|
||||||
|
std::cout << "Usage: " << argv[0] << " [options]\n"
|
||||||
|
<< "Options:\n"
|
||||||
|
<< " --max-loop-count <value> Set the maximum loop "
|
||||||
|
"count (default: 3000)\n"
|
||||||
|
<< " --update-frequency <value> Set the update frequency "
|
||||||
|
"(default: 3)\n"
|
||||||
|
<< " --max-nr-hypotheses <value> Set the maximum number of "
|
||||||
|
"hypotheses (default: 10)\n"
|
||||||
|
<< " --help Show this help message\n";
|
||||||
|
std::exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Main function
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
Experiment experiment(findExampleDataFile("T1_city10000_04.txt"));
|
Experiment experiment(findExampleDataFile("T1_city10000_04.txt"));
|
||||||
// Experiment experiment("../data/mh_T1_city10000_04.txt"); //Type #1 only
|
// Experiment experiment("../data/mh_T1_city10000_04.txt"); //Type #1 only
|
||||||
// Experiment experiment("../data/mh_T3b_city10000_10.txt"); //Type #3 only
|
// Experiment experiment("../data/mh_T3b_city10000_10.txt"); //Type #3 only
|
||||||
// Experiment experiment("../data/mh_T1_T3_city10000_04.txt"); //Type #1 +
|
// Experiment experiment("../data/mh_T1_T3_city10000_04.txt"); //Type #1 +
|
||||||
// Type #3
|
// Type #3
|
||||||
|
|
||||||
|
// Parse command-line arguments
|
||||||
|
parseArguments(argc, argv, experiment.maxLoopCount,
|
||||||
|
experiment.updateFrequency, experiment.maxNrHypotheses);
|
||||||
|
|
||||||
// Run the experiment
|
// Run the experiment
|
||||||
experiment.run(kMaxLoopCount);
|
experiment.run();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
|
@ -74,8 +74,8 @@ namespace gtsam {
|
||||||
|
|
||||||
/// equality up to tolerance
|
/// equality up to tolerance
|
||||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||||
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
if (!q.isLeaf()) return false;
|
||||||
if (!other) return false;
|
const Leaf* other = static_cast<const Leaf*>(&q);
|
||||||
return compare(this->constant_, other->constant_);
|
return compare(this->constant_, other->constant_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,36 +202,39 @@ namespace gtsam {
|
||||||
* @param node The root node of the decision tree.
|
* @param node The root node of the decision tree.
|
||||||
* @return NodePtr
|
* @return NodePtr
|
||||||
*/
|
*/
|
||||||
|
#ifdef GTSAM_DT_MERGING
|
||||||
|
|
||||||
static NodePtr Unique(const NodePtr& node) {
|
static NodePtr Unique(const NodePtr& node) {
|
||||||
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
|
if (node->isLeaf()) return node; // Leaf node, return as is
|
||||||
// Choice node, we recurse!
|
|
||||||
// Make non-const copy so we can update
|
|
||||||
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
|
||||||
|
|
||||||
// Iterate over all the branches
|
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
// Choice node, we recurse!
|
||||||
auto branch = choice->branches_[i];
|
// Make non-const copy so we can update
|
||||||
f->push_back(Unique(branch));
|
auto f = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef GTSAM_DT_MERGING
|
// Iterate over all the branches
|
||||||
// If all the branches are the same, we can merge them into one
|
for (const auto& branch : choice->branches_) {
|
||||||
if (f->allSame_) {
|
f->push_back(Unique(branch));
|
||||||
assert(f->branches().size() > 0);
|
|
||||||
NodePtr f0 = f->branches_[0];
|
|
||||||
|
|
||||||
NodePtr newLeaf(
|
|
||||||
new Leaf(std::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
|
||||||
return newLeaf;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
return f;
|
|
||||||
} else {
|
|
||||||
// Leaf node, return as is
|
|
||||||
return node;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If all the branches are the same, we can merge them into one
|
||||||
|
if (f->allSame_) {
|
||||||
|
assert(f->branches().size() > 0);
|
||||||
|
auto f0 = std::static_pointer_cast<const Leaf>(f->branches_[0]);
|
||||||
|
return std::make_shared<Leaf>(f0->constant());
|
||||||
|
}
|
||||||
|
|
||||||
|
return f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
static NodePtr Unique(const NodePtr& node) {
|
||||||
|
// No-op when GTSAM_DT_MERGING is not defined
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
bool isLeaf() const override { return false; }
|
bool isLeaf() const override { return false; }
|
||||||
|
|
||||||
/// Constructor, given choice label and mandatory expected branch count.
|
/// Constructor, given choice label and mandatory expected branch count.
|
||||||
|
@ -322,9 +325,9 @@ namespace gtsam {
|
||||||
const NodePtr& branch = branches_[i];
|
const NodePtr& branch = branches_[i];
|
||||||
|
|
||||||
// Check if zero
|
// Check if zero
|
||||||
if (!showZero) {
|
if (!showZero && branch->isLeaf()) {
|
||||||
const Leaf* leaf = dynamic_cast<const Leaf*>(branch.get());
|
auto leaf = std::static_pointer_cast<const Leaf>(branch);
|
||||||
if (leaf && valueFormatter(leaf->constant()).compare("0")) continue;
|
if (valueFormatter(leaf->constant()).compare("0")) continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
os << "\"" << this->id() << "\" -> \"" << branch->id() << "\"";
|
||||||
|
@ -346,8 +349,8 @@ namespace gtsam {
|
||||||
|
|
||||||
/// equality
|
/// equality
|
||||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||||
const Choice* other = dynamic_cast<const Choice*>(&q);
|
if (q.isLeaf()) return false;
|
||||||
if (!other) return false;
|
const Choice* other = static_cast<const Choice*>(&q);
|
||||||
if (this->label_ != other->label_) return false;
|
if (this->label_ != other->label_) return false;
|
||||||
if (branches_.size() != other->branches_.size()) return false;
|
if (branches_.size() != other->branches_.size()) return false;
|
||||||
// we don't care about shared pointers being equal here
|
// we don't care about shared pointers being equal here
|
||||||
|
@ -570,11 +573,13 @@ namespace gtsam {
|
||||||
struct ApplyUnary {
|
struct ApplyUnary {
|
||||||
const Unary& op;
|
const Unary& op;
|
||||||
void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
|
void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
if (auto leaf = std::dynamic_pointer_cast<Leaf>(node)) {
|
if (node->isLeaf()) {
|
||||||
// Apply the unary operation to the leaf's constant value
|
// Apply the unary operation to the leaf's constant value
|
||||||
|
auto leaf = std::static_pointer_cast<Leaf>(node);
|
||||||
leaf->constant_ = op(leaf->constant_);
|
leaf->constant_ = op(leaf->constant_);
|
||||||
} else if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
|
} else {
|
||||||
// Recurse into the choice branches
|
// Recurse into the choice branches
|
||||||
|
auto choice = std::static_pointer_cast<Choice>(node);
|
||||||
for (NodePtr& branch : choice->branches()) {
|
for (NodePtr& branch : choice->branches()) {
|
||||||
(*this)(branch);
|
(*this)(branch);
|
||||||
}
|
}
|
||||||
|
@ -622,8 +627,7 @@ namespace gtsam {
|
||||||
for (Iterator it = begin; it != end; it++) {
|
for (Iterator it = begin; it != end; it++) {
|
||||||
if (it->root_->isLeaf())
|
if (it->root_->isLeaf())
|
||||||
continue;
|
continue;
|
||||||
std::shared_ptr<const Choice> c =
|
auto c = std::static_pointer_cast<const Choice>(it->root_);
|
||||||
std::dynamic_pointer_cast<const Choice>(it->root_);
|
|
||||||
if (!highestLabel || c->label() > *highestLabel) {
|
if (!highestLabel || c->label() > *highestLabel) {
|
||||||
highestLabel = c->label();
|
highestLabel = c->label();
|
||||||
nrChoices = c->nrChoices();
|
nrChoices = c->nrChoices();
|
||||||
|
@ -729,11 +733,7 @@ namespace gtsam {
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||||
It begin, It end, ValueIt beginY, ValueIt endY) {
|
It begin, It end, ValueIt beginY, ValueIt endY) {
|
||||||
auto node = build(begin, end, beginY, endY);
|
auto node = build(begin, end, beginY, endY);
|
||||||
if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
|
return Choice::Unique(node);
|
||||||
return Choice::Unique(choice);
|
|
||||||
} else {
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
@ -742,18 +742,17 @@ namespace gtsam {
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
const typename DecisionTree<L, X>::NodePtr& f,
|
const typename DecisionTree<L, X>::NodePtr& f,
|
||||||
std::function<Y(const X&)> Y_of_X) {
|
std::function<Y(const X&)> Y_of_X) {
|
||||||
|
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
||||||
|
using LXChoice = typename DecisionTree<L, X>::Choice;
|
||||||
|
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||||
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
if (f->isLeaf()) {
|
||||||
if (auto leaf = std::dynamic_pointer_cast<LXLeaf>(f)) {
|
auto leaf = std::static_pointer_cast<LXLeaf>(f);
|
||||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if Choice
|
// Now a Choice!
|
||||||
using LXChoice = typename DecisionTree<L, X>::Choice;
|
auto choice = std::static_pointer_cast<const LXChoice>(f);
|
||||||
auto choice = std::dynamic_pointer_cast<const LXChoice>(f);
|
|
||||||
if (!choice) throw std::invalid_argument(
|
|
||||||
"DecisionTree::convertFrom: Invalid NodePtr");
|
|
||||||
|
|
||||||
// Create a new Choice node with the same label
|
// Create a new Choice node with the same label
|
||||||
auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||||
|
@ -773,18 +772,17 @@ namespace gtsam {
|
||||||
const typename DecisionTree<M, X>::NodePtr& f,
|
const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<L(const M&)> L_of_M, std::function<Y(const X&)> Y_of_X) {
|
std::function<L(const M&)> L_of_M, std::function<Y(const X&)> Y_of_X) {
|
||||||
using LY = DecisionTree<L, Y>;
|
using LY = DecisionTree<L, Y>;
|
||||||
|
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||||
|
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||||
|
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
if (f->isLeaf()) {
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
|
auto leaf = std::static_pointer_cast<const MXLeaf>(f);
|
||||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if Choice
|
// Now is Choice!
|
||||||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
auto choice = std::static_pointer_cast<const MXChoice>(f);
|
||||||
auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
|
|
||||||
if (!choice)
|
|
||||||
throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr");
|
|
||||||
|
|
||||||
// get new label
|
// get new label
|
||||||
const M oldLabel = choice->label();
|
const M oldLabel = choice->label();
|
||||||
|
@ -826,13 +824,14 @@ namespace gtsam {
|
||||||
/// Do a depth-first visit on the tree rooted at node.
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
|
|
||||||
return f(leaf->constant());
|
|
||||||
|
|
||||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
auto choice = std::dynamic_pointer_cast<const Choice>(node);
|
|
||||||
if (!choice)
|
if (node->isLeaf()) {
|
||||||
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
|
auto leaf = std::static_pointer_cast<const Leaf>(node);
|
||||||
|
return f(leaf->constant());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||||
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -863,13 +862,14 @@ namespace gtsam {
|
||||||
/// Do a depth-first visit on the tree rooted at node.
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
|
|
||||||
return f(*leaf);
|
|
||||||
|
|
||||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
auto choice = std::dynamic_pointer_cast<const Choice>(node);
|
|
||||||
if (!choice)
|
if (node->isLeaf()) {
|
||||||
throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
|
auto leaf = std::static_pointer_cast<const Leaf>(node);
|
||||||
|
return f(*leaf);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||||
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -898,13 +898,16 @@ namespace gtsam {
|
||||||
/// Do a depth-first visit on the tree rooted at node.
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
||||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const Leaf>(node))
|
|
||||||
return f(assignment, leaf->constant());
|
|
||||||
|
|
||||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
auto choice = std::dynamic_pointer_cast<const Choice>(node);
|
|
||||||
if (!choice)
|
if (node->isLeaf()) {
|
||||||
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
|
auto leaf = std::static_pointer_cast<const Leaf>(node);
|
||||||
|
return f(assignment, leaf->constant());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
auto choice = std::static_pointer_cast<const Choice>(node);
|
||||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||||
assignment[choice->label()] = i; // Set assignment for label to i
|
assignment[choice->label()] = i; // Set assignment for label to i
|
||||||
|
|
||||||
|
|
|
@ -164,6 +164,12 @@ namespace gtsam {
|
||||||
virtual DiscreteFactor::shared_ptr multiply(
|
virtual DiscreteFactor::shared_ptr multiply(
|
||||||
const DiscreteFactor::shared_ptr& f) const override;
|
const DiscreteFactor::shared_ptr& f) const override;
|
||||||
|
|
||||||
|
/// multiply with a scalar
|
||||||
|
DiscreteFactor::shared_ptr operator*(double s) const override {
|
||||||
|
return std::make_shared<DecisionTreeFactor>(
|
||||||
|
apply([s](const double& a) { return Ring::mul(a, s); }));
|
||||||
|
}
|
||||||
|
|
||||||
/// multiply two factors
|
/// multiply two factors
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
||||||
return apply(f, Ring::mul);
|
return apply(f, Ring::mul);
|
||||||
|
@ -201,6 +207,9 @@ namespace gtsam {
|
||||||
return combine(keys, Ring::add);
|
return combine(keys, Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Find the maximum value in the factor.
|
||||||
|
double max() const override { return ADT::max(); };
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
||||||
return combine(nrFrontals, Ring::max);
|
return combine(nrFrontals, Ring::max);
|
||||||
|
|
|
@ -89,7 +89,7 @@ DiscreteBayesNet DiscreteBayesNet::prune(
|
||||||
DiscreteValues deadModesValues;
|
DiscreteValues deadModesValues;
|
||||||
// If we have a dead mode threshold and discrete variables left after pruning,
|
// If we have a dead mode threshold and discrete variables left after pruning,
|
||||||
// then we run dead mode removal.
|
// then we run dead mode removal.
|
||||||
if (marginalThreshold.has_value() && pruned.keys().size() > 0) {
|
if (marginalThreshold && pruned.keys().size() > 0) {
|
||||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||||
for (auto dkey : pruned.discreteKeys()) {
|
for (auto dkey : pruned.discreteKeys()) {
|
||||||
const Vector probabilities = marginals.marginalProbabilities(dkey);
|
const Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||||
|
|
|
@ -73,10 +73,7 @@ AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteFactor::shared_ptr DiscreteFactor::scale() const {
|
DiscreteFactor::shared_ptr DiscreteFactor::scale() const {
|
||||||
// Max over all the potentials by pretending all keys are frontal:
|
return this->operator*(1.0 / max());
|
||||||
shared_ptr denominator = this->max(this->size());
|
|
||||||
// Normalize the product factor to prevent underflow.
|
|
||||||
return this->operator/(denominator);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -126,6 +126,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
/// Compute error for each assignment and return as a tree
|
/// Compute error for each assignment and return as a tree
|
||||||
virtual AlgebraicDecisionTree<Key> errorTree() const;
|
virtual AlgebraicDecisionTree<Key> errorTree() const;
|
||||||
|
|
||||||
|
/// Multiply with a scalar
|
||||||
|
virtual DiscreteFactor::shared_ptr operator*(double s) const = 0;
|
||||||
|
|
||||||
/// Multiply in a DecisionTreeFactor and return the result as
|
/// Multiply in a DecisionTreeFactor and return the result as
|
||||||
/// DecisionTreeFactor
|
/// DecisionTreeFactor
|
||||||
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
||||||
|
@ -152,6 +155,9 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
|
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;
|
||||||
|
|
||||||
|
/// Find the maximum value in the factor.
|
||||||
|
virtual double max() const = 0;
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;
|
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;
|
||||||
|
|
||||||
|
|
|
@ -65,15 +65,10 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
|
DiscreteFactor::shared_ptr DiscreteFactorGraph::product() const {
|
||||||
DiscreteFactor::shared_ptr result;
|
DiscreteFactor::shared_ptr result = nullptr;
|
||||||
for (auto it = this->begin(); it != this->end(); ++it) {
|
for (const auto& factor : *this) {
|
||||||
if (*it) {
|
if (factor) {
|
||||||
if (result) {
|
result = result ? result->multiply(factor) : factor;
|
||||||
result = result->multiply(*it);
|
|
||||||
} else {
|
|
||||||
// Assign to the first non-null factor
|
|
||||||
result = *it;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
@ -120,15 +115,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DiscreteFactor::shared_ptr DiscreteFactorGraph::scaledProduct() const {
|
DiscreteFactor::shared_ptr DiscreteFactorGraph::scaledProduct() const {
|
||||||
// PRODUCT: multiply all factors
|
return product()->scale();
|
||||||
gttic(product);
|
|
||||||
DiscreteFactor::shared_ptr product = this->product();
|
|
||||||
gttoc(product);
|
|
||||||
|
|
||||||
// Normalize the product factor to prevent underflow.
|
|
||||||
product = product->scale();
|
|
||||||
|
|
||||||
return product;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
@ -216,7 +203,7 @@ namespace gtsam {
|
||||||
const Ordering& frontalKeys) {
|
const Ordering& frontalKeys) {
|
||||||
gttic(product);
|
gttic(product);
|
||||||
// `product` is scaled later to prevent underflow.
|
// `product` is scaled later to prevent underflow.
|
||||||
DiscreteFactor::shared_ptr product = factors.product();
|
DiscreteFactor::shared_ptr product = factors.scaledProduct();
|
||||||
gttoc(product);
|
gttoc(product);
|
||||||
|
|
||||||
// sum out frontals, this is the factor on the separator
|
// sum out frontals, this is the factor on the separator
|
||||||
|
@ -224,16 +211,6 @@ namespace gtsam {
|
||||||
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
|
DiscreteFactor::shared_ptr sum = product->sum(frontalKeys);
|
||||||
gttoc(sum);
|
gttoc(sum);
|
||||||
|
|
||||||
// Normalize/scale to prevent underflow.
|
|
||||||
// We divide both `product` and `sum` by `max(sum)`
|
|
||||||
// since it is faster to compute and when the conditional
|
|
||||||
// is formed by `product/sum`, the scaling term cancels out.
|
|
||||||
// gttic(scale);
|
|
||||||
// DiscreteFactor::shared_ptr denominator = sum->max(sum->size());
|
|
||||||
// product = product->operator/(denominator);
|
|
||||||
// sum = sum->operator/(denominator);
|
|
||||||
// gttoc(scale);
|
|
||||||
|
|
||||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||||
Ordering orderedKeys;
|
Ordering orderedKeys;
|
||||||
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
|
orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(),
|
||||||
|
|
|
@ -110,6 +110,11 @@ DiscreteFactor::shared_ptr TableDistribution::max(const Ordering& keys) const {
|
||||||
return table_.max(keys);
|
return table_.max(keys);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
DiscreteFactor::shared_ptr TableDistribution::operator*(double s) const {
|
||||||
|
return table_ * s;
|
||||||
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
DiscreteFactor::shared_ptr TableDistribution::operator/(
|
DiscreteFactor::shared_ptr TableDistribution::operator/(
|
||||||
const DiscreteFactor::shared_ptr& f) const {
|
const DiscreteFactor::shared_ptr& f) const {
|
||||||
|
|
|
@ -116,12 +116,19 @@ class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override;
|
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override;
|
||||||
|
|
||||||
|
/// Find the maximum value in the factor.
|
||||||
|
double max() const override { return table_.max(); }
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override;
|
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override;
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
|
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
|
||||||
|
|
||||||
|
|
||||||
|
/// Multiply by scalar s
|
||||||
|
DiscreteFactor::shared_ptr operator*(double s) const override;
|
||||||
|
|
||||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||||
DiscreteFactor::shared_ptr operator/(
|
DiscreteFactor::shared_ptr operator/(
|
||||||
const DiscreteFactor::shared_ptr& f) const override;
|
const DiscreteFactor::shared_ptr& f) const override;
|
||||||
|
|
|
@ -389,6 +389,36 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
|
||||||
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
|
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactor::shared_ptr TableFactor::sum(size_t nrFrontals) const {
|
||||||
|
return combine(nrFrontals, Ring::add);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactor::shared_ptr TableFactor::sum(const Ordering& keys) const {
|
||||||
|
return combine(keys, Ring::add);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
double TableFactor::max() const {
|
||||||
|
double max_value = std::numeric_limits<double>::lowest();
|
||||||
|
for (Eigen::SparseVector<double>::InnerIterator it(sparse_table_); it; ++it) {
|
||||||
|
max_value = std::max(max_value, it.value());
|
||||||
|
}
|
||||||
|
return max_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactor::shared_ptr TableFactor::max(size_t nrFrontals) const {
|
||||||
|
return combine(nrFrontals, Ring::max);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
DiscreteFactor::shared_ptr TableFactor::max(const Ordering& keys) const {
|
||||||
|
return combine(keys, Ring::max);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor TableFactor::apply(Unary op) const {
|
TableFactor TableFactor::apply(Unary op) const {
|
||||||
// Initialize new factor.
|
// Initialize new factor.
|
||||||
|
|
|
@ -171,6 +171,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||||
double error(const DiscreteValues& values) const override;
|
double error(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
|
/// multiply with a scalar
|
||||||
|
DiscreteFactor::shared_ptr operator*(double s) const override {
|
||||||
|
return std::make_shared<TableFactor>(
|
||||||
|
apply([s](const double& a) { return Ring::mul(a, s); }));
|
||||||
|
}
|
||||||
|
|
||||||
/// multiply two TableFactors
|
/// multiply two TableFactors
|
||||||
TableFactor operator*(const TableFactor& f) const {
|
TableFactor operator*(const TableFactor& f) const {
|
||||||
return apply(f, Ring::mul);
|
return apply(f, Ring::mul);
|
||||||
|
@ -215,24 +221,19 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
DiscreteKeys parent_keys) const;
|
DiscreteKeys parent_keys) const;
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
|
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override;
|
||||||
return combine(nrFrontals, Ring::add);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
|
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override;
|
||||||
return combine(keys, Ring::add);
|
|
||||||
}
|
/// Find the maximum value in the factor.
|
||||||
|
double max() const override;
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override;
|
||||||
return combine(nrFrontals, Ring::max);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create new factor by maximizing over all values with the same separator.
|
/// Create new factor by maximizing over all values with the same separator.
|
||||||
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
|
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
|
||||||
return combine(keys, Ring::max);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
|
|
|
@ -65,6 +65,7 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
}
|
}
|
||||||
|
|
||||||
HybridBayesNet result;
|
HybridBayesNet result;
|
||||||
|
result.reserve(size());
|
||||||
|
|
||||||
// Go through all the Gaussian conditionals, restrict them according to
|
// Go through all the Gaussian conditionals, restrict them according to
|
||||||
// fixed values, and then prune further.
|
// fixed values, and then prune further.
|
||||||
|
@ -84,9 +85,9 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
}
|
}
|
||||||
// Type-erase and add to the pruned Bayes Net fragment.
|
// Type-erase and add to the pruned Bayes Net fragment.
|
||||||
result.push_back(prunedHybridGaussianConditional);
|
result.push_back(prunedHybridGaussianConditional);
|
||||||
} else if (auto gc = conditional->asGaussian()) {
|
} else if (conditional->isContinuous()) {
|
||||||
// Add the non-HybridGaussianConditional conditional
|
// Add the non-Hybrid GaussianConditional conditional
|
||||||
result.push_back(gc);
|
result.push_back(conditional);
|
||||||
} else
|
} else
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"HybrdiBayesNet::prune: Unknown HybridConditional type.");
|
"HybrdiBayesNet::prune: Unknown HybridConditional type.");
|
||||||
|
|
|
@ -23,6 +23,7 @@ namespace gtsam {
|
||||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys) {
|
const DiscreteKeys &discreteKeys) {
|
||||||
KeyVector allKeys;
|
KeyVector allKeys;
|
||||||
|
allKeys.reserve(continuousKeys.size() + discreteKeys.size());
|
||||||
std::copy(continuousKeys.begin(), continuousKeys.end(),
|
std::copy(continuousKeys.begin(), continuousKeys.end(),
|
||||||
std::back_inserter(allKeys));
|
std::back_inserter(allKeys));
|
||||||
std::transform(discreteKeys.begin(), discreteKeys.end(),
|
std::transform(discreteKeys.begin(), discreteKeys.end(),
|
||||||
|
@ -34,6 +35,7 @@ KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) {
|
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) {
|
||||||
KeyVector allKeys;
|
KeyVector allKeys;
|
||||||
|
allKeys.reserve(keys1.size() + keys2.size());
|
||||||
std::copy(keys1.begin(), keys1.end(), std::back_inserter(allKeys));
|
std::copy(keys1.begin(), keys1.end(), std::back_inserter(allKeys));
|
||||||
std::copy(keys2.begin(), keys2.end(), std::back_inserter(allKeys));
|
std::copy(keys2.begin(), keys2.end(), std::back_inserter(allKeys));
|
||||||
return allKeys;
|
return allKeys;
|
||||||
|
|
|
@ -191,11 +191,19 @@ size_t HybridGaussianConditional::nrComponents() const {
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
|
||||||
const DiscreteValues &discreteValues) const {
|
const DiscreteValues &discreteValues) const {
|
||||||
auto &[factor, _] = factors()(discreteValues);
|
try {
|
||||||
if (!factor) return nullptr;
|
auto &[factor, _] = factors()(discreteValues);
|
||||||
|
if (!factor) return nullptr;
|
||||||
|
|
||||||
auto conditional = checkConditional(factor);
|
auto conditional = checkConditional(factor);
|
||||||
return conditional;
|
return conditional;
|
||||||
|
} catch (const std::out_of_range &e) {
|
||||||
|
GTSAM_PRINT(*this);
|
||||||
|
GTSAM_PRINT(discreteValues);
|
||||||
|
throw std::runtime_error(
|
||||||
|
"HybridGaussianConditional::choose: discreteValues does not contain "
|
||||||
|
"all discrete parents.");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
@ -50,7 +50,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#define GTSAM_HYBRID_WITH_TABLEFACTOR 0
|
#define GTSAM_HYBRID_WITH_TABLEFACTOR 1
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
|
@ -21,55 +21,70 @@
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
|
// #define DEBUG_SMOOTHER
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors,
|
Ordering HybridSmoother::getOrdering(const HybridGaussianFactorGraph &factors,
|
||||||
const KeySet &newFactorKeys) {
|
const KeySet &lastKeysToEliminate) {
|
||||||
// Get all the discrete keys from the factors
|
// Get all the discrete keys from the factors
|
||||||
KeySet allDiscrete = factors.discreteKeySet();
|
KeySet allDiscrete = factors.discreteKeySet();
|
||||||
|
|
||||||
// Create KeyVector with continuous keys followed by discrete keys.
|
// Create KeyVector with continuous keys followed by discrete keys.
|
||||||
KeyVector newKeysDiscreteLast;
|
KeyVector lastKeys;
|
||||||
|
|
||||||
// Insert continuous keys first.
|
// Insert continuous keys first.
|
||||||
for (auto &k : newFactorKeys) {
|
for (auto &k : lastKeysToEliminate) {
|
||||||
if (!allDiscrete.exists(k)) {
|
if (!allDiscrete.exists(k)) {
|
||||||
newKeysDiscreteLast.push_back(k);
|
lastKeys.push_back(k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert discrete keys at the end
|
// Insert discrete keys at the end
|
||||||
std::copy(allDiscrete.begin(), allDiscrete.end(),
|
std::copy(allDiscrete.begin(), allDiscrete.end(),
|
||||||
std::back_inserter(newKeysDiscreteLast));
|
std::back_inserter(lastKeys));
|
||||||
|
|
||||||
const VariableIndex index(factors);
|
|
||||||
|
|
||||||
// Get an ordering where the new keys are eliminated last
|
// Get an ordering where the new keys are eliminated last
|
||||||
Ordering ordering = Ordering::ColamdConstrainedLast(
|
Ordering ordering = Ordering::ColamdConstrainedLast(
|
||||||
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
|
factors, KeyVector(lastKeys.begin(), lastKeys.end()), true);
|
||||||
true);
|
|
||||||
return ordering;
|
return ordering;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
void HybridSmoother::update(const HybridGaussianFactorGraph &newFactors,
|
||||||
std::optional<size_t> maxNrLeaves,
|
std::optional<size_t> maxNrLeaves,
|
||||||
const std::optional<Ordering> given_ordering) {
|
const std::optional<Ordering> given_ordering) {
|
||||||
|
const KeySet originalNewFactorKeys = newFactors.keys();
|
||||||
|
#ifdef DEBUG_SMOOTHER
|
||||||
|
std::cout << "hybridBayesNet_ size before: " << hybridBayesNet_.size()
|
||||||
|
<< std::endl;
|
||||||
|
std::cout << "newFactors size: " << newFactors.size() << std::endl;
|
||||||
|
#endif
|
||||||
HybridGaussianFactorGraph updatedGraph;
|
HybridGaussianFactorGraph updatedGraph;
|
||||||
// Add the necessary conditionals from the previous timestep(s).
|
// Add the necessary conditionals from the previous timestep(s).
|
||||||
std::tie(updatedGraph, hybridBayesNet_) =
|
std::tie(updatedGraph, hybridBayesNet_) =
|
||||||
addConditionals(graph, hybridBayesNet_);
|
addConditionals(newFactors, hybridBayesNet_);
|
||||||
|
#ifdef DEBUG_SMOOTHER
|
||||||
|
// print size of newFactors, updatedGraph, hybridBayesNet_
|
||||||
|
std::cout << "updatedGraph size: " << updatedGraph.size() << std::endl;
|
||||||
|
std::cout << "hybridBayesNet_ size after: " << hybridBayesNet_.size()
|
||||||
|
<< std::endl;
|
||||||
|
std::cout << "total size: " << updatedGraph.size() + hybridBayesNet_.size()
|
||||||
|
<< std::endl;
|
||||||
|
#endif
|
||||||
|
|
||||||
Ordering ordering;
|
Ordering ordering;
|
||||||
// If no ordering provided, then we compute one
|
// If no ordering provided, then we compute one
|
||||||
if (!given_ordering.has_value()) {
|
if (!given_ordering.has_value()) {
|
||||||
// Get the keys from the new factors
|
// Get the keys from the new factors
|
||||||
const KeySet newFactorKeys = graph.keys();
|
KeySet continuousKeysToInclude; // Scheme 1: empty, 15sec/2000, 64sec/3000 (69s without TF)
|
||||||
|
// continuousKeysToInclude = newFactors.keys(); // Scheme 2: all, 8sec/2000, 160sec/3000
|
||||||
|
// continuousKeysToInclude = updatedGraph.keys(); // Scheme 3: all, stopped after 80sec/2000
|
||||||
|
|
||||||
// Since updatedGraph now has all the connected conditionals,
|
// Since updatedGraph now has all the connected conditionals,
|
||||||
// we can get the correct ordering.
|
// we can get the correct ordering.
|
||||||
ordering = this->getOrdering(updatedGraph, newFactorKeys);
|
ordering = this->getOrdering(updatedGraph, continuousKeysToInclude);
|
||||||
} else {
|
} else {
|
||||||
ordering = *given_ordering;
|
ordering = *given_ordering;
|
||||||
}
|
}
|
||||||
|
@ -83,6 +98,22 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
||||||
gttoc_(HybridSmootherEliminate);
|
gttoc_(HybridSmootherEliminate);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef DEBUG_SMOOTHER_DETAIL
|
||||||
|
for (auto conditional : bayesNetFragment) {
|
||||||
|
auto e = std::dynamic_pointer_cast<HybridConditional::BaseConditional>(
|
||||||
|
conditional);
|
||||||
|
GTSAM_PRINT(*e);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef DEBUG_SMOOTHER
|
||||||
|
// Print discrete keys in the bayesNetFragment:
|
||||||
|
std::cout << "Discrete keys in bayesNetFragment: ";
|
||||||
|
for (auto &key : HybridFactorGraph(bayesNetFragment).discreteKeySet()) {
|
||||||
|
std::cout << DefaultKeyFormatter(key) << " ";
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
/// Prune
|
/// Prune
|
||||||
if (maxNrLeaves) {
|
if (maxNrLeaves) {
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
|
@ -90,24 +121,47 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
||||||
#endif
|
#endif
|
||||||
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
||||||
// all the conditionals with the same keys in bayesNetFragment.
|
// all the conditionals with the same keys in bayesNetFragment.
|
||||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_);
|
DiscreteValues newlyFixedValues;
|
||||||
|
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_,
|
||||||
|
&newlyFixedValues);
|
||||||
|
fixedValues_.insert(newlyFixedValues);
|
||||||
#if GTSAM_HYBRID_TIMING
|
#if GTSAM_HYBRID_TIMING
|
||||||
gttoc_(HybridSmootherPrune);
|
gttoc_(HybridSmootherPrune);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef DEBUG_SMOOTHER
|
||||||
|
// Print discrete keys in the bayesNetFragment:
|
||||||
|
std::cout << "\nAfter pruning: ";
|
||||||
|
for (auto &key : HybridFactorGraph(bayesNetFragment).discreteKeySet()) {
|
||||||
|
std::cout << DefaultKeyFormatter(key) << " ";
|
||||||
|
}
|
||||||
|
std::cout << std::endl << std::endl;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef DEBUG_SMOOTHER_DETAIL
|
||||||
|
for (auto conditional : bayesNetFragment) {
|
||||||
|
auto c = std::dynamic_pointer_cast<HybridConditional::BaseConditional>(
|
||||||
|
conditional);
|
||||||
|
GTSAM_PRINT(*c);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// Add the partial bayes net to the posterior bayes net.
|
// Add the partial bayes net to the posterior bayes net.
|
||||||
hybridBayesNet_.add(bayesNetFragment);
|
hybridBayesNet_.add(bayesNetFragment);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::pair<HybridGaussianFactorGraph, HybridBayesNet>
|
std::pair<HybridGaussianFactorGraph, HybridBayesNet>
|
||||||
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
HybridSmoother::addConditionals(const HybridGaussianFactorGraph &newFactors,
|
||||||
const HybridBayesNet &hybridBayesNet) const {
|
const HybridBayesNet &hybridBayesNet) const {
|
||||||
HybridGaussianFactorGraph graph(originalGraph);
|
HybridGaussianFactorGraph graph(newFactors);
|
||||||
HybridBayesNet updatedHybridBayesNet(hybridBayesNet);
|
HybridBayesNet updatedHybridBayesNet(hybridBayesNet);
|
||||||
|
|
||||||
KeySet factorKeys = graph.keys();
|
KeySet involvedKeys = newFactors.keys();
|
||||||
|
auto involved = [&involvedKeys](const Key &key) {
|
||||||
|
return involvedKeys.find(key) != involvedKeys.end();
|
||||||
|
};
|
||||||
|
|
||||||
// If hybridBayesNet is not empty,
|
// If hybridBayesNet is not empty,
|
||||||
// it means we have conditionals to add to the factor graph.
|
// it means we have conditionals to add to the factor graph.
|
||||||
|
@ -129,25 +183,26 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
|
||||||
auto conditional = hybridBayesNet.at(i);
|
auto conditional = hybridBayesNet.at(i);
|
||||||
|
|
||||||
for (auto &key : conditional->frontals()) {
|
for (auto &key : conditional->frontals()) {
|
||||||
if (std::find(factorKeys.begin(), factorKeys.end(), key) !=
|
if (involved(key)) {
|
||||||
factorKeys.end()) {
|
// Add the conditional parents to involvedKeys
|
||||||
// Add the conditional parents to factorKeys
|
|
||||||
// so we add those conditionals too.
|
// so we add those conditionals too.
|
||||||
for (auto &&parentKey : conditional->parents()) {
|
for (auto &&parentKey : conditional->parents()) {
|
||||||
factorKeys.insert(parentKey);
|
involvedKeys.insert(parentKey);
|
||||||
}
|
}
|
||||||
// Break so we don't add parents twice.
|
// Break so we don't add parents twice.
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#ifdef DEBUG_SMOOTHER
|
||||||
|
PrintKeySet(involvedKeys);
|
||||||
|
#endif
|
||||||
|
|
||||||
for (size_t i = 0; i < hybridBayesNet.size(); i++) {
|
for (size_t i = 0; i < hybridBayesNet.size(); i++) {
|
||||||
auto conditional = hybridBayesNet.at(i);
|
auto conditional = hybridBayesNet.at(i);
|
||||||
|
|
||||||
for (auto &key : conditional->frontals()) {
|
for (auto &key : conditional->frontals()) {
|
||||||
if (std::find(factorKeys.begin(), factorKeys.end(), key) !=
|
if (involved(key)) {
|
||||||
factorKeys.end()) {
|
|
||||||
newConditionals.push_back(conditional);
|
newConditionals.push_back(conditional);
|
||||||
|
|
||||||
// Remove the conditional from the updated Bayes net
|
// Remove the conditional from the updated Bayes net
|
||||||
|
@ -177,4 +232,21 @@ const HybridBayesNet &HybridSmoother::hybridBayesNet() const {
|
||||||
return hybridBayesNet_;
|
return hybridBayesNet_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
HybridValues HybridSmoother::optimize() const {
|
||||||
|
// Solve for the MPE
|
||||||
|
DiscreteValues mpe = hybridBayesNet_.mpe();
|
||||||
|
|
||||||
|
// Add fixed values to the MPE.
|
||||||
|
mpe.insert(fixedValues_);
|
||||||
|
|
||||||
|
// Given the MPE, compute the optimal continuous values.
|
||||||
|
GaussianBayesNet gbn = hybridBayesNet_.choose(mpe);
|
||||||
|
const VectorValues continuous = gbn.optimize();
|
||||||
|
if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
|
||||||
|
throw std::runtime_error("At least one nullptr factor in hybridBayesNet_");
|
||||||
|
}
|
||||||
|
return HybridValues(continuous, mpe);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -106,6 +106,9 @@ class GTSAM_EXPORT HybridSmoother {
|
||||||
|
|
||||||
/// Return the Bayes Net posterior.
|
/// Return the Bayes Net posterior.
|
||||||
const HybridBayesNet& hybridBayesNet() const;
|
const HybridBayesNet& hybridBayesNet() const;
|
||||||
|
|
||||||
|
/// Optimize the hybrid Bayes Net, taking into accound fixed values.
|
||||||
|
HybridValues optimize() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -28,6 +28,28 @@ using symbol_shorthand::M;
|
||||||
using symbol_shorthand::X;
|
using symbol_shorthand::X;
|
||||||
using symbol_shorthand::Z;
|
using symbol_shorthand::Z;
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test the HybridConditional constructor.
|
||||||
|
TEST(HybridConditional, Constructor) {
|
||||||
|
// Create a HybridGaussianConditional.
|
||||||
|
const KeyVector continuousKeys{X(0), X(1)};
|
||||||
|
const DiscreteKeys discreteKeys{{M(0), 2}};
|
||||||
|
const size_t nFrontals = 1;
|
||||||
|
const HybridConditional hc(continuousKeys, discreteKeys, nFrontals);
|
||||||
|
|
||||||
|
// Check Frontals:
|
||||||
|
EXPECT_LONGS_EQUAL(1, hc.nrFrontals());
|
||||||
|
const auto frontals = hc.frontals();
|
||||||
|
EXPECT_LONGS_EQUAL(1, frontals.size());
|
||||||
|
EXPECT_LONGS_EQUAL(X(0), *frontals.begin());
|
||||||
|
|
||||||
|
// Check parents:
|
||||||
|
const auto parents = hc.parents();
|
||||||
|
EXPECT_LONGS_EQUAL(2, parents.size());
|
||||||
|
EXPECT_LONGS_EQUAL(X(1), *parents.begin());
|
||||||
|
EXPECT_LONGS_EQUAL(M(0), *(parents.begin() + 1));
|
||||||
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Check invariants for all conditionals in a tiny Bayes net.
|
// Check invariants for all conditionals in a tiny Bayes net.
|
||||||
TEST(HybridConditional, Invariants) {
|
TEST(HybridConditional, Invariants) {
|
||||||
|
@ -43,6 +65,12 @@ TEST(HybridConditional, Invariants) {
|
||||||
auto hc0 = bn.at(0);
|
auto hc0 = bn.at(0);
|
||||||
CHECK(hc0->isHybrid());
|
CHECK(hc0->isHybrid());
|
||||||
|
|
||||||
|
// Check parents:
|
||||||
|
const auto parents = hc0->parents();
|
||||||
|
EXPECT_LONGS_EQUAL(2, parents.size());
|
||||||
|
EXPECT_LONGS_EQUAL(X(0), *parents.begin());
|
||||||
|
EXPECT_LONGS_EQUAL(M(0), *(parents.begin() + 1));
|
||||||
|
|
||||||
// Check invariants as a HybridGaussianConditional.
|
// Check invariants as a HybridGaussianConditional.
|
||||||
const auto conditional = hc0->asHybrid();
|
const auto conditional = hc0->asHybrid();
|
||||||
EXPECT(HybridGaussianConditional::CheckInvariants(*conditional, values));
|
EXPECT(HybridGaussianConditional::CheckInvariants(*conditional, values));
|
||||||
|
|
|
@ -467,7 +467,7 @@ std::shared_ptr<GaussianConditional> HessianFactor::eliminateCholesky(const Orde
|
||||||
info_.choleskyPartial(nFrontals);
|
info_.choleskyPartial(nFrontals);
|
||||||
|
|
||||||
// TODO(frank): pre-allocate GaussianConditional and write into it
|
// TODO(frank): pre-allocate GaussianConditional and write into it
|
||||||
const VerticalBlockMatrix Ab = info_.split(nFrontals);
|
VerticalBlockMatrix Ab = info_.split(nFrontals);
|
||||||
conditional = std::make_shared<GaussianConditional>(keys_, nFrontals, std::move(Ab));
|
conditional = std::make_shared<GaussianConditional>(keys_, nFrontals, std::move(Ab));
|
||||||
|
|
||||||
// Erase the eliminated keys in this factor
|
// Erase the eliminated keys in this factor
|
||||||
|
|
|
@ -87,6 +87,16 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
|
||||||
this->operator*(df->toDecisionTreeFactor()));
|
this->operator*(df->toDecisionTreeFactor()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Multiply by a scalar
|
||||||
|
virtual DiscreteFactor::shared_ptr operator*(double s) const override {
|
||||||
|
return this->toDecisionTreeFactor() * s;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Multiply by a DecisionTreeFactor and return a DecisionTreeFactor
|
||||||
|
DecisionTreeFactor operator*(const DecisionTreeFactor& dtf) const override {
|
||||||
|
return this->toDecisionTreeFactor() * dtf;
|
||||||
|
}
|
||||||
|
|
||||||
/// divide by DiscreteFactor::shared_ptr f (safely)
|
/// divide by DiscreteFactor::shared_ptr f (safely)
|
||||||
DiscreteFactor::shared_ptr operator/(
|
DiscreteFactor::shared_ptr operator/(
|
||||||
const DiscreteFactor::shared_ptr& df) const override {
|
const DiscreteFactor::shared_ptr& df) const override {
|
||||||
|
@ -104,6 +114,9 @@ class GTSAM_UNSTABLE_EXPORT Constraint : public DiscreteFactor {
|
||||||
return toDecisionTreeFactor().sum(keys);
|
return toDecisionTreeFactor().sum(keys);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Find the max value
|
||||||
|
double max() const override { return toDecisionTreeFactor().max(); }
|
||||||
|
|
||||||
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
|
||||||
return toDecisionTreeFactor().max(nrFrontals);
|
return toDecisionTreeFactor().max(nrFrontals);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue