Merge branch 'develop' into feature/attitude
commit 0c0ae7478c
|
|
@ -1,8 +1 @@
|
||||||
BasedOnStyle: Google
|
BasedOnStyle: Google
|
||||||
|
|
||||||
BinPackArguments: false
|
|
||||||
BinPackParameters: false
|
|
||||||
ColumnLimit: 100
|
|
||||||
DerivePointerAlignment: false
|
|
||||||
IncludeBlocks: Preserve
|
|
||||||
PointerAlignment: Left
|
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,12 @@ name: "Bug Report"
|
||||||
about: Submit a bug report to help us improve GTSAM
|
about: Submit a bug report to help us improve GTSAM
|
||||||
---
|
---
|
||||||
|
|
||||||
|
<!-- This is a channel to report bugs/issues, not a support channel to help install/use/debug your own code. We'd love to help, but just don't have the bandwidth. Please post questions in the GTSAM Google group (https://groups.google.com/forum/#!forum/gtsam-users) -->
|
||||||
|
|
||||||
<!--Please only submit issues/bug reports that come with enough information to reproduce them, ideally a unit test that fails, and possible ideas on what might be wrong. -->
|
<!--Please only submit issues/bug reports that come with enough information to reproduce them, ideally a unit test that fails, and possible ideas on what might be wrong. -->
|
||||||
|
|
||||||
<!-- Even better yet, fix the bug and/or documentation, add a unit test, and create a pull request! -->
|
<!-- Even better yet, fix the bug and/or documentation, add a unit test, and create a pull request! -->
|
||||||
|
|
||||||
<!-- This is a channel to report bugs/issues, not a support channel to help install/use/debug your own code. We'd love to help, but just don't have the bandwidth. Please post questions in the GTSAM Google group (https://groups.google.com/forum/#!forum/gtsam-users) -->
|
|
||||||
|
|
||||||
## Description
|
## Description
|
||||||
|
|
||||||
<!-- A clear description of the bug -->
|
<!-- A clear description of the bug -->
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@
|
||||||
// Finally, once all of the factors have been added to our factor graph, we will want to
|
// Finally, once all of the factors have been added to our factor graph, we will want to
|
||||||
// solve/optimize the graph to find the best (Maximum A Posteriori) set of variable values.
|
// solve/optimize the graph to find the best (Maximum A Posteriori) set of variable values.
|
||||||
// GTSAM includes several nonlinear optimizers to perform this step. Here we will use a
|
// GTSAM includes several nonlinear optimizers to perform this step. Here we will use a
|
||||||
// trust-region method known as Powell's Degleg
|
// trust-region method known as Powell's Dogleg
|
||||||
#include <gtsam/nonlinear/DoglegOptimizer.h>
|
#include <gtsam/nonlinear/DoglegOptimizer.h>
|
||||||
|
|
||||||
// The nonlinear solvers within GTSAM are iterative solvers, meaning they linearize the
|
// The nonlinear solvers within GTSAM are iterative solvers, meaning they linearize the
|
||||||
|
|
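A minimal usage sketch of the Dogleg optimizer named in the comment above, assuming the `graph` and `initialEstimate` objects built in this example; `deltaInitial` (the initial trust-region radius) is an assumption about the standard GTSAM DoglegParams API.

    #include <gtsam/nonlinear/DoglegOptimizer.h>

    // Sketch only: `graph` (NonlinearFactorGraph) and `initialEstimate` (Values)
    // are assumed to exist, as in the surrounding example.
    DoglegParams params;
    params.deltaInitial = 1.0;  // initial trust-region radius (assumed field)
    Values result = DoglegOptimizer(graph, initialEstimate, params).optimize();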
@ -57,7 +57,7 @@ using namespace gtsam;
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main(int argc, char* argv[]) {
|
int main(int argc, char* argv[]) {
|
||||||
// Define the camera calibration parameters
|
// Define the camera calibration parameters
|
||||||
Cal3_S2::shared_ptr K(new Cal3_S2(50.0, 50.0, 0.0, 50.0, 50.0));
|
auto K = std::make_shared<Cal3_S2>(50.0, 50.0, 0.0, 50.0, 50.0);
|
||||||
|
|
||||||
// Define the camera observation noise model
|
// Define the camera observation noise model
|
||||||
auto measurementNoise =
|
auto measurementNoise =
|
||||||
|
|
|
||||||
|
|
@ -94,11 +94,10 @@ int main(int argc, char* argv[]) {
|
||||||
parameters.maxIterations = 500;
|
parameters.maxIterations = 500;
|
||||||
PCGSolverParameters::shared_ptr pcg =
|
PCGSolverParameters::shared_ptr pcg =
|
||||||
std::make_shared<PCGSolverParameters>();
|
std::make_shared<PCGSolverParameters>();
|
||||||
pcg->preconditioner_ =
|
pcg->preconditioner = std::make_shared<BlockJacobiPreconditionerParameters>();
|
||||||
std::make_shared<BlockJacobiPreconditionerParameters>();
|
|
||||||
// The following is crucial:
|
// The following is crucial:
|
||||||
pcg->setEpsilon_abs(1e-10);
|
pcg->epsilon_abs = 1e-10;
|
||||||
pcg->setEpsilon_rel(1e-10);
|
pcg->epsilon_rel = 1e-10;
|
||||||
parameters.iterativeParams = pcg;
|
parameters.iterativeParams = pcg;
|
||||||
|
|
||||||
LevenbergMarquardtOptimizer optimizer(graph, initialEstimate, parameters);
|
LevenbergMarquardtOptimizer optimizer(graph, initialEstimate, parameters);
|
||||||
|
|
|
||||||
|
|
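Putting the renamed fields together, a hedged end-to-end sketch of the iterative-solver setup this hunk arrives at. The `linearSolverType` line is an assumption about the surrounding (unshown) setup; the remaining lines mirror the new side of the hunk.

    LevenbergMarquardtParams parameters;
    parameters.linearSolverType = NonlinearOptimizerParams::Iterative;  // assumed
    parameters.maxIterations = 500;
    auto pcg = std::make_shared<PCGSolverParameters>();
    pcg->preconditioner = std::make_shared<BlockJacobiPreconditionerParameters>();
    pcg->epsilon_abs = 1e-10;  // crucial: tight absolute tolerance
    pcg->epsilon_rel = 1e-10;  // crucial: tight relative tolerance
    parameters.iterativeParams = pcg;
    LevenbergMarquardtOptimizer optimizer(graph, initialEstimate, parameters);
    Values result = optimizer.optimize();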
@ -16,56 +16,89 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A structure-from-motion example with landmarks, default function arguments give
|
* A structure-from-motion example with landmarks, default arguments give:
|
||||||
* - The landmarks form a 10 meter cube
|
* - The landmarks form a 10 meter cube
|
||||||
* - The robot rotates around the landmarks, always facing towards the cube
|
* - The robot rotates around the landmarks, always facing towards the cube
|
||||||
* Passing function argument allows to specificy an initial position, a pose increment and step count.
|
 * Passing function arguments allows one to specify an initial position, a pose
|
||||||
|
* increment and step count.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
// As this is a full 3D problem, we will use Pose3 variables to represent the camera
|
// As this is a full 3D problem, we will use Pose3 variables to represent the
|
||||||
// positions and Point3 variables (x, y, z) to represent the landmark coordinates.
|
// camera positions and Point3 variables (x, y, z) to represent the landmark
|
||||||
// Camera observations of landmarks (i.e. pixel coordinates) will be stored as Point2 (x, y).
|
// coordinates. Camera observations of landmarks (i.e. pixel coordinates) will
|
||||||
// We will also need a camera object to hold calibration information and perform projections.
|
// be stored as Point2 (x, y).
|
||||||
#include <gtsam/geometry/Pose3.h>
|
|
||||||
#include <gtsam/geometry/Point3.h>
|
#include <gtsam/geometry/Point3.h>
|
||||||
|
#include <gtsam/geometry/Pose3.h>
|
||||||
|
|
||||||
// We will also need a camera object to hold calibration information and perform projections.
|
// We will also need a camera object to hold calibration information and perform
|
||||||
#include <gtsam/geometry/PinholeCamera.h>
|
// projections.
|
||||||
#include <gtsam/geometry/Cal3_S2.h>
|
#include <gtsam/geometry/Cal3_S2.h>
|
||||||
|
#include <gtsam/geometry/PinholeCamera.h>
|
||||||
|
|
||||||
/* ************************************************************************* */
|
namespace gtsam {
|
||||||
std::vector<gtsam::Point3> createPoints() {
|
|
||||||
|
|
||||||
// Create the set of ground-truth landmarks
|
/// Create a set of ground-truth landmarks
|
||||||
std::vector<gtsam::Point3> points;
|
std::vector<Point3> createPoints() {
|
||||||
points.push_back(gtsam::Point3(10.0,10.0,10.0));
|
std::vector<Point3> points;
|
||||||
points.push_back(gtsam::Point3(-10.0,10.0,10.0));
|
points.push_back(Point3(10.0, 10.0, 10.0));
|
||||||
points.push_back(gtsam::Point3(-10.0,-10.0,10.0));
|
points.push_back(Point3(-10.0, 10.0, 10.0));
|
||||||
points.push_back(gtsam::Point3(10.0,-10.0,10.0));
|
points.push_back(Point3(-10.0, -10.0, 10.0));
|
||||||
points.push_back(gtsam::Point3(10.0,10.0,-10.0));
|
points.push_back(Point3(10.0, -10.0, 10.0));
|
||||||
points.push_back(gtsam::Point3(-10.0,10.0,-10.0));
|
points.push_back(Point3(10.0, 10.0, -10.0));
|
||||||
points.push_back(gtsam::Point3(-10.0,-10.0,-10.0));
|
points.push_back(Point3(-10.0, 10.0, -10.0));
|
||||||
points.push_back(gtsam::Point3(10.0,-10.0,-10.0));
|
points.push_back(Point3(-10.0, -10.0, -10.0));
|
||||||
|
points.push_back(Point3(10.0, -10.0, -10.0));
|
||||||
|
|
||||||
return points;
|
return points;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/**
|
||||||
std::vector<gtsam::Pose3> createPoses(
|
* Create a set of ground-truth poses
|
||||||
const gtsam::Pose3& init = gtsam::Pose3(gtsam::Rot3::Ypr(M_PI/2,0,-M_PI/2), gtsam::Point3(30, 0, 0)),
|
* Default values give a circular trajectory, radius 30 at pi/4 intervals,
|
||||||
const gtsam::Pose3& delta = gtsam::Pose3(gtsam::Rot3::Ypr(0,-M_PI/4,0), gtsam::Point3(sin(M_PI/4)*30, 0, 30*(1-sin(M_PI/4)))),
|
* always facing the circle center
|
||||||
|
*/
|
||||||
|
std::vector<Pose3> createPoses(
|
||||||
|
const Pose3& init = Pose3(Rot3::Ypr(M_PI_2, 0, -M_PI_2), {30, 0, 0}),
|
||||||
|
const Pose3& delta = Pose3(Rot3::Ypr(0, -M_PI_4, 0),
|
||||||
|
{sin(M_PI_4) * 30, 0, 30 * (1 - sin(M_PI_4))}),
|
||||||
int steps = 8) {
|
int steps = 8) {
|
||||||
|
std::vector<Pose3> poses;
|
||||||
|
poses.reserve(steps);
|
||||||
|
|
||||||
// Create the set of ground-truth poses
|
|
||||||
// Default values give a circular trajectory, radius 30 at pi/4 intervals, always facing the circle center
|
|
||||||
std::vector<gtsam::Pose3> poses;
|
|
||||||
int i = 1;
|
|
||||||
poses.push_back(init);
|
poses.push_back(init);
|
||||||
for(; i < steps; ++i) {
|
for (int i = 1; i < steps; ++i) {
|
||||||
poses.push_back(poses[i - 1].compose(delta));
|
poses.push_back(poses[i - 1].compose(delta));
|
||||||
}
|
}
|
||||||
|
|
||||||
return poses;
|
return poses;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create regularly spaced poses with specified radius and number of cameras
|
||||||
|
*/
|
||||||
|
std::vector<Pose3> posesOnCircle(int num_cameras = 8, double R = 30) {
|
||||||
|
const double theta = 2 * M_PI / num_cameras;
|
||||||
|
|
||||||
|
// Initial pose at angle 0, position (R, 0, 0), facing the center with Y-axis
|
||||||
|
// pointing down
|
||||||
|
const Pose3 init(Rot3::Ypr(M_PI_2, 0, -M_PI_2), {R, 0, 0});
|
||||||
|
|
||||||
|
// Delta rotation: rotate by -theta around Z-axis (counterclockwise movement)
|
||||||
|
Rot3 delta_rotation = Rot3::Ypr(0, -theta, 0);
|
||||||
|
|
||||||
|
// Delta translation in world frame
|
||||||
|
Vector3 delta_translation_world(R * (cos(theta) - 1), R * sin(theta), 0);
|
||||||
|
|
||||||
|
// Transform delta translation to local frame of the camera
|
||||||
|
Vector3 delta_translation_local =
|
||||||
|
init.rotation().inverse() * delta_translation_world;
|
||||||
|
|
||||||
|
// Define delta pose
|
||||||
|
const Pose3 delta(delta_rotation, delta_translation_local);
|
||||||
|
|
||||||
|
// Generate poses using createPoses
|
||||||
|
return createPoses(init, delta, num_cameras);
|
||||||
|
}
|
||||||
|
} // namespace gtsam
|
||||||
|
|
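A short sketch of how the helpers in this new header are consumed (ViewGraphExample.cpp below does exactly this); the counts are the documented defaults.

    #include "SFMdata.h"

    using namespace gtsam;

    // 8 corners of a 10 m cube, and 8 cameras on a radius-30 circle facing it.
    std::vector<Point3> points = createPoints();
    std::vector<Pose3> poses = posesOnCircle(8, 30.0);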
@ -0,0 +1,136 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file ViewGraphExample.cpp
|
||||||
|
* @brief View-graph calibration on a simulated dataset, a la Sweeney 2015
|
||||||
|
* @author Frank Dellaert
|
||||||
|
* @date October 2024
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/geometry/Cal3_S2.h>
|
||||||
|
#include <gtsam/geometry/PinholeCamera.h>
|
||||||
|
#include <gtsam/geometry/Point2.h>
|
||||||
|
#include <gtsam/geometry/Point3.h>
|
||||||
|
#include <gtsam/geometry/Pose3.h>
|
||||||
|
#include <gtsam/inference/EdgeKey.h>
|
||||||
|
#include <gtsam/nonlinear/LevenbergMarquardtOptimizer.h>
|
||||||
|
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||||
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
#include <gtsam/sfm/TransferFactor.h>
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "SFMdata.h"
|
||||||
|
#include "gtsam/inference/Key.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace gtsam;
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
int main(int argc, char* argv[]) {
|
||||||
|
// Define the camera calibration parameters
|
||||||
|
Cal3_S2 K(50.0, 50.0, 0.0, 50.0, 50.0);
|
||||||
|
|
||||||
|
// Create the set of 8 ground-truth landmarks
|
||||||
|
vector<Point3> points = createPoints();
|
||||||
|
|
||||||
|
// Create the set of 4 ground-truth poses
|
||||||
|
vector<Pose3> poses = posesOnCircle(4, 30);
|
||||||
|
|
||||||
|
// Calculate ground truth fundamental matrices, 1 and 2 poses apart
|
||||||
|
auto F1 = FundamentalMatrix(K, poses[0].between(poses[1]), K);
|
||||||
|
auto F2 = FundamentalMatrix(K, poses[0].between(poses[2]), K);
|
||||||
|
|
||||||
|
// Simulate measurements from each camera pose
|
||||||
|
std::array<std::array<Point2, 8>, 4> p;
|
||||||
|
for (size_t i = 0; i < 4; ++i) {
|
||||||
|
PinholeCamera<Cal3_S2> camera(poses[i], K);
|
||||||
|
for (size_t j = 0; j < 8; ++j) {
|
||||||
|
p[i][j] = camera.project(points[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// This section of the code is inspired by the work of Sweeney et al.
|
||||||
|
// [link](sites.cs.ucsb.edu/~holl/pubs/Sweeney-2015-ICCV.pdf) on view-graph
|
||||||
|
// calibration. The graph is made up of transfer factors that enforce the
|
||||||
|
// epipolar constraint between corresponding points across three views, as
|
||||||
|
// described in the paper. Rather than adding one ternary error term per point
|
||||||
|
// in a triplet, we add three binary factors for sparsity during optimization.
|
||||||
|
// In this version, we only include triplets between 3 successive cameras.
|
||||||
|
NonlinearFactorGraph graph;
|
||||||
|
using Factor = TransferFactor<FundamentalMatrix>;
|
||||||
|
|
||||||
|
for (size_t a = 0; a < 4; ++a) {
|
||||||
|
size_t b = (a + 1) % 4; // Next camera
|
||||||
|
size_t c = (a + 2) % 4; // Camera after next
|
||||||
|
|
||||||
|
// Vectors to collect tuples for each factor
|
||||||
|
std::vector<std::tuple<Point2, Point2, Point2>> tuples1, tuples2, tuples3;
|
||||||
|
|
||||||
|
// Collect data for the three factors
|
||||||
|
for (size_t j = 0; j < 8; ++j) {
|
||||||
|
tuples1.emplace_back(p[a][j], p[b][j], p[c][j]);
|
||||||
|
tuples2.emplace_back(p[a][j], p[c][j], p[b][j]);
|
||||||
|
tuples3.emplace_back(p[c][j], p[b][j], p[a][j]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add transfer factors between views a, b, and c. Note that the EdgeKeys
|
||||||
|
// are crucial in performing the transfer in the right direction. We use
|
||||||
|
// exactly 8 unique EdgeKeys, corresponding to 8 unknown fundamental
|
||||||
|
// matrices we will optimize for.
|
||||||
|
graph.emplace_shared<Factor>(EdgeKey(a, c), EdgeKey(b, c), tuples1);
|
||||||
|
graph.emplace_shared<Factor>(EdgeKey(a, b), EdgeKey(b, c), tuples2);
|
||||||
|
graph.emplace_shared<Factor>(EdgeKey(a, c), EdgeKey(a, b), tuples3);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto formatter = [](Key key) {
|
||||||
|
EdgeKey edge(key);
|
||||||
|
return (std::string)edge;
|
||||||
|
};
|
||||||
|
|
||||||
|
graph.print("Factor Graph:\n", formatter);
|
||||||
|
|
||||||
|
// Create a delta vector to perturb the ground truth
|
||||||
|
// We can't really go far before convergence becomes problematic :-(
|
||||||
|
Vector7 delta;
|
||||||
|
delta << 1, 2, 3, 4, 5, 6, 7;
|
||||||
|
delta *= 1e-5;
|
||||||
|
|
||||||
|
// Create the data structure to hold the initial estimate to the solution
|
||||||
|
Values initialEstimate;
|
||||||
|
for (size_t a = 0; a < 4; ++a) {
|
||||||
|
size_t b = (a + 1) % 4; // Next camera
|
||||||
|
size_t c = (a + 2) % 4; // Camera after next
|
||||||
|
initialEstimate.insert(EdgeKey(a, b), F1.retract(delta));
|
||||||
|
initialEstimate.insert(EdgeKey(a, c), F2.retract(delta));
|
||||||
|
}
|
||||||
|
initialEstimate.print("Initial Estimates:\n", formatter);
|
||||||
|
graph.printErrors(initialEstimate, "errors: ", formatter);
|
||||||
|
|
||||||
|
/* Optimize the graph and print results */
|
||||||
|
LevenbergMarquardtParams params;
|
||||||
|
params.setlambdaInitial(1000.0); // Initialize lambda to a high value
|
||||||
|
params.setVerbosityLM("SUMMARY");
|
||||||
|
Values result =
|
||||||
|
LevenbergMarquardtOptimizer(graph, initialEstimate, params).optimize();
|
||||||
|
|
||||||
|
cout << "initial error = " << graph.error(initialEstimate) << endl;
|
||||||
|
cout << "final error = " << graph.error(result) << endl;
|
||||||
|
|
||||||
|
result.print("Final results:\n", formatter);
|
||||||
|
|
||||||
|
cout << "Ground Truth F1:\n" << F1.matrix() << endl;
|
||||||
|
cout << "Ground Truth F2:\n" << F2.matrix() << endl;
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
/* ************************************************************************* */
|
||||||
|
|
@ -162,7 +162,7 @@ struct FixedDimension {
|
||||||
typedef const int value_type;
|
typedef const int value_type;
|
||||||
static const int value = traits<T>::dimension;
|
static const int value = traits<T>::dimension;
|
||||||
static_assert(value != Eigen::Dynamic,
|
static_assert(value != Eigen::Dynamic,
|
||||||
"FixedDimension instantiated for dymanically-sized type.");
|
"FixedDimension instantiated for dynamically-sized type.");
|
||||||
};
|
};
|
||||||
} // \ namespace gtsam
|
} // \ namespace gtsam
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,12 @@
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -70,6 +72,7 @@ namespace gtsam {
|
||||||
return a / b;
|
return a / b;
|
||||||
}
|
}
|
||||||
static inline double id(const double& x) { return x; }
|
static inline double id(const double& x) { return x; }
|
||||||
|
static inline double negate(const double& x) { return -x; }
|
||||||
};
|
};
|
||||||
|
|
||||||
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
|
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}
|
||||||
|
|
@ -181,11 +184,36 @@ namespace gtsam {
|
||||||
this->root_ = DecisionTree<L, double>::convertFrom(other.root_, L_of_M, op);
|
this->root_ = DecisionTree<L, double>::convertFrom(other.root_, L_of_M, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Create from an arbitrary DecisionTree<L, X> by operating on it
|
||||||
|
* with a functional `f`.
|
||||||
|
*
|
||||||
|
* @tparam X The type of the leaf of the original DecisionTree
|
||||||
|
* @tparam Func Type signature of functional `f`.
|
||||||
|
* @param other The original DecisionTree from which the
|
||||||
|
* AlgbraicDecisionTree is constructed.
|
||||||
|
* @param f Functional used to operate on
|
||||||
|
* the leaves of the input DecisionTree.
|
||||||
|
*/
|
||||||
|
template <typename X, typename Func>
|
||||||
|
AlgebraicDecisionTree(const DecisionTree<L, X>& other, Func f)
|
||||||
|
: Base(other, f) {}
|
||||||
|
|
||||||
/** sum */
|
/** sum */
|
||||||
AlgebraicDecisionTree operator+(const AlgebraicDecisionTree& g) const {
|
AlgebraicDecisionTree operator+(const AlgebraicDecisionTree& g) const {
|
||||||
return this->apply(g, &Ring::add);
|
return this->apply(g, &Ring::add);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** negation */
|
||||||
|
AlgebraicDecisionTree operator-() const {
|
||||||
|
return this->apply(&Ring::negate);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** subtract */
|
||||||
|
AlgebraicDecisionTree operator-(const AlgebraicDecisionTree& g) const {
|
||||||
|
return *this + (-g);
|
||||||
|
}
|
||||||
|
|
||||||
/** product */
|
/** product */
|
||||||
AlgebraicDecisionTree operator*(const AlgebraicDecisionTree& g) const {
|
AlgebraicDecisionTree operator*(const AlgebraicDecisionTree& g) const {
|
||||||
return this->apply(g, &Ring::mul);
|
return this->apply(g, &Ring::mul);
|
||||||
|
|
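A hedged sketch of the new negation and subtraction operators in use; `key` is a hypothetical label, and the two-leaf constructor is assumed from the base class.

    AlgebraicDecisionTree<Key> f(key, 1.0, 2.0);  // leaves 1 and 2 on `key`
    AlgebraicDecisionTree<Key> g(key, 0.5, 0.5);
    AlgebraicDecisionTree<Key> neg = -f;       // per-leaf negation via Ring::negate
    AlgebraicDecisionTree<Key> diff = f - g;   // f + (-g): leaves 0.5 and 1.5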
@ -208,12 +236,9 @@ namespace gtsam {
|
||||||
* @brief Helper method to perform normalization such that all leaves in the
|
* @brief Helper method to perform normalization such that all leaves in the
|
||||||
* tree sum to 1
|
* tree sum to 1
|
||||||
*
|
*
|
||||||
* @param sum
|
|
||||||
* @return AlgebraicDecisionTree
|
* @return AlgebraicDecisionTree
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree normalize(double sum) const {
|
AlgebraicDecisionTree normalize() const { return (*this) / this->sum(); }
|
||||||
return this->apply([&sum](const double& x) { return x / sum; });
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Find the minimum values amongst all leaves
|
/// Find the minimum values amongst all leaves
|
||||||
double min() const {
|
double min() const {
|
||||||
|
|
|
||||||
|
|
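Sketch of the new self-normalizing call. The expression `(*this) / this->sum()` is well-formed because the `AlgebraicDecisionTree(double leaf)` constructor above converts the scalar sum into a constant tree.

    AlgebraicDecisionTree<Key> p(key, 1.0, 3.0);   // hypothetical unnormalized tree
    AlgebraicDecisionTree<Key> q = p.normalize();  // leaves 0.25 and 0.75, sum to 1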
@ -1,7 +1,6 @@
|
||||||
# Install headers
|
# Install headers
|
||||||
set(subdir discrete)
|
set(subdir discrete)
|
||||||
file(GLOB discrete_headers "*.h")
|
file(GLOB discrete_headers "*.h")
|
||||||
# FIXME: exclude headers
|
|
||||||
install(FILES ${discrete_headers} DESTINATION include/gtsam/discrete)
|
install(FILES ${discrete_headers} DESTINATION include/gtsam/discrete)
|
||||||
|
|
||||||
# Add all tests
|
# Add all tests
|
||||||
|
|
|
||||||
|
|
@ -22,18 +22,16 @@
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
#include <cmath>
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <list>
|
#include <iterator>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <optional>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <optional>
|
|
||||||
#include <cassert>
|
|
||||||
#include <iterator>
|
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -251,22 +249,28 @@ namespace gtsam {
|
||||||
label_ = f.label();
|
label_ = f.label();
|
||||||
size_t count = f.nrChoices();
|
size_t count = f.nrChoices();
|
||||||
branches_.reserve(count);
|
branches_.reserve(count);
|
||||||
for (size_t i = 0; i < count; i++)
|
for (size_t i = 0; i < count; i++) {
|
||||||
push_back(f.branches_[i]->apply_f_op_g(g, op));
|
NodePtr newBranch = f.branches_[i]->apply_f_op_g(g, op);
|
||||||
|
push_back(std::move(newBranch));
|
||||||
|
}
|
||||||
} else if (g.label() > f.label()) {
|
} else if (g.label() > f.label()) {
|
||||||
// f lower than g
|
// f lower than g
|
||||||
label_ = g.label();
|
label_ = g.label();
|
||||||
size_t count = g.nrChoices();
|
size_t count = g.nrChoices();
|
||||||
branches_.reserve(count);
|
branches_.reserve(count);
|
||||||
for (size_t i = 0; i < count; i++)
|
for (size_t i = 0; i < count; i++) {
|
||||||
push_back(g.branches_[i]->apply_g_op_fC(f, op));
|
NodePtr newBranch = g.branches_[i]->apply_g_op_fC(f, op);
|
||||||
|
push_back(std::move(newBranch));
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// f same level as g
|
// f same level as g
|
||||||
label_ = f.label();
|
label_ = f.label();
|
||||||
size_t count = f.nrChoices();
|
size_t count = f.nrChoices();
|
||||||
branches_.reserve(count);
|
branches_.reserve(count);
|
||||||
for (size_t i = 0; i < count; i++)
|
for (size_t i = 0; i < count; i++) {
|
||||||
push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op));
|
NodePtr newBranch = f.branches_[i]->apply_f_op_g(*g.branches_[i], op);
|
||||||
|
push_back(std::move(newBranch));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -283,13 +287,17 @@ namespace gtsam {
|
||||||
return branches_;
|
return branches_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<NodePtr>& branches() {
|
||||||
|
return branches_;
|
||||||
|
}
|
||||||
|
|
||||||
/** add a branch: TODO merge into constructor */
|
/** add a branch: TODO merge into constructor */
|
||||||
void push_back(const NodePtr& node) {
|
void push_back(NodePtr&& node) {
|
||||||
// allSame_ is restricted to leaf nodes in a decision tree
|
// allSame_ is restricted to leaf nodes in a decision tree
|
||||||
if (allSame_ && !branches_.empty()) {
|
if (allSame_ && !branches_.empty()) {
|
||||||
allSame_ = node->sameLeaf(*branches_.back());
|
allSame_ = node->sameLeaf(*branches_.back());
|
||||||
}
|
}
|
||||||
branches_.push_back(node);
|
branches_.push_back(std::move(node));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// print (as a tree).
|
/// print (as a tree).
|
||||||
|
|
@ -480,7 +488,7 @@ namespace gtsam {
|
||||||
// DecisionTree
|
// DecisionTree
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree() {}
|
DecisionTree<L, Y>::DecisionTree() : root_(nullptr) {}
|
||||||
|
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
|
DecisionTree<L, Y>::DecisionTree(const NodePtr& root) :
|
||||||
|
|
@ -497,9 +505,9 @@ namespace gtsam {
|
||||||
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
||||||
auto a = std::make_shared<Choice>(label, 2);
|
auto a = std::make_shared<Choice>(label, 2);
|
||||||
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
||||||
a->push_back(l1);
|
a->push_back(std::move(l1));
|
||||||
a->push_back(l2);
|
a->push_back(std::move(l2));
|
||||||
root_ = Choice::Unique(a);
|
root_ = Choice::Unique(std::move(a));
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
|
@ -510,11 +518,10 @@ namespace gtsam {
|
||||||
"DecisionTree: binary constructor called with non-binary label");
|
"DecisionTree: binary constructor called with non-binary label");
|
||||||
auto a = std::make_shared<Choice>(labelC.first, 2);
|
auto a = std::make_shared<Choice>(labelC.first, 2);
|
||||||
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
NodePtr l1(new Leaf(y1)), l2(new Leaf(y2));
|
||||||
a->push_back(l1);
|
a->push_back(std::move(l1));
|
||||||
a->push_back(l2);
|
a->push_back(std::move(l2));
|
||||||
root_ = Choice::Unique(a);
|
root_ = Choice::Unique(std::move(a));
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||||
|
|
@ -552,14 +559,42 @@ namespace gtsam {
|
||||||
root_ = compose(functions.begin(), functions.end(), label);
|
root_ = compose(functions.begin(), functions.end(), label);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
DecisionTree<L, Y>::DecisionTree(const Unary& op,
|
||||||
|
DecisionTree&& other) noexcept
|
||||||
|
: root_(std::move(other.root_)) {
|
||||||
|
// Apply the unary operation directly to each leaf in the tree
|
||||||
|
if (root_) {
|
||||||
|
// Define a helper function to traverse and apply the operation
|
||||||
|
struct ApplyUnary {
|
||||||
|
const Unary& op;
|
||||||
|
void operator()(typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
|
if (auto leaf = std::dynamic_pointer_cast<Leaf>(node)) {
|
||||||
|
// Apply the unary operation to the leaf's constant value
|
||||||
|
leaf->constant_ = op(leaf->constant_);
|
||||||
|
} else if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
|
||||||
|
// Recurse into the choice branches
|
||||||
|
for (NodePtr& branch : choice->branches()) {
|
||||||
|
(*this)(branch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
ApplyUnary applyUnary{op};
|
||||||
|
applyUnary(root_);
|
||||||
|
}
|
||||||
|
// Reset the other tree's root to nullptr to avoid dangling references
|
||||||
|
other.root_ = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename X, typename Func>
|
template <typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||||
Func Y_of_X) {
|
Func Y_of_X) {
|
||||||
// Define functor for identity mapping of node label.
|
root_ = convertFrom<X>(other.root_, Y_of_X);
|
||||||
auto L_of_L = [](const L& label) { return label; };
|
|
||||||
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
|
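Hedged usage sketch of the new consume-and-transform constructor; the tree and the op are hypothetical, and per the doc comment nothing is reallocated.

    DecisionTree<std::string, double> tree("x", 1.0, 2.0);
    auto negate = [](const double& y) { return -y; };
    DecisionTree<std::string, double> negated(negate, std::move(tree));
    // `tree` is now empty; `negated` has leaves -1 and -2.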
@ -580,7 +615,7 @@ namespace gtsam {
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename Iterator>
|
template <typename Iterator>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
|
||||||
Iterator begin, Iterator end, const L& label) const {
|
Iterator begin, Iterator end, const L& label) {
|
||||||
// find highest label among branches
|
// find highest label among branches
|
||||||
std::optional<L> highestLabel;
|
std::optional<L> highestLabel;
|
||||||
size_t nrChoices = 0;
|
size_t nrChoices = 0;
|
||||||
|
|
@ -598,8 +633,10 @@ namespace gtsam {
|
||||||
// if label is already in correct order, just put together a choice on label
|
// if label is already in correct order, just put together a choice on label
|
||||||
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
if (!nrChoices || !highestLabel || label > *highestLabel) {
|
||||||
auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
|
auto choiceOnLabel = std::make_shared<Choice>(label, end - begin);
|
||||||
for (Iterator it = begin; it != end; it++)
|
for (Iterator it = begin; it != end; it++) {
|
||||||
choiceOnLabel->push_back(it->root_);
|
NodePtr root = it->root_;
|
||||||
|
choiceOnLabel->push_back(std::move(root));
|
||||||
|
}
|
||||||
// If no reordering, no need to call Choice::Unique
|
// If no reordering, no need to call Choice::Unique
|
||||||
return choiceOnLabel;
|
return choiceOnLabel;
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -618,7 +655,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
// We then recurse, for all values of the highest label
|
// We then recurse, for all values of the highest label
|
||||||
NodePtr fi = compose(functions.begin(), functions.end(), label);
|
NodePtr fi = compose(functions.begin(), functions.end(), label);
|
||||||
choiceOnHighestLabel->push_back(fi);
|
choiceOnHighestLabel->push_back(std::move(fi));
|
||||||
}
|
}
|
||||||
return choiceOnHighestLabel;
|
return choiceOnHighestLabel;
|
||||||
}
|
}
|
||||||
|
|
@ -648,7 +685,7 @@ namespace gtsam {
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::build(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::build(
|
||||||
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
It begin, It end, ValueIt beginY, ValueIt endY) {
|
||||||
// get crucial counts
|
// get crucial counts
|
||||||
size_t nrChoices = begin->second;
|
size_t nrChoices = begin->second;
|
||||||
size_t size = endY - beginY;
|
size_t size = endY - beginY;
|
||||||
|
|
@ -675,6 +712,7 @@ namespace gtsam {
|
||||||
// Creates one tree (i.e.,function) for each choice of current key
|
// Creates one tree (i.e.,function) for each choice of current key
|
||||||
// by calling create recursively, and then puts them all together.
|
// by calling create recursively, and then puts them all together.
|
||||||
std::vector<DecisionTree> functions;
|
std::vector<DecisionTree> functions;
|
||||||
|
functions.reserve(nrChoices);
|
||||||
size_t split = size / nrChoices;
|
size_t split = size / nrChoices;
|
||||||
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
|
for (size_t i = 0; i < nrChoices; i++, beginY += split) {
|
||||||
NodePtr f = build<It, ValueIt>(labelC, end, beginY, beginY + split);
|
NodePtr f = build<It, ValueIt>(labelC, end, beginY, beginY + split);
|
||||||
|
|
@ -689,26 +727,53 @@ namespace gtsam {
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||||
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
It begin, It end, ValueIt beginY, ValueIt endY) {
|
||||||
auto node = build(begin, end, beginY, endY);
|
auto node = build(begin, end, beginY, endY);
|
||||||
if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) {
|
if (auto choice = std::dynamic_pointer_cast<Choice>(node)) {
|
||||||
return Choice::Unique(choice);
|
return Choice::Unique(choice);
|
||||||
} else {
|
} else {
|
||||||
return node;
|
return node;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename X>
|
||||||
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
|
const typename DecisionTree<L, X>::NodePtr& f,
|
||||||
|
std::function<Y(const X&)> Y_of_X) {
|
||||||
|
|
||||||
|
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||||
|
using LXLeaf = typename DecisionTree<L, X>::Leaf;
|
||||||
|
if (auto leaf = std::dynamic_pointer_cast<LXLeaf>(f)) {
|
||||||
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if Choice
|
||||||
|
using LXChoice = typename DecisionTree<L, X>::Choice;
|
||||||
|
auto choice = std::dynamic_pointer_cast<const LXChoice>(f);
|
||||||
|
if (!choice) throw std::invalid_argument(
|
||||||
|
"DecisionTree::convertFrom: Invalid NodePtr");
|
||||||
|
|
||||||
|
// Create a new Choice node with the same label
|
||||||
|
auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices());
|
||||||
|
|
||||||
|
// Convert each branch recursively
|
||||||
|
for (auto&& branch : choice->branches()) {
|
||||||
|
newChoice->push_back(convertFrom<X>(branch, Y_of_X));
|
||||||
|
}
|
||||||
|
|
||||||
|
return Choice::Unique(newChoice);
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
const typename DecisionTree<M, X>::NodePtr& f,
|
const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<L(const M&)> L_of_M,
|
std::function<L(const M&)> L_of_M, std::function<Y(const X&)> Y_of_X) {
|
||||||
std::function<Y(const X&)> Y_of_X) const {
|
|
||||||
using LY = DecisionTree<L, Y>;
|
using LY = DecisionTree<L, Y>;
|
||||||
|
|
||||||
// Ugliness below because apparently we can't have templated virtual
|
|
||||||
// functions.
|
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
|
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
|
||||||
|
|
@ -718,19 +783,27 @@ namespace gtsam {
|
||||||
// Check if Choice
|
// Check if Choice
|
||||||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||||
auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
|
auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
|
||||||
if (!choice) throw std::invalid_argument(
|
if (!choice)
|
||||||
"DecisionTree::convertFrom: Invalid NodePtr");
|
throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr");
|
||||||
|
|
||||||
// get new label
|
// get new label
|
||||||
const M oldLabel = choice->label();
|
const M oldLabel = choice->label();
|
||||||
const L newLabel = L_of_M(oldLabel);
|
const L newLabel = L_of_M(oldLabel);
|
||||||
|
|
||||||
// put together via Shannon expansion otherwise not sorted.
|
// Shannon expansion in this context involves:
|
||||||
|
// 1. Creating separate subtrees (functions) for each possible value of the new label.
|
||||||
|
// 2. Combining these subtrees using the 'compose' method, which implements the expansion.
|
||||||
|
// This approach guarantees that the resulting tree maintains the correct variable ordering
|
||||||
|
// based on the new labels (L) after translation from the old labels (M).
|
||||||
|
// Simply creating a Choice node here would not work because it wouldn't account for the
|
||||||
|
// potentially new ordering of variables resulting from the label translation,
|
||||||
|
// which is crucial for maintaining consistency and efficiency in the converted tree.
|
||||||
std::vector<LY> functions;
|
std::vector<LY> functions;
|
||||||
for (auto&& branch : choice->branches()) {
|
for (auto&& branch : choice->branches()) {
|
||||||
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||||
}
|
}
|
||||||
return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel));
|
return Choice::Unique(
|
||||||
|
LY::compose(functions.begin(), functions.end(), newLabel));
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
|
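A sketch of why the Shannon-expansion path matters: translating labels can change the variable ordering, which a naive Choice copy would violate. This assumes the map-based conversion constructor documented in DecisionTree.h; labels and values are hypothetical.

    DecisionTree<std::string, int> source("B", 1, 2);
    std::map<std::string, int> labelMap{{"B", 7}};       // rename label "B" -> 7
    auto Y_of_X = [](const int& x) { return 0.5 * x; };  // int leaves -> double
    DecisionTree<int, double> target(source, labelMap, Y_of_X);
    // compose() re-sorts on the translated labels, keeping the tree ordered.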
@ -913,11 +986,16 @@ namespace gtsam {
|
||||||
return root_->equals(*other.root_);
|
return root_->equals(*other.root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
|
const Y& DecisionTree<L, Y>::operator()(const Assignment<L>& x) const {
|
||||||
|
if (root_ == nullptr)
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DecisionTree::operator() called on empty tree");
|
||||||
return root_->operator ()(x);
|
return root_->operator ()(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const Unary& op) const {
|
||||||
// It is unclear what should happen if tree is empty:
|
// It is unclear what should happen if tree is empty:
|
||||||
|
|
@ -928,6 +1006,7 @@ namespace gtsam {
|
||||||
return DecisionTree(root_->apply(op));
|
return DecisionTree(root_->apply(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
/// Apply unary operator with assignment
|
/// Apply unary operator with assignment
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
|
||||||
|
|
@ -1011,6 +1090,18 @@ namespace gtsam {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/******************************************************************************/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename A, typename B>
|
||||||
|
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> DecisionTree<L, Y>::split(
|
||||||
|
std::function<std::pair<A, B>(const Y&)> AB_of_Y) const {
|
||||||
|
using AB = std::pair<A, B>;
|
||||||
|
const DecisionTree<L, AB> ab(*this, AB_of_Y);
|
||||||
|
const DecisionTree<L, A> a(ab, [](const AB& p) { return p.first; });
|
||||||
|
const DecisionTree<L, B> b(ab, [](const AB& p) { return p.second; });
|
||||||
|
return {a, b};
|
||||||
|
}
|
||||||
|
|
||||||
/******************************************************************************/
|
/******************************************************************************/
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
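Sketch of split() projecting one tree of pairs into two trees; the tree and the splitting functor are hypothetical.

    DecisionTree<std::string, double> tree("x", 1.5, 2.25);
    auto parts = tree.split<int, double>([](const double& y) {
      return std::pair<int, double>{static_cast<int>(y), y - static_cast<int>(y)};
    });
    // parts.first has leaves {1, 2}; parts.second has leaves {0.5, 0.25}.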
@ -31,7 +31,6 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <sstream>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
@ -86,7 +85,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
struct Node {
|
struct Node {
|
||||||
using Ptr = std::shared_ptr<const Node>;
|
using Ptr = std::shared_ptr<Node>;
|
||||||
|
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
static int nrNodes;
|
static int nrNodes;
|
||||||
|
|
@ -155,15 +154,28 @@ namespace gtsam {
|
||||||
* and Y values
|
* and Y values
|
||||||
*/
|
*/
|
||||||
template <typename It, typename ValueIt>
|
template <typename It, typename ValueIt>
|
||||||
NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY);
|
||||||
|
|
||||||
/** Internal helper function to create from
|
/**
|
||||||
* keys, cardinalities, and Y values.
|
* Internal helper function to create a tree from keys, cardinalities, and Y
|
||||||
* Calls `build` which builds thetree bottom-up,
|
* values. Calls `build` which builds the tree bottom-up, before we prune in
|
||||||
* before we prune in a top-down fashion.
|
* a top-down fashion.
|
||||||
*/
|
*/
|
||||||
template <typename It, typename ValueIt>
|
template <typename It, typename ValueIt>
|
||||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Convert from a DecisionTree<L, X> to DecisionTree<L, Y>.
|
||||||
|
*
|
||||||
|
* @tparam M The previous label type.
|
||||||
|
* @tparam X The previous value type.
|
||||||
|
* @param f The node pointer to the root of the previous DecisionTree.
|
||||||
|
* @param Y_of_X Functor to convert from value type X to type Y.
|
||||||
|
* @return NodePtr
|
||||||
|
*/
|
||||||
|
template <typename X>
|
||||||
|
static NodePtr convertFrom(const typename DecisionTree<L, X>::NodePtr& f,
|
||||||
|
std::function<Y(const X&)> Y_of_X);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
|
* @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
|
||||||
|
|
@ -176,9 +188,9 @@ namespace gtsam {
|
||||||
* @return NodePtr
|
* @return NodePtr
|
||||||
*/
|
*/
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
|
static NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<L(const M&)> L_of_M,
|
std::function<L(const M&)> L_of_M,
|
||||||
std::function<Y(const X&)> Y_of_X) const;
|
std::function<Y(const X&)> Y_of_X);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
|
|
@ -216,6 +228,15 @@ namespace gtsam {
|
||||||
DecisionTree(const L& label, const DecisionTree& f0,
|
DecisionTree(const L& label, const DecisionTree& f0,
|
||||||
const DecisionTree& f1);
|
const DecisionTree& f1);
|
||||||
|
|
||||||
|
/**
|
||||||
|
 * @brief Move constructor for DecisionTree. Very efficient as it does not
|
||||||
|
* allocate anything, just changes in-place. But `other` is consumed.
|
||||||
|
*
|
||||||
|
* @param op The unary operation to apply to the moved DecisionTree.
|
||||||
|
* @param other The DecisionTree to move from, will be empty afterwards.
|
||||||
|
*/
|
||||||
|
DecisionTree(const Unary& op, DecisionTree&& other) noexcept;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a different value type.
|
* @brief Convert from a different value type.
|
||||||
*
|
*
|
||||||
|
|
@ -227,7 +248,7 @@ namespace gtsam {
|
||||||
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
|
DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a different value type X to value type Y, also transate
|
* @brief Convert from a different value type X to value type Y, also translate
|
||||||
* labels via map from type M to L.
|
* labels via map from type M to L.
|
||||||
*
|
*
|
||||||
* @tparam M Previous label type.
|
* @tparam M Previous label type.
|
||||||
|
|
@ -394,6 +415,18 @@ namespace gtsam {
|
||||||
const ValueFormatter& valueFormatter,
|
const ValueFormatter& valueFormatter,
|
||||||
bool showZero = true) const;
|
bool showZero = true) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Convert into two trees with value types A and B.
|
||||||
|
*
|
||||||
|
* @tparam A First new value type.
|
||||||
|
* @tparam B Second new value type.
|
||||||
|
 * @param AB_of_Y Functor to convert from type Y to std::pair<A, B>.
|
||||||
|
* @return A pair of DecisionTrees with value types A and B respectively.
|
||||||
|
*/
|
||||||
|
template <typename A, typename B>
|
||||||
|
std::pair<DecisionTree<L, A>, DecisionTree<L, B>> split(
|
||||||
|
std::function<std::pair<A, B>(const Y&)> AB_of_Y) const;
|
||||||
|
|
||||||
/// @name Advanced Interface
|
/// @name Advanced Interface
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
@ -402,7 +435,7 @@ namespace gtsam {
|
||||||
|
|
||||||
// internal use only
|
// internal use only
|
||||||
template<typename Iterator> NodePtr
|
template<typename Iterator> NodePtr
|
||||||
compose(Iterator begin, Iterator end, const L& label) const;
|
static compose(Iterator begin, Iterator end, const L& label);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -62,22 +62,6 @@ namespace gtsam {
|
||||||
return error(values.discrete());
|
return error(values.discrete());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
AlgebraicDecisionTree<Key> DecisionTreeFactor::errorTree() const {
|
|
||||||
// Get all possible assignments
|
|
||||||
DiscreteKeys dkeys = discreteKeys();
|
|
||||||
// Reverse to make cartesian product output a more natural ordering.
|
|
||||||
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
|
|
||||||
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
|
|
||||||
|
|
||||||
// Construct vector with error values
|
|
||||||
std::vector<double> errors;
|
|
||||||
for (const auto& assignment : assignments) {
|
|
||||||
errors.push_back(error(assignment));
|
|
||||||
}
|
|
||||||
return AlgebraicDecisionTree<Key>(dkeys, errors);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
|
|
@ -385,6 +369,16 @@ namespace gtsam {
|
||||||
// Now threshold the decision tree
|
// Now threshold the decision tree
|
||||||
size_t total = 0;
|
size_t total = 0;
|
||||||
auto thresholdFunc = [threshold, &total, N](const double& value) {
|
auto thresholdFunc = [threshold, &total, N](const double& value) {
|
||||||
|
// There is a possible case where the `threshold` is equal to 0.0
|
||||||
|
// In that case `(value < threshold) == false`
|
||||||
|
// which increases the leaf total erroneously.
|
||||||
|
// Hence we check for 0.0 explicitly.
|
||||||
|
if (fpEqual(value, 0.0, 1e-12)) {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if value is less than the threshold and
|
||||||
|
// we haven't exceeded the maximum number of leaves.
|
||||||
if (value < threshold || total >= N) {
|
if (value < threshold || total >= N) {
|
||||||
return 0.0;
|
return 0.0;
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
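A standalone sketch of the pruning predicate above, isolating the zero-threshold edge case. Simplified: `fpEqual` is approximated by an absolute tolerance, and the (unshown) else branch is assumed to count the surviving leaf.

    #include <cmath>

    // A (near-)zero leaf is dropped without being counted toward the
    // N leaves we are allowed to keep.
    double thresholded(double value, double threshold, size_t& total, size_t N) {
      if (std::abs(value) < 1e-12) return 0.0;         // explicit zero check
      if (value < threshold || total >= N) return 0.0;
      total++;  // assumed: leaf survives and counts toward the cap
      return value;
    }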
@ -131,7 +131,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/// Calculate probability for given values `x`,
|
/// Calculate probability for given values `x`,
|
||||||
/// which is just a lookup in the AlgebraicDecisionTree.
|
/// which is just a lookup in the AlgebraicDecisionTree.
|
||||||
double evaluate(const DiscreteValues& values) const {
|
double evaluate(const Assignment<Key>& values) const {
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -141,7 +141,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
/// Calculate error for DiscreteValues `x`, is -log(probability).
|
||||||
double error(const DiscreteValues& values) const;
|
double error(const DiscreteValues& values) const override;
|
||||||
|
|
||||||
/// multiply two factors
|
/// multiply two factors
|
||||||
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override {
|
||||||
|
|
@ -292,9 +292,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues& values) const override;
|
double error(const HybridValues& values) const override;
|
||||||
|
|
||||||
/// Compute error for each assignment and return as a tree
|
|
||||||
AlgebraicDecisionTree<Key> errorTree() const override;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,12 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
||||||
// sample each node in turn in topological sort order (parents first)
|
// sample each node in turn in topological sort order (parents first)
|
||||||
for (auto it = std::make_reverse_iterator(end());
|
for (auto it = std::make_reverse_iterator(end());
|
||||||
it != std::make_reverse_iterator(begin()); ++it) {
|
it != std::make_reverse_iterator(begin()); ++it) {
|
||||||
(*it)->sampleInPlace(&result);
|
const DiscreteConditional::shared_ptr& conditional = *it;
|
||||||
|
// Sample the conditional only if value for j not already in result
|
||||||
|
const Key j = conditional->firstFrontalKey();
|
||||||
|
if (result.count(j) == 0) {
|
||||||
|
conditional->sampleInPlace(&result);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
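Hedged sketch of the behavioral change: values already present act as evidence instead of being overwritten. `bn` and the key `X` are hypothetical.

    DiscreteValues given;
    given[X] = 1;                            // clamp X as evidence
    DiscreteValues full = bn.sample(given);  // samples remaining keys; X stays 1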
@ -259,8 +259,18 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
|
||||||
assert(nrFrontals() == 1);
|
// throw if more than one frontal:
|
||||||
Key j = (firstFrontalKey());
|
if (nrFrontals() != 1) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteConditional::sampleInPlace can only be called on single "
|
||||||
|
"variable conditionals");
|
||||||
|
}
|
||||||
|
Key j = firstFrontalKey();
|
||||||
|
// throw if values already contains j:
|
||||||
|
if (values->count(j) > 0) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteConditional::sampleInPlace: values already contains j");
|
||||||
|
}
|
||||||
size_t sampled = sample(*values); // Sample variable given parents
|
size_t sampled = sample(*values); // Sample variable given parents
|
||||||
(*values)[j] = sampled; // store result in partial solution
|
(*values)[j] = sampled; // store result in partial solution
|
||||||
}
|
}
|
||||||
|
|
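With the asserts promoted to exceptions, misuse is now catchable in release builds; `c` is a hypothetical single-frontal DiscreteConditional.

    DiscreteValues values;
    c.sampleInPlace(&values);     // ok: frontal key not yet present
    try {
      c.sampleInPlace(&values);   // throws: value for the key already exists
    } catch (const std::invalid_argument& e) {
      std::cerr << e.what() << '\n';
    }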
@ -467,9 +477,7 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
double DiscreteConditional::negLogConstant() const {
|
double DiscreteConditional::negLogConstant() const { return 0.0; }
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -168,7 +168,7 @@ class GTSAM_EXPORT DiscreteConditional
|
||||||
static_cast<const BaseConditional*>(this)->print(s, formatter);
|
static_cast<const BaseConditional*>(this)->print(s, formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Evaluate, just look up in AlgebraicDecisonTree
|
/// Evaluate, just look up in AlgebraicDecisionTree
|
||||||
double evaluate(const DiscreteValues& values) const {
|
double evaluate(const DiscreteValues& values) const {
|
||||||
return ADT::operator()(values);
|
return ADT::operator()(values);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,22 @@ double DiscreteFactor::error(const HybridValues& c) const {
|
||||||
return this->error(c.discrete());
|
return this->error(c.discrete());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
AlgebraicDecisionTree<Key> DiscreteFactor::errorTree() const {
|
||||||
|
// Get all possible assignments
|
||||||
|
DiscreteKeys dkeys = discreteKeys();
|
||||||
|
// Reverse to make cartesian product output a more natural ordering.
|
||||||
|
DiscreteKeys rdkeys(dkeys.rbegin(), dkeys.rend());
|
||||||
|
const auto assignments = DiscreteValues::CartesianProduct(rdkeys);
|
||||||
|
|
||||||
|
// Construct vector with error values
|
||||||
|
std::vector<double> errors;
|
||||||
|
for (const auto& assignment : assignments) {
|
||||||
|
errors.push_back(error(assignment));
|
||||||
|
}
|
||||||
|
return AlgebraicDecisionTree<Key>(dkeys, errors);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
|
||||||
double maxLogProb = -std::numeric_limits<double>::infinity();
|
double maxLogProb = -std::numeric_limits<double>::infinity();
|
||||||
|
|
|
||||||
|
|
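Sketch of the effect of hoisting errorTree() into the base class: any DiscreteFactor now enumerates errors densely over the Cartesian product of its keys. The factor `f` and keys `A`, `B` are hypothetical.

    AlgebraicDecisionTree<Key> errors = f.errorTree();
    DiscreteValues assignment;
    assignment[A] = 0;
    assignment[B] = 1;
    double e = errors(assignment);  // equals f.error(assignment)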
@@ -96,7 +96,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
   virtual double operator()(const DiscreteValues&) const = 0;
 
   /// Error is just -log(value)
-  double error(const DiscreteValues& values) const;
+  virtual double error(const DiscreteValues& values) const;
 
   /**
    * The Factor::error simply extracts the \class DiscreteValues from the
@@ -105,7 +105,7 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
   double error(const HybridValues& c) const override;
 
   /// Compute error for each assignment and return as a tree
-  virtual AlgebraicDecisionTree<Key> errorTree() const = 0;
+  virtual AlgebraicDecisionTree<Key> errorTree() const;
 
   /// Multiply in a DecisionTreeFactor and return the result as
   /// DecisionTreeFactor
@@ -158,8 +158,8 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
 // DiscreteFactor
 
 // traits
-template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
+template <>
+struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
 
 /**
  * @brief Normalize a set of log probabilities.

@@ -179,5 +179,4 @@ template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
  */
 std::vector<double> expNormalize(const std::vector<double>& logProbs);
 
-
 } // namespace gtsam
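expNormalize implements the standard log-sum-exp guard: shifting every
log-probability by the maximum before exponentiating keeps the largest exponent
at exp(0) = 1, so normalization cannot overflow. A standalone re-implementation
of the idea for illustration (this is a sketch, not the GTSAM code itself):

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

std::vector<double> expNormalizeSketch(const std::vector<double>& logProbs) {
  // Shift by the maximum so the largest term is exp(0) = 1.
  const double maxLogProb =
      *std::max_element(logProbs.begin(), logProbs.end());
  std::vector<double> probs;
  double total = 0.0;
  for (double lp : logProbs) {
    probs.push_back(std::exp(lp - maxLogProb));
    total += probs.back();
  }
  for (double& p : probs) p /= total;  // Normalize to sum to one.
  return probs;
}

int main() {
  // A naive exp() would overflow on these; the shifted version is fine.
  for (double p : expNormalizeSketch({1000.0, 1001.0, 1002.0}))
    std::cout << p << " ";  // ~0.090 0.245 0.665
  std::cout << "\n";
  return 0;
}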
@@ -168,11 +168,6 @@ double TableFactor::error(const HybridValues& values) const {
   return error(values.discrete());
 }
 
-/* ************************************************************************ */
-AlgebraicDecisionTree<Key> TableFactor::errorTree() const {
-  return toDecisionTreeFactor().errorTree();
-}
-
 /* ************************************************************************ */
 DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
   return toDecisionTreeFactor() * f;
@@ -179,7 +179,7 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
   double operator()(const DiscreteValues& values) const override;
 
   /// Calculate error for DiscreteValues `x`, is -log(probability).
-  double error(const DiscreteValues& values) const;
+  double error(const DiscreteValues& values) const override;
 
   /// multiply two TableFactors
   TableFactor operator*(const TableFactor& f) const {
@@ -358,9 +358,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
    */
   double error(const HybridValues& values) const override;
 
-  /// Compute error for each assignment and return as a tree
-  AlgebraicDecisionTree<Key> errorTree() const override;
-
   /// @}
 };
@@ -10,8 +10,8 @@
  * -------------------------------------------------------------------------- */
 
 /*
- * @file testDecisionTree.cpp
- * @brief Develop DecisionTree
+ * @file testAlgebraicDecisionTree.cpp
+ * @brief Unit tests for Algebraic decision tree
  * @author Frank Dellaert
  * @date Mar 6, 2011
  */
@@ -46,23 +46,35 @@ void dot(const T& f, const string& filename) {
 #endif
 }
 
-/** I can't get this to work !
-class Mul: std::function<double(const double&, const double&)> {
-  inline double operator()(const double& a, const double& b) {
-    return a * b;
-  }
-};
-
-// If second argument of binary op is Leaf
-template<typename L>
-typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
-double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
-  Ptr h(new Choice(label(), cardinality()));
-  for(const NodePtr& branch: branches_)
-    h->push_back(branch->apply_f_op_g(cache, gL, op));
-  return Unique(cache, h);
-}
-*/
+/* ************************************************************************** */
+// Test arithmetic:
+TEST(ADT, arithmetic) {
+  DiscreteKey A(0, 2), B(1, 2);
+  ADT zero{0}, one{1};
+  ADT a(A, 1, 2);
+  ADT b(B, 3, 4);
+
+  // Addition
+  CHECK(assert_equal(a, zero + a));
+
+  // Negation and subtraction
+  CHECK(assert_equal(-a, zero - a));
+  CHECK(assert_equal({zero}, a - a));
+  CHECK(assert_equal(a + b, b + a));
+  CHECK(assert_equal({A, 3, 4}, a + 2));
+  CHECK(assert_equal({B, 1, 2}, b - 2));
+
+  // Multiplication
+  CHECK(assert_equal(zero, zero * a));
+  CHECK(assert_equal(zero, a * zero));
+  CHECK(assert_equal(a, one * a));
+  CHECK(assert_equal(a, a * one));
+  CHECK(assert_equal(a * b, b * a));
+
+  // Division
+  // CHECK(assert_equal(a, (a * b) / b));  // not true because no pruning
+  CHECK(assert_equal(b, (a * b) / a));
+}
 
 /* ************************************************************************** */
 // instrumented operators
@@ -550,7 +562,7 @@ TEST(ADT, Sum) {
 TEST(ADT, Normalize) {
   ADT a = exampleADT();
   double sum = a.sum();
-  auto actual = a.normalize(sum);
+  auto actual = a.normalize();
 
   DiscreteKey A(0, 2), B(1, 3), C(2, 2);
   DiscreteKeys keys = DiscreteKeys{A, B, C};
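The change above reflects an API simplification: normalize() now computes the
sum internally instead of taking it as an argument. A hypothetical extra check
in the style of this test file, assuming only ADT operations exercised
elsewhere in the file (sum(), apply(), assert_equal); the equivalence to
dividing by the sum is my reading of the change, not stated in the diff:

TEST(ADT, NormalizeEquivalence) {
  ADT a = exampleADT();
  // Old spelling was a.normalize(a.sum()); dividing by the sum by hand
  // should give the same tree as the new argument-less normalize().
  const double sum = a.sum();
  ADT expected = a.apply([sum](double v) { return v / sum; });
  CHECK(assert_equal(expected, a.normalize()));
}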
@@ -11,7 +11,7 @@
 
 /*
  * @file testDecisionTree.cpp
- * @brief Develop DecisionTree
+ * @brief DecisionTree unit tests
  * @author Frank Dellaert
  * @author Can Erdogan
  * @date Jan 30, 2012
@@ -108,6 +108,7 @@ struct DT : public DecisionTree<string, int> {
     std::cout << s;
     Base::print("", keyFormatter, valueFormatter);
   }
+
   /// Equality method customized to int node type
   bool equals(const Base& other, double tol = 1e-9) const {
     auto compare = [](const int& v, const int& w) { return v == w; };
@@ -271,6 +272,58 @@ TEST(DecisionTree, Example) {
   DOT(acnotb);
 }
 
+/* ************************************************************************** */
+// Test that we can create two trees out of one, using a function that
+// returns a pair.
+TEST(DecisionTree, Split) {
+  // Create labels
+  string A("A"), B("B");
+
+  // Create a decision tree
+  DT original(A, DT(B, 1, 2), DT(B, 3, 4));
+
+  // Define a function that returns an int/bool pair
+  auto split_function = [](const int& value) -> std::pair<int, bool> {
+    return {value * 3, value * 3 % 2 == 0};
+  };
+
+  // Split the original tree into two new trees
+  auto [la, lb] = original.split<int, bool>(split_function);
+
+  // Check the first resulting tree
+  EXPECT_LONGS_EQUAL(3, la(Assignment<string>{{A, 0}, {B, 0}}));
+  EXPECT_LONGS_EQUAL(6, la(Assignment<string>{{A, 0}, {B, 1}}));
+  EXPECT_LONGS_EQUAL(9, la(Assignment<string>{{A, 1}, {B, 0}}));
+  EXPECT_LONGS_EQUAL(12, la(Assignment<string>{{A, 1}, {B, 1}}));
+
+  // Check the second resulting tree
+  EXPECT(!lb(Assignment<string>{{A, 0}, {B, 0}}));
+  EXPECT(lb(Assignment<string>{{A, 0}, {B, 1}}));
+  EXPECT(!lb(Assignment<string>{{A, 1}, {B, 0}}));
+  EXPECT(lb(Assignment<string>{{A, 1}, {B, 1}}));
+}
+
+/* ************************************************************************** */
+// Test that we can create a tree by modifying an rvalue.
+TEST(DecisionTree, Consume) {
+  // Create labels
+  string A("A"), B("B");
+
+  // Create a decision tree
+  DT original(A, DT(B, 1, 2), DT(B, 3, 4));
+
+  DT modified([](int i) { return i * 2; }, std::move(original));
+
+  // Check the resulting tree
+  EXPECT_LONGS_EQUAL(2, modified(Assignment<string>{{A, 0}, {B, 0}}));
+  EXPECT_LONGS_EQUAL(4, modified(Assignment<string>{{A, 0}, {B, 1}}));
+  EXPECT_LONGS_EQUAL(6, modified(Assignment<string>{{A, 1}, {B, 0}}));
+  EXPECT_LONGS_EQUAL(8, modified(Assignment<string>{{A, 1}, {B, 1}}));
+
+  // Check original was moved
+  EXPECT(original.root_ == nullptr);
+}
+
 /* ************************************************************************** */
 // test Conversion of values
 bool bool_of_int(const int& y) { return y != 0; };
@@ -0,0 +1,149 @@
/*
 * @file FundamentalMatrix.cpp
 * @brief FundamentalMatrix classes
 * @author Frank Dellaert
 * @date Oct 23, 2024
 */

#include <gtsam/geometry/FundamentalMatrix.h>

namespace gtsam {

//*************************************************************************
Point2 EpipolarTransfer(const Matrix3& Fca, const Point2& pa,  //
                        const Matrix3& Fcb, const Point2& pb) {
  // Create epipolar lines in camera c from the points in the other two cameras
  Vector3 line_a = Fca * Vector3(pa.x(), pa.y(), 1);
  Vector3 line_b = Fcb * Vector3(pb.x(), pb.y(), 1);

  // Cross the lines to find the intersection point
  Vector3 intersectionPoint = line_a.cross(line_b);

  // Normalize the intersection point
  intersectionPoint /= intersectionPoint(2);

  return intersectionPoint.head<2>();  // Return the 2D point
}

//*************************************************************************
FundamentalMatrix::FundamentalMatrix(const Matrix3& F) {
  // Perform SVD
  Eigen::JacobiSVD<Matrix3> svd(F, Eigen::ComputeFullU | Eigen::ComputeFullV);

  // Extract U and V
  Matrix3 U = svd.matrixU();
  Matrix3 V = svd.matrixV();
  Vector3 singularValues = svd.singularValues();

  // Scale the singular values
  double scale = singularValues(0);
  if (scale != 0) {
    singularValues /= scale;  // Normalize the first singular value to 1.0
  }

  // Check if the third singular value is close to zero (valid F condition)
  if (std::abs(singularValues(2)) > 1e-9) {
    throw std::invalid_argument(
        "The input matrix does not represent a valid fundamental matrix.");
  }

  // Ensure the second singular value is recorded as s
  s_ = singularValues(1);

  // Check if U is a reflection
  if (U.determinant() < 0) {
    U = -U;
    s_ = -s_;  // Change sign of scalar if U is a reflection
  }

  // Check if V is a reflection
  if (V.determinant() < 0) {
    V = -V;
    s_ = -s_;  // Change sign of scalar if V is a reflection
  }

  // Assign the rotations
  U_ = Rot3(U);
  V_ = Rot3(V);
}

Matrix3 FundamentalMatrix::matrix() const {
  return U_.matrix() * Vector3(1, s_, 0).asDiagonal() * V_.transpose().matrix();
}

void FundamentalMatrix::print(const std::string& s) const {
  std::cout << s << matrix() << std::endl;
}

bool FundamentalMatrix::equals(const FundamentalMatrix& other,
                               double tol) const {
  return U_.equals(other.U_, tol) && std::abs(s_ - other.s_) < tol &&
         V_.equals(other.V_, tol);
}

Vector FundamentalMatrix::localCoordinates(const FundamentalMatrix& F) const {
  Vector result(7);
  result.head<3>() = U_.localCoordinates(F.U_);
  result(3) = F.s_ - s_;  // Difference in scalar
  result.tail<3>() = V_.localCoordinates(F.V_);
  return result;
}

FundamentalMatrix FundamentalMatrix::retract(const Vector& delta) const {
  Rot3 newU = U_.retract(delta.head<3>());
  double newS = s_ + delta(3);  // Update scalar
  Rot3 newV = V_.retract(delta.tail<3>());
  return FundamentalMatrix(newU, newS, newV);
}

//*************************************************************************
Matrix3 SimpleFundamentalMatrix::Ka() const {
  Matrix3 K;
  K << fa_, 0, ca_.x(), 0, fa_, ca_.y(), 0, 0, 1;
  return K;
}

Matrix3 SimpleFundamentalMatrix::Kb() const {
  Matrix3 K;
  K << fb_, 0, cb_.x(), 0, fb_, cb_.y(), 0, 0, 1;
  return K;
}

Matrix3 SimpleFundamentalMatrix::matrix() const {
  return Ka().transpose().inverse() * E_.matrix() * Kb().inverse();
}

void SimpleFundamentalMatrix::print(const std::string& s) const {
  std::cout << s << " E:\n"
            << E_.matrix() << "\nfa: " << fa_ << "\nfb: " << fb_
            << "\nca: " << ca_.transpose() << "\ncb: " << cb_.transpose()
            << std::endl;
}

bool SimpleFundamentalMatrix::equals(const SimpleFundamentalMatrix& other,
                                     double tol) const {
  return E_.equals(other.E_, tol) && std::abs(fa_ - other.fa_) < tol &&
         std::abs(fb_ - other.fb_) < tol && (ca_ - other.ca_).norm() < tol &&
         (cb_ - other.cb_).norm() < tol;
}

Vector SimpleFundamentalMatrix::localCoordinates(
    const SimpleFundamentalMatrix& F) const {
  Vector result(7);
  result.head<5>() = E_.localCoordinates(F.E_);
  result(5) = F.fa_ - fa_;  // Difference in fa
  result(6) = F.fb_ - fb_;  // Difference in fb
  return result;
}

SimpleFundamentalMatrix SimpleFundamentalMatrix::retract(
    const Vector& delta) const {
  EssentialMatrix newE = E_.retract(delta.head<5>());
  double newFa = fa_ + delta(5);  // Update fa
  double newFb = fb_ + delta(6);  // Update fb
  return SimpleFundamentalMatrix(newE, newFa, newFb, ca_, cb_);
}

//*************************************************************************

}  // namespace gtsam
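The parameterization used above, F = U * diag(1, s, 0) * V^T, bakes in the two
defining properties of a fundamental matrix: rank 2 (the zero singular value)
and scale fixed by normalizing the leading singular value to one. A quick
Eigen-only numeric check of that claim (an illustrative sketch, not GTSAM
code):

#include <Eigen/Dense>
#include <iostream>

int main() {
  using Eigen::Matrix3d;
  using Eigen::Vector3d;
  // Any two rotations and a scalar s give a valid fundamental matrix.
  const Matrix3d U =
      Eigen::AngleAxisd(0.3, Vector3d::UnitZ()).toRotationMatrix();
  const Matrix3d V =
      Eigen::AngleAxisd(-0.2, Vector3d::UnitY()).toRotationMatrix();
  const double s = 0.5;
  const Matrix3d F = U * Vector3d(1, s, 0).asDiagonal() * V.transpose();

  // The singular values should come back as (1, s, 0): unit leading
  // singular value and rank 2.
  Eigen::JacobiSVD<Matrix3d> svd(F);
  std::cout << "singular values: " << svd.singularValues().transpose() << "\n";
  return 0;
}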
@@ -0,0 +1,207 @@
/*
 * @file FundamentalMatrix.h
 * @brief FundamentalMatrix classes
 * @author Frank Dellaert
 * @date Oct 23, 2024
 */

#pragma once

#include <gtsam/geometry/EssentialMatrix.h>
#include <gtsam/geometry/Rot3.h>
#include <gtsam/geometry/Unit3.h>

namespace gtsam {

/**
 * @class FundamentalMatrix
 * @brief Represents a general fundamental matrix.
 *
 * This class represents a general fundamental matrix, which is a 3x3 matrix
 * that describes the relationship between two images. It is parameterized by a
 * left rotation U, a scalar s, and a right rotation V.
 */
class GTSAM_EXPORT FundamentalMatrix {
 private:
  Rot3 U_;    ///< Left rotation
  double s_;  ///< Scalar parameter for S
  Rot3 V_;    ///< Right rotation

 public:
  /// Default constructor
  FundamentalMatrix() : U_(Rot3()), s_(1.0), V_(Rot3()) {}

  /**
   * @brief Construct from U, V, and scalar s
   *
   * Initializes the FundamentalMatrix with the given left rotation U,
   * scalar s, and right rotation V.
   *
   * @param U Left rotation matrix
   * @param s Scalar parameter for the fundamental matrix
   * @param V Right rotation matrix
   */
  FundamentalMatrix(const Rot3& U, double s, const Rot3& V)
      : U_(U), s_(s), V_(V) {}

  /**
   * @brief Construct from a 3x3 matrix using SVD
   *
   * Initializes the FundamentalMatrix by performing SVD on the given
   * matrix and ensuring U and V are not reflections.
   *
   * @param F A 3x3 matrix representing the fundamental matrix
   */
  FundamentalMatrix(const Matrix3& F);

  /**
   * @brief Construct from calibration matrices Ka, Kb, and pose aPb
   *
   * Initializes the FundamentalMatrix from the given calibration
   * matrices Ka and Kb, and the pose aPb.
   *
   * @tparam CAL Calibration type, expected to have a matrix() method
   * @param Ka Calibration matrix for the left camera
   * @param aPb Pose from the left to the right camera
   * @param Kb Calibration matrix for the right camera
   */
  template <typename CAL>
  FundamentalMatrix(const CAL& Ka, const Pose3& aPb, const CAL& Kb)
      : FundamentalMatrix(Ka.K().transpose().inverse() *
                          EssentialMatrix::FromPose3(aPb).matrix() *
                          Kb.K().inverse()) {}

  /// Return the fundamental matrix representation
  Matrix3 matrix() const;

  /// @name Testable
  /// @{
  /// Print the FundamentalMatrix
  void print(const std::string& s = "") const;

  /// Check if the FundamentalMatrix is equal to another within a
  /// tolerance
  bool equals(const FundamentalMatrix& other, double tol = 1e-9) const;
  /// @}

  /// @name Manifold
  /// @{
  enum { dimension = 7 };  // 3 for U, 1 for s, 3 for V
  inline static size_t Dim() { return dimension; }
  inline size_t dim() const { return dimension; }

  /// Return local coordinates with respect to another FundamentalMatrix
  Vector localCoordinates(const FundamentalMatrix& F) const;

  /// Retract the given vector to get a new FundamentalMatrix
  FundamentalMatrix retract(const Vector& delta) const;
  /// @}
};

/**
 * @class SimpleFundamentalMatrix
 * @brief Class for representing a simple fundamental matrix.
 *
 * This class represents a simple fundamental matrix, which is a
 * parameterization of the essential matrix and focal lengths for left and
 * right cameras. Principal points are not part of the manifold but a
 * convenience.
 */
class GTSAM_EXPORT SimpleFundamentalMatrix {
 private:
  EssentialMatrix E_;  ///< Essential matrix
  double fa_;          ///< Focal length for left camera
  double fb_;          ///< Focal length for right camera
  Point2 ca_;          ///< Principal point for left camera
  Point2 cb_;          ///< Principal point for right camera

  /// Return the left calibration matrix
  Matrix3 Ka() const;

  /// Return the right calibration matrix
  Matrix3 Kb() const;

 public:
  /// Default constructor
  SimpleFundamentalMatrix()
      : E_(), fa_(1.0), fb_(1.0), ca_(0.0, 0.0), cb_(0.0, 0.0) {}

  /**
   * @brief Construct from essential matrix and focal lengths
   * @param E Essential matrix
   * @param fa Focal length for left camera
   * @param fb Focal length for right camera
   * @param ca Principal point for left camera
   * @param cb Principal point for right camera
   */
  SimpleFundamentalMatrix(const EssentialMatrix& E,  //
                          double fa, double fb, const Point2& ca,
                          const Point2& cb)
      : E_(E), fa_(fa), fb_(fb), ca_(ca), cb_(cb) {}

  /// Return the fundamental matrix representation
  /// F = Ka^(-T) * E * Kb^(-1)
  Matrix3 matrix() const;

  /// @name Testable
  /// @{
  /// Print the SimpleFundamentalMatrix
  void print(const std::string& s = "") const;

  /// Check equality within a tolerance
  bool equals(const SimpleFundamentalMatrix& other, double tol = 1e-9) const;
  /// @}

  /// @name Manifold
  /// @{
  enum { dimension = 7 };  // 5 for E, 1 for fa, 1 for fb
  inline static size_t Dim() { return dimension; }
  inline size_t dim() const { return dimension; }

  /// Return local coordinates with respect to another SimpleFundamentalMatrix
  Vector localCoordinates(const SimpleFundamentalMatrix& F) const;

  /// Retract the given vector to get a new SimpleFundamentalMatrix
  SimpleFundamentalMatrix retract(const Vector& delta) const;
  /// @}
};

/**
 * @brief Transfer projections from cameras a and b to camera c
 *
 * Takes two fundamental matrices Fca and Fcb, and two points pa and pb, and
 * returns the 2D point in view (c) where the epipolar lines intersect.
 */
GTSAM_EXPORT Point2 EpipolarTransfer(const Matrix3& Fca, const Point2& pa,
                                     const Matrix3& Fcb, const Point2& pb);

/// Represents a set of three fundamental matrices for transferring points
/// between three cameras.
template <typename F>
struct TripleF {
  F Fab, Fbc, Fca;

  /// Transfers a point from cameras b,c to camera a.
  Point2 transferToA(const Point2& pb, const Point2& pc) {
    return EpipolarTransfer(Fab.matrix(), pb, Fca.matrix().transpose(), pc);
  }

  /// Transfers a point from cameras a,c to camera b.
  Point2 transferToB(const Point2& pa, const Point2& pc) {
    return EpipolarTransfer(Fab.matrix().transpose(), pa, Fbc.matrix(), pc);
  }

  /// Transfers a point from cameras a,b to camera c.
  Point2 transferToC(const Point2& pa, const Point2& pb) {
    return EpipolarTransfer(Fca.matrix(), pa, Fbc.matrix().transpose(), pb);
  }
};

template <>
struct traits<FundamentalMatrix>
    : public internal::Manifold<FundamentalMatrix> {};

template <>
struct traits<SimpleFundamentalMatrix>
    : public internal::Manifold<SimpleFundamentalMatrix> {};

}  // namespace gtsam
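EpipolarTransfer reduces point transfer to plane geometry: each fundamental
matrix maps a point in one view to an epipolar line in view c, and in
homogeneous coordinates the cross product of two lines is their intersection.
A tiny Eigen-only illustration of that last fact (a sketch, independent of
GTSAM):

#include <Eigen/Dense>
#include <iostream>

int main() {
  using Eigen::Vector3d;
  // Lines are (a, b, c) with a*x + b*y + c = 0.
  const Vector3d lineA(0, 1, -2);   // y = 2
  const Vector3d lineB(1, 0, -3);   // x = 3
  Vector3d p = lineA.cross(lineB);  // homogeneous intersection point
  p /= p(2);                        // normalize
  std::cout << p.head<2>().transpose() << "\n";  // prints: 3 2
  return 0;
}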
@@ -0,0 +1,234 @@
/*
 * @file testFundamentalMatrix.cpp
 * @brief Test FundamentalMatrix classes
 * @author Frank Dellaert
 * @date October 23, 2024
 */

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/geometry/FundamentalMatrix.h>
#include <gtsam/geometry/Rot3.h>
#include <gtsam/geometry/SimpleCamera.h>
#include <gtsam/geometry/Unit3.h>

using namespace std::placeholders;
using namespace std;
using namespace gtsam;

GTSAM_CONCEPT_TESTABLE_INST(FundamentalMatrix)
GTSAM_CONCEPT_MANIFOLD_INST(FundamentalMatrix)

//*************************************************************************
// Create two rotations and corresponding fundamental matrix F
Rot3 trueU = Rot3::Yaw(M_PI_2);
Rot3 trueV = Rot3::Yaw(M_PI_4);
double trueS = 0.5;
FundamentalMatrix trueF(trueU, trueS, trueV);

//*************************************************************************
TEST(FundamentalMatrix, localCoordinates) {
  Vector expected = Z_7x1;  // Assuming 7 dimensions for U, V, and s
  Vector actual = trueF.localCoordinates(trueF);
  EXPECT(assert_equal(expected, actual, 1e-8));
}

//*************************************************************************
TEST(FundamentalMatrix, retract) {
  FundamentalMatrix actual = trueF.retract(Z_7x1);
  EXPECT(assert_equal(trueF, actual));
}

//*************************************************************************
TEST(FundamentalMatrix, RoundTrip) {
  Vector7 d;
  d << 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7;
  FundamentalMatrix hx = trueF.retract(d);
  Vector actual = trueF.localCoordinates(hx);
  EXPECT(assert_equal(d, actual, 1e-8));
}

//*************************************************************************
// Create the simplest SimpleFundamentalMatrix, a stereo pair
EssentialMatrix defaultE(Rot3(), Unit3(1, 0, 0));
Point2 zero(0.0, 0.0);
SimpleFundamentalMatrix stereoF(defaultE, 1.0, 1.0, zero, zero);

//*************************************************************************
TEST(SimpleStereo, Conversion) {
  FundamentalMatrix convertedF(stereoF.matrix());
  EXPECT(assert_equal(stereoF.matrix(), convertedF.matrix(), 1e-8));
}

//*************************************************************************
TEST(SimpleStereo, localCoordinates) {
  Vector expected = Z_7x1;
  Vector actual = stereoF.localCoordinates(stereoF);
  EXPECT(assert_equal(expected, actual, 1e-8));
}

//*************************************************************************
TEST(SimpleStereo, retract) {
  SimpleFundamentalMatrix actual = stereoF.retract(Z_9x1);
  EXPECT(assert_equal(stereoF, actual));
}

//*************************************************************************
TEST(SimpleStereo, RoundTrip) {
  Vector7 d;
  d << 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7;
  SimpleFundamentalMatrix hx = stereoF.retract(d);
  Vector actual = stereoF.localCoordinates(hx);
  EXPECT(assert_equal(d, actual, 1e-8));
}

//*************************************************************************
TEST(SimpleStereo, EpipolarLine) {
  // Create a point in b
  Point3 p_b(0, 2, 1);
  // Convert the point to a horizontal line in a
  Vector3 l_a = stereoF.matrix() * p_b;
  // Check that the line is horizontal at height 2
  EXPECT(assert_equal(Vector3(0, -1, 2), l_a));
}

//*************************************************************************
// Create a stereo pair, but in pixels not normalized coordinates.
// We're still using zero principal points here.
double focalLength = 1000;
SimpleFundamentalMatrix pixelStereo(defaultE, focalLength, focalLength, zero,
                                    zero);

//*************************************************************************
TEST(PixelStereo, Conversion) {
  auto expected = pixelStereo.matrix();

  FundamentalMatrix convertedF(pixelStereo.matrix());

  // Check equality of F-matrices up to a scale
  auto actual = convertedF.matrix();
  actual *= expected(1, 2) / actual(1, 2);
  EXPECT(assert_equal(expected, actual, 1e-5));
}

//*************************************************************************
TEST(PixelStereo, PointInBToHorizontalLineInA) {
  // Create a point in b
  Point3 p_b = Point3(0, 300, 1);
  // Convert the point to a horizontal line in a
  Vector3 l_a = pixelStereo.matrix() * p_b;
  // Check that the line is horizontal at height 300
  EXPECT(assert_equal(Vector3(0, -0.001, 0.3), l_a));
}

//*************************************************************************
// Create a stereo pair with the right camera rotated 90 degrees
Rot3 aRb = Rot3::Rz(M_PI_2);  // Rotate 90 degrees around the Z-axis
EssentialMatrix rotatedE(aRb, Unit3(1, 0, 0));
SimpleFundamentalMatrix rotatedPixelStereo(rotatedE, focalLength, focalLength,
                                           zero, zero);

//*************************************************************************
TEST(RotatedPixelStereo, Conversion) {
  auto expected = rotatedPixelStereo.matrix();

  FundamentalMatrix convertedF(rotatedPixelStereo.matrix());

  // Check equality of F-matrices up to a scale
  auto actual = convertedF.matrix();
  actual *= expected(1, 2) / actual(1, 2);
  EXPECT(assert_equal(expected, actual, 1e-4));
}

//*************************************************************************
TEST(RotatedPixelStereo, PointInBToHorizontalLineInA) {
  // Create a point in b
  Point3 p_b = Point3(300, 0, 1);
  // Convert the point to a horizontal line in a
  Vector3 l_a = rotatedPixelStereo.matrix() * p_b;
  // Check that the line is horizontal at height 300
  EXPECT(assert_equal(Vector3(0, -0.001, 0.3), l_a));
}

//*************************************************************************
// Now check that principal points also survive conversion
Point2 principalPoint(640 / 2, 480 / 2);
SimpleFundamentalMatrix stereoWithPrincipalPoints(rotatedE, focalLength,
                                                  focalLength, principalPoint,
                                                  principalPoint);

//*************************************************************************
TEST(stereoWithPrincipalPoints, Conversion) {
  auto expected = stereoWithPrincipalPoints.matrix();

  FundamentalMatrix convertedF(stereoWithPrincipalPoints.matrix());

  // Check equality of F-matrices up to a scale
  auto actual = convertedF.matrix();
  actual *= expected(1, 2) / actual(1, 2);
  EXPECT(assert_equal(expected, actual, 1e-4));
}

//*************************************************************************
/// Generate three cameras on a circle, looking inwards
std::array<Pose3, 3> generateCameraPoses() {
  std::array<Pose3, 3> cameraPoses;
  const double radius = 1.0;
  for (int i = 0; i < 3; ++i) {
    double angle = i * 2.0 * M_PI / 3.0;
    double c = cos(angle), s = sin(angle);
    Rot3 aRb({-s, c, 0}, {0, 0, -1}, {-c, -s, 0});
    cameraPoses[i] = {aRb, Point3(radius * c, radius * s, 0)};
  }
  return cameraPoses;
}

//*************************************************************************
/// Function to generate a TripleF from camera poses
TripleF<SimpleFundamentalMatrix> generateTripleF(
    const std::array<Pose3, 3>& cameraPoses) {
  std::array<SimpleFundamentalMatrix, 3> F;
  for (size_t i = 0; i < 3; ++i) {
    size_t j = (i + 1) % 3;
    const Pose3 iPj = cameraPoses[i].between(cameraPoses[j]);
    EssentialMatrix E(iPj.rotation(), Unit3(iPj.translation()));
    F[i] = {E, focalLength, focalLength, principalPoint, principalPoint};
  }
  return {F[0], F[1], F[2]};  // Return a TripleF instance
}

//*************************************************************************
TEST(TripleF, Transfer) {
  // Generate cameras on a circle, as well as fundamental matrices
  auto cameraPoses = generateCameraPoses();
  auto triplet = generateTripleF(cameraPoses);

  // Check that they are all equal
  EXPECT(triplet.Fab.equals(triplet.Fbc, 1e-9));
  EXPECT(triplet.Fbc.equals(triplet.Fca, 1e-9));
  EXPECT(triplet.Fca.equals(triplet.Fab, 1e-9));

  // Now project a point into the three cameras
  const Point3 P(0.1, 0.2, 0.3);
  const Cal3_S2 K(focalLength, focalLength, 0.0,  //
                  principalPoint.x(), principalPoint.y());

  std::array<Point2, 3> p;
  for (size_t i = 0; i < 3; ++i) {
    // Project the point into each camera
    PinholeCameraCal3_S2 camera(cameraPoses[i], K);
    p[i] = camera.project(P);
  }

  // Check that transfer works
  EXPECT(assert_equal<Point2>(p[0], triplet.transferToA(p[1], p[2]), 1e-9));
  EXPECT(assert_equal<Point2>(p[1], triplet.transferToB(p[0], p[2]), 1e-9));
  EXPECT(assert_equal<Point2>(p[2], triplet.transferToC(p[0], p[1]), 1e-9));
}

//*************************************************************************
int main() {
  TestResult tr;
  return TestRegistry::runAllTests(tr);
}
//*************************************************************************
@@ -1,7 +1,6 @@
 # Install headers
 set(subdir hybrid)
 file(GLOB hybrid_headers "*.h")
-# FIXME: exclude headers
 install(FILES ${hybrid_headers} DESTINATION include/gtsam/hybrid)
 
 # Add all tests
@@ -17,10 +17,13 @@
  */
 
 #include <gtsam/discrete/DiscreteBayesNet.h>
+#include <gtsam/discrete/DiscreteConditional.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridValues.h>
 
+#include <memory>
+
 // In Wrappers we have no access to this so have a default ready
 static std::mt19937_64 kRandomNumberGenerator(42);
@@ -38,135 +41,26 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
 }
 
 /* ************************************************************************* */
-/**
- * @brief Helper function to get the pruner functional.
- *
- * @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
- * @param conditional Conditional to prune. Used to get full assignment.
- * @return std::function<double(const Assignment<Key> &, double)>
- */
-std::function<double(const Assignment<Key> &, double)> prunerFunc(
-    const DecisionTreeFactor &prunedDiscreteProbs,
-    const HybridConditional &conditional) {
-  // Get the discrete keys as sets for the decision tree
-  // and the hybrid Gaussian conditional.
-  std::set<DiscreteKey> discreteProbsKeySet =
-      DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
-  std::set<DiscreteKey> conditionalKeySet =
-      DiscreteKeysAsSet(conditional.discreteKeys());
-
-  auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
-                    const Assignment<Key> &choices,
-                    double probability) -> double {
-    // This corresponds to 0 probability
-    double pruned_prob = 0.0;
-
-    // typecast so we can use this to get probability value
-    DiscreteValues values(choices);
-    // Case where the hybrid Gaussian conditional has the same
-    // discrete keys as the decision tree.
-    if (conditionalKeySet == discreteProbsKeySet) {
-      if (prunedDiscreteProbs(values) == 0) {
-        return pruned_prob;
-      } else {
-        return probability;
-      }
-    } else {
-      // Due to branch merging (aka pruning) in DecisionTree, it is possible we
-      // get a `values` which doesn't have the full set of keys.
-      std::set<Key> valuesKeys;
-      for (auto kvp : values) {
-        valuesKeys.insert(kvp.first);
-      }
-      std::set<Key> conditionalKeys;
-      for (auto kvp : conditionalKeySet) {
-        conditionalKeys.insert(kvp.first);
-      }
-      // If true, then values is missing some keys
-      if (conditionalKeys != valuesKeys) {
-        // Get the keys present in conditionalKeys but not in valuesKeys
-        std::vector<Key> missing_keys;
-        std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
-                            valuesKeys.begin(), valuesKeys.end(),
-                            std::back_inserter(missing_keys));
-        // Insert missing keys with a default assignment.
-        for (auto missing_key : missing_keys) {
-          values[missing_key] = 0;
-        }
-      }
-
-      // Now we generate the full assignment by enumerating
-      // over all keys in the prunedDiscreteProbs.
-      // First we find the differing keys
-      std::vector<DiscreteKey> set_diff;
-      std::set_difference(discreteProbsKeySet.begin(),
-                          discreteProbsKeySet.end(), conditionalKeySet.begin(),
-                          conditionalKeySet.end(),
-                          std::back_inserter(set_diff));
-
-      // Now enumerate over all assignments of the differing keys
-      const std::vector<DiscreteValues> assignments =
-          DiscreteValues::CartesianProduct(set_diff);
-      for (const DiscreteValues &assignment : assignments) {
-        DiscreteValues augmented_values(values);
-        augmented_values.insert(assignment);
-
-        // If any one of the sub-branches are non-zero,
-        // we need this probability.
-        if (prunedDiscreteProbs(augmented_values) > 0.0) {
-          return probability;
-        }
-      }
-      // If we are here, it means that all the sub-branches are 0,
-      // so we prune.
-      return pruned_prob;
-    }
-  };
-  return pruner;
-}
-
-/* ************************************************************************* */
-DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
-    size_t maxNrLeaves) {
-  // Get the joint distribution of only the discrete keys
-  // The joint discrete probability.
-  DiscreteConditional discreteProbs;
-
-  std::vector<size_t> discrete_factor_idxs;
-  // Record frontal keys so we can maintain ordering
-  Ordering discrete_frontals;
-
-  for (size_t i = 0; i < this->size(); i++) {
-    auto conditional = this->at(i);
-    if (conditional->isDiscrete()) {
-      discreteProbs = discreteProbs * (*conditional->asDiscrete());
-
-      Ordering conditional_keys(conditional->frontals());
-      discrete_frontals += conditional_keys;
-      discrete_factor_idxs.push_back(i);
-    }
-  }
-
-  const DecisionTreeFactor prunedDiscreteProbs =
-      discreteProbs.prune(maxNrLeaves);
-
-  // Eliminate joint probability back into conditionals
-  DiscreteFactorGraph dfg{prunedDiscreteProbs};
-  DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
-
-  // Assign pruned discrete conditionals back at the correct indices.
-  for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
-    size_t idx = discrete_factor_idxs.at(i);
-    this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
-  }
-
-  return prunedDiscreteProbs;
-}
-
-/* ************************************************************************* */
-HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
-  DecisionTreeFactor prunedDiscreteProbs =
-      this->pruneDiscreteConditionals(maxNrLeaves);
+// The implementation is: build the entire joint into one factor and then prune.
+// TODO(Frank): This can be quite expensive *unless* the factors have already
+// been pruned before. Another, possibly faster approach is branch and bound
+// search to find the K-best leaves and then create a single pruned conditional.
+HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
+  // Collect all the discrete conditionals. Could be small if already pruned.
+  const DiscreteBayesNet marginal = discreteMarginal();
+
+  // Multiply into one big conditional. NOTE: possibly quite expensive.
+  DiscreteConditional joint;
+  for (auto &&conditional : marginal) {
+    joint = joint * (*conditional);
+  }
+
+  // Prune the joint. NOTE: again, possibly quite expensive.
+  const DecisionTreeFactor pruned = joint.prune(maxNrLeaves);
+
+  // Create the result, starting with the pruned joint.
+  HybridBayesNet result;
+  result.emplace_shared<DiscreteConditional>(pruned.size(), pruned);
 
 /* To prune, we visitWith every leaf in the HybridGaussianConditional.
  * For each leaf, using the assignment we can check the discrete decision tree
@@ -175,28 +69,34 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
  * We can later check the HybridGaussianConditional for just nullptrs.
  */
 
-  HybridBayesNet prunedBayesNetFragment;
-
-  // Go through all the conditionals in the
-  // Bayes Net and prune them as per prunedDiscreteProbs.
+  // Go through all the Gaussian conditionals in the Bayes Net and prune them
+  // as per the pruned discrete joint.
   for (auto &&conditional : *this) {
-    if (auto gm = conditional->asHybrid()) {
-      // Make a copy of the hybrid Gaussian conditional and prune it!
-      auto prunedHybridGaussianConditional =
-          std::make_shared<HybridGaussianConditional>(*gm);
-      prunedHybridGaussianConditional->prune(
-          prunedDiscreteProbs);  // imperative :-(
+    if (auto hgc = conditional->asHybrid()) {
+      // Prune the hybrid Gaussian conditional!
+      auto prunedHybridGaussianConditional = hgc->prune(pruned);
 
       // Type-erase and add to the pruned Bayes Net fragment.
-      prunedBayesNetFragment.push_back(prunedHybridGaussianConditional);
-    } else {
+      result.push_back(prunedHybridGaussianConditional);
+    } else if (auto gc = conditional->asGaussian()) {
       // Add the non-HybridGaussianConditional conditional
-      prunedBayesNetFragment.push_back(conditional);
+      result.push_back(gc);
     }
+    // We ignore DiscreteConditional as they are already pruned and added.
   }
 
-  return prunedBayesNetFragment;
+  return result;
+}
+
+/* ************************************************************************* */
+DiscreteBayesNet HybridBayesNet::discreteMarginal() const {
+  DiscreteBayesNet result;
+  for (auto &&conditional : *this) {
+    if (auto dc = conditional->asDiscrete()) {
+      result.push_back(dc);
+    }
+  }
+  return result;
 }
 
 /* ************************************************************************* */
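A hedged usage sketch of the reworked pruning API (the one-conditional net
below is purely illustrative; real hybrid Bayes nets come out of elimination).
The key change: prune() is now const and returns a new net instead of mutating
in place, and the discrete part can be retrieved on its own via
discreteMarginal():

#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/HybridBayesNet.h>

int main() {
  using namespace gtsam;
  // A toy, purely discrete hybrid Bayes net P(M) with one binary variable.
  DiscreteKey M(0, 2);
  HybridBayesNet hbn;
  hbn.emplace_shared<DiscreteConditional>(M, "4/6");

  // prune() no longer mutates: call sites re-assign the returned net.
  HybridBayesNet pruned = hbn.prune(1);  // keep at most one leaf

  // Inspect the discrete marginal P(M) of the pruned net.
  DiscreteBayesNet marginal = pruned.discreteMarginal();
  marginal.print("P(M) after pruning:\n");
  return 0;
}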
@@ -206,7 +106,7 @@ GaussianBayesNet HybridBayesNet::choose(
   for (auto &&conditional : *this) {
     if (auto gm = conditional->asHybrid()) {
       // If conditional is hybrid, select based on assignment.
-      gbn.push_back((*gm)(assignment));
+      gbn.push_back(gm->choose(assignment));
     } else if (auto gc = conditional->asGaussian()) {
       // If continuous only, add Gaussian conditional.
       gbn.push_back(gc);
@@ -291,66 +191,19 @@ AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
 
   // Iterate over each conditional.
   for (auto &&conditional : *this) {
-    if (auto gm = conditional->asHybrid()) {
-      // If conditional is hybrid, compute error for all assignments.
-      result = result + gm->errorTree(continuousValues);
-    } else if (auto gc = conditional->asGaussian()) {
-      // If continuous, get the error and add it to the result
-      double error = gc->error(continuousValues);
-      // Add the computed error to every leaf of the result tree.
-      result = result.apply(
-          [error](double leaf_value) { return leaf_value + error; });
-    } else if (auto dc = conditional->asDiscrete()) {
-      // If discrete, add the discrete error in the right branch
-      result = result.apply(
-          [dc](const Assignment<Key> &assignment, double leaf_value) {
-            return leaf_value + dc->error(DiscreteValues(assignment));
-          });
-    }
+    result = result + conditional->errorTree(continuousValues);
   }
 
   return result;
 }
 
 /* ************************************************************************* */
-AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
+AlgebraicDecisionTree<Key> HybridBayesNet::discretePosterior(
     const VectorValues &continuousValues) const {
-  AlgebraicDecisionTree<Key> result(0.0);
-
-  // Iterate over each conditional.
-  for (auto &&conditional : *this) {
-    if (auto gm = conditional->asHybrid()) {
-      // If conditional is hybrid, select based on assignment and compute
-      // logProbability.
-      result = result + gm->logProbability(continuousValues);
-    } else if (auto gc = conditional->asGaussian()) {
-      // If continuous, get the (double) logProbability and add it to the
-      // result
-      double logProbability = gc->logProbability(continuousValues);
-      // Add the computed logProbability to every leaf of the logProbability
-      // tree.
-      result = result.apply([logProbability](double leaf_value) {
-        return leaf_value + logProbability;
-      });
-    } else if (auto dc = conditional->asDiscrete()) {
-      // If discrete, add the discrete logProbability in the right branch
-      result = result.apply(
-          [dc](const Assignment<Key> &assignment, double leaf_value) {
-            return leaf_value + dc->logProbability(DiscreteValues(assignment));
-          });
-    }
-  }
-
-  return result;
-}
-
-/* ************************************************************************* */
-AlgebraicDecisionTree<Key> HybridBayesNet::evaluate(
-    const VectorValues &continuousValues) const {
-  AlgebraicDecisionTree<Key> tree = this->logProbability(continuousValues);
-  return tree.apply([](double log) { return exp(log); });
+  AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
+  AlgebraicDecisionTree<Key> p =
+      errors.apply([](double error) { return exp(-error); });
+  return p / p.sum();
 }
 
 /* ************************************************************************* */
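The new discretePosterior() is just a softmax over the error tree: since
error = -log p up to a constant, P(M|x) = exp(-error) / sum(exp(-error)) and
the unknown constant cancels in the normalization. A standalone numeric sketch
of exactly that computation (illustrative, not GTSAM code):

#include <cmath>
#include <iostream>
#include <vector>

int main() {
  // One unnormalized error per discrete assignment of M.
  const std::vector<double> errors = {1.0, 2.0, 3.0};
  std::vector<double> p;
  double total = 0.0;
  for (double e : errors) {
    p.push_back(std::exp(-e));
    total += p.back();
  }
  for (double& v : p) v /= total;  // normalize: any common offset cancels
  for (double v : p) std::cout << v << " ";  // ~0.665 0.245 0.090
  std::cout << "\n";
  return 0;
}

For very large errors a robust version would subtract the minimum error before
exponentiating, mirroring the expNormalize guard earlier in this changeset.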
@@ -18,6 +18,7 @@
 #pragma once
 
 #include <gtsam/discrete/DecisionTreeFactor.h>
+#include <gtsam/discrete/DiscreteBayesNet.h>
 #include <gtsam/global_includes.h>
 #include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/hybrid/HybridValues.h>
@@ -77,16 +78,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
   }
 
   /**
-   * Add a conditional using a shared_ptr, using implicit conversion to
-   * a HybridConditional.
-   *
-   * This is useful when you create a conditional shared pointer as you need it
-   * somewhere else.
+   * Move a HybridConditional into a shared pointer and add.
    *
    * Example:
-   *   auto shared_ptr_to_a_conditional =
-   *     std::make_shared<HybridGaussianConditional>(...);
-   *   hbn.push_back(shared_ptr_to_a_conditional);
+   *   HybridGaussianConditional conditional(...);
+   *   hbn.push_back(conditional);  // loses the original conditional
    */
   void push_back(HybridConditional &&conditional) {
     factors_.push_back(
@@ -124,11 +120,21 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
   }
 
   /**
-   * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
-   * value assignment.
+   * @brief Get the discrete Bayes Net P(M). As the hybrid Bayes net defines
+   * P(X,M) = P(X|M) P(M), this method returns the marginal distribution on the
+   * discrete variables.
+   *
+   * @return discrete marginal as a DiscreteBayesNet.
+   */
+  DiscreteBayesNet discreteMarginal() const;
+
+  /**
+   * @brief Get the Gaussian Bayes net P(X|M=m) corresponding to a specific
+   * assignment m for the discrete variables M. As the hybrid Bayes net defines
+   * P(X,M) = P(X|M) P(M), this method returns the **posterior** p(X|M=m).
    *
    * @param assignment The discrete value assignment for the discrete keys.
-   * @return GaussianBayesNet
+   * @return Gaussian posterior P(X|M=m) as a GaussianBayesNet.
    */
   GaussianBayesNet choose(const DiscreteValues &assignment) const;
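Together the two accessors expose the factorization p(x, m) = p(x | m) P(m)
stated above. A small hedged sketch (the helper function is hypothetical, not
part of GTSAM; it assumes the evaluate() methods that GaussianBayesNet and
DiscreteBayesNet provide):

#include <gtsam/hybrid/HybridBayesNet.h>

// Evaluate the hybrid joint density at (x, m) via the factorization.
double jointDensity(const gtsam::HybridBayesNet& hbn,
                    const gtsam::VectorValues& x,
                    const gtsam::DiscreteValues& m) {
  const gtsam::GaussianBayesNet gbn = hbn.choose(m);           // p(X | M = m)
  const gtsam::DiscreteBayesNet dbn = hbn.discreteMarginal();  // P(M)
  return gbn.evaluate(x) * dbn.evaluate(m);
}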
@@ -199,18 +205,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   HybridValues sample() const;
 
-  /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
-  HybridBayesNet prune(size_t maxNrLeaves);
-
   /**
-   * @brief Compute conditional error for each discrete assignment,
-   * and return as a tree.
+   * @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
    *
-   * @param continuousValues Continuous values at which to compute the error.
-   * @return AlgebraicDecisionTree<Key>
+   * @param maxNrLeaves The maximum number of leaves to keep after pruning.
+   * @return A pruned HybridBayesNet
    */
-  AlgebraicDecisionTree<Key> errorTree(
-      const VectorValues &continuousValues) const;
+  HybridBayesNet prune(size_t maxNrLeaves) const;
 
   /**
    * @brief Error method using HybridValues which returns specific error for
@ -219,29 +220,33 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
using Base::error;
|
using Base::error;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute log probability for each discrete assignment,
|
* @brief Compute the negative log posterior log P'(M|x) of all assignments up
|
||||||
* and return as a tree.
|
* to a constant, returning the result as an algebraic decision tree.
|
||||||
*
|
*
|
||||||
* @param continuousValues Continuous values at which
|
* @note The joint P(X,M) is p(X|M) P(M)
|
||||||
* to compute the log probability.
|
* Then the posterior on M given X=x is is P(M|x) = p(x|M) P(M) / p(x).
|
||||||
|
* Ideally we want log P(M|x) = log p(x|M) + log P(M) - log p(x), but
|
||||||
|
* unfortunately log p(x) is expensive, so we compute the log of the
|
||||||
|
* unnormalized posterior log P'(M|x) = log p(x|M) + log P(M)
|
||||||
|
*
|
||||||
|
* @param continuousValues Continuous values x at which to compute log P'(M|x)
|
||||||
* @return AlgebraicDecisionTree<Key>
|
* @return AlgebraicDecisionTree<Key>
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> logProbability(
|
AlgebraicDecisionTree<Key> errorTree(
|
||||||
const VectorValues &continuousValues) const;
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
using BayesNet::logProbability; // expose HybridValues version
|
using BayesNet::logProbability; // expose HybridValues version
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute unnormalized probability q(μ|M),
|
* @brief Compute normalized posterior P(M|X=x) and return as a tree.
|
||||||
* for each discrete assignment, and return as a tree.
|
|
||||||
* q(μ|M) is the unnormalized probability at the MLE point μ,
|
|
||||||
* conditioned on the discrete variables.
|
|
||||||
*
|
*
|
||||||
* @param continuousValues Continuous values at which to compute the
|
* @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
|
||||||
* probability.
|
* which we would need, are hard to recover.
|
||||||
|
*
|
||||||
|
* @param continuousValues Continuous values x to condition P(M|X=x) on.
|
||||||
* @return AlgebraicDecisionTree<Key>
|
* @return AlgebraicDecisionTree<Key>
|
||||||
*/
|
*/
|
||||||
AlgebraicDecisionTree<Key> evaluate(
|
AlgebraicDecisionTree<Key> discretePosterior(
|
||||||
const VectorValues &continuousValues) const;
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
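
Per the notes above, errorTree returns E(m; x) = -log P'(M=m|x) up to a shared constant, and discretePosterior normalizes exp(-E). A sketch of that relation, assuming those semantics:

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/hybrid/HybridBayesNet.h>

using namespace gtsam;

void comparePosteriors(const HybridBayesNet& hybridBayesNet,
                       const VectorValues& x) {
  AlgebraicDecisionTree<Key> err = hybridBayesNet.errorTree(x);
  AlgebraicDecisionTree<Key> post = hybridBayesNet.discretePosterior(x);
  // For every assignment m:
  //   post(m) == exp(-err(m)) / (sum over m' of exp(-err(m')))
}
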
@@ -253,13 +258,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
  /// @}

 private:
-  /**
-   * @brief Prune all the discrete conditionals.
-   *
-   * @param maxNrLeaves
-   */
-  DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
-
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
  /** Serialization function */
  friend class boost::serialization::access;
@@ -22,10 +22,13 @@
 #include <gtsam/discrete/DiscreteFactorGraph.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridBayesTree.h>
+#include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/inference/BayesTree-inst.h>
 #include <gtsam/inference/BayesTreeCliqueBase-inst.h>
 #include <gtsam/linear/GaussianJunctionTree.h>

+#include <memory>
+
 namespace gtsam {

 // Instantiate base class
@@ -207,7 +210,9 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
    if (conditional->isHybrid()) {
      auto hybridGaussianCond = conditional->asHybrid();

-      hybridGaussianCond->prune(parentData.prunedDiscreteProbs);
+      // Imperative
+      clique->conditional() = std::make_shared<HybridConditional>(
+          hybridGaussianCond->prune(parentData.prunedDiscreteProbs));
    }
    return parentData;
  }
@@ -13,6 +13,7 @@
 * @file HybridConditional.cpp
 * @date Mar 11, 2022
 * @author Fan Jiang
+ * @author Varun Agrawal
 */

#include <gtsam/hybrid/HybridConditional.h>
@@ -64,7 +65,6 @@ void HybridConditional::print(const std::string &s,

  if (inner_) {
    inner_->print("", formatter);
-
  } else {
    if (isContinuous()) std::cout << "Continuous ";
    if (isDiscrete()) std::cout << "Discrete ";
@@ -100,16 +100,13 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
  if (auto gm = asHybrid()) {
    auto other = e->asHybrid();
    return other != nullptr && gm->equals(*other, tol);
-  }
-  if (auto gc = asGaussian()) {
+  } else if (auto gc = asGaussian()) {
    auto other = e->asGaussian();
    return other != nullptr && gc->equals(*other, tol);
-  }
-  if (auto dc = asDiscrete()) {
+  } else if (auto dc = asDiscrete()) {
    auto other = e->asDiscrete();
    return other != nullptr && dc->equals(*other, tol);
-  }
-
+  } else
  return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
                : !(e->inner_);
}
@@ -118,13 +115,11 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
double HybridConditional::error(const HybridValues &values) const {
  if (auto gc = asGaussian()) {
    return gc->error(values.continuous());
-  }
-  if (auto gm = asHybrid()) {
+  } else if (auto gm = asHybrid()) {
    return gm->error(values);
-  }
-  if (auto dc = asDiscrete()) {
+  } else if (auto dc = asDiscrete()) {
    return dc->error(values.discrete());
-  }
+  } else
    throw std::runtime_error(
        "HybridConditional::error: conditional type not handled");
}
@@ -133,14 +128,12 @@ double HybridConditional::error(const HybridValues &values) const {
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
    const VectorValues &values) const {
  if (auto gc = asGaussian()) {
-    return AlgebraicDecisionTree<Key>(gc->error(values));
-  }
-  if (auto gm = asHybrid()) {
+    return {gc->error(values)};  // NOTE: a "constant" tree
+  } else if (auto gm = asHybrid()) {
    return gm->errorTree(values);
-  }
-  if (auto dc = asDiscrete()) {
-    return AlgebraicDecisionTree<Key>(0.0);
-  }
+  } else if (auto dc = asDiscrete()) {
+    return dc->errorTree();
+  } else
    throw std::runtime_error(
        "HybridConditional::error: conditional type not handled");
}
@@ -149,13 +142,11 @@ AlgebraicDecisionTree<Key> HybridConditional::errorTree(
double HybridConditional::logProbability(const HybridValues &values) const {
  if (auto gc = asGaussian()) {
    return gc->logProbability(values.continuous());
-  }
-  if (auto gm = asHybrid()) {
+  } else if (auto gm = asHybrid()) {
    return gm->logProbability(values);
-  }
-  if (auto dc = asDiscrete()) {
+  } else if (auto dc = asDiscrete()) {
    return dc->logProbability(values.discrete());
-  }
+  } else
    throw std::runtime_error(
        "HybridConditional::logProbability: conditional type not handled");
}
@@ -164,13 +155,11 @@ double HybridConditional::logProbability(const HybridValues &values) const {
double HybridConditional::negLogConstant() const {
  if (auto gc = asGaussian()) {
    return gc->negLogConstant();
-  }
-  if (auto gm = asHybrid()) {
-    return gm->negLogConstant();  // 0.0!
-  }
-  if (auto dc = asDiscrete()) {
+  } else if (auto gm = asHybrid()) {
+    return gm->negLogConstant();
+  } else if (auto dc = asDiscrete()) {
    return dc->negLogConstant();  // 0.0!
-  }
+  } else
    throw std::runtime_error(
        "HybridConditional::negLogConstant: conditional type not handled");
}
@@ -13,6 +13,7 @@
 * @file HybridConditional.h
 * @date Mar 11, 2022
 * @author Fan Jiang
+ * @author Varun Agrawal
 */

#pragma once
@@ -32,9 +32,6 @@ namespace gtsam {

class HybridValues;

-/// Alias for DecisionTree of GaussianFactorGraphs
-using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
-
KeyVector CollectKeys(const KeyVector &continuousKeys,
                      const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
@@ -25,18 +25,41 @@
 #include <gtsam/hybrid/HybridValues.h>
 #include <gtsam/inference/Conditional-inst.h>
 #include <gtsam/linear/GaussianBayesNet.h>
+#include <gtsam/linear/GaussianConditional.h>
 #include <gtsam/linear/GaussianFactorGraph.h>
 #include <gtsam/linear/JacobianFactor.h>

 #include <cstddef>
+#include <memory>

 namespace gtsam {

 /* *******************************************************************************/
+GaussianConditional::shared_ptr checkConditional(
+    const GaussianFactor::shared_ptr &factor) {
+  if (auto conditional =
+          std::dynamic_pointer_cast<GaussianConditional>(factor)) {
+    return conditional;
+  } else {
+    throw std::logic_error(
+        "A HybridGaussianConditional unexpectedly contained a non-conditional");
+  }
+}
+
+/* *******************************************************************************/
+/**
+ * @brief Helper struct for constructing HybridGaussianConditional objects
+ *
+ * This struct contains the following fields:
+ * - nrFrontals: Optional size_t for number of frontal variables
+ * - pairs: FactorValuePairs for storing conditionals with their negLogConstant
+ * - minNegLogConstant: minimum negLogConstant, computed here, subtracted in
+ *   constructor
+ */
 struct HybridGaussianConditional::Helper {
-  std::optional<size_t> nrFrontals;
  FactorValuePairs pairs;
-  Conditionals conditionals;
-  double minNegLogConstant;
+  std::optional<size_t> nrFrontals = {};
+  double minNegLogConstant = std::numeric_limits<double>::infinity();

  using GC = GaussianConditional;
  using P = std::vector<std::pair<Vector, double>>;
@@ -45,8 +68,6 @@ struct HybridGaussianConditional::Helper {
  template <typename... Args>
  explicit Helper(const DiscreteKey &mode, const P &p, Args &&...args) {
    nrFrontals = 1;
-    minNegLogConstant = std::numeric_limits<double>::infinity();

    std::vector<GaussianFactorValuePair> fvs;
    std::vector<GC::shared_ptr> gcs;
    fvs.reserve(p.size());
@@ -60,24 +81,17 @@ struct HybridGaussianConditional::Helper {
      gcs.push_back(gaussianConditional);
    }

-    conditionals = Conditionals({mode}, gcs);
    pairs = FactorValuePairs({mode}, fvs);
  }

  /// Construct from tree of GaussianConditionals.
-  explicit Helper(const Conditionals &conditionals)
-      : conditionals(conditionals),
-        minNegLogConstant(std::numeric_limits<double>::infinity()) {
-    auto func = [this](const GC::shared_ptr &c) -> GaussianFactorValuePair {
-      double value = 0.0;
-      if (c) {
-        if (!nrFrontals.has_value()) {
-          nrFrontals = c->nrFrontals();
-        }
-        value = c->negLogConstant();
-        minNegLogConstant = std::min(minNegLogConstant, value);
-      }
-      return {std::dynamic_pointer_cast<GaussianFactor>(c), value};
+  explicit Helper(const Conditionals &conditionals) {
+    auto func = [this](const GC::shared_ptr &gc) -> GaussianFactorValuePair {
+      if (!gc) return {nullptr, std::numeric_limits<double>::infinity()};
+      if (!nrFrontals) nrFrontals = gc->nrFrontals();
+      double value = gc->negLogConstant();
+      minNegLogConstant = std::min(minNegLogConstant, value);
+      return {gc, value};
    };
    pairs = FactorValuePairs(conditionals, func);
    if (!nrFrontals.has_value()) {
@@ -86,14 +100,36 @@ struct HybridGaussianConditional::Helper {
          "Provided conditionals do not contain any frontal variables.");
    }
  }
+
+  /// Construct from tree of factor/scalar pairs.
+  explicit Helper(const FactorValuePairs &pairs) : pairs(pairs) {
+    auto func = [this](const GaussianFactorValuePair &pair) {
+      if (!pair.first) return;
+      auto gc = checkConditional(pair.first);
+      if (!nrFrontals) nrFrontals = gc->nrFrontals();
+      minNegLogConstant = std::min(minNegLogConstant, pair.second);
+    };
+    pairs.visit(func);
+    if (!nrFrontals.has_value()) {
+      throw std::runtime_error(
+          "HybridGaussianConditional: need at least one frontal variable. "
+          "Provided conditionals do not contain any frontal variables.");
+    }
+  }
};

/* *******************************************************************************/
HybridGaussianConditional::HybridGaussianConditional(
-    const DiscreteKeys &discreteParents, const Helper &helper)
-    : BaseFactor(discreteParents, helper.pairs),
+    const DiscreteKeys &discreteParents, Helper &&helper)
+    : BaseFactor(discreteParents,
+                 FactorValuePairs(
+                     [&](const GaussianFactorValuePair
+                             &pair) {  // subtract minNegLogConstant
+                       return GaussianFactorValuePair{
+                           pair.first, pair.second - helper.minNegLogConstant};
+                     },
+                     std::move(helper.pairs))),
      BaseConditional(*helper.nrFrontals),
-      conditionals_(helper.conditionals),
      negLogConstant_(helper.minNegLogConstant) {}

HybridGaussianConditional::HybridGaussianConditional(
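
The constructor shifts every stored scalar by minNegLogConstant so the smallest becomes zero, keeping the shared offset once in negLogConstant_. A standalone sketch of that normalization in plain C++ (no GTSAM types; names are illustrative):

#include <algorithm>
#include <limits>
#include <vector>

// Each mode m carries a scalar K_m (its negative log normalization constant).
// Store K_m minus the minimum over all modes per leaf, and the minimum once,
// so the per-mode scalars stay non-negative and the common constant is not
// duplicated in every leaf.
struct Normalized {
  std::vector<double> shifted;
  double minNegLogConstant;
};

Normalized normalize(const std::vector<double>& negLogConstants) {
  double mn = std::numeric_limits<double>::infinity();
  for (double k : negLogConstants) mn = std::min(mn, k);
  Normalized result{{}, mn};
  for (double k : negLogConstants) result.shifted.push_back(k - mn);
  return result;
}
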
@@ -129,55 +165,35 @@ HybridGaussianConditional::HybridGaussianConditional(
    const HybridGaussianConditional::Conditionals &conditionals)
    : HybridGaussianConditional(discreteParents, Helper(conditionals)) {}

+/* *******************************************************************************/
+HybridGaussianConditional::HybridGaussianConditional(
+    const DiscreteKeys &discreteParents, const FactorValuePairs &pairs)
+    : HybridGaussianConditional(discreteParents, Helper(pairs)) {}
+
 /* *******************************************************************************/
-const HybridGaussianConditional::Conditionals &
-HybridGaussianConditional::conditionals() const {
-  return conditionals_;
-}
-
-/* *******************************************************************************/
-GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
-    const {
-  auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
-    // First check if conditional has not been pruned
-    if (gc) {
-      const double Cgm_Kgcm = gc->negLogConstant() - this->negLogConstant_;
-      // If there is a difference in the covariances, we need to account for
-      // that since the error is dependent on the mode.
-      if (Cgm_Kgcm > 0.0) {
-        // We add a constant factor which will be used when computing
-        // the probability of the discrete variables.
-        Vector c(1);
-        c << std::sqrt(2.0 * Cgm_Kgcm);
-        auto constantFactor = std::make_shared<JacobianFactor>(c);
-        return GaussianFactorGraph{gc, constantFactor};
-      }
-    }
-    return GaussianFactorGraph{gc};
-  };
-  return {conditionals_, wrap};
+const HybridGaussianConditional::Conditionals
+HybridGaussianConditional::conditionals() const {
+  return Conditionals(factors(), [](auto &&pair) {
+    return std::dynamic_pointer_cast<GaussianConditional>(pair.first);
+  });
}

/* *******************************************************************************/
size_t HybridGaussianConditional::nrComponents() const {
  size_t total = 0;
-  conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
-    if (node) total += 1;
+  factors().visit([&total](auto &&node) {
+    if (node.first) total += 1;
  });
  return total;
}

/* *******************************************************************************/
-GaussianConditional::shared_ptr HybridGaussianConditional::operator()(
+GaussianConditional::shared_ptr HybridGaussianConditional::choose(
    const DiscreteValues &discreteValues) const {
-  auto &ptr = conditionals_(discreteValues);
-  if (!ptr) return nullptr;
-  auto conditional = std::dynamic_pointer_cast<GaussianConditional>(ptr);
-  if (conditional)
-    return conditional;
-  else
-    throw std::logic_error(
-        "A HybridGaussianConditional unexpectedly contained a non-conditional");
+  auto &[factor, _] = factors()(discreteValues);
+  if (!factor) return nullptr;
+  auto conditional = checkConditional(factor);
+  return conditional;
}

/* *******************************************************************************/
@@ -186,26 +202,22 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,
  const This *e = dynamic_cast<const This *>(&lf);
  if (e == nullptr) return false;

-  // This will return false if either conditionals_ is empty or e->conditionals_
-  // is empty, but not if both are empty or both are not empty:
-  if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
-
-  // Check the base and the factors:
-  return BaseFactor::equals(*e, tol) &&
-         conditionals_.equals(e->conditionals_,
-                              [tol](const GaussianConditional::shared_ptr &f1,
-                                    const GaussianConditional::shared_ptr &f2) {
-                                return f1->equals(*(f2), tol);
-                              });
+  // Factors existence and scalar values are checked in BaseFactor::equals.
+  // Here we check additionally that the factors *are* conditionals
+  // and are equal.
+  auto compareFunc = [tol](const GaussianFactorValuePair &pair1,
+                           const GaussianFactorValuePair &pair2) {
+    auto c1 = std::dynamic_pointer_cast<GaussianConditional>(pair1.first),
+         c2 = std::dynamic_pointer_cast<GaussianConditional>(pair2.first);
+    return (!c1 && !c2) || (c1 && c2 && c1->equals(*c2, tol));
+  };
+  return Base::equals(*e, tol) && factors().equals(e->factors(), compareFunc);
}

/* *******************************************************************************/
void HybridGaussianConditional::print(const std::string &s,
                                      const KeyFormatter &formatter) const {
  std::cout << (s.empty() ? "" : s + "\n");
-  if (isContinuous()) std::cout << "Continuous ";
-  if (isDiscrete()) std::cout << "Discrete ";
-  if (isHybrid()) std::cout << "Hybrid ";
  BaseConditional::print("", formatter);
  std::cout << " Discrete Keys = ";
  for (auto &dk : discreteKeys()) {

@@ -214,11 +226,12 @@ void HybridGaussianConditional::print(const std::string &s,
  std::cout << std::endl
            << " logNormalizationConstant: " << -negLogConstant() << std::endl
            << std::endl;
-  conditionals_.print(
+  factors().print(
      "", [&](Key k) { return formatter(k); },
-      [&](const GaussianConditional::shared_ptr &gf) -> std::string {
+      [&](const GaussianFactorValuePair &pair) -> std::string {
        RedirectCout rd;
-        if (gf && !gf->empty()) {
+        if (auto gf =
+                std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
          gf->print("", formatter);
          return rd.str();
        } else {
@@ -266,17 +279,15 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
  const DiscreteKeys discreteParentKeys = discreteKeys();
  const KeyVector continuousParentKeys = continuousParents();
  const HybridGaussianFactor::FactorValuePairs likelihoods(
-      conditionals_,
-      [&](const GaussianConditional::shared_ptr &conditional)
-          -> GaussianFactorValuePair {
-        const auto likelihood_m = conditional->likelihood(given);
-        const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_;
-        if (Cgm_Kgcm == 0.0) {
-          return {likelihood_m, 0.0};
+      factors(),
+      [&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
+        if (auto conditional =
+                std::dynamic_pointer_cast<GaussianConditional>(pair.first)) {
+          const auto likelihood_m = conditional->likelihood(given);
+          // pair.second == conditional->negLogConstant() - negLogConstant_
+          return {likelihood_m, pair.second};
        } else {
-          // Add a constant to the likelihood in case the noise models
-          // are not all equal.
-          return {likelihood_m, Cgm_Kgcm};
+          return {nullptr, std::numeric_limits<double>::infinity()};
        }
      });
  return std::make_shared<HybridGaussianFactor>(discreteParentKeys,
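
The invariant relied on by the comment above can be checked directly: after construction, the scalar stored with mode m equals that mode's negLogConstant minus the minimum over all modes, which is exactly the constant the likelihood needs. A small standalone check (plain C++; the tolerance is arbitrary):

#include <cassert>
#include <cmath>

void checkScalarInvariant(double negLogConstant_m, double minNegLogConstant,
                          double storedScalar) {
  // storedScalar should be exactly the per-mode offset reused by likelihood().
  assert(std::abs(storedScalar - (negLogConstant_m - minNegLogConstant)) <
         1e-9);
}
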
@@ -289,97 +300,49 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
  return s;
}

-/* ************************************************************************* */
-std::function<GaussianConditional::shared_ptr(
-    const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
-HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
-  // Get the discrete keys as sets for the decision tree
-  // and the hybrid gaussian conditional.
-  auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
-  auto hybridGaussianCondKeySet = DiscreteKeysAsSet(this->discreteKeys());
-
-  auto pruner = [discreteProbs, discreteProbsKeySet, hybridGaussianCondKeySet](
-                    const Assignment<Key> &choices,
-                    const GaussianConditional::shared_ptr &conditional)
-      -> GaussianConditional::shared_ptr {
-    // typecast so we can use this to get probability value
-    const DiscreteValues values(choices);
-
-    // Case where the hybrid gaussian conditional has the same
-    // discrete keys as the decision tree.
-    if (hybridGaussianCondKeySet == discreteProbsKeySet) {
-      if (discreteProbs(values) == 0.0) {
-        // empty aka null pointer
-        std::shared_ptr<GaussianConditional> null;
-        return null;
-      } else {
-        return conditional;
-      }
-    } else {
-      std::vector<DiscreteKey> set_diff;
-      std::set_difference(
-          discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
-          hybridGaussianCondKeySet.begin(), hybridGaussianCondKeySet.end(),
-          std::back_inserter(set_diff));
-
-      const std::vector<DiscreteValues> assignments =
-          DiscreteValues::CartesianProduct(set_diff);
-      for (const DiscreteValues &assignment : assignments) {
-        DiscreteValues augmented_values(values);
-        augmented_values.insert(assignment);
-
-        // If any one of the sub-branches are non-zero,
-        // we need this conditional.
-        if (discreteProbs(augmented_values) > 0.0) {
-          return conditional;
-        }
-      }
-      // If we are here, it means that all the sub-branches are 0,
-      // so we prune.
-      return nullptr;
-    }
-  };
-  return pruner;
-}
-
 /* *******************************************************************************/
-void HybridGaussianConditional::prune(const DecisionTreeFactor &discreteProbs) {
-  // Functional which loops over all assignments and create a set of
-  // GaussianConditionals
-  auto pruner = prunerFunc(discreteProbs);
-
-  auto pruned_conditionals = conditionals_.apply(pruner);
-  conditionals_.root_ = pruned_conditionals.root_;
-}
-
-/* *******************************************************************************/
-AlgebraicDecisionTree<Key> HybridGaussianConditional::logProbability(
-    const VectorValues &continuousValues) const {
-  // functor to calculate (double) logProbability value from
-  // GaussianConditional.
-  auto probFunc =
-      [continuousValues](const GaussianConditional::shared_ptr &conditional) {
-        if (conditional) {
-          return conditional->logProbability(continuousValues);
-        } else {
-          // Return arbitrarily small logProbability if conditional is null
-          // Conditional is null if it is pruned out.
-          return -1e20;
-        }
-      };
-  return DecisionTree<Key, double>(conditionals_, probFunc);
+HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune(
+    const DecisionTreeFactor &discreteProbs) const {
+  // Find keys in discreteProbs.keys() but not in this->keys():
+  std::set<Key> mine(this->keys().begin(), this->keys().end());
+  std::set<Key> theirs(discreteProbs.keys().begin(),
+                       discreteProbs.keys().end());
+  std::vector<Key> diff;
+  std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(),
+                      std::back_inserter(diff));
+
+  // Find maximum probability value for every combination of our keys.
+  Ordering keys(diff);
+  auto max = discreteProbs.max(keys);
+
+  // Check the max value for every combination of our keys.
+  // If the max value is 0.0, we can prune the corresponding conditional.
+  auto pruner =
+      [&](const Assignment<Key> &choices,
+          const GaussianFactorValuePair &pair) -> GaussianFactorValuePair {
+    if (max->evaluate(choices) == 0.0)
+      return {nullptr, std::numeric_limits<double>::infinity()};
+    else
+      return pair;
+  };
+
+  FactorValuePairs prunedConditionals = factors().apply(pruner);
+  return std::make_shared<HybridGaussianConditional>(discreteKeys(),
+                                                     prunedConditionals);
}

/* *******************************************************************************/
double HybridGaussianConditional::logProbability(
    const HybridValues &values) const {
-  auto conditional = conditionals_(values.discrete());
+  auto [factor, _] = factors()(values.discrete());
+  auto conditional = checkConditional(factor);
  return conditional->logProbability(values.continuous());
}

/* *******************************************************************************/
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
-  auto conditional = conditionals_(values.discrete());
+  auto [factor, _] = factors()(values.discrete());
+  auto conditional = checkConditional(factor);
  return conditional->evaluate(values.continuous());
}
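
The rewritten prune() max-marginalizes out the discrete keys it does not use, then nulls leaves whose best-case probability is zero. A toy standalone version of that rule (plain C++; the probability table is hypothetical):

#include <memory>
#include <vector>

struct Component { /* stands in for a GaussianConditional */ };

// A component for mode m survives iff the maximum probability over all
// assignments consistent with m is strictly positive.
std::vector<std::shared_ptr<Component>> pruneComponents(
    const std::vector<std::shared_ptr<Component>>& components,
    const std::vector<double>& maxProbPerMode) {
  std::vector<std::shared_ptr<Component>> pruned;
  for (size_t m = 0; m < components.size(); ++m) {
    pruned.push_back(maxProbPerMode[m] == 0.0 ? nullptr : components[m]);
  }
  return pruned;
}
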
@@ -14,6 +14,7 @@
 * @brief A hybrid conditional in the Conditional Linear Gaussian scheme
 * @author Fan Jiang
 * @author Varun Agrawal
+ * @author Frank Dellaert
 * @date Mar 12, 2022
 */
@@ -63,8 +64,6 @@ class GTSAM_EXPORT HybridGaussianConditional
  using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

 private:
-  Conditionals conditionals_;  ///< a decision tree of Gaussian conditionals.
-
  ///< Negative-log of the normalization constant (log(\sqrt(|2πΣ|))).
  ///< Take advantage of the neg-log space so everything is a minimization
  double negLogConstant_;
@@ -142,6 +141,19 @@ class GTSAM_EXPORT HybridGaussianConditional
  HybridGaussianConditional(const DiscreteKeys &discreteParents,
                            const Conditionals &conditionals);

+  /**
+   * @brief Construct from multiple discrete keys M and a tree of
+   * factor/scalar pairs, where the scalar is assumed to be the
+   * negative log constant for each assignment m, up to a constant.
+   *
+   * @note Will throw if factors are not actually conditionals.
+   *
+   * @param discreteParents the discrete parents. Will be placed last.
+   * @param pairs Decision tree of GaussianFactor/scalar pairs.
+   */
+  HybridGaussianConditional(const DiscreteKeys &discreteParents,
+                            const FactorValuePairs &pairs);
+
  /// @}
  /// @name Testable
  /// @{
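
A construction sketch for the new FactorValuePairs overload, assuming GaussianConditional's (key, d, R, noise model) constructor and the negLogConstant() accessor used elsewhere in this changeset; all keys and numbers are hypothetical:

#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>

using namespace gtsam;

HybridGaussianConditional::shared_ptr makeModeDependentPrior() {
  const Key x0 = Symbol('x', 0);
  const DiscreteKey m0(Symbol('m', 0), 2);  // binary mode

  // One Gaussian conditional per mode, paired with its negLogConstant.
  auto c0 = std::make_shared<GaussianConditional>(
      x0, Vector1(0.0), I_1x1, noiseModel::Isotropic::Sigma(1, 1.0));
  auto c1 = std::make_shared<GaussianConditional>(
      x0, Vector1(0.0), I_1x1, noiseModel::Isotropic::Sigma(1, 3.0));
  std::vector<GaussianFactorValuePair> pairs{{c0, c0->negLogConstant()},
                                             {c1, c1->negLogConstant()}};

  return std::make_shared<HybridGaussianConditional>(
      DiscreteKeys{m0},
      HybridGaussianConditional::FactorValuePairs({m0}, pairs));
}
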
@@ -159,9 +171,15 @@ class GTSAM_EXPORT HybridGaussianConditional
  /// @{

  /// @brief Return the conditional Gaussian for the given discrete assignment.
-  GaussianConditional::shared_ptr operator()(
+  GaussianConditional::shared_ptr choose(
      const DiscreteValues &discreteValues) const;

+  /// @brief Syntactic sugar for choose.
+  GaussianConditional::shared_ptr operator()(
+      const DiscreteValues &discreteValues) const {
+    return choose(discreteValues);
+  }
+
  /// Returns the total number of continuous components
  size_t nrComponents() const;
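
Usage sketch for the renamed accessor and its operator() sugar; the conditional hgc and assignment m are assumed to come from elsewhere:

#include <gtsam/hybrid/HybridGaussianConditional.h>

using namespace gtsam;

void pickComponent(const HybridGaussianConditional& hgc,
                   const DiscreteValues& m) {
  // Both calls return the Gaussian component for assignment m,
  // or nullptr if that component was pruned away.
  GaussianConditional::shared_ptr a = hgc.choose(m);
  GaussianConditional::shared_ptr b = hgc(m);  // sugar for choose(m)
}
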
@@ -185,18 +203,9 @@ class GTSAM_EXPORT HybridGaussianConditional
  std::shared_ptr<HybridGaussianFactor> likelihood(
      const VectorValues &given) const;

-  /// Getter for the underlying Conditionals DecisionTree
-  const Conditionals &conditionals() const;
+  /// Get Conditionals DecisionTree (dynamic cast from factors)
+  /// @note Slow: avoid using in favor of factors(), which uses existing tree.
+  const Conditionals conditionals() const;

-  /**
-   * @brief Compute logProbability of the HybridGaussianConditional as a tree.
-   *
-   * @param continuousValues The continuous VectorValues.
-   * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
-   * as the conditionals, and leaf values as the logProbability.
-   */
-  AlgebraicDecisionTree<Key> logProbability(
-      const VectorValues &continuousValues) const;
-
  /**
   * @brief Compute the logProbability of this hybrid Gaussian conditional.
@@ -219,8 +228,10 @@ class GTSAM_EXPORT HybridGaussianConditional
   * `discreteProbs`.
   *
   * @param discreteProbs A pruned set of probabilities for the discrete keys.
+   * @return Shared pointer to possibly a pruned HybridGaussianConditional
   */
-  void prune(const DecisionTreeFactor &discreteProbs);
+  HybridGaussianConditional::shared_ptr prune(
+      const DecisionTreeFactor &discreteProbs) const;

  /// @}
@@ -230,21 +241,7 @@ class GTSAM_EXPORT HybridGaussianConditional

  /// Private constructor that uses helper struct above.
  HybridGaussianConditional(const DiscreteKeys &discreteParents,
-                            const Helper &helper);
-
-  /// Convert to a DecisionTree of Gaussian factor graphs.
-  GaussianFactorGraphTree asGaussianFactorGraphTree() const;
-
-  /**
-   * @brief Get the pruner function from discrete probabilities.
-   *
-   * @param discreteProbs The probabilities of only discrete keys.
-   * @return std::function<GaussianConditional::shared_ptr(
-   * const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
-   */
-  std::function<GaussianConditional::shared_ptr(
-      const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
-  prunerFunc(const DecisionTreeFactor &prunedProbabilities);
+                            Helper &&helper);

  /// Check whether `given` has values for all frontal keys.
  bool allFrontalsGiven(const VectorValues &given) const;
@@ -256,7 +253,6 @@ class GTSAM_EXPORT HybridGaussianConditional
  void serialize(Archive &ar, const unsigned int /*version*/) {
    ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
    ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
-    ar &BOOST_SERIALIZATION_NVP(conditionals_);
  }
#endif
};
@@ -18,63 +18,26 @@
 * @date Mar 12, 2022
 */

+#include <gtsam/base/types.h>
 #include <gtsam/base/utilities.h>
 #include <gtsam/discrete/DecisionTree-inl.h>
 #include <gtsam/discrete/DecisionTree.h>
 #include <gtsam/hybrid/HybridFactor.h>
 #include <gtsam/hybrid/HybridGaussianFactor.h>
+#include <gtsam/hybrid/HybridGaussianProductFactor.h>
 #include <gtsam/hybrid/HybridValues.h>
 #include <gtsam/linear/GaussianFactor.h>
 #include <gtsam/linear/GaussianFactorGraph.h>

 namespace gtsam {

-/* *******************************************************************************/
-HybridGaussianFactor::Factors HybridGaussianFactor::augment(
-    const FactorValuePairs &factors) {
-  // Find the minimum value so we can "proselytize" to positive values.
-  // Done because we can't have sqrt of negative numbers.
-  Factors gaussianFactors;
-  AlgebraicDecisionTree<Key> valueTree;
-  std::tie(gaussianFactors, valueTree) = unzip(factors);
-
-  // Compute minimum value for normalization.
-  double min_value = valueTree.min();
-
-  // Finally, update the [A|b] matrices.
-  auto update = [&min_value](const GaussianFactorValuePair &gfv) {
-    auto [gf, value] = gfv;
-
-    auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
-    if (!jf) return gf;
-
-    double normalized_value = value - min_value;
-
-    // If the value is 0, do nothing
-    if (normalized_value == 0.0) return gf;
-
-    GaussianFactorGraph gfg;
-    gfg.push_back(jf);
-
-    Vector c(1);
-    // When hiding c inside the `b` vector, value == 0.5*c^2
-    c << std::sqrt(2.0 * normalized_value);
-    auto constantFactor = std::make_shared<JacobianFactor>(c);
-
-    gfg.push_back(constantFactor);
-    return std::dynamic_pointer_cast<GaussianFactor>(
-        std::make_shared<JacobianFactor>(gfg));
-  };
-  return Factors(factors, update);
-}
-
 /* *******************************************************************************/
 struct HybridGaussianFactor::ConstructorHelper {
  KeyVector continuousKeys;   // Continuous keys extracted from factors
  DiscreteKeys discreteKeys;  // Discrete keys provided to the constructors
-  FactorValuePairs pairs;     // Used only if factorsTree is empty
-  Factors factorsTree;
+  FactorValuePairs pairs;     // The decision tree with factors and scalars

+  /// Constructor for a single discrete key and a vector of Gaussian factors
  ConstructorHelper(const DiscreteKey& discreteKey,
                    const std::vector<GaussianFactor::shared_ptr>& factors)
      : discreteKeys({discreteKey}) {
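
The deleted augment() hid each per-mode constant inside the factor itself by appending a one-row residual c with value == 0.5 * c^2; keeping the scalar next to the factor makes that rewrite unnecessary. A standalone sketch of the old trick and its round-trip, in plain C++:

#include <cassert>
#include <cmath>

// A constant term `value` added to a least-squares error can be hidden as an
// extra residual row c in [A|b], since 0.5 * c^2 == value.
double hideConstantAsResidual(double value) {
  return std::sqrt(2.0 * value);
}

int main() {
  const double value = 1.7;  // hypothetical per-mode constant
  const double c = hideConstantAsResidual(value);
  assert(std::abs(0.5 * c * c - value) < 1e-12);  // round-trips exactly
  return 0;
}
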
@@ -85,16 +48,22 @@ struct HybridGaussianFactor::ConstructorHelper {
        break;
      }
    }
-    // Build the DecisionTree from the factor vector
-    factorsTree = Factors(discreteKeys, factors);
+    // Build the FactorValuePairs DecisionTree
+    pairs = FactorValuePairs(
+        DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
+        [](const sharedFactor& f) {
+          return std::pair{f,
+                           f ? 0.0 : std::numeric_limits<double>::infinity()};
+        });
  }

+  /// Constructor for a single discrete key and a vector of
+  /// GaussianFactorValuePairs
  ConstructorHelper(const DiscreteKey& discreteKey,
                    const std::vector<GaussianFactorValuePair>& factorPairs)
      : discreteKeys({discreteKey}) {
    // Extract continuous keys from the first non-null factor
-    for (const auto &pair : factorPairs) {
+    for (const GaussianFactorValuePair& pair : factorPairs) {
      if (pair.first && continuousKeys.empty()) {
        continuousKeys = pair.first->keys();
        break;
@@ -105,10 +74,13 @@ struct HybridGaussianFactor::ConstructorHelper {
    pairs = FactorValuePairs(discreteKeys, factorPairs);
  }

+  /// Constructor for a vector of discrete keys and a vector of
+  /// GaussianFactorValuePairs
  ConstructorHelper(const DiscreteKeys& discreteKeys,
                    const FactorValuePairs& factorPairs)
      : discreteKeys(discreteKeys) {
    // Extract continuous keys from the first non-null factor
+    // TODO: just stop after first non-null factor
    factorPairs.visit([&](const GaussianFactorValuePair& pair) {
      if (pair.first && continuousKeys.empty()) {
        continuousKeys = pair.first->keys();
@@ -123,22 +95,18 @@ struct HybridGaussianFactor::ConstructorHelper {
 /* *******************************************************************************/
 HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper& helper)
    : Base(helper.continuousKeys, helper.discreteKeys),
-      factors_(helper.factorsTree.empty() ? augment(helper.pairs)
-                                          : helper.factorsTree) {}
+      factors_(helper.pairs) {}

-/* *******************************************************************************/
 HybridGaussianFactor::HybridGaussianFactor(
    const DiscreteKey& discreteKey,
    const std::vector<GaussianFactor::shared_ptr>& factors)
    : HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}

-/* *******************************************************************************/
 HybridGaussianFactor::HybridGaussianFactor(
    const DiscreteKey& discreteKey,
    const std::vector<GaussianFactorValuePair>& factorPairs)
    : HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}

-/* *******************************************************************************/
 HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys& discreteKeys,
                                            const FactorValuePairs& factors)
    : HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}
@@ -153,18 +121,19 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
  if (factors_.empty() ^ e->factors_.empty()) return false;

  // Check the base and the factors:
-  return Base::equals(*e, tol) &&
-         factors_.equals(e->factors_,
-                         [tol](const sharedFactor &f1, const sharedFactor &f2) {
-                           return f1->equals(*f2, tol);
-                         });
+  auto compareFunc = [tol](const GaussianFactorValuePair& pair1,
+                           const GaussianFactorValuePair& pair2) {
+    auto f1 = pair1.first, f2 = pair2.first;
+    bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
+    return match && gtsam::equal(pair1.second, pair2.second, tol);
+  };
+  return Base::equals(*e, tol) && factors_.equals(e->factors_, compareFunc);
}

/* *******************************************************************************/
void HybridGaussianFactor::print(const std::string& s,
                                 const KeyFormatter& formatter) const {
  std::cout << (s.empty() ? "" : s + "\n");
-  std::cout << "HybridGaussianFactor" << std::endl;
  HybridFactor::print("", formatter);
  std::cout << "{\n";
  if (factors_.empty()) {
@@ -172,11 +141,12 @@ void HybridGaussianFactor::print(const std::string &s,
  } else {
    factors_.print(
        "", [&](Key k) { return formatter(k); },
-        [&](const sharedFactor &gf) -> std::string {
+        [&](const GaussianFactorValuePair& pair) -> std::string {
          RedirectCout rd;
          std::cout << ":\n";
-          if (gf) {
-            gf->print("", formatter);
+          if (pair.first) {
+            pair.first->print("", formatter);
+            std::cout << "scalar: " << pair.second << "\n";
            return rd.str();
          } else {
            return "nullptr";
@@ -187,62 +157,46 @@ void HybridGaussianFactor::print(const std::string &s,
 }

 /* *******************************************************************************/
-HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
+GaussianFactorValuePair HybridGaussianFactor::operator()(
    const DiscreteValues& assignment) const {
  return factors_(assignment);
}

/* *******************************************************************************/
-GaussianFactorGraphTree HybridGaussianFactor::add(
-    const GaussianFactorGraphTree &sum) const {
-  using Y = GaussianFactorGraph;
-  auto add = [](const Y &graph1, const Y &graph2) {
-    auto result = graph1;
-    result.push_back(graph2);
-    return result;
-  };
-  const auto tree = asGaussianFactorGraphTree();
-  return sum.empty() ? tree : sum.apply(tree, add);
-}
-
-/* *******************************************************************************/
-GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
-    const {
-  auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; };
-  return {factors_, wrap};
+HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
+  // Implemented by creating a new DecisionTree where:
+  // - The structure (keys and assignments) is preserved from factors_
+  // - Each leaf converted to a GaussianFactorGraph with just the factor and its
+  //   scalar.
+  return {{factors_,
+           [](const GaussianFactorValuePair& pair)
+               -> std::pair<GaussianFactorGraph, double> {
+             return {GaussianFactorGraph{pair.first}, pair.second};
+           }}};
}

/* *******************************************************************************/
-double HybridGaussianFactor::potentiallyPrunedComponentError(
-    const sharedFactor &gf, const VectorValues &values) const {
-  // Check if valid pointer
-  if (gf) {
-    return gf->error(values);
-  } else {
-    // If not valid, pointer, it means this component was pruned,
-    // so we return maximum error.
-    // This way the negative exponential will give
-    // a probability value close to 0.0.
-    return std::numeric_limits<double>::max();
-  }
+inline static double PotentiallyPrunedComponentError(
+    const GaussianFactorValuePair& pair, const VectorValues& continuousValues) {
+  return pair.first ? pair.first->error(continuousValues) + pair.second
+                    : std::numeric_limits<double>::infinity();
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
    const VectorValues& continuousValues) const {
  // functor to convert from sharedFactor to double error value.
-  auto errorFunc = [this, &continuousValues](const sharedFactor &gf) {
-    return this->potentiallyPrunedComponentError(gf, continuousValues);
+  auto errorFunc = [&continuousValues](const GaussianFactorValuePair& pair) {
+    return PotentiallyPrunedComponentError(pair, continuousValues);
  };
-  DecisionTree<Key, double> error_tree(factors_, errorFunc);
-  return error_tree;
+  return {factors_, errorFunc};
}

/* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues& values) const {
  // Directly index to get the component, no need to build the whole tree.
-  const sharedFactor gf = factors_(values.discrete());
-  return potentiallyPrunedComponentError(gf, values.continuous());
+  const GaussianFactorValuePair pair = factors_(values.discrete());
+  return PotentiallyPrunedComponentError(pair, values.continuous());
}

} // namespace gtsam
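
With the scalar carried explicitly, a component's error is its factor error plus the stored scalar, and pruned components report infinity instead of numeric_limits::max, so exp(-error) is exactly zero for them. A standalone sketch of that contract (plain C++):

#include <limits>
#include <optional>

// error(x, m) = phi_m(x) + s_m for live components; infinity when pruned.
double componentError(std::optional<double> factorError, double scalar) {
  return factorError ? (*factorError + scalar)
                     : std::numeric_limits<double>::infinity();
}
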
@@ -24,6 +24,7 @@
 #include <gtsam/discrete/DecisionTree.h>
 #include <gtsam/discrete/DiscreteKey.h>
 #include <gtsam/hybrid/HybridFactor.h>
+#include <gtsam/hybrid/HybridGaussianProductFactor.h>
 #include <gtsam/linear/GaussianFactor.h>
 #include <gtsam/linear/GaussianFactorGraph.h>
@@ -66,12 +67,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {

  /// typedef for Decision Tree of Gaussian factors and arbitrary value.
  using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;
-  /// typedef for Decision Tree of Gaussian factors.
-  using Factors = DecisionTree<Key, sharedFactor>;

 private:
  /// Decision tree of Gaussian factors indexed by discrete keys.
-  Factors factors_;
+  FactorValuePairs factors_;

 public:
  /// @name Constructors
|
|
@ -110,10 +109,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
* The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m.
|
* The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m.
|
||||||
*
|
*
|
||||||
* @param discreteKeys Discrete variables and their cardinalities.
|
* @param discreteKeys Discrete variables and their cardinalities.
|
||||||
* @param factors The decision tree of Gaussian factor/scalar pairs.
|
* @param factorPairs The decision tree of Gaussian factor/scalar pairs.
|
||||||
*/
|
*/
|
||||||
HybridGaussianFactor(const DiscreteKeys &discreteKeys,
|
HybridGaussianFactor(const DiscreteKeys &discreteKeys,
|
||||||
const FactorValuePairs &factors);
|
const FactorValuePairs &factorPairs);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
|
|
@ -129,17 +128,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
/// Get factor at a given discrete assignment.
|
/// Get factor at a given discrete assignment.
|
||||||
sharedFactor operator()(const DiscreteValues &assignment) const;
|
GaussianFactorValuePair operator()(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
|
|
||||||
* maintaining the original tree structure.
|
|
||||||
*
|
|
||||||
* @param sum Decision Tree of Gaussian Factor Graphs indexed by the
|
|
||||||
* variables.
|
|
||||||
* @return Sum
|
|
||||||
*/
|
|
||||||
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute error of the HybridGaussianFactor as a tree.
|
* @brief Compute error of the HybridGaussianFactor as a tree.
|
||||||
|
|
@ -158,24 +147,16 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
double error(const HybridValues &values) const override;
|
double error(const HybridValues &values) const override;
|
||||||
|
|
||||||
/// Getter for GaussianFactor decision tree
|
/// Getter for GaussianFactor decision tree
|
||||||
const Factors &factors() const { return factors_; }
|
const FactorValuePairs &factors() const { return factors_; }
|
||||||
|
|
||||||
/// Add HybridNonlinearFactor to a Sum, syntactic sugar.
|
|
||||||
friend GaussianFactorGraphTree &operator+=(
|
|
||||||
GaussianFactorGraphTree &sum, const HybridGaussianFactor &factor) {
|
|
||||||
sum = factor.add(sum);
|
|
||||||
return sum;
|
|
||||||
}
|
|
||||||
/// @}
|
|
||||||
|
|
||||||
protected:
|
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to return factors and functional to create a
|
* @brief Helper function to return factors and functional to create a
|
||||||
* DecisionTree of Gaussian Factor Graphs.
|
* DecisionTree of Gaussian Factor Graphs.
|
||||||
*
|
*
|
||||||
* @return GaussianFactorGraphTree
|
* @return HybridGaussianProductFactor
|
||||||
*/
|
*/
|
||||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
virtual HybridGaussianProductFactor asProductFactor() const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/**
|
/**
|
||||||
|
|
@ -184,14 +165,9 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
* value in the `b` vector as an additional row.
|
* value in the `b` vector as an additional row.
|
||||||
*
|
*
|
||||||
* @param factors DecisionTree of GaussianFactors and arbitrary scalars.
|
* @param factors DecisionTree of GaussianFactors and arbitrary scalars.
|
||||||
* Gaussian factor in factors.
|
* @return FactorValuePairs
|
||||||
* @return HybridGaussianFactor::Factors
|
|
||||||
*/
|
*/
|
||||||
static Factors augment(const FactorValuePairs &factors);
|
static FactorValuePairs augment(const FactorValuePairs &factors);
|
||||||
|
|
||||||
/// Helper method to compute the error of a component.
|
|
||||||
double potentiallyPrunedComponentError(
|
|
||||||
const sharedFactor &gf, const VectorValues &continuousValues) const;
|
|
||||||
|
|
||||||
/// Helper struct to assist private constructor below.
|
/// Helper struct to assist private constructor below.
|
||||||
struct ConstructorHelper;
|
struct ConstructorHelper;
|
||||||
|
|
|
||||||
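The `augment` helper referenced above is what folds each leaf's scalar into the factor, so a plain `GaussianFactor::error` call already includes it. A rough sketch of the idea only, under stated assumptions: `AugmentWithConstant` is not a GTSAM function, the real `augment` maps over whole `FactorValuePairs` trees, and this sketch assumes a single-block JacobianFactor:

#include <gtsam/linear/JacobianFactor.h>
#include <cmath>

namespace sketch {
using namespace gtsam;

// Append a zero row so the factor's error picks up a constant c >= 0:
// error = 0.5 * ||A x - b||^2, and the extra row contributes 0.5 * 2c = c.
GaussianFactor::shared_ptr AugmentWithConstant(const JacobianFactor& jf,
                                               double c) {
  const Matrix A = jf.getA(jf.begin());  // single-block factor assumed
  const Vector b = jf.getb();
  Matrix A2(A.rows() + 1, A.cols());
  A2 << A, Matrix::Zero(1, A.cols());  // zero row: does not constrain x
  Vector b2(b.size() + 1);
  b2 << b, std::sqrt(2.0 * c);         // contributes exactly c to the error
  return std::make_shared<JacobianFactor>(jf.keys().front(), A2, b2);
}
}  // namespace sketch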
@@ -20,9 +20,12 @@

 #include <gtsam/base/utilities.h>
 #include <gtsam/discrete/Assignment.h>
+#include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteEliminationTree.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
 #include <gtsam/discrete/DiscreteJunctionTree.h>
+#include <gtsam/discrete/DiscreteKey.h>
+#include <gtsam/discrete/DiscreteValues.h>
 #include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/hybrid/HybridEliminationTree.h>
 #include <gtsam/hybrid/HybridFactor.h>

@@ -39,10 +42,8 @@
 #include <gtsam/linear/HessianFactor.h>
 #include <gtsam/linear/JacobianFactor.h>

-#include <algorithm>
 #include <cstddef>
 #include <iostream>
-#include <iterator>
 #include <memory>
 #include <stdexcept>
 #include <utility>

@@ -53,9 +54,24 @@ namespace gtsam {
 /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
 template class EliminateableFactorGraph<HybridGaussianFactorGraph>;

+using std::dynamic_pointer_cast;
 using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;

-using std::dynamic_pointer_cast;
+/// Result from elimination.
+struct Result {
+  GaussianConditional::shared_ptr conditional;
+  double negLogK;
+  GaussianFactor::shared_ptr factor;
+  double scalar;
+
+  bool operator==(const Result &other) const {
+    return conditional == other.conditional && negLogK == other.negLogK &&
+           factor == other.factor && scalar == other.scalar;
+  }
+};
+using ResultTree = DecisionTree<Key, Result>;
+
+static const VectorValues kEmpty;

 /* ************************************************************************ */
 // Throw a runtime exception for method specified in string s, and factor f:
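A note for reviewers, since the four `Result` fields drive everything below: per discrete assignment m, `conditional` is p(x|m) with normalizer K_m (so `negLogK` = -log K_m), `factor` is the remaining factor f_m on the separator, and `scalar` s_m is the sum of mode-dependent constants collected so far. My reading of the bookkeeping used later in `createDiscreteFactor`, written out as math, where f_m(\emptyset) stands for `factor->error(kEmpty)`:

e(m) = s_m + f_m(\emptyset) - \bigl(-\log K_m\bigr),
\qquad
\phi(m) \;\propto\; \exp\!\bigl(-e(m)\bigr) = K_m \, e^{-s_m - f_m(\emptyset)}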
@@ -74,6 +90,61 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
       index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
 }

+/* ************************************************************************ */
+static void printFactor(const std::shared_ptr<Factor> &factor,
+                        const DiscreteValues &assignment,
+                        const KeyFormatter &keyFormatter) {
+  if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
+    if (assignment.empty()) {
+      hgf->print("HybridGaussianFactor:", keyFormatter);
+    } else {
+      hgf->operator()(assignment)
+          .first->print("HybridGaussianFactor, component:", keyFormatter);
+    }
+  } else if (auto gf = dynamic_pointer_cast<GaussianFactor>(factor)) {
+    factor->print("GaussianFactor:\n", keyFormatter);
+
+  } else if (auto df = dynamic_pointer_cast<DiscreteFactor>(factor)) {
+    factor->print("DiscreteFactor:\n", keyFormatter);
+  } else if (auto hc = dynamic_pointer_cast<HybridConditional>(factor)) {
+    if (hc->isContinuous()) {
+      factor->print("GaussianConditional:\n", keyFormatter);
+    } else if (hc->isDiscrete()) {
+      factor->print("DiscreteConditional:\n", keyFormatter);
+    } else {
+      if (assignment.empty()) {
+        hc->print("HybridConditional:", keyFormatter);
+      } else {
+        hc->asHybrid()
+            ->choose(assignment)
+            ->print("HybridConditional, component:\n", keyFormatter);
+      }
+    }
+  } else {
+    factor->print("Unknown factor type\n", keyFormatter);
+  }
+}
+
+/* ************************************************************************ */
+void HybridGaussianFactorGraph::print(const std::string &s,
+                                      const KeyFormatter &keyFormatter) const {
+  std::cout << (s.empty() ? "" : s + " ") << std::endl;
+  std::cout << "size: " << size() << std::endl;
+
+  for (size_t i = 0; i < factors_.size(); i++) {
+    auto &&factor = factors_[i];
+    if (factor == nullptr) {
+      std::cout << "Factor " << i << ": nullptr\n";
+      continue;
+    }
+    // Print the factor
+    std::cout << "Factor " << i << "\n";
+    printFactor(factor, {}, keyFormatter);
+    std::cout << "\n";
+  }
+  std::cout.flush();
+}
+
 /* ************************************************************************ */
 void HybridGaussianFactorGraph::printErrors(
     const HybridValues &values, const std::string &str,
@@ -83,111 +154,46 @@ void HybridGaussianFactorGraph::printErrors(
     &printCondition) const {
   std::cout << str << "size: " << size() << std::endl << std::endl;

-  std::stringstream ss;
-
   for (size_t i = 0; i < factors_.size(); i++) {
     auto &&factor = factors_[i];
-    std::cout << "Factor " << i << ": ";
-
-    // Clear the stringstream
-    ss.str(std::string());
-
-    if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
     if (factor == nullptr) {
-      std::cout << "nullptr"
-                << "\n";
-    } else {
-      hgf->operator()(values.discrete())->print(ss.str(), keyFormatter);
-      std::cout << "error = " << factor->error(values) << std::endl;
+      std::cout << "Factor " << i << ": nullptr\n";
+      continue;
     }
-    } else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
-      if (factor == nullptr) {
-        std::cout << "nullptr"
-                  << "\n";
-      } else {
-        if (hc->isContinuous()) {
-          factor->print(ss.str(), keyFormatter);
-          std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
-        } else if (hc->isDiscrete()) {
-          factor->print(ss.str(), keyFormatter);
-          std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
-                    << "\n";
-        } else {
-          // Is hybrid
-          auto conditionalComponent =
-              hc->asHybrid()->operator()(values.discrete());
-          conditionalComponent->print(ss.str(), keyFormatter);
-          std::cout << "error = " << conditionalComponent->error(values)
-                    << "\n";
-        }
-      }
-    } else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
-      const double errorValue = (factor != nullptr ? gf->error(values) : .0);
+    const double errorValue = factor->error(values);
     if (!printCondition(factor.get(), errorValue, i))
       continue;  // User-provided filter did not pass

-    if (factor == nullptr) {
-      std::cout << "nullptr"
-                << "\n";
-    } else {
-      factor->print(ss.str(), keyFormatter);
-      std::cout << "error = " << errorValue << "\n";
-    }
-    } else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
-      if (factor == nullptr) {
-        std::cout << "nullptr"
-                  << "\n";
-      } else {
-        factor->print(ss.str(), keyFormatter);
-        std::cout << "error = " << df->error(values.discrete()) << std::endl;
-      }
-    } else {
-      continue;
-    }
-
+    // Print the factor
+    std::cout << "Factor " << i << ", error = " << errorValue << "\n";
+    printFactor(factor, values.discrete(), keyFormatter);
     std::cout << "\n";
   }
   std::cout.flush();
 }

 /* ************************************************************************ */
-static GaussianFactorGraphTree addGaussian(
-    const GaussianFactorGraphTree &gfgTree,
-    const GaussianFactor::shared_ptr &factor) {
-  // If the decision tree is not initialized, then initialize it.
-  if (gfgTree.empty()) {
-    GaussianFactorGraph result{factor};
-    return GaussianFactorGraphTree(result);
-  } else {
-    auto add = [&factor](const GaussianFactorGraph &graph) {
-      auto result = graph;
-      result.push_back(factor);
-      return result;
-    };
-    return gfgTree.apply(add);
-  }
-}
-
-/* ************************************************************************ */
-// TODO(dellaert): it's probably more efficient to first collect the discrete
-// keys, and then loop over all assignments to populate a vector.
-GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
-  GaussianFactorGraphTree result;
-
+HybridGaussianProductFactor HybridGaussianFactorGraph::collectProductFactor()
+    const {
+  HybridGaussianProductFactor result;
   for (auto &f : factors_) {
-    // TODO(dellaert): just use a virtual method defined in HybridFactor.
-    if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
-      result = addGaussian(result, gf);
+    // TODO(dellaert): can we make this cleaner and less error-prone?
+    if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
+      continue;  // Ignore OrphanWrapper
+    } else if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
+      result += gf;
+    } else if (auto gc = dynamic_pointer_cast<GaussianConditional>(f)) {
+      result += gc;
     } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
-      result = gmf->add(result);
+      result += *gmf;
     } else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
-      result = gm->add(result);
+      result += *gm;  // handled above already?
     } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
       if (auto gm = hc->asHybrid()) {
-        result = gm->add(result);
+        result += *gm;
       } else if (auto g = hc->asGaussian()) {
-        result = addGaussian(result, g);
+        result += g;
       } else {
         // Has to be discrete.
         // TODO(dellaert): in C++20, we can use std::visit.

@@ -200,7 +206,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
       } else {
         // TODO(dellaert): there was an unattributed comment here: We need to
         // handle the case where the object is actually a BayesTreeOrphanWrapper!
-        throwRuntimeError("gtsam::assembleGraphTree", f);
+        throwRuntimeError("gtsam::collectProductFactor", f);
       }
     }

@@ -233,21 +239,19 @@ continuousElimination(const HybridGaussianFactorGraph &factors,

 /* ************************************************************************ */
 /**
- * @brief Exponentiate (not necessarily normalized) negative log-values,
- * normalize, and then return as AlgebraicDecisionTree<Key>.
+ * @brief Take negative log-values, shift them so that the minimum value is 0,
+ * and then exponentiate to create a DecisionTreeFactor (not normalized yet!).
  *
- * @param logValues DecisionTree of (unnormalized) log values.
- * @return AlgebraicDecisionTree<Key>
+ * @param errors DecisionTree of (unnormalized) errors.
+ * @return DecisionTreeFactor::shared_ptr
  */
-static AlgebraicDecisionTree<Key> probabilitiesFromNegativeLogValues(
-    const AlgebraicDecisionTree<Key> &logValues) {
-  // Perform normalization
-  double min_log = logValues.min();
-  AlgebraicDecisionTree<Key> probabilities = DecisionTree<Key, double>(
-      logValues, [&min_log](const double x) { return exp(-(x - min_log)); });
-  probabilities = probabilities.normalize(probabilities.sum());
-
-  return probabilities;
+static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
+    const DiscreteKeys &discreteKeys,
+    const AlgebraicDecisionTree<Key> &errors) {
+  double min_log = errors.min();
+  AlgebraicDecisionTree<Key> potentials(
+      errors, [&min_log](const double x) { return exp(-(x - min_log)); });
+  return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
 }

 /* ************************************************************************ */
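Worth spelling out what the min-shift buys: subtracting the smallest error before exponentiating keeps the potentials in a sane numerical range no matter how large the absolute errors get. A self-contained sketch with invented numbers that mirrors the body of `DiscreteFactorFromErrors` (not a library entry point):

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <cmath>

using namespace gtsam;

int main() {
  const DiscreteKey m(0, 2);
  // Errors 1005.0 and 1005.7: exp(-1005) underflows double precision, but
  // after shifting by the minimum the potentials are 1.0 and exp(-0.7)≈0.497.
  const AlgebraicDecisionTree<Key> errors(m, 1005.0, 1005.7);
  const double min_log = errors.min();
  AlgebraicDecisionTree<Key> potentials(
      errors, [&min_log](double x) { return std::exp(-(x - min_log)); });
  DecisionTreeFactor factor({m}, potentials);  // unnormalized, as in the diff
  factor.print("potentials:");
  return 0;
}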
@@ -261,19 +265,15 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
       dfg.push_back(df);
     } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
       // Case where we have a HybridGaussianFactor with no continuous keys.
-      // In this case, compute discrete probabilities.
-      auto logProbability =
-          [&](const GaussianFactor::shared_ptr &factor) -> double {
-        if (!factor) return 0.0;
-        return factor->error(VectorValues());
+      // In this case, compute a discrete factor from the remaining error.
+      auto calculateError = [&](const auto &pair) -> double {
+        auto [factor, scalar] = pair;
+        // If factor is null, it has been pruned, hence return infinite error
+        if (!factor) return std::numeric_limits<double>::infinity();
+        return scalar + factor->error(kEmpty);
       };
-      AlgebraicDecisionTree<Key> logProbabilities =
-          DecisionTree<Key, double>(gmf->factors(), logProbability);
-
-      AlgebraicDecisionTree<Key> probabilities =
-          probabilitiesFromNegativeLogValues(logProbabilities);
-      dfg.emplace_shared<DecisionTreeFactor>(gmf->discreteKeys(),
-                                             probabilities);
-
+      AlgebraicDecisionTree<Key> errors(gmf->factors(), calculateError);
+      dfg.push_back(DiscreteFactorFromErrors(gmf->discreteKeys(), errors));
     } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
       // Ignore orphaned clique.

@@ -294,24 +294,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
 }

 /* ************************************************************************ */
-// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
-// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
-// otherwise create a GFG with a single (null) factor,
-// which doesn't register as null.
-GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
-  auto emptyGaussian = [](const GaussianFactorGraph &graph) {
-    bool hasNull =
-        std::any_of(graph.begin(), graph.end(),
-                    [](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
-    return hasNull ? GaussianFactorGraph() : graph;
-  };
-  return GaussianFactorGraphTree(sum, emptyGaussian);
-}
-
-/* ************************************************************************ */
-using Result = std::pair<std::shared_ptr<GaussianConditional>,
-                         HybridGaussianFactor::sharedFactor>;
-
 /**
  * Compute the probability p(μ;m) = exp(-error(μ;m)) * sqrt(det(2π Σ_m))
  * from the residual error ||b||^2 at the mean μ.

@@ -319,46 +301,42 @@ using Result = std::pair<std::shared_ptr<GaussianConditional>,
  * depends on the discrete separator if present.
  */
 static std::shared_ptr<Factor> createDiscreteFactor(
-    const DecisionTree<Key, Result> &eliminationResults,
+    const ResultTree &eliminationResults,
     const DiscreteKeys &discreteSeparator) {
-  auto negLogProbability = [&](const Result &pair) -> double {
-    const auto &[conditional, factor] = pair;
-    static const VectorValues kEmpty;
-    // If the factor is not null, it has no keys, just contains the residual.
-    if (!factor) return 1.0;  // TODO(dellaert): not loving this.
-
-    // Negative logspace version of:
-    // exp(-factor->error(kEmpty)) / conditional->normalizationConstant();
-    // negLogConstant gives `-log(k)`
-    // which is `-log(k) = log(1/k) = log(\sqrt{|2πΣ|})`.
-    return factor->error(kEmpty) - conditional->negLogConstant();
+  auto calculateError = [&](const Result &result) -> double {
+    if (result.conditional && result.factor) {
+      // `error` has the following contributions:
+      // - the scalar is the sum of all mode-dependent constants
+      // - factor->error(kEmpty) is the error remaining after elimination
+      // - negLogK is what is given to the conditional to normalize
+      return result.scalar + result.factor->error(kEmpty) - result.negLogK;
+    } else if (!result.conditional && !result.factor) {
+      // If the factor has been pruned, return infinite error
+      return std::numeric_limits<double>::infinity();
+    } else {
+      throw std::runtime_error("createDiscreteFactor has mixed NULLs");
+    }
   };

-  AlgebraicDecisionTree<Key> negLogProbabilities(
-      DecisionTree<Key, double>(eliminationResults, negLogProbability));
-  AlgebraicDecisionTree<Key> probabilities =
-      probabilitiesFromNegativeLogValues(negLogProbabilities);
-
-  return std::make_shared<DecisionTreeFactor>(discreteSeparator, probabilities);
+  AlgebraicDecisionTree<Key> errors(eliminationResults, calculateError);
+  return DiscreteFactorFromErrors(discreteSeparator, errors);
 }

+/* *******************************************************************************/
 // Create HybridGaussianFactor on the separator, taking care to correct
 // for conditional constants.
 static std::shared_ptr<Factor> createHybridGaussianFactor(
-    const DecisionTree<Key, Result> &eliminationResults,
+    const ResultTree &eliminationResults,
     const DiscreteKeys &discreteSeparator) {
   // Correct for the normalization constant used up by the conditional
-  auto correct = [&](const Result &pair) -> GaussianFactorValuePair {
-    const auto &[conditional, factor] = pair;
-    if (factor) {
-      auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
-      if (!hf) throw std::runtime_error("Expected HessianFactor!");
-      // Add 2.0 term since the constant term will be premultiplied by 0.5
-      // as per the Hessian definition,
-      // and negative since we want log(k)
-      hf->constantTerm() += -2.0 * conditional->negLogConstant();
+  auto correct = [&](const Result &result) -> GaussianFactorValuePair {
+    if (result.conditional && result.factor) {
+      return {result.factor, result.scalar - result.negLogK};
+    } else if (!result.conditional && !result.factor) {
+      return {nullptr, std::numeric_limits<double>::infinity()};
+    } else {
+      throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
     }
-    return {factor, 0.0};
   };
   DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
                                                         correct);
@@ -366,42 +344,56 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
   return std::make_shared<HybridGaussianFactor>(discreteSeparator, newFactors);
 }

+/* *******************************************************************************/
+/// Get the discrete keys from the HybridGaussianFactorGraph as DiscreteKeys.
+static auto GetDiscreteKeys =
+    [](const HybridGaussianFactorGraph &hfg) -> DiscreteKeys {
+  const std::set<DiscreteKey> discreteKeySet = hfg.discreteKeys();
+  return {discreteKeySet.begin(), discreteKeySet.end()};
+};
+
 /* *******************************************************************************/
 std::pair<HybridConditional::shared_ptr, std::shared_ptr<Factor>>
 HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
   // Since we eliminate all continuous variables first,
   // the discrete separator will be *all* the discrete keys.
-  const std::set<DiscreteKey> keysForDiscreteVariables = discreteKeys();
-  DiscreteKeys discreteSeparator(keysForDiscreteVariables.begin(),
-                                 keysForDiscreteVariables.end());
+  DiscreteKeys discreteSeparator = GetDiscreteKeys(*this);

   // Collect all the factors to create a set of Gaussian factor graphs in a
-  // decision tree indexed by all discrete keys involved.
-  GaussianFactorGraphTree factorGraphTree = assembleGraphTree();
+  // decision tree indexed by all discrete keys involved. Just like any hybrid
+  // factor, every assignment also has a scalar error, in this case the sum of
+  // all errors in the graph. This error is assignment-specific and accounts for
+  // any difference in noise models used.
+  HybridGaussianProductFactor productFactor = collectProductFactor();

-  // Convert factor graphs with a nullptr to an empty factor graph.
-  // This is done after assembly since it is non-trivial to keep track of which
-  // FG has a nullptr as we're looping over the factors.
-  factorGraphTree = removeEmpty(factorGraphTree);
+  // Check if a factor is null
+  auto isNull = [](const GaussianFactor::shared_ptr &ptr) { return !ptr; };

   // This is the elimination method on the leaf nodes
   bool someContinuousLeft = false;
-  auto eliminate = [&](const GaussianFactorGraph &graph) -> Result {
-    if (graph.empty()) {
-      return {nullptr, nullptr};
+  auto eliminate =
+      [&](const std::pair<GaussianFactorGraph, double> &pair) -> Result {
+    const auto &[graph, scalar] = pair;
+
+    // If any product contains a pruned factor, prune it here. Done here as it
+    // is non-trivial to do within collectProductFactor.
+    if (graph.empty() || std::any_of(graph.begin(), graph.end(), isNull)) {
+      return {nullptr, 0.0, nullptr, 0.0};
     }

     // Expensive elimination of product factor.
-    auto result = EliminatePreferCholesky(graph, keys);
+    auto [conditional, factor] =
+        EliminatePreferCholesky(graph, keys);  /// <<<<<< MOST COMPUTE IS HERE

     // Record whether there are any continuous variables left
-    someContinuousLeft |= !result.second->empty();
+    someContinuousLeft |= !factor->empty();

-    return result;
+    // We pass on the scalar unmodified.
+    return {conditional, conditional->negLogConstant(), factor, scalar};
   };

   // Perform elimination!
-  DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
+  const ResultTree eliminationResults(productFactor, eliminate);

   // If there are no more continuous parents we create a DiscreteFactor with the
   // error for each discrete choice. Otherwise, create a HybridGaussianFactor
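A hedged sketch of how this entry point is exercised from user code, assuming `eliminate` is public at this revision as the header later in this diff suggests; the graph contents are placeholders and reuse the constructor pattern from the earlier sketch:

#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/JacobianFactor.h>

using namespace gtsam;
using symbol_shorthand::X;

int main() {
  const DiscreteKey mode(Symbol('m', 0), 2);

  HybridGaussianFactorGraph graph;
  // A two-mode measurement on x0 plus a plain Gaussian prior.
  auto f0 = std::make_shared<JacobianFactor>(X(0), I_1x1, Vector1(-1.0));
  auto f1 = std::make_shared<JacobianFactor>(X(0), I_1x1, Vector1(1.0));
  const HybridGaussianFactor::FactorValuePairs pairs(
      {mode}, std::vector<GaussianFactorValuePair>{{f0, 0.0}, {f1, 0.0}});
  graph.push_back(std::make_shared<HybridGaussianFactor>(DiscreteKeys{mode}, pairs));
  graph.push_back(std::make_shared<JacobianFactor>(X(0), I_1x1, Vector1(0.0)));

  // Eliminating the single continuous key yields P(x0 | mode) plus the
  // induced factor on the discrete separator.
  Ordering ordering;
  ordering.push_back(X(0));
  auto [conditional, factorOnModes] = graph.eliminate(ordering);
  conditional->print("P(x0 | m):\n");
  return 0;
}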
@@ -411,11 +403,13 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
           ? createHybridGaussianFactor(eliminationResults, discreteSeparator)
           : createDiscreteFactor(eliminationResults, discreteSeparator);

-  // Create the HybridGaussianConditional from the conditionals
-  HybridGaussianConditional::Conditionals conditionals(
-      eliminationResults, [](const Result &pair) { return pair.first; });
-  auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
-      discreteSeparator, conditionals);
+  // Create the HybridGaussianConditional without re-calculating constants:
+  HybridGaussianConditional::FactorValuePairs pairs(
+      eliminationResults, [](const Result &result) -> GaussianFactorValuePair {
+        return {result.conditional, result.negLogK};
+      });
+  auto hybridGaussian =
+      std::make_shared<HybridGaussianConditional>(discreteSeparator, pairs);

   return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
 }

@@ -496,6 +490,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
     } else if (hybrid_factor->isHybrid()) {
       only_continuous = false;
       only_discrete = false;
+      break;
     }
   } else if (auto cont_factor =
                  std::dynamic_pointer_cast<GaussianFactor>(factor)) {

@@ -523,22 +518,23 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
 /* ************************************************************************ */
 AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::errorTree(
     const VectorValues &continuousValues) const {
-  AlgebraicDecisionTree<Key> error_tree(0.0);
+  AlgebraicDecisionTree<Key> result(0.0);
   // Iterate over each factor.
   for (auto &factor : factors_) {
-    if (auto f = std::dynamic_pointer_cast<HybridFactor>(factor)) {
-      // Check for HybridFactor, and call errorTree
-      error_tree = error_tree + f->errorTree(continuousValues);
-    } else if (auto f = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
-      // Skip discrete factors
-      continue;
+    if (auto hf = std::dynamic_pointer_cast<HybridFactor>(factor)) {
+      // Add errorTree for hybrid factors, includes HybridGaussianConditionals!
+      result = result + hf->errorTree(continuousValues);
+    } else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
+      // If discrete, just add its errorTree as well
+      result = result + df->errorTree();
+    } else if (auto gf = dynamic_pointer_cast<GaussianFactor>(factor)) {
+      // For a continuous only factor, just add its error
+      result = result + gf->error(continuousValues);
     } else {
-      // Everything else is a continuous only factor
-      HybridValues hv(continuousValues, DiscreteValues());
-      error_tree = error_tree + AlgebraicDecisionTree<Key>(factor->error(hv));
+      throwRuntimeError("HybridGaussianFactorGraph::errorTree", factor);
     }
   }
-  return error_tree;
+  return result;
 }

 /* ************************************************************************ */
@@ -549,18 +545,18 @@ double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
 }

 /* ************************************************************************ */
-AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
+AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::discretePosterior(
     const VectorValues &continuousValues) const {
-  AlgebraicDecisionTree<Key> error_tree = this->errorTree(continuousValues);
-  AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
+  AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
+  AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
     // NOTE: The 0.5 term is handled by each factor
     return exp(-error);
   });
-  return prob_tree;
+  return p / p.sum();
 }

 /* ************************************************************************ */
-GaussianFactorGraph HybridGaussianFactorGraph::operator()(
+GaussianFactorGraph HybridGaussianFactorGraph::choose(
     const DiscreteValues &assignment) const {
   GaussianFactorGraph gfg;
   for (auto &&f : *this) {
@@ -569,9 +565,14 @@ GaussianFactorGraph HybridGaussianFactorGraph::operator()(
     } else if (auto gc = std::dynamic_pointer_cast<GaussianConditional>(f)) {
       gfg.push_back(gf);
     } else if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
-      gfg.push_back((*hgf)(assignment));
+      gfg.push_back((*hgf)(assignment).first);
     } else if (auto hgc = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
       gfg.push_back((*hgc)(assignment));
+    } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
+      if (auto gc = hc->asGaussian())
+        gfg.push_back(gc);
+      else if (auto hgc = hc->asHybrid())
+        gfg.push_back((*hgc)(assignment));
     } else {
       continue;
     }

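`choose` (with `operator()` kept as sugar, see the header below) fixes the discrete assignment and leaves a purely Gaussian problem. A hedged usage sketch, reusing the toy `graph` and `mode` from the elimination example above:

// Fix the mode, then solve the resulting linear problem exactly.
const DiscreteValues assignment{{mode.first, 1}};
GaussianFactorGraph gfg = graph.choose(assignment);  // or graph(assignment)
const VectorValues xStar = gfg.optimize();
xStar.print("x* given m=1: ");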
@@ -18,6 +18,7 @@

 #pragma once

+#include <gtsam/discrete/DiscreteKey.h>
 #include <gtsam/hybrid/HybridFactor.h>
 #include <gtsam/hybrid/HybridFactorGraph.h>
 #include <gtsam/hybrid/HybridGaussianFactor.h>

@@ -131,6 +132,14 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
   explicit HybridGaussianFactorGraph(const CONTAINER& factors)
       : Base(factors) {}

+  /**
+   * Construct from an initializer list of GaussianFactor shared pointers.
+   * Example:
+   *   HybridGaussianFactorGraph graph = { factor1, factor2, factor3 };
+   */
+  HybridGaussianFactorGraph(std::initializer_list<sharedFactor> factors)
+      : Base(factors) {}
+
   /**
    * Implicit copy/downcast constructor to override explicit template container
    * constructor. In BayesTree this is used for:

@@ -144,10 +153,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
   /// @name Testable
   /// @{

-  // TODO(dellaert): customize print and equals.
-  // void print(
-  //     const std::string& s = "HybridGaussianFactorGraph",
-  //     const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;
+  void print(
+      const std::string& s = "HybridGaussianFactorGraph",
+      const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

   /**
    * @brief Print the errors of each factor in the hybrid factor graph.

@@ -187,17 +195,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
   AlgebraicDecisionTree<Key> errorTree(
       const VectorValues& continuousValues) const;

-  /**
-   * @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
-   * for each discrete assignment, and return as a tree.
-   *
-   * @param continuousValues Continuous values at which to compute the
-   * probability.
-   * @return AlgebraicDecisionTree<Key>
-   */
-  AlgebraicDecisionTree<Key> probPrime(
-      const VectorValues& continuousValues) const;
-
   /**
    * @brief Compute the unnormalized posterior probability for a continuous
    * vector values given a specific assignment.

@@ -206,6 +203,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
    */
   double probPrime(const HybridValues& values) const;

+  /**
+   * @brief Compute the posterior P(M|X=x) when all continuous values X are
+   * given. This is efficient as it is simply probPrime normalized.
+   *
+   * @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
+   * which we would need, are hard to recover.
+   *
+   * @param continuousValues Continuous values x to condition on.
+   * @return DecisionTreeFactor
+   */
+  AlgebraicDecisionTree<Key> discretePosterior(
+      const VectorValues& continuousValues) const;
+
   /**
    * @brief Create a decision tree of factor graphs out of this hybrid factor
    * graph.

@@ -215,7 +225,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
    * one for A and one for B. The leaves of the tree will be the Gaussian
    * factors that have only continuous keys.
    */
-  GaussianFactorGraphTree assembleGraphTree() const;
+  HybridGaussianProductFactor collectProductFactor() const;

   /**
    * @brief Eliminate the given continuous keys.

@@ -227,8 +237,28 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
   eliminate(const Ordering& keys) const;
   /// @}

-  /// Get the GaussianFactorGraph at a given discrete assignment.
-  GaussianFactorGraph operator()(const DiscreteValues& assignment) const;
+  /**
+   * @brief Get the GaussianFactorGraph at a given discrete assignment. Note
+   * this corresponds to the Gaussian posterior p(X|M=m, Z=z) of the continuous
+   * variables X given the discrete assignment M=m and whatever measurements z
+   * were assumed in the creation of the factor graph.
+   *
+   * @note Be careful, as any factors not Gaussian are ignored.
+   *
+   * @param assignment The discrete value assignment for the discrete keys.
+   * @return Gaussian factors as a GaussianFactorGraph
+   */
+  GaussianFactorGraph choose(const DiscreteValues& assignment) const;
+
+  /// Syntactic sugar for choose
+  GaussianFactorGraph operator()(const DiscreteValues& assignment) const {
+    return choose(assignment);
+  }
 };

+// traits
+template <>
+struct traits<HybridGaussianFactorGraph>
+    : public Testable<HybridGaussianFactorGraph> {};
+
 } // namespace gtsam
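`discretePosterior` is the normalized replacement for the removed tree-valued `probPrime`. A short sketch of what it returns, continuing the toy `graph` above (values are placeholders):

// P(m | x): exp(-errorTree(x)), normalized to sum to one over assignments.
VectorValues x;
x.insert(X(0), Vector1(0.3));
AlgebraicDecisionTree<Key> posterior = graph.discretePosterior(x);
posterior.print("P(m | x=0.3): ");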
@@ -10,7 +10,7 @@
 * -------------------------------------------------------------------------- */

 /**
- * @file HybridGaussianISAM.h
+ * @file HybridGaussianISAM.cpp
  * @date March 31, 2022
  * @author Fan Jiang
  * @author Frank Dellaert
@@ -0,0 +1,112 @@
+/* ----------------------------------------------------------------------------
+
+ * GTSAM Copyright 2010, Georgia Tech Research Corporation,
+ * Atlanta, Georgia 30332-0415
+ * All Rights Reserved
+ * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
+
+ * See LICENSE for the license information
+
+ * -------------------------------------------------------------------------- */
+
+/**
+ * @file HybridGaussianProductFactor.cpp
+ * @date Oct 2, 2024
+ * @author Frank Dellaert
+ * @author Varun Agrawal
+ */
+
+#include <gtsam/base/types.h>
+#include <gtsam/discrete/DecisionTree.h>
+#include <gtsam/hybrid/HybridGaussianFactor.h>
+#include <gtsam/hybrid/HybridGaussianProductFactor.h>
+#include <gtsam/linear/GaussianFactorGraph.h>
+
+#include <string>
+
+namespace gtsam {
+
+using Y = GaussianFactorGraphValuePair;
+
+/* *******************************************************************************/
+static Y add(const Y& y1, const Y& y2) {
+  GaussianFactorGraph result = y1.first;
+  result.push_back(y2.first);
+  return {result, y1.second + y2.second};
+}
+
+/* *******************************************************************************/
+HybridGaussianProductFactor operator+(const HybridGaussianProductFactor& a,
+                                      const HybridGaussianProductFactor& b) {
+  return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add));
+}
+
+/* *******************************************************************************/
+HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
+    const HybridGaussianFactor& factor) const {
+  return *this + factor.asProductFactor();
+}
+
+/* *******************************************************************************/
+HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
+    const GaussianFactor::shared_ptr& factor) const {
+  return *this + HybridGaussianProductFactor(factor);
+}
+
+/* *******************************************************************************/
+HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
+    const GaussianFactor::shared_ptr& factor) {
+  *this = *this + factor;
+  return *this;
+}
+
+/* *******************************************************************************/
+HybridGaussianProductFactor& HybridGaussianProductFactor::operator+=(
+    const HybridGaussianFactor& factor) {
+  *this = *this + factor;
+  return *this;
+}
+
+/* *******************************************************************************/
+void HybridGaussianProductFactor::print(const std::string& s,
+                                        const KeyFormatter& formatter) const {
+  KeySet keys;
+  auto printer = [&](const Y& y) {
+    if (keys.empty()) keys = y.first.keys();
+    return "Graph of size " + std::to_string(y.first.size()) +
+           ", scalar sum: " + std::to_string(y.second);
+  };
+  Base::print(s, formatter, printer);
+  if (!keys.empty()) {
+    std::cout << s << " Keys:";
+    for (auto&& key : keys) std::cout << " " << formatter(key);
+    std::cout << "." << std::endl;
+  }
+}
+
+/* *******************************************************************************/
+bool HybridGaussianProductFactor::equals(
+    const HybridGaussianProductFactor& other, double tol) const {
+  return Base::equals(other, [tol](const Y& a, const Y& b) {
+    return a.first.equals(b.first, tol) && std::abs(a.second - b.second) < tol;
+  });
+}
+
+/* *******************************************************************************/
+HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
+  auto emptyGaussian = [](const Y& y) {
+    bool hasNull =
+        std::any_of(y.first.begin(), y.first.end(),
+                    [](const GaussianFactor::shared_ptr& ptr) { return !ptr; });
+    return hasNull ? Y{GaussianFactorGraph(), 0.0} : y;
+  };
+  return {Base(*this, emptyGaussian)};
+}
+
+/* *******************************************************************************/
+std::istream& operator>>(std::istream& is, GaussianFactorGraphValuePair& pair) {
+  // Dummy, don't do anything
+  return is;
+}
+
+} // namespace gtsam
@@ -0,0 +1,147 @@
+/* ----------------------------------------------------------------------------
+
+ * GTSAM Copyright 2010, Georgia Tech Research Corporation,
+ * Atlanta, Georgia 30332-0415
+ * All Rights Reserved
+ * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
+
+ * See LICENSE for the license information
+
+ * -------------------------------------------------------------------------- */
+
+/**
+ * @file HybridGaussianProductFactor.h
+ * @date Oct 2, 2024
+ * @author Frank Dellaert
+ * @author Varun Agrawal
+ */
+
+#pragma once
+
+#include <gtsam/discrete/DecisionTree.h>
+#include <gtsam/inference/Key.h>
+#include <gtsam/linear/GaussianFactorGraph.h>
+
+#include <iostream>
+
+namespace gtsam {
+
+class HybridGaussianFactor;
+
+using GaussianFactorGraphValuePair = std::pair<GaussianFactorGraph, double>;
+
+/// Alias for DecisionTree of GaussianFactorGraphs and their scalar sums
+class GTSAM_EXPORT HybridGaussianProductFactor
+    : public DecisionTree<Key, GaussianFactorGraphValuePair> {
+ public:
+  using Base = DecisionTree<Key, GaussianFactorGraphValuePair>;
+
+  /// @name Constructors
+  /// @{
+
+  /// Default constructor
+  HybridGaussianProductFactor() = default;
+
+  /**
+   * @brief Construct from a single factor
+   * @tparam FACTOR Factor type
+   * @param factor Shared pointer to the factor
+   */
+  template <class FACTOR>
+  HybridGaussianProductFactor(const std::shared_ptr<FACTOR>& factor)
+      : Base(GaussianFactorGraphValuePair{GaussianFactorGraph{factor}, 0.0}) {}
+
+  /**
+   * @brief Construct from DecisionTree
+   * @param tree Decision tree to construct from
+   */
+  HybridGaussianProductFactor(Base&& tree) : Base(std::move(tree)) {}
+
+  /// @}
+
+  /// @name Operators
+  /// @{
+
+  /// Add GaussianFactor into HybridGaussianProductFactor
+  HybridGaussianProductFactor operator+(
+      const GaussianFactor::shared_ptr& factor) const;
+
+  /// Add HybridGaussianFactor into HybridGaussianProductFactor
+  HybridGaussianProductFactor operator+(
+      const HybridGaussianFactor& factor) const;
+
+  /// Add-assign operator for GaussianFactor
+  HybridGaussianProductFactor& operator+=(
+      const GaussianFactor::shared_ptr& factor);
+
+  /// Add-assign operator for HybridGaussianFactor
+  HybridGaussianProductFactor& operator+=(const HybridGaussianFactor& factor);
+
+  /// @}
+
+  /// @name Testable
+  /// @{
+
+  /**
+   * @brief Print the HybridGaussianProductFactor
+   * @param s Optional string to prepend
+   * @param formatter Optional key formatter
+   */
+  void print(const std::string& s = "",
+             const KeyFormatter& formatter = DefaultKeyFormatter) const;
+
+  /**
+   * @brief Check if this HybridGaussianProductFactor is equal to another
+   * @param other The other HybridGaussianProductFactor to compare with
+   * @param tol Tolerance for floating point comparisons
+   * @return true if equal, false otherwise
+   */
+  bool equals(const HybridGaussianProductFactor& other,
+              double tol = 1e-9) const;
+
+  /// @}
+
+  /// @name Other methods
+  /// @{
+
+  /**
+   * @brief Remove empty GaussianFactorGraphs from the decision tree
+   * @return A new HybridGaussianProductFactor with empty GaussianFactorGraphs
+   * removed
+   *
+   * If any GaussianFactorGraph in the decision tree contains a nullptr, convert
+   * that leaf to an empty GaussianFactorGraph with zero scalar sum. This is
+   * needed because the DecisionTree will otherwise create a GaussianFactorGraph
+   * with a single (null) factor, which doesn't register as null.
+   */
+  HybridGaussianProductFactor removeEmpty() const;
+
+  /// @}
+
+ private:
+#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
+  /** Serialization function */
+  friend class boost::serialization::access;
+  template <class Archive>
+  void serialize(Archive& ar, const unsigned int /*version*/) {
+    ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
+  }
+#endif
+};
+
+// Testable traits
+template <>
+struct traits<HybridGaussianProductFactor>
+    : public Testable<HybridGaussianProductFactor> {};
+
+/**
+ * Create a dummy overload of >> for GaussianFactorGraphValuePair
+ * so that HybridGaussianProductFactor compiles
+ * with the constructor
+ * `DecisionTree(const std::vector<LabelC>& labelCs, const std::string& table)`.
+ *
+ * Needed to compile on Windows.
+ */
+std::istream& operator>>(std::istream& is, GaussianFactorGraphValuePair& pair);
+
+} // namespace gtsam
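Taken together, the new class behaves like an accumulator: `+=` with a plain Gaussian appends it to the graph at every leaf, while `+=` with a hybrid factor branches the tree on that factor's discrete keys and also sums the scalars. A hedged sketch, reusing the toy `f0`, `f1`, and `mode` from the earlier examples (not verbatim library usage):

// Build the per-assignment (GaussianFactorGraph, scalar) tree by hand.
HybridGaussianProductFactor product;
product += std::make_shared<JacobianFactor>(X(0), I_1x1, Vector1(0.0));  // every leaf
const HybridGaussianFactor::FactorValuePairs pairs(
    {mode}, std::vector<GaussianFactorValuePair>{{f0, 0.0}, {f1, 0.5}});
product += HybridGaussianFactor({mode}, pairs);  // branches on mode

// Leaves holding a pruned (null) component collapse to empty graphs:
HybridGaussianProductFactor cleaned = product.removeEmpty();
cleaned.print("product:");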
@@ -100,8 +100,7 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactor::errorTree(
     auto [factor, val] = f;
     return factor->error(continuousValues) + val;
   };
-  DecisionTree<Key, double> result(factors_, errorFunc);
-  return result;
+  return {factors_, errorFunc};
 }

 /* *******************************************************************************/
@@ -179,4 +179,47 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
   return linearFG;
 }

+/* ************************************************************************* */
+AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::errorTree(
+    const Values& values) const {
+  AlgebraicDecisionTree<Key> result(0.0);
+
+  // Iterate over each factor.
+  for (auto& factor : factors_) {
+    if (auto hnf = std::dynamic_pointer_cast<HybridNonlinearFactor>(factor)) {
+      // Compute factor error and add it.
+      result = result + hnf->errorTree(values);
+    } else if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(factor)) {
+      // If continuous only, get the (double) error
+      // and add it to every leaf of the result
+      result = result + nf->error(values);
+    } else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
+      // If discrete, just add its errorTree as well
+      result = result + df->errorTree();
+    } else {
+      throw std::runtime_error(
+          "HybridNonlinearFactorGraph::errorTree(Values) not implemented for "
+          "factor type " +
+          demangle(typeid(factor).name()) + ".");
+    }
+  }
+
+  return result;
+}
+
+/* ************************************************************************ */
+AlgebraicDecisionTree<Key> HybridNonlinearFactorGraph::discretePosterior(
+    const Values& continuousValues) const {
+  AlgebraicDecisionTree<Key> errors = this->errorTree(continuousValues);
+  AlgebraicDecisionTree<Key> p = errors.apply([](double error) {
+    // NOTE: The 0.5 term is handled by each factor
+    return exp(-error);
+  });
+  return p / p.sum();
+}
+
+/* ************************************************************************ */
 } // namespace gtsam
|
|
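Aside, not part of the changeset: written out, the normalization that discretePosterior performs is

    P(M | X=x) = \exp(-E_M(x)) / \sum_{M'} \exp(-E_{M'}(x)),

where E_M(x) is the errorTree leaf for discrete assignment M. The NOTE in the code means the 1/2 of each squared Mahalanobis norm is already inside the factor's error, so no extra 0.5 is applied here.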
@@ -90,6 +90,32 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
   /// Expose error(const HybridValues&) method.
   using Base::error;

+  /**
+   * @brief Compute error of (hybrid) nonlinear factors and discrete factors
+   * over each discrete assignment, and return as a tree.
+   *
+   * Error \f$ e = \Vert f(x) - \mu \Vert_{\Sigma} \f$.
+   *
+   * @note Gaussian and hybrid Gaussian factors are not considered!
+   *
+   * @param values Manifold values at which to compute the error.
+   * @return AlgebraicDecisionTree<Key>
+   */
+  AlgebraicDecisionTree<Key> errorTree(const Values& values) const;
+
+  /**
+   * @brief Compute posterior P(M|X=x) when all continuous values X are given.
+   * This is efficient as it simply takes exp(-.) of errorTree and normalizes.
+   *
+   * @note Not a DiscreteConditional as the cardinalities of the DiscreteKeys,
+   * which we would need, are hard to recover.
+   *
+   * @param continuousValues Continuous values x to condition on.
+   * @return AlgebraicDecisionTree<Key>
+   */
+  AlgebraicDecisionTree<Key> discretePosterior(
+      const Values& continuousValues) const;
+
   /// @}
 };
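Aside, not part of the changeset: how the two new methods compose, using the Switching test fixture that this changeset refactors below (<cassert> and <cmath> assumed, plus "Switching.h"):

    void sketchDiscretePosterior() {
      using namespace gtsam;
      Switching s(3);  // X(0..2), binary modes M(0) and M(1)
      const Values& values = s.linearizationPoint;
      // One error leaf per assignment of {M(0), M(1)}:
      AlgebraicDecisionTree<Key> errors =
          s.nonlinearFactorGraph().errorTree(values);
      // Normalized posterior P(M | X = x); its leaves sum to 1:
      AlgebraicDecisionTree<Key> posterior =
          s.nonlinearFactorGraph().discretePosterior(values);
      assert(std::abs(posterior.sum() - 1.0) < 1e-9);
    }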
@@ -39,8 +39,8 @@ void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors,
   if (newFactors.size() > 0) {
     // Reorder and relinearize every reorderInterval updates
     if (reorderInterval_ > 0 && ++reorderCounter_ >= reorderInterval_) {
-      // TODO(Varun) Relinearization doesn't take into account pruning
-      reorder_relinearize();
+      // TODO(Varun) Re-linearization doesn't take into account pruning
+      reorderRelinearize();
       reorderCounter_ = 0;
     }

@@ -60,7 +60,7 @@ void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors,
 }

 /* ************************************************************************* */
-void HybridNonlinearISAM::reorder_relinearize() {
+void HybridNonlinearISAM::reorderRelinearize() {
   if (factors_.size() > 0) {
     // Obtain the new linearization point
     const Values newLinPoint = estimate();

@@ -69,7 +69,7 @@ void HybridNonlinearISAM::reorder_relinearize() {

   // Just recreate the whole BayesTree
   // TODO: allow for constrained ordering here
-  // TODO: decouple relinearization and reordering to avoid
+  // TODO: decouple re-linearization and reordering to avoid
   isam_.update(*factors_.linearize(newLinPoint), {}, {},
                eliminationFunction_);
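Aside, not part of the changeset: a worked trace of the cadence encoded by `++reorderCounter_ >= reorderInterval_` above, assuming reorderInterval_ == 3:

    // update #:      1  2  3                  4  5  6
    // counter after: 1  2  0 (relinearized)   1  2  0 (relinearized)
    // i.e. reorderRelinearize() fires on every third call to update().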
@@ -37,7 +37,7 @@ class GTSAM_EXPORT HybridNonlinearISAM {
   /// The discrete assignment
   DiscreteValues assignment_;

-  /** The original factors, used when relinearizing */
+  /** The original factors, used when re-linearizing */
   HybridNonlinearFactorGraph factors_;

   /** The reordering interval and counter */

@@ -52,8 +52,8 @@ class GTSAM_EXPORT HybridNonlinearISAM {
   /// @{

   /**
-   * Periodically reorder and relinearize
-   * @param reorderInterval is the number of updates between reorderings,
+   * Periodically reorder and re-linearize
+   * @param reorderInterval is the number of updates between re-orderings,
    * 0 never reorders (and is dangerous for memory consumption)
    * 1 (default) reorders every time; the worst case is a batch solve on every update
    * typical values are 50 or 100

@@ -124,8 +124,8 @@ class GTSAM_EXPORT HybridNonlinearISAM {
       const std::optional<size_t>& maxNrLeaves = {},
       const std::optional<Ordering>& ordering = {});

-  /** Relinearization and reordering of variables */
-  void reorder_relinearize();
+  /** Re-linearization and reordering of variables */
+  void reorderRelinearize();

   /// @}
 };
@@ -72,21 +72,17 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph,
   addConditionals(graph, hybridBayesNet_, ordering);

   // Eliminate.
-  HybridBayesNet::shared_ptr bayesNetFragment =
-      graph.eliminateSequential(ordering);
+  HybridBayesNet bayesNetFragment = *graph.eliminateSequential(ordering);

   /// Prune
   if (maxNrLeaves) {
     // `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
     // all the conditionals with the same keys in bayesNetFragment.
-    HybridBayesNet prunedBayesNetFragment =
-        bayesNetFragment->prune(*maxNrLeaves);
-    // Set the bayes net fragment to the pruned version
-    bayesNetFragment = std::make_shared<HybridBayesNet>(prunedBayesNetFragment);
+    bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves);
   }

   // Add the partial bayes net to the posterior bayes net.
-  hybridBayesNet_.add(*bayesNetFragment);
+  hybridBayesNet_.add(bayesNetFragment);
 }

 /* ************************************************************************* */
@@ -39,7 +39,7 @@ class GTSAM_EXPORT HybridSmoother {
    * discrete factor on all discrete keys, plus all discrete factors in the
    * original graph.
    *
-   * \note If maxComponents is given, we look at the discrete factor resulting
+   * \note If maxNrLeaves is given, we look at the discrete factor resulting
    * from this elimination, and prune it and the Gaussian components
    * corresponding to the pruned choices.
    *
@@ -46,29 +46,29 @@ using symbol_shorthand::X;
  * @brief Create a switching system chain. A switching system is a continuous
  * system which depends on a discrete mode at each time step of the chain.
  *
- * @param n The number of chain elements.
+ * @param K The number of chain elements.
  * @param x The functional to help specify the continuous key.
  * @param m The functional to help specify the discrete key.
  * @return HybridGaussianFactorGraph::shared_ptr
  */
 inline HybridGaussianFactorGraph::shared_ptr makeSwitchingChain(
-    size_t n, std::function<Key(int)> x = X, std::function<Key(int)> m = M) {
+    size_t K, std::function<Key(int)> x = X, std::function<Key(int)> m = M) {
   HybridGaussianFactorGraph hfg;

   hfg.add(JacobianFactor(x(1), I_3x3, Z_3x1));

   // x(1) to x(n+1)
-  for (size_t t = 1; t < n; t++) {
-    DiscreteKeys dKeys{{m(t), 2}};
+  for (size_t k = 1; k < K; k++) {
+    DiscreteKeys dKeys{{m(k), 2}};
     std::vector<GaussianFactor::shared_ptr> components;
     components.emplace_back(
-        new JacobianFactor(x(t), I_3x3, x(t + 1), I_3x3, Z_3x1));
+        new JacobianFactor(x(k), I_3x3, x(k + 1), I_3x3, Z_3x1));
     components.emplace_back(
-        new JacobianFactor(x(t), I_3x3, x(t + 1), I_3x3, Vector3::Ones()));
-    hfg.add(HybridGaussianFactor({m(t), 2}, components));
+        new JacobianFactor(x(k), I_3x3, x(k + 1), I_3x3, Vector3::Ones()));
+    hfg.add(HybridGaussianFactor({m(k), 2}, components));

-    if (t > 1) {
-      hfg.add(DecisionTreeFactor({{m(t - 1), 2}, {m(t), 2}}, "0 1 1 3"));
+    if (k > 1) {
+      hfg.add(DecisionTreeFactor({{m(k - 1), 2}, {m(k), 2}}, "0 1 1 3"));
     }
   }
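Aside, not part of the changeset: a minimal sketch of calling the helper above with its defaults (x = X, m = M); <cassert> assumed:

    void sketchMakeSwitchingChain() {
      using namespace gtsam;
      auto chain = makeSwitchingChain(4);  // keys x(1..4), modes m(1..3)
      // 1 prior + 3 hybrid motion factors + 2 mode-coupling DecisionTreeFactors:
      assert(chain->size() == 6);
      chain->print("switching chain: ");
    }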
@@ -114,18 +114,27 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
   return {new_order, levels};
 }

-/* ***************************************************************************
- */
+/* ****************************************************************************/
 using MotionModel = BetweenFactor<double>;

 // Test fixture with switching network.
+/// ϕ(X(0)) .. ϕ(X(k),X(k+1)) .. ϕ(X(k);z_k) .. ϕ(M(0)) .. ϕ(M(K-3),M(K-2))
 struct Switching {
+ private:
+  HybridNonlinearFactorGraph nonlinearFactorGraph_;
+
+ public:
   size_t K;
   DiscreteKeys modes;
-  HybridNonlinearFactorGraph nonlinearFactorGraph;
+  HybridNonlinearFactorGraph unaryFactors, binaryFactors, modeChain;
   HybridGaussianFactorGraph linearizedFactorGraph;
   Values linearizationPoint;

+  // Access the flat nonlinear factor graph.
+  const HybridNonlinearFactorGraph &nonlinearFactorGraph() const {
+    return nonlinearFactorGraph_;
+  }
+
   /**
    * @brief Create with given number of time steps.
    *
@@ -136,12 +145,12 @@ struct Switching {
    */
   Switching(size_t K, double between_sigma = 1.0, double prior_sigma = 0.1,
             std::vector<double> measurements = {},
-            std::string discrete_transition_prob = "1/2 3/2")
+            std::string transitionProbabilityTable = "1/2 3/2")
       : K(K) {
     using noiseModel::Isotropic;

-    // Create DiscreteKeys for binary K modes.
-    for (size_t k = 0; k < K; k++) {
+    // Create DiscreteKeys for K-1 binary modes.
+    for (size_t k = 0; k < K - 1; k++) {
       modes.emplace_back(M(k), 2);
     }
@@ -153,35 +162,38 @@ struct Switching {
     }

     // Create hybrid factor graph.
-    // Add a prior on X(0).
-    nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
-        X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma));

-    // Add "motion models".
+    // Add a prior ϕ(X(0)) on X(0).
+    nonlinearFactorGraph_.emplace_shared<PriorFactor<double>>(
+        X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma));
+    unaryFactors.push_back(nonlinearFactorGraph_.back());
+
+    // Add "motion models" ϕ(X(k),X(k+1),M(k)).
     for (size_t k = 0; k < K - 1; k++) {
       auto motion_models = motionModels(k, between_sigma);
-      nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(modes[k],
-                                                                 motion_models);
+      nonlinearFactorGraph_.emplace_shared<HybridNonlinearFactor>(modes[k],
+                                                                  motion_models);
+      binaryFactors.push_back(nonlinearFactorGraph_.back());
     }

-    // Add measurement factors
+    // Add measurement factors ϕ(X(k);z_k).
     auto measurement_noise = Isotropic::Sigma(1, prior_sigma);
     for (size_t k = 1; k < K; k++) {
-      nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
+      nonlinearFactorGraph_.emplace_shared<PriorFactor<double>>(
           X(k), measurements.at(k), measurement_noise);
+      unaryFactors.push_back(nonlinearFactorGraph_.back());
     }

-    // Add "mode chain"
-    addModeChain(&nonlinearFactorGraph, discrete_transition_prob);
+    // Add "mode chain" ϕ(M(0)) ϕ(M(0),M(1)) ... ϕ(M(K-3),M(K-2))
+    modeChain = createModeChain(transitionProbabilityTable);
+    nonlinearFactorGraph_ += modeChain;

     // Create the linearization point.
     for (size_t k = 0; k < K; k++) {
       linearizationPoint.insert<double>(X(k), static_cast<double>(k + 1));
     }

-    // The ground truth is robot moving forward
-    // and one less than the linearization point
-    linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
+    linearizedFactorGraph = *nonlinearFactorGraph_.linearize(linearizationPoint);
   }

   // Create motion models for a given time step
|
||||||
*
|
*
|
||||||
* @param fg The factor graph to which the mode chain is added.
|
* @param fg The factor graph to which the mode chain is added.
|
||||||
*/
|
*/
|
||||||
template <typename FACTORGRAPH>
|
HybridNonlinearFactorGraph createModeChain(
|
||||||
void addModeChain(FACTORGRAPH *fg,
|
std::string transitionProbabilityTable = "1/2 3/2") {
|
||||||
std::string discrete_transition_prob = "1/2 3/2") {
|
HybridNonlinearFactorGraph chain;
|
||||||
fg->template emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
chain.emplace_shared<DiscreteDistribution>(modes[0], "1/1");
|
||||||
for (size_t k = 0; k < K - 2; k++) {
|
for (size_t k = 0; k < K - 2; k++) {
|
||||||
auto parents = {modes[k]};
|
auto parents = {modes[k]};
|
||||||
fg->template emplace_shared<DiscreteConditional>(
|
chain.emplace_shared<DiscreteConditional>(modes[k + 1], parents,
|
||||||
modes[k + 1], parents, discrete_transition_prob);
|
transitionProbabilityTable);
|
||||||
}
|
}
|
||||||
|
return chain;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,17 +60,6 @@ double prob_m_z(double mu0, double mu1, double sigma0, double sigma1,
|
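Aside, not part of the changeset: typical use of the refactored fixture, mirroring the tests later in this changeset (<cassert> assumed):

    void sketchSwitchingFixture() {
      using namespace gtsam;
      Switching s(3);  // X(0..2), binary modes M(0) and M(1)
      // The flat graph now sits behind an accessor; grouped views are members:
      const HybridNonlinearFactorGraph& fg = s.nonlinearFactorGraph();
      assert(fg.size() == s.unaryFactors.size() + s.binaryFactors.size() +
                              s.modeChain.size());
      // The pre-linearized graph is ready for elimination:
      auto posterior = s.linearizedFactorGraph.eliminateSequential();
    }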
||||||
return p1 / (p0 + p1);
|
return p1 / (p0 + p1);
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Given \phi(m;z)\phi(m) use eliminate to obtain P(m|z).
|
|
||||||
DiscreteConditional SolveHFG(const HybridGaussianFactorGraph &hfg) {
|
|
||||||
return *hfg.eliminateSequential()->at(0)->asDiscrete();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Given p(z,m) and z, convert to HFG and solve.
|
|
||||||
DiscreteConditional SolveHBN(const HybridBayesNet &hbn, double z) {
|
|
||||||
VectorValues given{{Z(0), Vector1(z)}};
|
|
||||||
return SolveHFG(hbn.toFactorGraph(given));
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Test a Gaussian Mixture Model P(m)p(z|m) with same sigma.
|
* Test a Gaussian Mixture Model P(m)p(z|m) with same sigma.
|
||||||
* The posterior, as a function of z, should be a sigmoid function.
|
* The posterior, as a function of z, should be a sigmoid function.
|
||||||
|
|
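Aside, not part of the changeset: why the equal-sigma posterior is a sigmoid, with w_0, w_1 the mixture weights P(m):

    P(m=1 | z) = w_1 N(z; \mu_1, \sigma) / (w_0 N(z; \mu_0, \sigma) + w_1 N(z; \mu_1, \sigma))
               = 1 / (1 + \exp(-\frac{\mu_1 - \mu_0}{\sigma^2} (z - \frac{\mu_0 + \mu_1}{2}) - \log\frac{w_1}{w_0})),

a logistic function of z with slope (\mu_1 - \mu_0)/\sigma^2; prob_m_z above computes the same ratio p1 / (p0 + p1).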
@@ -88,7 +77,9 @@ TEST(GaussianMixture, GaussianMixtureModel) {

   // At the halfway point between the means, we should get P(m|z)=0.5
   double midway = mu1 - mu0;
-  auto pMid = SolveHBN(gmm, midway);
+  auto eliminationResult =
+      gmm.toFactorGraph({{Z(0), Vector1(midway)}}).eliminateSequential();
+  auto pMid = *eliminationResult->at(0)->asDiscrete();
   EXPECT(assert_equal(DiscreteConditional(m, "60/40"), pMid));

   // Everywhere else, the result should be a sigmoid.

@@ -97,7 +88,9 @@ TEST(GaussianMixture, GaussianMixtureModel) {
     const double expected = prob_m_z(mu0, mu1, sigma, sigma, z);

     // Workflow 1: convert HBN to HFG and solve
-    auto posterior1 = SolveHBN(gmm, z);
+    auto eliminationResult1 =
+        gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
+    auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
     EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);

     // Workflow 2: directly specify HFG and solve

@@ -105,7 +98,8 @@ TEST(GaussianMixture, GaussianMixtureModel) {
     hfg1.emplace_shared<DecisionTreeFactor>(
         m, std::vector{Gaussian(mu0, sigma, z), Gaussian(mu1, sigma, z)});
     hfg1.push_back(mixing);
-    auto posterior2 = SolveHFG(hfg1);
+    auto eliminationResult2 = hfg1.eliminateSequential();
+    auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
     EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
   }
 }
@@ -128,7 +122,23 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
   // We get zMax=3.1333 by finding the maximum value of the function, at which
   // point the mode m==1 is about twice as probable as m==0.
   double zMax = 3.133;
-  auto pMax = SolveHBN(gmm, zMax);
+  const VectorValues vv{{Z(0), Vector1(zMax)}};
+  auto gfg = gmm.toFactorGraph(vv);
+
+  // Equality of posteriors asserts that the elimination is correct (same ratios
+  // for all modes)
+  const auto& expectedDiscretePosterior = gmm.discretePosterior(vv);
+  EXPECT(assert_equal(expectedDiscretePosterior, gfg.discretePosterior(vv)));
+
+  // Eliminate the graph!
+  auto eliminationResultMax = gfg.eliminateSequential();
+
+  // Equality of posteriors asserts that the elimination is correct (same ratios
+  // for all modes)
+  EXPECT(assert_equal(expectedDiscretePosterior,
+                      eliminationResultMax->discretePosterior(vv)));
+
+  auto pMax = *eliminationResultMax->at(0)->asDiscrete();
   EXPECT(assert_equal(DiscreteConditional(m, "42/58"), pMax, 1e-4));

   // Everywhere else, the result should be a bell curve like function.

@@ -137,7 +147,9 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
     const double expected = prob_m_z(mu0, mu1, sigma0, sigma1, z);

     // Workflow 1: convert HBN to HFG and solve
-    auto posterior1 = SolveHBN(gmm, z);
+    auto eliminationResult1 =
+        gmm.toFactorGraph({{Z(0), Vector1(z)}}).eliminateSequential();
+    auto posterior1 = *eliminationResult1->at(0)->asDiscrete();
     EXPECT_DOUBLES_EQUAL(expected, posterior1(m1Assignment), 1e-8);

     // Workflow 2: directly specify HFG and solve

@@ -145,11 +157,11 @@ TEST(GaussianMixture, GaussianMixtureModel2) {
     hfg.emplace_shared<DecisionTreeFactor>(
         m, std::vector{Gaussian(mu0, sigma0, z), Gaussian(mu1, sigma1, z)});
     hfg.push_back(mixing);
-    auto posterior2 = SolveHFG(hfg);
+    auto eliminationResult2 = hfg.eliminateSequential();
+    auto posterior2 = *eliminationResult2->at(0)->asDiscrete();
     EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8);
   }
 }

 /* ************************************************************************* */
 int main() {
   TestResult tr;
@@ -18,9 +18,12 @@
  * @date December 2021
  */

+#include <gtsam/base/Testable.h>
+#include <gtsam/discrete/DiscreteFactor.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridBayesTree.h>
 #include <gtsam/hybrid/HybridConditional.h>
+#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
 #include <gtsam/nonlinear/NonlinearFactorGraph.h>

 #include "Switching.h"

@@ -28,6 +31,7 @@

 // Include for test suite
 #include <CppUnitLite/TestHarness.h>
+#include <memory>

 using namespace std;
 using namespace gtsam;
@@ -62,32 +66,162 @@ TEST(HybridBayesNet, Add) {
 }

 /* ****************************************************************************/
-// Test evaluate for a pure discrete Bayes net P(Asia).
+// Test API for a pure discrete Bayes net P(Asia).
 TEST(HybridBayesNet, EvaluatePureDiscrete) {
   HybridBayesNet bayesNet;
-  bayesNet.emplace_shared<DiscreteConditional>(Asia, "4/6");
-  HybridValues values;
-  values.insert(asiaKey, 0);
-  EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(values), 1e-9);
+  const auto pAsia = std::make_shared<DiscreteConditional>(Asia, "4/6");
+  bayesNet.push_back(pAsia);
+  HybridValues zero{{}, {{asiaKey, 0}}}, one{{}, {{asiaKey, 1}}};
+
+  // choose
+  GaussianBayesNet empty;
+  EXPECT(assert_equal(empty, bayesNet.choose(zero.discrete()), 1e-9));
+
+  // evaluate
+  EXPECT_DOUBLES_EQUAL(0.4, bayesNet.evaluate(zero), 1e-9);
+  EXPECT_DOUBLES_EQUAL(0.4, bayesNet(zero), 1e-9);
+
+  // optimize
+  EXPECT(assert_equal(one, bayesNet.optimize()));
+  EXPECT(assert_equal(VectorValues{}, bayesNet.optimize(one.discrete())));
+
+  // sample
+  std::mt19937_64 rng(42);
+  EXPECT(assert_equal(zero, bayesNet.sample(&rng)));
+  EXPECT(assert_equal(one, bayesNet.sample(one, &rng)));
+  EXPECT(assert_equal(zero, bayesNet.sample(zero, &rng)));
+
+  // prune
+  EXPECT(assert_equal(bayesNet, bayesNet.prune(2)));
+  EXPECT_LONGS_EQUAL(1, bayesNet.prune(1).at(0)->size());
+
+  // error
+  EXPECT_DOUBLES_EQUAL(-log(0.4), bayesNet.error(zero), 1e-9);
+  EXPECT_DOUBLES_EQUAL(-log(0.6), bayesNet.error(one), 1e-9);
+
+  // errorTree
+  AlgebraicDecisionTree<Key> expected(asiaKey, -log(0.4), -log(0.6));
+  EXPECT(assert_equal(expected, bayesNet.errorTree({})));
+
+  // logProbability
+  EXPECT_DOUBLES_EQUAL(log(0.4), bayesNet.logProbability(zero), 1e-9);
+  EXPECT_DOUBLES_EQUAL(log(0.6), bayesNet.logProbability(one), 1e-9);
+
+  // discretePosterior
+  AlgebraicDecisionTree<Key> expectedPosterior(asiaKey, 0.4, 0.6);
+  EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior({})));
+
+  // toFactorGraph
+  HybridGaussianFactorGraph expectedFG{pAsia}, fg = bayesNet.toFactorGraph({});
+  EXPECT(assert_equal(expectedFG, fg));
 }

 /* ****************************************************************************/
-// Test creation of a tiny hybrid Bayes net.
+// Test API for a tiny hybrid Bayes net.
 TEST(HybridBayesNet, Tiny) {
-  auto bn = tiny::createHybridBayesNet();
-  EXPECT_LONGS_EQUAL(3, bn.size());
+  auto bayesNet = tiny::createHybridBayesNet();  // P(z|x,mode)P(x)P(mode)
+  EXPECT_LONGS_EQUAL(3, bayesNet.size());

   const VectorValues vv{{Z(0), Vector1(5.0)}, {X(0), Vector1(5.0)}};
-  auto fg = bn.toFactorGraph(vv);
+  HybridValues zero{vv, {{M(0), 0}}}, one{vv, {{M(0), 1}}};
+
+  // Check Invariants for components
+  HybridGaussianConditional::shared_ptr hgc = bayesNet.at(0)->asHybrid();
+  GaussianConditional::shared_ptr gc0 = hgc->choose(zero.discrete()),
+                                  gc1 = hgc->choose(one.discrete());
+  GaussianConditional::shared_ptr px = bayesNet.at(1)->asGaussian();
+  GaussianConditional::CheckInvariants(*gc0, vv);
+  GaussianConditional::CheckInvariants(*gc1, vv);
+  GaussianConditional::CheckInvariants(*px, vv);
+  HybridGaussianConditional::CheckInvariants(*hgc, zero);
+  HybridGaussianConditional::CheckInvariants(*hgc, one);
+
+  // choose
+  GaussianBayesNet expectedChosen;
+  expectedChosen.push_back(gc0);
+  expectedChosen.push_back(px);
+  auto chosen0 = bayesNet.choose(zero.discrete());
+  auto chosen1 = bayesNet.choose(one.discrete());
+  EXPECT(assert_equal(expectedChosen, chosen0, 1e-9));
+
+  // logProbability
+  const double logP0 = chosen0.logProbability(vv) + log(0.4);  // 0.4 is prior
+  const double logP1 = chosen1.logProbability(vv) + log(0.6);  // 0.6 is prior
+  EXPECT_DOUBLES_EQUAL(logP0, bayesNet.logProbability(zero), 1e-9);
+  EXPECT_DOUBLES_EQUAL(logP1, bayesNet.logProbability(one), 1e-9);
+
+  // evaluate
+  EXPECT_DOUBLES_EQUAL(exp(logP0), bayesNet.evaluate(zero), 1e-9);
+
+  // optimize
+  EXPECT(assert_equal(one, bayesNet.optimize()));
+  EXPECT(assert_equal(chosen0.optimize(), bayesNet.optimize(zero.discrete())));
+
+  // sample. Not deterministic !!! TODO(Frank): figure out why
+  // std::mt19937_64 rng(42);
+  // EXPECT(assert_equal({{M(0), 1}}, bayesNet.sample(&rng).discrete()));
+
+  // prune
+  auto pruned = bayesNet.prune(1);
+  CHECK(pruned.at(1)->asHybrid());
+  EXPECT_LONGS_EQUAL(1, pruned.at(1)->asHybrid()->nrComponents());
+  EXPECT(!pruned.equals(bayesNet));
+
+  // error
+  const double error0 = chosen0.error(vv) + gc0->negLogConstant() -
+                        px->negLogConstant() - log(0.4);
+  const double error1 = chosen1.error(vv) + gc1->negLogConstant() -
+                        px->negLogConstant() - log(0.6);
+  EXPECT_DOUBLES_EQUAL(error0, bayesNet.error(zero), 1e-9);
+  EXPECT_DOUBLES_EQUAL(error1, bayesNet.error(one), 1e-9);
+  EXPECT_DOUBLES_EQUAL(error0 + logP0, error1 + logP1, 1e-9);
+
+  // errorTree
+  AlgebraicDecisionTree<Key> expected(M(0), error0, error1);
+  EXPECT(assert_equal(expected, bayesNet.errorTree(vv)));
+
+  // discretePosterior
+  // We have: P(z|x,mode)P(x)P(mode). When we condition on z and x, we get
+  // P(mode|z,x) \propto P(z|x,mode)P(x)P(mode)
+  // Normalizing this yields posterior P(mode|z,x) = {0.8, 0.2}
+  double q0 = std::exp(logP0), q1 = std::exp(logP1), sum = q0 + q1;
+  AlgebraicDecisionTree<Key> expectedPosterior(M(0), q0 / sum, q1 / sum);
+  EXPECT(assert_equal(expectedPosterior, bayesNet.discretePosterior(vv)));
+
+  // toFactorGraph
+  auto fg = bayesNet.toFactorGraph({{Z(0), Vector1(5.0)}});
   EXPECT_LONGS_EQUAL(3, fg.size());

+  // Create the product factor for eliminating x0:
+  HybridGaussianFactorGraph factors_x0;
+  factors_x0.push_back(fg.at(0));
+  factors_x0.push_back(fg.at(1));
+  auto productFactor = factors_x0.collectProductFactor();
+
+  // Check that scalars are 0 and 1.79 (regression)
+  EXPECT_DOUBLES_EQUAL(0.0, productFactor({{M(0), 0}}).second, 1e-9);
+  EXPECT_DOUBLES_EQUAL(1.791759, productFactor({{M(0), 1}}).second, 1e-5);
+
+  // Call eliminate and check scalar:
+  auto result = factors_x0.eliminate({X(0)});
+  auto df = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
+
   // Check that the ratio of probPrime to evaluate is the same for all modes.
   std::vector<double> ratio(2);
-  for (size_t mode : {0, 1}) {
-    const HybridValues hv{vv, {{M(0), mode}}};
-    ratio[mode] = std::exp(-fg.error(hv)) / bn.evaluate(hv);
-  }
+  ratio[0] = std::exp(-fg.error(zero)) / bayesNet.evaluate(zero);
+  ratio[1] = std::exp(-fg.error(one)) / bayesNet.evaluate(one);
   EXPECT_DOUBLES_EQUAL(ratio[0], ratio[1], 1e-8);
+
+  // Better and more general test:
+  // Since ϕ(M, x) \propto P(M,x|z) the discretePosteriors should agree
+  q0 = std::exp(-fg.error(zero));
+  q1 = std::exp(-fg.error(one));
+  sum = q0 + q1;
+  EXPECT(assert_equal(expectedPosterior, {M(0), q0 / sum, q1 / sum}));
+  VectorValues xv{{X(0), Vector1(5.0)}};
+  auto fgPosterior = fg.discretePosterior(xv);
+  EXPECT(assert_equal(expectedPosterior, fgPosterior));
 }
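Aside, not part of the changeset: the invariant behind the ratio check above. Fixing z in toFactorGraph only drops a constant, so for every hybrid assignment (x, M)

    \exp(-E_{fg}(x, M)) / P_{bn}(x, M) = c(z),   independent of (x, M);

hence ratio[0] == ratio[1], and after normalization the factor graph's discretePosterior must reproduce the Bayes net's posterior, which is exactly the "better and more general test" above.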
@@ -121,21 +255,6 @@ TEST(HybridBayesNet, evaluateHybrid) {
                        bayesNet.evaluate(values), 1e-9);
 }

-/* ****************************************************************************/
-// Test error for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
-TEST(HybridBayesNet, Error) {
-  using namespace different_sigmas;
-
-  AlgebraicDecisionTree<Key> actual = bayesNet.errorTree(values.continuous());
-
-  // Regression.
-  // Manually added all the error values from the 3 conditional types.
-  AlgebraicDecisionTree<Key> expected(
-      {Asia}, std::vector<double>{2.33005033585, 5.38619084965});
-
-  EXPECT(assert_equal(expected, actual));
-}
-
 /* ****************************************************************************/
 // Test choosing an assignment of conditionals
 TEST(HybridBayesNet, Choose) {
@@ -223,29 +342,29 @@ TEST(HybridBayesNet, Optimize) {
 /* ****************************************************************************/
 // Test Bayes net error
 TEST(HybridBayesNet, Pruning) {
+  // Create switching network with three continuous variables and two discrete:
+  // ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x0;z0) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
   Switching s(3);

   HybridBayesNet::shared_ptr posterior =
       s.linearizedFactorGraph.eliminateSequential();
   EXPECT_LONGS_EQUAL(5, posterior->size());

+  // Optimize
   HybridValues delta = posterior->optimize();
-  auto actualTree = posterior->evaluate(delta.continuous());

-  // Regression test on density tree.
-  std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
-  std::vector<double> leaves = {6.1112424, 20.346113, 17.785849, 19.738098};
-  AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
-  EXPECT(assert_equal(expected, actualTree, 1e-6));
+  // Verify discrete posterior at optimal value sums to 1.
+  auto discretePosterior = posterior->discretePosterior(delta.continuous());
+  EXPECT_DOUBLES_EQUAL(1.0, discretePosterior.sum(), 1e-9);
+
+  // Regression test on discrete posterior at optimal value.
+  std::vector<double> leaves = {0.095516068, 0.31800092, 0.27798511, 0.3084979};
+  AlgebraicDecisionTree<Key> expected(s.modes, leaves);
+  EXPECT(assert_equal(expected, discretePosterior, 1e-6));

   // Prune and get probabilities
   auto prunedBayesNet = posterior->prune(2);
-  auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
+  auto prunedTree = prunedBayesNet.discretePosterior(delta.continuous());

-  // Regression test on pruned logProbability tree
-  std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
-  AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
-  EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
-
   // Verify logProbability computation and check specific logProbability value
   const DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};

@@ -254,19 +373,25 @@ TEST(HybridBayesNet, Pruning) {
   logProbability += posterior->at(0)->asHybrid()->logProbability(hybridValues);
   logProbability += posterior->at(1)->asHybrid()->logProbability(hybridValues);
   logProbability += posterior->at(2)->asHybrid()->logProbability(hybridValues);
-  // NOTE(dellaert): the discrete errors were not added in logProbability tree!
   logProbability +=
       posterior->at(3)->asDiscrete()->logProbability(hybridValues);
   logProbability +=
       posterior->at(4)->asDiscrete()->logProbability(hybridValues);

-  // Regression
-  double density = exp(logProbability);
-  EXPECT_DOUBLES_EQUAL(density,
-                       1.6078460548731697 * actualTree(discrete_values), 1e-6);
-  EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
   EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
                        1e-9);

+  // Check agreement with discrete posterior
+  // double density = exp(logProbability);
+  // FAILS: EXPECT_DOUBLES_EQUAL(density, discretePosterior(discrete_values),
+  // 1e-6);
+
+  // Regression test on pruned logProbability tree
+  std::vector<double> pruned_leaves = {0.0, 0.50758422, 0.0, 0.49241578};
+  AlgebraicDecisionTree<Key> expected_pruned(s.modes, pruned_leaves);
+  EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
+
+  // Regression
+  // FAILS: EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
 }
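Aside, not part of the changeset: the pruned regression values are simply the two surviving leaves of the unpruned posterior, renormalized; by hand:

    0.31800092 / (0.31800092 + 0.3084979) ≈ 0.50758422
    0.3084979  / (0.31800092 + 0.3084979) ≈ 0.49241578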
@@ -296,50 +421,47 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
       s.linearizedFactorGraph.eliminateSequential();
   EXPECT_LONGS_EQUAL(7, posterior->size());

+  DiscreteConditional joint;
+  for (auto&& conditional : posterior->discreteMarginal()) {
+    joint = joint * (*conditional);
+  }
+
   size_t maxNrLeaves = 3;
-  DiscreteConditional discreteConditionals;
-  for (auto&& conditional : *posterior) {
-    if (conditional->isDiscrete()) {
-      discreteConditionals =
-          discreteConditionals * (*conditional->asDiscrete());
-    }
-  }
-  const DecisionTreeFactor::shared_ptr prunedDecisionTree =
-      std::make_shared<DecisionTreeFactor>(
-          discreteConditionals.prune(maxNrLeaves));
+  auto prunedDecisionTree = joint.prune(maxNrLeaves);

 #ifdef GTSAM_DT_MERGING
   EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
-                     prunedDecisionTree->nrLeaves());
+                     prunedDecisionTree.nrLeaves());
 #else
-  EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree->nrLeaves());
+  EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree.nrLeaves());
 #endif

   // regression
-  DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
+  // NOTE(Frank): I had to include *three* non-zeroes here now.
   DecisionTreeFactor::ADT potentials(
-      dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
-  DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
+      s.modes,
+      std::vector<double>{0, 0, 0, 0.28739288, 0, 0.43106901, 0, 0.2815381});
+  DiscreteConditional expectedConditional(3, s.modes, potentials);

   // Prune!
-  posterior->prune(maxNrLeaves);
+  auto pruned = posterior->prune(maxNrLeaves);

-  // Functor to verify values against the expected_discrete_conditionals
+  // Functor to verify values against the expectedConditional
   auto checker = [&](const Assignment<Key>& assignment,
                      double probability) -> double {
     // typecast so we can use this to get probability value
     DiscreteValues choices(assignment);
-    if (prunedDecisionTree->operator()(choices) == 0) {
+    if (prunedDecisionTree(choices) == 0) {
       EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
     } else {
-      EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
-                           1e-9);
+      EXPECT_DOUBLES_EQUAL(expectedConditional(choices), probability, 1e-6);
     }
     return 0.0;
   };

   // Get the pruned discrete conditionals as an AlgebraicDecisionTree
-  auto pruned_discrete_conditionals = posterior->at(4)->asDiscrete();
+  CHECK(pruned.at(0)->asDiscrete());
+  auto pruned_discrete_conditionals = pruned.at(0)->asDiscrete();
   auto discrete_conditional_tree =
       std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
           pruned_discrete_conditionals);
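Aside, not part of the changeset: a quick sanity check on the regression vector above; the three surviving leaves of the pruned joint still form a normalized distribution:

    0.28739288 + 0.43106901 + 0.2815381 ≈ 1.0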
@@ -463,8 +585,8 @@ TEST(HybridBayesNet, ErrorTreeWithConditional) {
   AlgebraicDecisionTree<Key> errorTree = gfg.errorTree(vv);

   // regression
-  AlgebraicDecisionTree<Key> expected(m1, 59.335390372, 5050.125);
-  EXPECT(assert_equal(expected, errorTree, 1e-9));
+  AlgebraicDecisionTree<Key> expected(m1, 60.028538, 5050.8181);
+  EXPECT(assert_equal(expected, errorTree, 1e-4));
 }

 /* ************************************************************************* */
@@ -20,6 +20,9 @@
 #include <gtsam/discrete/DiscreteFactorGraph.h>
 #include <gtsam/hybrid/HybridBayesTree.h>
 #include <gtsam/hybrid/HybridGaussianISAM.h>
+#include <gtsam/inference/DotWriter.h>
+
+#include <numeric>

 #include "Switching.h"
@ -28,9 +31,320 @@
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
using noiseModel::Isotropic;
|
using symbol_shorthand::D;
|
||||||
using symbol_shorthand::M;
|
using symbol_shorthand::M;
|
||||||
using symbol_shorthand::X;
|
using symbol_shorthand::X;
|
||||||
|
using symbol_shorthand::Y;
|
||||||
|
|
||||||
|
static const DiscreteKey m0(M(0), 2), m1(M(1), 2), m2(M(2), 2), m3(M(3), 2);
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(HybridGaussianFactorGraph, EliminateMultifrontal) {
|
||||||
|
// Test multifrontal elimination
|
||||||
|
HybridGaussianFactorGraph hfg;
|
||||||
|
|
||||||
|
// Add priors on x0 and c1
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||||
|
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
||||||
|
|
||||||
|
Ordering ordering;
|
||||||
|
ordering.push_back(X(0));
|
||||||
|
auto result = hfg.eliminatePartialMultifrontal(ordering);
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(result.first->size(), 1);
|
||||||
|
EXPECT_LONGS_EQUAL(result.second->size(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
namespace two {
|
||||||
|
std::vector<GaussianFactor::shared_ptr> components(Key key) {
|
||||||
|
return {std::make_shared<JacobianFactor>(key, I_3x3, Z_3x1),
|
||||||
|
std::make_shared<JacobianFactor>(key, I_3x3, Vector3::Ones())};
|
||||||
|
}
|
||||||
|
} // namespace two
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(HybridGaussianFactorGraph,
|
||||||
|
HybridGaussianFactorGraphEliminateFullMultifrontalSimple) {
|
||||||
|
HybridGaussianFactorGraph hfg;
|
||||||
|
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
||||||
|
|
||||||
|
hfg.add(HybridGaussianFactor(m1, two::components(X(1))));
|
||||||
|
|
||||||
|
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
||||||
|
// TODO(Varun) Adding extra discrete variable not connected to continuous
|
||||||
|
// variable throws segfault
|
||||||
|
// hfg.add(DecisionTreeFactor({m1, m2, "1 2 3 4"));
|
||||||
|
|
||||||
|
HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal();
|
||||||
|
|
||||||
|
// The bayes tree should have 3 cliques
|
||||||
|
EXPECT_LONGS_EQUAL(3, result->size());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
|
||||||
|
HybridGaussianFactorGraph hfg;
|
||||||
|
|
||||||
|
// Prior on x0
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||||
|
// Factor between x0-x1
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
||||||
|
|
||||||
|
// Hybrid factor P(x1|c1)
|
||||||
|
hfg.add(HybridGaussianFactor(m1, two::components(X(1))));
|
||||||
|
// Prior factor on c1
|
||||||
|
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
||||||
|
|
||||||
|
// Get a constrained ordering keeping c1 last
|
||||||
|
auto ordering_full = HybridOrdering(hfg);
|
||||||
|
|
||||||
|
// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
|
||||||
|
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
|
||||||
|
|
||||||
|
EXPECT_LONGS_EQUAL(3, hbt->size());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check assembling the Bayes Tree roots after we do partial elimination
|
||||||
|
TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
|
||||||
|
HybridGaussianFactorGraph hfg;
|
||||||
|
|
||||||
|
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
||||||
|
hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1));
|
||||||
|
|
||||||
|
hfg.add(HybridGaussianFactor(m0, two::components(X(0))));
|
||||||
|
hfg.add(HybridGaussianFactor(m1, two::components(X(2))));
|
||||||
|
|
||||||
|
hfg.add(DecisionTreeFactor({m1, m2}, "1 2 3 4"));
|
||||||
|
|
||||||
|
hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1));
|
||||||
|
hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1));
|
||||||
|
|
||||||
|
hfg.add(HybridGaussianFactor(m3, two::components(X(3))));
|
||||||
|
hfg.add(HybridGaussianFactor(m2, two::components(X(5))));
|
||||||
|
|
||||||
|
auto ordering_full =
|
||||||
|
Ordering::ColamdConstrainedLast(hfg, {M(0), M(1), M(2), M(3)});
|
||||||
|
|
||||||
|
const auto [hbt, remaining] = hfg.eliminatePartialMultifrontal(ordering_full);
|
||||||
|
|
||||||
|
// 9 cliques in the bayes tree and 0 remaining variables to eliminate.
|
||||||
|
EXPECT_LONGS_EQUAL(9, hbt->size());
|
||||||
|
EXPECT_LONGS_EQUAL(0, remaining->size());
|
||||||
|
|
||||||
|
/*
|
||||||
|
(Fan) Explanation: the Junction tree will need to re-eliminate to get to the
|
||||||
|
marginal on X(1), which is not possible because it involves eliminating
|
||||||
|
discrete before continuous. The solution to this, however, is in Murphy02.
|
||||||
|
TLDR is that this is 1. expensive and 2. inexact. nevertheless it is doable.
|
||||||
|
And I believe that we should do this.
|
||||||
|
*/
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
void dotPrint(const HybridGaussianFactorGraph::shared_ptr& hfg,
|
||||||
|
const HybridBayesTree::shared_ptr& hbt,
|
||||||
|
const Ordering& ordering) {
|
||||||
|
DotWriter dw;
|
||||||
|
dw.positionHints['c'] = 2;
|
||||||
|
dw.positionHints['x'] = 1;
|
||||||
|
std::cout << hfg->dot(DefaultKeyFormatter, dw);
|
||||||
|
std::cout << "\n";
|
||||||
|
hbt->dot(std::cout);
|
||||||
|
|
||||||
|
std::cout << "\n";
|
||||||
|
std::cout << hfg->eliminateSequential(ordering)->dot(DefaultKeyFormatter, dw);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// TODO(fan): make a graph like Varun's paper one
|
||||||
|
TEST(HybridGaussianFactorGraph, Switching) {
|
||||||
|
auto N = 12;
|
||||||
|
auto hfg = makeSwitchingChain(N);
|
||||||
|
|
||||||
|
// X(5) will be the center, X(1-4), X(6-9)
|
||||||
|
// X(3), X(7)
|
||||||
|
// X(2), X(8)
|
||||||
|
// X(1), X(4), X(6), X(9)
|
||||||
|
// M(5) will be the center, M(1-4), M(6-8)
|
||||||
|
// M(3), M(7)
|
||||||
|
// M(1), M(4), M(2), M(6), M(8)
|
||||||
|
// auto ordering_full =
|
||||||
|
// Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7),
|
||||||
|
// X(5),
|
||||||
|
// M(1), M(4), M(2), M(6), M(8), M(3), M(7), M(5)});
|
||||||
|
KeyVector ordering;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::vector<int> naturalX(N);
|
||||||
|
std::iota(naturalX.begin(), naturalX.end(), 1);
|
||||||
|
std::vector<Key> ordX;
|
||||||
|
std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX),
|
||||||
|
+                   [](int x) { return X(x); });
+
+    auto [ndX, lvls] = makeBinaryOrdering(ordX);
+    std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
+    // TODO(dellaert): this has no effect!
+    for (auto& l : lvls) {
+      l = -l;
+    }
+  }
+  {
+    std::vector<int> naturalC(N - 1);
+    std::iota(naturalC.begin(), naturalC.end(), 1);
+    std::vector<Key> ordC;
+    std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC),
+                   [](int x) { return M(x); });
+
+    // std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
+    const auto [ndC, lvls] = makeBinaryOrdering(ordC);
+    std::copy(ndC.begin(), ndC.end(), std::back_inserter(ordering));
+  }
+  auto ordering_full = Ordering(ordering);
+
+  const auto [hbt, remaining] =
+      hfg->eliminatePartialMultifrontal(ordering_full);
+
+  // 12 cliques in the bayes tree and 0 remaining variables to eliminate.
+  EXPECT_LONGS_EQUAL(12, hbt->size());
+  EXPECT_LONGS_EQUAL(0, remaining->size());
+}
+
+/* ************************************************************************* */
+// TODO(fan): make a graph like Varun's paper one
+TEST(HybridGaussianFactorGraph, SwitchingISAM) {
+  auto N = 11;
+  auto hfg = makeSwitchingChain(N);
+
+  // X(5) will be the center, X(1-4), X(6-9)
+  // X(3), X(7)
+  // X(2), X(8)
+  // X(1), X(4), X(6), X(9)
+  // M(5) will be the center, M(1-4), M(6-8)
+  // M(3), M(7)
+  // M(1), M(4), M(2), M(6), M(8)
+  // auto ordering_full =
+  //     Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7),
+  //     X(5),
+  //                        M(1), M(4), M(2), M(6), M(8), M(3), M(7), M(5)});
+  KeyVector ordering;
+
+  {
+    std::vector<int> naturalX(N);
+    std::iota(naturalX.begin(), naturalX.end(), 1);
+    std::vector<Key> ordX;
+    std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX),
+                   [](int x) { return X(x); });
+
+    auto [ndX, lvls] = makeBinaryOrdering(ordX);
+    std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
+    // TODO(dellaert): this has no effect!
+    for (auto& l : lvls) {
+      l = -l;
+    }
+  }
+  {
+    std::vector<int> naturalC(N - 1);
+    std::iota(naturalC.begin(), naturalC.end(), 1);
+    std::vector<Key> ordC;
+    std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC),
+                   [](int x) { return M(x); });
+
+    // std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
+    const auto [ndC, lvls] = makeBinaryOrdering(ordC);
+    std::copy(ndC.begin(), ndC.end(), std::back_inserter(ordering));
+  }
+  auto ordering_full = Ordering(ordering);
+
+  const auto [hbt, remaining] =
+      hfg->eliminatePartialMultifrontal(ordering_full);
+
+  auto new_fg = makeSwitchingChain(12);
+  auto isam = HybridGaussianISAM(*hbt);
+
+  // Run an ISAM update.
+  HybridGaussianFactorGraph factorGraph;
+  factorGraph.push_back(new_fg->at(new_fg->size() - 2));
+  factorGraph.push_back(new_fg->at(new_fg->size() - 1));
+  isam.update(factorGraph);
+
+  // ISAM should have 12 factors after the last update
+  EXPECT_LONGS_EQUAL(12, isam.size());
+}
+
+/* ************************************************************************* */
+TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
+  const int N = 7;
+  auto hfg = makeSwitchingChain(N, X);
+  hfg->push_back(*makeSwitchingChain(N, Y, D));
+
+  for (int t = 1; t <= N; t++) {
+    hfg->add(JacobianFactor(X(t), I_3x3, Y(t), -I_3x3, Vector3(1.0, 0.0, 0.0)));
+  }
+
+  KeyVector ordering;
+
+  KeyVector naturalX(N);
+  std::iota(naturalX.begin(), naturalX.end(), 1);
+  KeyVector ordX;
+  for (size_t i = 1; i <= N; i++) {
+    ordX.emplace_back(X(i));
+    ordX.emplace_back(Y(i));
+  }
+
+  for (size_t i = 1; i <= N - 1; i++) {
+    ordX.emplace_back(M(i));
+  }
+  for (size_t i = 1; i <= N - 1; i++) {
+    ordX.emplace_back(D(i));
+  }
+
+  {
+    DotWriter dw;
+    dw.positionHints['x'] = 1;
+    dw.positionHints['c'] = 0;
+    dw.positionHints['d'] = 3;
+    dw.positionHints['y'] = 2;
+    // std::cout << hfg->dot(DefaultKeyFormatter, dw);
+    // std::cout << "\n";
+  }
+
+  {
+    DotWriter dw;
+    dw.positionHints['y'] = 9;
+    // dw.positionHints['c'] = 0;
+    // dw.positionHints['d'] = 3;
+    dw.positionHints['x'] = 1;
+    // std::cout << "\n";
+    // std::cout << hfg->eliminateSequential(Ordering(ordX))
+    //                  ->dot(DefaultKeyFormatter, dw);
+    // hfg->eliminateMultifrontal(Ordering(ordX))->dot(std::cout);
+  }
+
+  Ordering ordering_partial;
+  for (size_t i = 1; i <= N; i++) {
+    ordering_partial.emplace_back(X(i));
+    ordering_partial.emplace_back(Y(i));
+  }
+  const auto [hbn, remaining] =
+      hfg->eliminatePartialSequential(ordering_partial);
+
+  EXPECT_LONGS_EQUAL(14, hbn->size());
+  EXPECT_LONGS_EQUAL(11, remaining->size());
+
+  {
+    DotWriter dw;
+    dw.positionHints['x'] = 1;
+    dw.positionHints['c'] = 0;
+    dw.positionHints['d'] = 3;
+    dw.positionHints['y'] = 2;
+    // std::cout << remaining->dot(DefaultKeyFormatter, dw);
+    // std::cout << "\n";
+  }
+}

 /* ****************************************************************************/
 // Test multifrontal optimize

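The tests added above rely on makeBinaryOrdering (from Switching.h) to turn the
natural chain order X(1)..X(N) into a nested-dissection order in which each
segment's midpoint is eliminated last, matching the "X(5) will be the center"
comments. A minimal standalone sketch of that idea, using only the standard
library (binaryOrdering here is a hypothetical stand-in, not the GTSAM helper):

#include <cstddef>
#include <iostream>
#include <vector>

// Recursively order a chain so each segment's midpoint comes after both
// halves, mimicking the "center last" comments in the tests above.
void binaryOrdering(const std::vector<int>& keys, std::size_t lo,
                    std::size_t hi, std::vector<int>* out) {
  if (lo >= hi) return;
  const std::size_t mid = lo + (hi - lo) / 2;
  binaryOrdering(keys, lo, mid, out);      // left half first
  binaryOrdering(keys, mid + 1, hi, out);  // then right half
  out->push_back(keys[mid]);               // segment center eliminated last
}

int main() {
  std::vector<int> keys{1, 2, 3, 4, 5, 6, 7, 8, 9};
  std::vector<int> order;
  binaryOrdering(keys, 0, keys.size(), &order);
  for (int k : order) std::cout << k << ' ';  // prints 1 2 4 3 6 7 9 8 5
  std::cout << '\n';                          // midpoint 5 comes out last
}
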
@@ -37,7 +37,7 @@
 // Include for test suite
 #include <CppUnitLite/TestHarness.h>

-#include <bitset>
+#include <string>

 #include "Switching.h"

@@ -47,6 +47,33 @@ using namespace gtsam;
 using symbol_shorthand::X;
 using symbol_shorthand::Z;

+namespace estimation_fixture {
+std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
+                                    7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
+// Ground truth discrete seq
+std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
+                                    1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
+
+Switching InitializeEstimationProblem(
+    const size_t K, const double between_sigma, const double measurement_sigma,
+    const std::vector<double>& measurements,
+    const std::string& transitionProbabilityTable,
+    HybridNonlinearFactorGraph& graph, Values& initial) {
+  Switching switching(K, between_sigma, measurement_sigma, measurements,
+                      transitionProbabilityTable);
+
+  // Add prior on M(0)
+  graph.push_back(switching.modeChain.at(0));
+
+  // Add the X(0) prior
+  graph.push_back(switching.unaryFactors.at(0));
+  initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
+
+  return switching;
+}
+
+}  // namespace estimation_fixture
+
 TEST(HybridEstimation, Full) {
   size_t K = 6;
   std::vector<double> measurements = {0, 1, 2, 2, 2, 3};

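The new estimation_fixture namespace centralizes the measurement sequence, the
ground-truth mode sequence, and the boilerplate that seeds a Switching problem
with its M(0) and X(0) priors. A minimal sketch of the intended call pattern,
assuming the declarations added above and the Switching helper from
Switching.h (the K value here is illustrative):

#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
#include <gtsam/nonlinear/Values.h>

#include "Switching.h"

using namespace gtsam;

void exampleFixtureUsage() {
  using namespace estimation_fixture;
  HybridNonlinearFactorGraph graph;
  Values initial;
  // Seed the problem: priors on M(0) and X(0) are pushed into graph/initial.
  Switching switching = InitializeEstimationProblem(
      15, 1.0, 0.1, measurements, "1/1 1/1", graph, initial);
  // Factors for k = 1..K-1 are then appended by the individual tests.
}
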
@@ -90,37 +117,80 @@ TEST(HybridEstimation, Full) {
 /****************************************************************************/
 // Test approximate inference with an additional pruning step.
 TEST(HybridEstimation, IncrementalSmoother) {
+  using namespace estimation_fixture;
+
   size_t K = 15;
-  std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
-                                      7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
-  // Ground truth discrete seq
-  std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
-                                      1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
   // Switching example of robot moving in 1D
   // with given measurements and equal mode priors.
-  Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
-  HybridSmoother smoother;
   HybridNonlinearFactorGraph graph;
   Values initial;
-  // Add the X(0) prior
-  graph.push_back(switching.nonlinearFactorGraph.at(0));
-  initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
+  Switching switching = InitializeEstimationProblem(K, 1.0, 0.1, measurements,
+                                                    "1/1 1/1", graph, initial);
+  HybridSmoother smoother;

   HybridGaussianFactorGraph linearized;

+  constexpr size_t maxNrLeaves = 3;
   for (size_t k = 1; k < K; k++) {
-    // Motion Model
-    graph.push_back(switching.nonlinearFactorGraph.at(k));
-    // Measurement
-    graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1));
+    if (k > 1) graph.push_back(switching.modeChain.at(k - 1));  // Mode chain
+    graph.push_back(switching.binaryFactors.at(k - 1));  // Motion Model
+    graph.push_back(switching.unaryFactors.at(k));       // Measurement

     initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));

     linearized = *graph.linearize(initial);
     Ordering ordering = smoother.getOrdering(linearized);

-    smoother.update(linearized, 3, ordering);
+    smoother.update(linearized, maxNrLeaves, ordering);
+    graph.resize(0);
+  }
+
+  HybridValues delta = smoother.hybridBayesNet().optimize();
+
+  Values result = initial.retract(delta.continuous());
+
+  DiscreteValues expected_discrete;
+  for (size_t k = 0; k < K - 1; k++) {
+    expected_discrete[M(k)] = discrete_seq[k];
+  }
+  EXPECT(assert_equal(expected_discrete, delta.discrete()));
+
+  Values expected_continuous;
+  for (size_t k = 0; k < K; k++) {
+    expected_continuous.insert(X(k), measurements[k]);
+  }
+  EXPECT(assert_equal(expected_continuous, result));
+}
+
+/****************************************************************************/
+// Test if pruned factor is set to correct error and no errors are thrown.
+TEST(HybridEstimation, ValidPruningError) {
+  using namespace estimation_fixture;
+
+  size_t K = 8;
+
+  HybridNonlinearFactorGraph graph;
+  Values initial;
+  Switching switching = InitializeEstimationProblem(K, 1e-2, 1e-3, measurements,
+                                                    "1/1 1/1", graph, initial);
+  HybridSmoother smoother;
+
+  HybridGaussianFactorGraph linearized;
+
+  constexpr size_t maxNrLeaves = 3;
+  for (size_t k = 1; k < K; k++) {
+    if (k > 1) graph.push_back(switching.modeChain.at(k - 1));  // Mode chain
+    graph.push_back(switching.binaryFactors.at(k - 1));  // Motion Model
+    graph.push_back(switching.unaryFactors.at(k));       // Measurement
+
+    initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
+
+    linearized = *graph.linearize(initial);
+    Ordering ordering = smoother.getOrdering(linearized);
+
+    smoother.update(linearized, maxNrLeaves, ordering);

     graph.resize(0);
   }

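Both smoother tests cap the hybrid posterior at maxNrLeaves = 3: after each
update, only the three most probable discrete assignments keep their Gaussian
components. A standalone sketch of that thresholding rule on a plain
probability table (illustrative only; HybridSmoother's actual pruning operates
on decision trees):

#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

int main() {
  // Unnormalized probabilities of 8 discrete assignments.
  std::vector<double> leaves{0.05, 0.30, 0.02, 0.20, 0.10, 0.25, 0.05, 0.03};
  const std::size_t maxNrLeaves = 3;

  // Find the value of the maxNrLeaves-th largest leaf...
  std::vector<double> sorted = leaves;
  std::nth_element(sorted.begin(), sorted.begin() + maxNrLeaves - 1,
                   sorted.end(), std::greater<double>());
  const double threshold = sorted[maxNrLeaves - 1];

  // ...and zero out every leaf below it, which is what pruning amounts to.
  for (double& p : leaves)
    if (p < threshold) p = 0.0;

  for (double p : leaves) std::cout << p << ' ';  // 0.30, 0.25, 0.20 survive
  std::cout << '\n';
}
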
|
@ -144,37 +214,31 @@ TEST(HybridEstimation, IncrementalSmoother) {
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
// Test approximate inference with an additional pruning step.
|
// Test approximate inference with an additional pruning step.
|
||||||
TEST(HybridEstimation, ISAM) {
|
TEST(HybridEstimation, ISAM) {
|
||||||
|
using namespace estimation_fixture;
|
||||||
|
|
||||||
size_t K = 15;
|
size_t K = 15;
|
||||||
std::vector<double> measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6,
|
|
||||||
7, 8, 9, 9, 9, 10, 11, 11, 11, 11};
|
|
||||||
// Ground truth discrete seq
|
|
||||||
std::vector<size_t> discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0,
|
|
||||||
1, 1, 1, 0, 0, 1, 1, 0, 0, 0};
|
|
||||||
// Switching example of robot moving in 1D
|
// Switching example of robot moving in 1D
|
||||||
// with given measurements and equal mode priors.
|
// with given measurements and equal mode priors.
|
||||||
Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1");
|
|
||||||
HybridNonlinearISAM isam;
|
|
||||||
HybridNonlinearFactorGraph graph;
|
HybridNonlinearFactorGraph graph;
|
||||||
Values initial;
|
Values initial;
|
||||||
|
Switching switching = InitializeEstimationProblem(K, 1.0, 0.1, measurements,
|
||||||
// gttic_(Estimation);
|
"1/1 1/1", graph, initial);
|
||||||
|
HybridNonlinearISAM isam;
|
||||||
// Add the X(0) prior
|
|
||||||
graph.push_back(switching.nonlinearFactorGraph.at(0));
|
|
||||||
initial.insert(X(0), switching.linearizationPoint.at<double>(X(0)));
|
|
||||||
|
|
||||||
HybridGaussianFactorGraph linearized;
|
HybridGaussianFactorGraph linearized;
|
||||||
|
|
||||||
|
const size_t maxNrLeaves = 3;
|
||||||
for (size_t k = 1; k < K; k++) {
|
for (size_t k = 1; k < K; k++) {
|
||||||
// Motion Model
|
if (k > 1) graph.push_back(switching.modeChain.at(k - 1)); // Mode chain
|
||||||
graph.push_back(switching.nonlinearFactorGraph.at(k));
|
graph.push_back(switching.binaryFactors.at(k - 1)); // Motion Model
|
||||||
// Measurement
|
graph.push_back(switching.unaryFactors.at(k)); // Measurement
|
||||||
graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1));
|
|
||||||
|
|
||||||
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
initial.insert(X(k), switching.linearizationPoint.at<double>(X(k)));
|
||||||
|
|
||||||
isam.update(graph, initial, 3);
|
isam.update(graph, initial, maxNrLeaves);
|
||||||
// isam.bayesTree().print("\n\n");
|
// isam.saveGraph("NLiSAM" + std::to_string(k) + ".dot");
|
||||||
|
// GTSAM_PRINT(isam);
|
||||||
|
|
||||||
graph.resize(0);
|
graph.resize(0);
|
||||||
initial.clear();
|
initial.clear();
|
||||||
|
|
@@ -196,65 +260,6 @@ TEST(HybridEstimation, ISAM) {
   EXPECT(assert_equal(expected_continuous, result));
 }

-/**
- * @brief A function to get a specific 1D robot motion problem as a linearized
- * factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous
- * positions given the measurements and discrete sequence.
- *
- * @param K The number of timesteps.
- * @param measurements The vector of measurements for each timestep.
- * @param discrete_seq The discrete sequence governing the motion of the robot.
- * @param measurement_sigma Noise model sigma for measurements.
- * @param between_sigma Noise model sigma for the between factor.
- * @return GaussianFactorGraph::shared_ptr
- */
-GaussianFactorGraph::shared_ptr specificModesFactorGraph(
-    size_t K, const std::vector<double>& measurements,
-    const std::vector<size_t>& discrete_seq, double measurement_sigma = 0.1,
-    double between_sigma = 1.0) {
-  NonlinearFactorGraph graph;
-  Values linearizationPoint;
-
-  // Add measurement factors
-  auto measurement_noise = noiseModel::Isotropic::Sigma(1, measurement_sigma);
-  for (size_t k = 0; k < K; k++) {
-    graph.emplace_shared<PriorFactor<double>>(X(k), measurements.at(k),
-                                              measurement_noise);
-    linearizationPoint.insert<double>(X(k), static_cast<double>(k + 1));
-  }
-
-  using MotionModel = BetweenFactor<double>;
-
-  // Add "motion models".
-  auto motion_noise_model = noiseModel::Isotropic::Sigma(1, between_sigma);
-  for (size_t k = 0; k < K - 1; k++) {
-    auto motion_model = std::make_shared<MotionModel>(
-        X(k), X(k + 1), discrete_seq.at(k), motion_noise_model);
-    graph.push_back(motion_model);
-  }
-  GaussianFactorGraph::shared_ptr linear_graph =
-      graph.linearize(linearizationPoint);
-  return linear_graph;
-}
-
-/**
- * @brief Get the discrete sequence from the integer `x`.
- *
- * @tparam K Template parameter so we can set the correct bitset size.
- * @param x The integer to convert to a discrete binary sequence.
- * @return std::vector<size_t>
- */
-template <size_t K>
-std::vector<size_t> getDiscreteSequence(size_t x) {
-  std::bitset<K - 1> seq = x;
-  std::vector<size_t> discrete_seq(K - 1);
-  for (size_t i = 0; i < K - 1; i++) {
-    // Save to discrete vector in reverse order
-    discrete_seq[K - 2 - i] = seq[i];
-  }
-  return discrete_seq;
-}
-
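The removed getDiscreteSequence helper decoded an integer into a binary mode
sequence via std::bitset, most-significant mode first. A standalone sketch of
the same decoding, kept here because the index reversal is easy to get
backwards (this reproduces the deleted helper's logic; it is no longer part of
the test file):

#include <bitset>
#include <cstddef>
#include <iostream>
#include <vector>

// Decode integer x into K-1 binary modes, mirroring the removed helper:
// bit i of x becomes entry K-2-i, so the sequence reads MSB-first.
template <std::size_t K>
std::vector<std::size_t> getDiscreteSequence(std::size_t x) {
  std::bitset<K - 1> seq = x;
  std::vector<std::size_t> discrete_seq(K - 1);
  for (std::size_t i = 0; i < K - 1; i++) {
    discrete_seq[K - 2 - i] = seq[i];
  }
  return discrete_seq;
}

int main() {
  // For K = 4 there are 2^3 = 8 mode sequences; x = 5 = 0b101 -> {1, 0, 1}.
  for (std::size_t m : getDiscreteSequence<4>(5)) std::cout << m << ' ';
  std::cout << '\n';
}
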
 /**
  * @brief Helper method to get the tree of
  * unnormalized probabilities as per the elimination scheme.

@@ -265,7 +270,7 @@ std::vector<size_t> getDiscreteSequence(size_t x) {
  * @param graph The HybridGaussianFactorGraph to eliminate.
  * @return AlgebraicDecisionTree<Key>
  */
-AlgebraicDecisionTree<Key> getProbPrimeTree(
+AlgebraicDecisionTree<Key> GetProbPrimeTree(
     const HybridGaussianFactorGraph& graph) {
   Ordering continuous(graph.continuousKeySet());
   const auto [bayesNet, remainingGraph] =

@@ -311,8 +316,9 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
  * The values should match those of P'(Continuous) for each discrete mode.
  ********************************************************************************/
 TEST(HybridEstimation, Probability) {
+  using namespace estimation_fixture;
+
   constexpr size_t K = 4;
-  std::vector<double> measurements = {0, 1, 2, 2};
   double between_sigma = 1.0, measurement_sigma = 0.1;

   // Switching example of robot moving in 1D with

@@ -338,12 +344,8 @@ TEST(HybridEstimation, Probability) {
   HybridValues hybrid_values = bayesNet->optimize();

   // This is the correct sequence as designed
-  DiscreteValues discrete_seq;
-  discrete_seq[M(0)] = 1;
-  discrete_seq[M(1)] = 1;
-  discrete_seq[M(2)] = 0;
-
-  EXPECT(assert_equal(discrete_seq, hybrid_values.discrete()));
+  DiscreteValues expectedSequence{{M(0), 1}, {M(1), 1}, {M(2), 0}};
+  EXPECT(assert_equal(expectedSequence, hybrid_values.discrete()));
 }

 /****************************************************************************/

@@ -353,8 +355,9 @@ TEST(HybridEstimation, Probability) {
  * for each discrete mode.
  */
 TEST(HybridEstimation, ProbabilityMultifrontal) {
+  using namespace estimation_fixture;
+
   constexpr size_t K = 4;
-  std::vector<double> measurements = {0, 1, 2, 2};

   double between_sigma = 1.0, measurement_sigma = 0.1;

@@ -365,7 +368,7 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
   auto graph = switching.linearizedFactorGraph;

   // Get the tree of unnormalized probabilities for each mode sequence.
-  AlgebraicDecisionTree<Key> expected_probPrimeTree = getProbPrimeTree(graph);
+  AlgebraicDecisionTree<Key> expected_probPrimeTree = GetProbPrimeTree(graph);

   // Eliminate continuous
   Ordering continuous_ordering(graph.continuousKeySet());

@@ -409,18 +412,14 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
   HybridValues hybrid_values = discreteBayesTree->optimize();

   // This is the correct sequence as designed
-  DiscreteValues discrete_seq;
-  discrete_seq[M(0)] = 1;
-  discrete_seq[M(1)] = 1;
-  discrete_seq[M(2)] = 0;
-
-  EXPECT(assert_equal(discrete_seq, hybrid_values.discrete()));
+  DiscreteValues expectedSequence{{M(0), 1}, {M(1), 1}, {M(2), 0}};
+  EXPECT(assert_equal(expectedSequence, hybrid_values.discrete()));
 }

 /*********************************************************************************
 // Create a hybrid nonlinear factor graph f(x0, x1, m0; z0, z1)
 ********************************************************************************/
-static HybridNonlinearFactorGraph createHybridNonlinearFactorGraph() {
+static HybridNonlinearFactorGraph CreateHybridNonlinearFactorGraph() {
   HybridNonlinearFactorGraph nfg;

   constexpr double sigma = 0.5;  // measurement noise

@@ -446,8 +445,8 @@ static HybridNonlinearFactorGraph createHybridNonlinearFactorGraph() {
 /*********************************************************************************
 // Create a hybrid linear factor graph f(x0, x1, m0; z0, z1)
 ********************************************************************************/
-static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
-  HybridNonlinearFactorGraph nfg = createHybridNonlinearFactorGraph();
+static HybridGaussianFactorGraph::shared_ptr CreateHybridGaussianFactorGraph() {
+  HybridNonlinearFactorGraph nfg = CreateHybridNonlinearFactorGraph();

   Values initial;
   double z0 = 0.0, z1 = 1.0;

@@ -459,9 +458,9 @@ static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
 /*********************************************************************************
  * Do hybrid elimination and do regression test on discrete conditional.
 ********************************************************************************/
-TEST(HybridEstimation, eliminateSequentialRegression) {
+TEST(HybridEstimation, EliminateSequentialRegression) {
   // Create the factor graph from the nonlinear factor graph.
-  HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
+  HybridGaussianFactorGraph::shared_ptr fg = CreateHybridGaussianFactorGraph();

   // Create expected discrete conditional on m0.
   DiscreteKey m(M(0), 2);

@@ -496,7 +495,7 @@ TEST(HybridEstimation, eliminateSequentialRegression) {
 ********************************************************************************/
 TEST(HybridEstimation, CorrectnessViaSampling) {
   // 1. Create the factor graph from the nonlinear factor graph.
-  const auto fg = createHybridGaussianFactorGraph();
+  const auto fg = CreateHybridGaussianFactorGraph();

   // 2. Eliminate into BN
   const HybridBayesNet::shared_ptr bn = fg->eliminateSequential();

@@ -513,8 +512,6 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
   // the normalizing term computed via the Bayes net determinant.
   const HybridValues sample = bn->sample(&rng);
   double expected_ratio = compute_ratio(sample);
-  // regression
-  EXPECT_DOUBLES_EQUAL(0.728588, expected_ratio, 1e-6);

   // 3. Do sampling
   constexpr int num_samples = 10;

@@ -18,6 +18,8 @@
  * @date December 2021
  */

+#include <gtsam/discrete/DecisionTree.h>
+#include <gtsam/discrete/DiscreteKey.h>
 #include <gtsam/discrete/DiscreteValues.h>
 #include <gtsam/hybrid/HybridGaussianConditional.h>
 #include <gtsam/hybrid/HybridGaussianFactor.h>

@@ -25,6 +27,7 @@
 #include <gtsam/inference/Symbol.h>
 #include <gtsam/linear/GaussianConditional.h>

+#include <memory>
 #include <vector>

 // Include for test suite

@@ -74,17 +77,6 @@ TEST(HybridGaussianConditional, Invariants) {
 /// Check LogProbability.
 TEST(HybridGaussianConditional, LogProbability) {
   using namespace equal_constants;
-  auto actual = hybrid_conditional.logProbability(vv);
-
-  // Check result.
-  std::vector<DiscreteKey> discrete_keys = {mode};
-  std::vector<double> leaves = {conditionals[0]->logProbability(vv),
-                                conditionals[1]->logProbability(vv)};
-  AlgebraicDecisionTree<Key> expected(discrete_keys, leaves);
-
-  EXPECT(assert_equal(expected, actual, 1e-6));
-
-  // Check for non-tree version.
   for (size_t mode : {0, 1}) {
     const HybridValues hv{vv, {{M(0), mode}}};
     EXPECT_DOUBLES_EQUAL(conditionals[mode]->logProbability(vv),

@@ -168,6 +160,9 @@ TEST(HybridGaussianConditional, ContinuousParents) {
   // Check that the continuous parent keys are correct:
   EXPECT(continuousParentKeys.size() == 1);
   EXPECT(continuousParentKeys[0] == X(0));
+
+  EXPECT(HybridGaussianConditional::CheckInvariants(hybrid_conditional, hv0));
+  EXPECT(HybridGaussianConditional::CheckInvariants(hybrid_conditional, hv1));
 }

 /* ************************************************************************* */

@@ -221,30 +216,16 @@ TEST(HybridGaussianConditional, Likelihood2) {
   // Check the detailed JacobianFactor calculation for mode==1.
   {
     // We have a JacobianFactor
-    const auto gf1 = (*likelihood)(assignment1);
+    const auto [gf1, _] = (*likelihood)(assignment1);
     const auto jf1 = std::dynamic_pointer_cast<JacobianFactor>(gf1);
     CHECK(jf1);

-    // It has 2 rows, not 1!
-    CHECK(jf1->rows() == 2);
-
-    // Check that the constant C1 is properly encoded in the JacobianFactor.
-    const double C1 =
-        conditionals[1]->negLogConstant() - hybrid_conditional.negLogConstant();
-    const double c1 = std::sqrt(2.0 * C1);
-    Vector expected_unwhitened(2);
-    expected_unwhitened << 4.9 - 5.0, -c1;
-    Vector actual_unwhitened = jf1->unweighted_error(vv);
-    EXPECT(assert_equal(expected_unwhitened, actual_unwhitened));
-
-    // Make sure the noise model does not touch it.
-    Vector expected_whitened(2);
-    expected_whitened << (4.9 - 5.0) / 3.0, -c1;
-    Vector actual_whitened = jf1->error_vector(vv);
-    EXPECT(assert_equal(expected_whitened, actual_whitened));
-
-    // Check that the error is equal to the conditional error:
-    EXPECT_DOUBLES_EQUAL(hybrid_conditional.error(hv1), jf1->error(hv1), 1e-8);
+    // Check that the JacobianFactor error with constants is equal to the
+    // conditional error:
+    EXPECT_DOUBLES_EQUAL(hybrid_conditional.error(hv1),
+                         jf1->error(hv1) + conditionals[1]->negLogConstant() -
+                             hybrid_conditional.negLogConstant(),
+                         1e-8);
   }

   // Check that the ratio of probPrime to evaluate is the same for all modes.

@@ -258,8 +239,60 @@ TEST(HybridGaussianConditional, Likelihood2) {
 }

 /* ************************************************************************* */
+// Test pruning a HybridGaussianConditional with two discrete keys, based on a
+// DecisionTreeFactor with 3 keys:
+TEST(HybridGaussianConditional, Prune) {
+  // Create a two key conditional:
+  DiscreteKeys modes{{M(1), 2}, {M(2), 2}};
+  std::vector<GaussianConditional::shared_ptr> gcs;
+  for (size_t i = 0; i < 4; i++) {
+    gcs.push_back(
+        GaussianConditional::sharedMeanAndStddev(Z(0), Vector1(i + 1), i + 1));
+  }
+  auto empty = std::make_shared<GaussianConditional>();
+  HybridGaussianConditional::Conditionals conditionals(modes, gcs);
+  HybridGaussianConditional hgc(modes, conditionals);
+
+  DiscreteKeys keys = modes;
+  keys.push_back({M(3), 2});
+  {
+    for (size_t i = 0; i < 8; i++) {
+      std::vector<double> potentials{0, 0, 0, 0, 0, 0, 0, 0};
+      potentials[i] = 1;
+      const DecisionTreeFactor decisionTreeFactor(keys, potentials);
+      // Prune the HybridGaussianConditional
+      const auto pruned = hgc.prune(decisionTreeFactor);
+      // Check that the pruned HybridGaussianConditional has 1 conditional
+      EXPECT_LONGS_EQUAL(1, pruned->nrComponents());
+    }
+  }
+  {
+    const std::vector<double> potentials{0, 0, 0.5, 0,  //
+                                         0, 0, 0.5, 0};
+    const DecisionTreeFactor decisionTreeFactor(keys, potentials);
+
+    const auto pruned = hgc.prune(decisionTreeFactor);
+
+    // Check that the pruned HybridGaussianConditional has 2 conditionals
+    EXPECT_LONGS_EQUAL(2, pruned->nrComponents());
+  }
+  {
+    const std::vector<double> potentials{0.2, 0, 0.3, 0,  //
+                                         0, 0, 0.5, 0};
+    const DecisionTreeFactor decisionTreeFactor(keys, potentials);
+
+    const auto pruned = hgc.prune(decisionTreeFactor);
+
+    // Check that the pruned HybridGaussianConditional has 3 conditionals
+    EXPECT_LONGS_EQUAL(3, pruned->nrComponents());
+  }
+}
+
+/* *************************************************************************
+ */
 int main() {
   TestResult tr;
   return TestRegistry::runAllTests(tr);
 }
-/* ************************************************************************* */
+/* *************************************************************************
+ */

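The Prune test's expected component counts follow from one rule: a
(M(1), M(2)) component survives if the pruning DecisionTreeFactor is nonzero
for at least one value of the extra key M(3). A standalone sketch checking the
three potential tables used above against that rule (the flat index below
assumes the last key varies fastest, consistent with the expected counts):

#include <iostream>
#include <vector>

// Count surviving (m1, m2) components: index = m1*4 + m2*2 + m3 matches the
// (M(1), M(2), M(3)) key order of the DecisionTreeFactor in the test above.
int survivors(const std::vector<double>& potentials) {
  int count = 0;
  for (int m1 = 0; m1 < 2; m1++)
    for (int m2 = 0; m2 < 2; m2++) {
      bool alive = false;
      for (int m3 = 0; m3 < 2; m3++)
        alive = alive || potentials[m1 * 4 + m2 * 2 + m3] > 0;
      if (alive) count++;
    }
  return count;
}

int main() {
  std::cout << survivors({0, 1, 0, 0, 0, 0, 0, 0}) << '\n';        // 1
  std::cout << survivors({0, 0, 0.5, 0, 0, 0, 0.5, 0}) << '\n';    // 2
  std::cout << survivors({0.2, 0, 0.3, 0, 0, 0, 0.5, 0}) << '\n';  // 3
}
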
@@ -82,40 +82,25 @@ TEST(HybridGaussianFactor, ConstructorVariants) {
 }

 /* ************************************************************************* */
-// "Add" two hybrid factors together.
-TEST(HybridGaussianFactor, Sum) {
+TEST(HybridGaussianFactor, Keys) {
   using namespace test_constructor;
-  DiscreteKey m2(2, 3);
-
-  auto A3 = Matrix::Zero(2, 3);
-  auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
-  auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
-  auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
-
-  // TODO(Frank): why specify keys at all? And: keys in factor should be *all*
-  // keys, deviating from Kevin's scheme. Should we index DT on DiscreteKey?
-  // Design review!
   HybridGaussianFactor hybridFactorA(m1, {f10, f11});
-  HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});

   // Check the number of keys matches what we expect
   EXPECT_LONGS_EQUAL(3, hybridFactorA.keys().size());
   EXPECT_LONGS_EQUAL(2, hybridFactorA.continuousKeys().size());
   EXPECT_LONGS_EQUAL(1, hybridFactorA.discreteKeys().size());

-  // Create sum of two hybrid factors: it will be a decision tree now on both
-  // discrete variables m1 and m2:
-  GaussianFactorGraphTree sum;
-  sum += hybridFactorA;
-  sum += hybridFactorB;
-
-  // Let's check that this worked:
-  Assignment<Key> mode;
-  mode[m1.first] = 1;
-  mode[m2.first] = 2;
-  auto actual = sum(mode);
-  EXPECT(actual.at(0) == f11);
-  EXPECT(actual.at(1) == f22);
+  DiscreteKey m2(2, 3);
+  auto A3 = Matrix::Zero(2, 3);
+  auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
+  auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
+  auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
+  HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});
+
+  // Check the number of keys matches what we expect
+  EXPECT_LONGS_EQUAL(3, hybridFactorB.keys().size());
+  EXPECT_LONGS_EQUAL(2, hybridFactorB.continuousKeys().size());
+  EXPECT_LONGS_EQUAL(1, hybridFactorB.discreteKeys().size());
 }

 /* ************************************************************************* */

@@ -124,8 +109,7 @@ TEST(HybridGaussianFactor, Printing) {
   HybridGaussianFactor hybridFactor(m1, {f10, f11});

   std::string expected =
-      R"(HybridGaussianFactor
-Hybrid [x1 x2; 1]{
+      R"(Hybrid [x1 x2; 1]{
  Choice(1)
  0 Leaf :
   A[x1] = [

@@ -138,6 +122,7 @@ Hybrid [x1 x2; 1]{
 ]
 b = [ 0 0 ]
 No noise model
+scalar: 0

 1 Leaf :
   A[x1] = [

@@ -150,6 +135,7 @@ Hybrid [x1 x2; 1]{
 ]
 b = [ 0 0 ]
 No noise model
+scalar: 0

 }
 )";

@@ -357,16 +343,9 @@ TEST(HybridGaussianFactor, DifferentCovariancesFG) {
   cv.insert(X(0), Vector1(0.0));
   cv.insert(X(1), Vector1(0.0));

-  // Check that the error values at the MLE point μ.
-  AlgebraicDecisionTree<Key> errorTree = hbn->errorTree(cv);
-
   DiscreteValues dv0{{M(1), 0}};
   DiscreteValues dv1{{M(1), 1}};

-  // regression
-  EXPECT_DOUBLES_EQUAL(9.90348755254, errorTree(dv0), 1e-9);
-  EXPECT_DOUBLES_EQUAL(0.69314718056, errorTree(dv1), 1e-9);
-
   DiscreteConditional expected_m1(m1, "0.5/0.5");
   DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());

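A recurring change across these hunks is that calling a HybridGaussianFactor
(or a hybrid likelihood) on a discrete assignment now yields a
(GaussianFactor, scalar) pair rather than a bare factor, as in
`const auto [gf1, _] = (*likelihood)(assignment1)` above, and the printed
representation gains a matching `scalar:` line. A minimal sketch of consuming
that pair, assuming only the operator() usage shown in these diffs (modeError
is a hypothetical helper, not GTSAM API):

#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/linear/VectorValues.h>

using namespace gtsam;

// Evaluate one mode of a hybrid factor; the scalar carries the per-component
// constant (e.g. the noise model's negLogConstant) alongside the factor.
double modeError(const HybridGaussianFactor &hgf, const DiscreteValues &mode,
                 const VectorValues &continuousValues) {
  const auto [factor, scalar] = hgf(mode);
  return factor->error(continuousValues) + scalar;
}
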
@@ -13,38 +13,34 @@
  * @file testHybridGaussianFactorGraph.cpp
  * @date Mar 11, 2022
  * @author Fan Jiang
+ * @author Varun Agrawal
+ * @author Frank Dellaert
  */

 #include <CppUnitLite/Test.h>
 #include <CppUnitLite/TestHarness.h>
+#include <gtsam/base/Testable.h>
 #include <gtsam/base/TestableAssertions.h>
 #include <gtsam/base/Vector.h>
 #include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/DiscreteKey.h>
 #include <gtsam/discrete/DiscreteValues.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
-#include <gtsam/hybrid/HybridBayesTree.h>
 #include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/hybrid/HybridFactor.h>
 #include <gtsam/hybrid/HybridGaussianConditional.h>
 #include <gtsam/hybrid/HybridGaussianFactor.h>
 #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
-#include <gtsam/hybrid/HybridGaussianISAM.h>
+#include <gtsam/hybrid/HybridGaussianProductFactor.h>
 #include <gtsam/hybrid/HybridValues.h>
 #include <gtsam/inference/BayesNet.h>
-#include <gtsam/inference/DotWriter.h>
 #include <gtsam/inference/Key.h>
 #include <gtsam/inference/Ordering.h>
 #include <gtsam/inference/Symbol.h>
 #include <gtsam/linear/JacobianFactor.h>

-#include <algorithm>
 #include <cstddef>
-#include <functional>
-#include <iostream>
-#include <iterator>
 #include <memory>
-#include <numeric>
 #include <vector>

 #include "Switching.h"

@@ -53,17 +49,15 @@
 using namespace std;
 using namespace gtsam;

-using gtsam::symbol_shorthand::D;
 using gtsam::symbol_shorthand::M;
 using gtsam::symbol_shorthand::N;
 using gtsam::symbol_shorthand::X;
-using gtsam::symbol_shorthand::Y;
 using gtsam::symbol_shorthand::Z;

 // Set up sampling
 std::mt19937_64 kRng(42);

-static const DiscreteKey m1(M(1), 2);
+static const DiscreteKey m0(M(0), 2), m1(M(1), 2), m2(M(2), 2);

 /* ************************************************************************* */
 TEST(HybridGaussianFactorGraph, Creation) {

@@ -76,7 +70,7 @@ TEST(HybridGaussianFactorGraph, Creation) {
   // Define a hybrid gaussian conditional P(x0|x1, c0)
   // and add it to the factor graph.
   HybridGaussianConditional gm(
-      {M(0), 2},
+      m0,
       {std::make_shared<GaussianConditional>(X(0), Z_3x1, I_3x3, X(1), I_3x3),
        std::make_shared<GaussianConditional>(X(0), Vector3::Ones(), I_3x3, X(1),
                                              I_3x3)});

@@ -97,22 +91,6 @@ TEST(HybridGaussianFactorGraph, EliminateSequential) {
   EXPECT_LONGS_EQUAL(result.first->size(), 1);
 }

-/* ************************************************************************* */
-TEST(HybridGaussianFactorGraph, EliminateMultifrontal) {
-  // Test multifrontal elimination
-  HybridGaussianFactorGraph hfg;
-
-  // Add priors on x0 and c1
-  hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
-  hfg.add(DecisionTreeFactor(m1, {2, 8}));
-
-  Ordering ordering;
-  ordering.push_back(X(0));
-  auto result = hfg.eliminatePartialMultifrontal(ordering);
-
-  EXPECT_LONGS_EQUAL(result.first->size(), 1);
-  EXPECT_LONGS_EQUAL(result.second->size(), 1);
-}
 /* ************************************************************************* */

 namespace two {

@@ -138,7 +116,8 @@ TEST(HybridGaussianFactorGraph, hybridEliminationOneFactor) {
   // Check that factor is discrete and correct
   auto factor = std::dynamic_pointer_cast<DecisionTreeFactor>(result.second);
   CHECK(factor);
-  EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor));
+  // regression test
+  EXPECT(assert_equal(DecisionTreeFactor{m1, "1 1"}, *factor, 1e-5));
 }

 /* ************************************************************************* */

@@ -178,7 +157,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialSimple) {
   // Discrete probability table for c1
   hfg.add(DecisionTreeFactor(m1, {2, 8}));
   // Joint discrete probability table for c1, c2
-  hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
+  hfg.add(DecisionTreeFactor({m1, m2}, "1 2 3 4"));

   HybridBayesNet::shared_ptr result = hfg.eliminateSequential();

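The `DecisionTreeFactor({m1, m2}, "1 2 3 4")` shorthand above packs a 2x2
potential table into a string. A small sketch of how such a table is read
back, under the assumption (consistent with the tests in this diff) that the
last key varies fastest:

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Symbol.h>

#include <iostream>

using namespace gtsam;
using symbol_shorthand::M;

int main() {
  const DiscreteKey m1(M(1), 2), m2(M(2), 2);
  const DecisionTreeFactor joint({m1, m2}, "1 2 3 4");

  // Evaluate one cell: with the last key varying fastest,
  // (m1 = 1, m2 = 0) should map to the third entry, i.e. 3.
  DiscreteValues assignment{{M(1), 1}, {M(2), 0}};
  std::cout << joint(assignment) << std::endl;
}
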
@ -187,295 +166,219 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialSimple) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
|
// Test API for the smallest switching network.
|
||||||
HybridGaussianFactorGraph hfg;
|
// None of these are regression tests.
|
||||||
|
TEST(HybridBayesNet, Switching) {
|
||||||
|
// Create switching network with two continuous variables and one discrete:
|
||||||
|
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1;z1) ϕ(m0)
|
||||||
|
const double betweenSigma = 0.3, priorSigma = 0.1;
|
||||||
|
Switching s(2, betweenSigma, priorSigma);
|
||||||
|
|
||||||
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
// Check size of linearized factor graph
|
||||||
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
|
||||||
|
EXPECT_LONGS_EQUAL(4, graph.size());
|
||||||
|
|
||||||
hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(1))));
|
// Create some continuous and discrete values
|
||||||
|
const VectorValues continuousValues{{X(0), Vector1(0.1)},
|
||||||
|
{X(1), Vector1(1.2)}};
|
||||||
|
const DiscreteValues modeZero{{M(0), 0}}, modeOne{{M(0), 1}};
|
||||||
|
|
||||||
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
// Get the hybrid gaussian factor and check it is as expected
|
||||||
// TODO(Varun) Adding extra discrete variable not connected to continuous
|
auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(graph.at(1));
|
||||||
// variable throws segfault
|
CHECK(hgf);
|
||||||
// hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
|
||||||
|
|
||||||
HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal();
|
// Get factors and scalars for both modes
|
||||||
|
auto [factor0, scalar0] = (*hgf)(modeZero);
|
||||||
|
auto [factor1, scalar1] = (*hgf)(modeOne);
|
||||||
|
CHECK(factor0);
|
||||||
|
CHECK(factor1);
|
||||||
|
|
||||||
// The bayes tree should have 3 cliques
|
// Check scalars against negLogConstant of noise model
|
||||||
EXPECT_LONGS_EQUAL(3, result->size());
|
auto betweenModel = noiseModel::Isotropic::Sigma(1, betweenSigma);
|
||||||
// GTSAM_PRINT(*result);
|
EXPECT_DOUBLES_EQUAL(betweenModel->negLogConstant(), scalar0, 1e-9);
|
||||||
// GTSAM_PRINT(*result->marginalFactor(M(2)));
|
EXPECT_DOUBLES_EQUAL(betweenModel->negLogConstant(), scalar1, 1e-9);
|
||||||
|
|
||||||
|
// Check error for M(0) = 0
|
||||||
|
const HybridValues values0{continuousValues, modeZero};
|
||||||
|
double expectedError0 = 0;
|
||||||
|
for (const auto &factor : graph) expectedError0 += factor->error(values0);
|
||||||
|
EXPECT_DOUBLES_EQUAL(expectedError0, graph.error(values0), 1e-5);
|
||||||
|
|
||||||
|
// Check error for M(0) = 1
|
||||||
|
const HybridValues values1{continuousValues, modeOne};
|
||||||
|
double expectedError1 = 0;
|
||||||
|
for (const auto &factor : graph) expectedError1 += factor->error(values1);
|
||||||
|
EXPECT_DOUBLES_EQUAL(expectedError1, graph.error(values1), 1e-5);
|
||||||
|
|
||||||
|
// Check errorTree
|
||||||
|
AlgebraicDecisionTree<Key> actualErrors = graph.errorTree(continuousValues);
|
||||||
|
// Create expected error tree
|
||||||
|
const AlgebraicDecisionTree<Key> expectedErrors(M(0), expectedError0,
|
||||||
|
expectedError1);
|
||||||
|
|
||||||
|
// Check that the actual error tree matches the expected one
|
||||||
|
EXPECT(assert_equal(expectedErrors, actualErrors, 1e-5));
|
||||||
|
|
||||||
|
// Check probPrime
|
||||||
|
const double probPrime0 = graph.probPrime(values0);
|
||||||
|
EXPECT_DOUBLES_EQUAL(std::exp(-expectedError0), probPrime0, 1e-5);
|
||||||
|
|
||||||
|
const double probPrime1 = graph.probPrime(values1);
|
||||||
|
EXPECT_DOUBLES_EQUAL(std::exp(-expectedError1), probPrime1, 1e-5);
|
||||||
|
|
||||||
|
// Check discretePosterior
|
||||||
|
const AlgebraicDecisionTree<Key> graphPosterior =
|
||||||
|
graph.discretePosterior(continuousValues);
|
||||||
|
const double sum = probPrime0 + probPrime1;
|
||||||
|
const AlgebraicDecisionTree<Key> expectedPosterior(M(0), probPrime0 / sum,
|
||||||
|
probPrime1 / sum);
|
||||||
|
EXPECT(assert_equal(expectedPosterior, graphPosterior, 1e-5));
|
||||||
|
|
||||||
|
// Make the clique of factors connected to x0:
|
||||||
|
HybridGaussianFactorGraph factors_x0;
|
||||||
|
factors_x0.push_back(graph.at(0));
|
||||||
|
factors_x0.push_back(hgf);
|
||||||
|
|
||||||
|
// Test collectProductFactor
|
||||||
|
auto productFactor = factors_x0.collectProductFactor();
|
||||||
|
|
||||||
|
// For M(0) = 0
|
||||||
|
auto [gaussianFactor0, actualScalar0] = productFactor(modeZero);
|
||||||
|
EXPECT(gaussianFactor0.size() == 2);
|
||||||
|
EXPECT_DOUBLES_EQUAL((*hgf)(modeZero).second, actualScalar0, 1e-5);
|
||||||
|
|
||||||
|
// For M(0) = 1
|
||||||
|
auto [gaussianFactor1, actualScalar1] = productFactor(modeOne);
|
||||||
|
EXPECT(gaussianFactor1.size() == 2);
|
||||||
|
EXPECT_DOUBLES_EQUAL((*hgf)(modeOne).second, actualScalar1, 1e-5);
|
||||||
|
|
||||||
|
// Test eliminate x0
|
||||||
|
const Ordering ordering{X(0)};
|
||||||
|
auto [conditional, factor] = factors_x0.eliminate(ordering);
|
||||||
|
|
||||||
|
// Check the conditional
|
||||||
|
CHECK(conditional);
|
||||||
|
EXPECT(conditional->isHybrid());
|
||||||
|
auto p_x0_given_x1_m = conditional->asHybrid();
|
||||||
|
CHECK(p_x0_given_x1_m);
|
||||||
|
EXPECT(HybridGaussianConditional::CheckInvariants(*p_x0_given_x1_m, values1));
|
||||||
|
EXPECT_LONGS_EQUAL(1, p_x0_given_x1_m->nrFrontals()); // x0
|
||||||
|
EXPECT_LONGS_EQUAL(2, p_x0_given_x1_m->nrParents()); // x1, m0
|
||||||
|
|
||||||
|
// Check the remaining factor
|
||||||
|
EXPECT(factor);
|
||||||
|
EXPECT(std::dynamic_pointer_cast<HybridGaussianFactor>(factor));
|
||||||
|
auto phi_x1_m = std::dynamic_pointer_cast<HybridGaussianFactor>(factor);
|
||||||
|
EXPECT_LONGS_EQUAL(2, phi_x1_m->keys().size()); // x1, m0
|
||||||
|
// Check that the scalars incorporate the negative log constant of the
|
||||||
|
// conditional
|
||||||
|
EXPECT_DOUBLES_EQUAL(scalar0 - (*p_x0_given_x1_m)(modeZero)->negLogConstant(),
|
||||||
|
(*phi_x1_m)(modeZero).second, 1e-9);
|
||||||
|
EXPECT_DOUBLES_EQUAL(scalar1 - (*p_x0_given_x1_m)(modeOne)->negLogConstant(),
|
||||||
|
(*phi_x1_m)(modeOne).second, 1e-9);
|
||||||
|
|
||||||
|
// Check that the conditional and remaining factor are consistent for both
|
||||||
|
// modes
|
||||||
|
for (auto &&mode : {modeZero, modeOne}) {
|
||||||
|
const auto gc = (*p_x0_given_x1_m)(mode);
|
||||||
|
const auto [gf, scalar] = (*phi_x1_m)(mode);
|
||||||
|
|
||||||
|
// The error of the original factors should equal the sum of errors of the
|
||||||
|
// conditional and remaining factor, modulo the normalization constant of
|
||||||
|
// the conditional.
|
||||||
|
double originalError = factors_x0.error({continuousValues, mode});
|
||||||
|
const double actualError = gc->negLogConstant() +
|
||||||
|
gc->error(continuousValues) +
|
||||||
|
gf->error(continuousValues) + scalar;
|
||||||
|
EXPECT_DOUBLES_EQUAL(originalError, actualError, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
// Create a clique for x1
|
||||||
TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
|
HybridGaussianFactorGraph factors_x1;
|
||||||
HybridGaussianFactorGraph hfg;
|
factors_x1.push_back(
|
||||||
|
factor); // Use the remaining factor from previous elimination
|
||||||
|
factors_x1.push_back(
|
||||||
|
graph.at(2)); // Add the factor for x1 from the original graph
|
||||||
|
|
||||||
// Prior on x0
|
// Test collectProductFactor for x1 clique
|
||||||
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
auto productFactor_x1 = factors_x1.collectProductFactor();
|
||||||
// Factor between x0-x1
|
|
||||||
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
|
||||||
|
|
||||||
// Hybrid factor P(x1|c1)
|
// For M(0) = 0
|
||||||
hfg.add(HybridGaussianFactor(m1, two::components(X(1))));
|
auto [gaussianFactor_x1_0, actualScalar_x1_0] = productFactor_x1(modeZero);
|
||||||
// Prior factor on c1
|
EXPECT_LONGS_EQUAL(2, gaussianFactor_x1_0.size());
|
||||||
hfg.add(DecisionTreeFactor(m1, {2, 8}));
|
// NOTE(Frank): prior on x1 does not contribute to the scalar
|
||||||
|
EXPECT_DOUBLES_EQUAL((*phi_x1_m)(modeZero).second, actualScalar_x1_0, 1e-5);
|
||||||
|
|
||||||
// Get a constrained ordering keeping c1 last
|
// For M(0) = 1
|
||||||
auto ordering_full = HybridOrdering(hfg);
|
auto [gaussianFactor_x1_1, actualScalar_x1_1] = productFactor_x1(modeOne);
|
||||||
|
EXPECT_LONGS_EQUAL(2, gaussianFactor_x1_1.size());
|
||||||
|
// NOTE(Frank): prior on x1 does not contribute to the scalar
|
||||||
|
EXPECT_DOUBLES_EQUAL((*phi_x1_m)(modeOne).second, actualScalar_x1_1, 1e-5);
|
||||||
|
|
||||||
// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
|
// Test eliminate for x1 clique
|
||||||
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
|
Ordering ordering_x1{X(1)};
|
||||||
|
auto [conditional_x1, factor_x1] = factors_x1.eliminate(ordering_x1);
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(3, hbt->size());
|
// Check the conditional for x1
|
||||||
}
|
CHECK(conditional_x1);
|
||||||
|
EXPECT(conditional_x1->isHybrid());
|
||||||
|
auto p_x1_given_m = conditional_x1->asHybrid();
|
||||||
|
CHECK(p_x1_given_m);
|
||||||
|
EXPECT_LONGS_EQUAL(1, p_x1_given_m->nrFrontals()); // x1
|
||||||
|
EXPECT_LONGS_EQUAL(1, p_x1_given_m->nrParents()); // m0
|
||||||
|
|
||||||
/* ************************************************************************* */
|
// Check the remaining factor for x1
|
||||||
/*
|
CHECK(factor_x1);
|
||||||
* This test is about how to assemble the Bayes Tree roots after we do partial
|
auto phi_x1 = std::dynamic_pointer_cast<DecisionTreeFactor>(factor_x1);
|
||||||
* elimination
|
CHECK(phi_x1);
|
||||||
*/
|
EXPECT_LONGS_EQUAL(1, phi_x1->keys().size()); // m0
|
||||||
TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) {
|
// We can't really check the error of the decision tree factor phi_x1, because
|
||||||
HybridGaussianFactorGraph hfg;
|
// the continuous factor whose error(kEmpty) we need is not available.
|
||||||
|
|
||||||
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
// Now test full elimination of the graph:
|
||||||
hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1));
|
auto hybridBayesNet = graph.eliminateSequential();
|
||||||
|
CHECK(hybridBayesNet);
|
||||||
|
|
||||||
hfg.add(HybridGaussianFactor({M(0), 2}, two::components(X(0))));
|
// Check that the posterior P(M|X=continuousValues) from the Bayes net is the
|
||||||
hfg.add(HybridGaussianFactor({M(1), 2}, two::components(X(2))));
|
// same as the same posterior from the graph. This is a sanity check that the
|
||||||
|
// elimination is done correctly.
|
||||||
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));
|
AlgebraicDecisionTree<Key> bnPosterior =
|
||||||
|
hybridBayesNet->discretePosterior(continuousValues);
|
||||||
hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1));
|
EXPECT(assert_equal(graphPosterior, bnPosterior));
|
||||||
hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1));
|
|
||||||
|
|
||||||
hfg.add(HybridGaussianFactor({M(3), 2}, two::components(X(3))));
|
|
||||||
hfg.add(HybridGaussianFactor({M(2), 2}, two::components(X(5))));
|
|
||||||
|
|
||||||
auto ordering_full =
|
|
||||||
Ordering::ColamdConstrainedLast(hfg, {M(0), M(1), M(2), M(3)});
|
|
||||||
|
|
||||||
const auto [hbt, remaining] = hfg.eliminatePartialMultifrontal(ordering_full);
|
|
||||||
|
|
||||||
// 9 cliques in the bayes tree and 0 remaining variables to eliminate.
|
|
||||||
EXPECT_LONGS_EQUAL(9, hbt->size());
|
|
||||||
EXPECT_LONGS_EQUAL(0, remaining->size());
|
|
||||||
|
|
||||||
/*
|
|
||||||
(Fan) Explanation: the Junction tree will need to re-eliminate to get to the
|
|
||||||
marginal on X(1), which is not possible because it involves eliminating
|
|
||||||
discrete before continuous. The solution to this, however, is in Murphy02.
|
|
||||||
TLDR is that this is 1. expensive and 2. inexact. nevertheless it is doable.
|
|
||||||
And I believe that we should do this.
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
void dotPrint(const HybridGaussianFactorGraph::shared_ptr &hfg,
              const HybridBayesTree::shared_ptr &hbt,
              const Ordering &ordering) {
  DotWriter dw;
  dw.positionHints['c'] = 2;
  dw.positionHints['x'] = 1;
  std::cout << hfg->dot(DefaultKeyFormatter, dw);
  std::cout << "\n";
  hbt->dot(std::cout);

  std::cout << "\n";
  std::cout << hfg->eliminateSequential(ordering)->dot(DefaultKeyFormatter, dw);
}

/* ************************************************************************* */
// TODO(fan): make a graph like Varun's paper one
TEST(HybridGaussianFactorGraph, Switching) {
  auto N = 12;
  auto hfg = makeSwitchingChain(N);

  // X(5) will be the center, X(1-4), X(6-9)
  // X(3), X(7)
  // X(2), X(8)
  // X(1), X(4), X(6), X(9)
  // M(5) will be the center, M(1-4), M(6-8)
  // M(3), M(7)
  // M(1), M(4), M(2), M(6), M(8)
  // auto ordering_full =
  //     Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7),
  //                        X(5),
  //                        M(1), M(4), M(2), M(6), M(8), M(3), M(7), M(5)});
  KeyVector ordering;

  {
    std::vector<int> naturalX(N);
    std::iota(naturalX.begin(), naturalX.end(), 1);
    std::vector<Key> ordX;
    std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX),
                   [](int x) { return X(x); });

    auto [ndX, lvls] = makeBinaryOrdering(ordX);
    std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
    // TODO(dellaert): this has no effect!
    for (auto &l : lvls) {
      l = -l;
    }
  }
  {
    std::vector<int> naturalC(N - 1);
    std::iota(naturalC.begin(), naturalC.end(), 1);
    std::vector<Key> ordC;
    std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC),
                   [](int x) { return M(x); });

    // std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
    const auto [ndC, lvls] = makeBinaryOrdering(ordC);
    std::copy(ndC.begin(), ndC.end(), std::back_inserter(ordering));
  }
  auto ordering_full = Ordering(ordering);

  // GTSAM_PRINT(*hfg);
  // GTSAM_PRINT(ordering_full);

  const auto [hbt, remaining] =
      hfg->eliminatePartialMultifrontal(ordering_full);

  // 12 cliques in the bayes tree and 0 remaining variables to eliminate.
  EXPECT_LONGS_EQUAL(12, hbt->size());
  EXPECT_LONGS_EQUAL(0, remaining->size());
}
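/*
 * `makeBinaryOrdering` above comes from Switching.h. A hedged sketch of the
 * idea (an assumption, not the exact implementation): recursively bisect the
 * key list so that the middle key of each range is eliminated last, yielding
 * the balanced "binary tree" elimination orderings used in these tests.
 */
static std::pair<gtsam::KeyVector, std::vector<int>> binaryOrderingSketch(
    const gtsam::KeyVector &keys, int level = 0) {
  gtsam::KeyVector order;
  std::vector<int> levels;
  if (keys.empty()) return {order, levels};
  const size_t mid = keys.size() / 2;
  const gtsam::KeyVector left(keys.begin(), keys.begin() + mid);
  const gtsam::KeyVector right(keys.begin() + mid + 1, keys.end());
  for (const auto &half : {left, right}) {
    auto [sub, subLevels] = binaryOrderingSketch(half, level + 1);
    order.insert(order.end(), sub.begin(), sub.end());
    levels.insert(levels.end(), subLevels.begin(), subLevels.end());
  }
  order.push_back(keys[mid]);  // the middle key is eliminated last
  levels.push_back(level);
  return {order, levels};
}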
/* ************************************************************************* */
// TODO(fan): make a graph like Varun's paper one
TEST(HybridGaussianFactorGraph, SwitchingISAM) {
  auto N = 11;
  auto hfg = makeSwitchingChain(N);

  // X(5) will be the center, X(1-4), X(6-9)
  // X(3), X(7)
  // X(2), X(8)
  // X(1), X(4), X(6), X(9)
  // M(5) will be the center, M(1-4), M(6-8)
  // M(3), M(7)
  // M(1), M(4), M(2), M(6), M(8)
  // auto ordering_full =
  //     Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7),
  //                        X(5),
  //                        M(1), M(4), M(2), M(6), M(8), M(3), M(7), M(5)});
  KeyVector ordering;

  {
    std::vector<int> naturalX(N);
    std::iota(naturalX.begin(), naturalX.end(), 1);
    std::vector<Key> ordX;
    std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX),
                   [](int x) { return X(x); });

    auto [ndX, lvls] = makeBinaryOrdering(ordX);
    std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering));
    // TODO(dellaert): this has no effect!
    for (auto &l : lvls) {
      l = -l;
    }
  }
  {
    std::vector<int> naturalC(N - 1);
    std::iota(naturalC.begin(), naturalC.end(), 1);
    std::vector<Key> ordC;
    std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC),
                   [](int x) { return M(x); });

    // std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering));
    const auto [ndC, lvls] = makeBinaryOrdering(ordC);
    std::copy(ndC.begin(), ndC.end(), std::back_inserter(ordering));
  }
  auto ordering_full = Ordering(ordering);

  const auto [hbt, remaining] =
      hfg->eliminatePartialMultifrontal(ordering_full);

  auto new_fg = makeSwitchingChain(12);
  auto isam = HybridGaussianISAM(*hbt);

  // Run an ISAM update.
  HybridGaussianFactorGraph factorGraph;
  factorGraph.push_back(new_fg->at(new_fg->size() - 2));
  factorGraph.push_back(new_fg->at(new_fg->size() - 1));
  isam.update(factorGraph);

  // ISAM should have 12 factors after the last update
  EXPECT_LONGS_EQUAL(12, isam.size());
}
/* ************************************************************************* */
TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
  const int N = 7;
  auto hfg = makeSwitchingChain(N, X);
  hfg->push_back(*makeSwitchingChain(N, Y, D));

  for (int t = 1; t <= N; t++) {
    hfg->add(JacobianFactor(X(t), I_3x3, Y(t), -I_3x3, Vector3(1.0, 0.0, 0.0)));
  }

  KeyVector ordering;

  KeyVector naturalX(N);
  std::iota(naturalX.begin(), naturalX.end(), 1);
  KeyVector ordX;
  for (size_t i = 1; i <= N; i++) {
    ordX.emplace_back(X(i));
    ordX.emplace_back(Y(i));
  }

  for (size_t i = 1; i <= N - 1; i++) {
    ordX.emplace_back(M(i));
  }
  for (size_t i = 1; i <= N - 1; i++) {
    ordX.emplace_back(D(i));
  }

  {
    DotWriter dw;
    dw.positionHints['x'] = 1;
    dw.positionHints['c'] = 0;
    dw.positionHints['d'] = 3;
    dw.positionHints['y'] = 2;
    // std::cout << hfg->dot(DefaultKeyFormatter, dw);
    // std::cout << "\n";
  }

  {
    DotWriter dw;
    dw.positionHints['y'] = 9;
    // dw.positionHints['c'] = 0;
    // dw.positionHints['d'] = 3;
    dw.positionHints['x'] = 1;
    // std::cout << "\n";
    // std::cout << hfg->eliminateSequential(Ordering(ordX))
    //                  ->dot(DefaultKeyFormatter, dw);
    // hfg->eliminateMultifrontal(Ordering(ordX))->dot(std::cout);
  }

  Ordering ordering_partial;
  for (size_t i = 1; i <= N; i++) {
    ordering_partial.emplace_back(X(i));
    ordering_partial.emplace_back(Y(i));
  }
  const auto [hbn, remaining] =
      hfg->eliminatePartialSequential(ordering_partial);

  EXPECT_LONGS_EQUAL(14, hbn->size());
  EXPECT_LONGS_EQUAL(11, remaining->size());

  {
    DotWriter dw;
    dw.positionHints['x'] = 1;
    dw.positionHints['c'] = 0;
    dw.positionHints['d'] = 3;
    dw.positionHints['y'] = 2;
    // std::cout << remaining->dot(DefaultKeyFormatter, dw);
    // std::cout << "\n";
  }
}
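/*
 * DotWriter::positionHints maps a symbol character to a grid row, so the
 * generated graphviz output lays out, e.g., all X variables on one row and
 * all mode variables on another. A minimal usage sketch (the hint values
 * here are illustrative):
 */
void writeDotWithHints(const gtsam::HybridGaussianFactorGraph &fg) {
  gtsam::DotWriter dw;
  dw.positionHints['x'] = 1;  // continuous states on row 1
  dw.positionHints['m'] = 0;  // discrete modes on row 0
  std::cout << fg.dot(gtsam::DefaultKeyFormatter, dw);
}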
/* ****************************************************************************/
// Test subset of API for switching network with 3 states.
// None of these are regression tests.
TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
  // Create switching network with three continuous variables and two discrete:
  // ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
  Switching s(3);

  // Check size of linearized factor graph
  const HybridGaussianFactorGraph &graph = s.linearizedFactorGraph;
  EXPECT_LONGS_EQUAL(7, graph.size());

  // Eliminate the graph
  const HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();

  const HybridValues delta = hybridBayesNet->optimize();
  const double error = graph.error(delta);

  // Check that the probability prime is the exponential of the error
  EXPECT(assert_equal(graph.probPrime(delta), exp(-error), 1e-7));

  // Check that the posterior P(M|X=continuousValues) from the Bayes net is the
  // same as the posterior from the graph. This is a sanity check that the
  // elimination is done correctly.
  const AlgebraicDecisionTree<Key> graphPosterior =
      graph.discretePosterior(delta.continuous());
  const AlgebraicDecisionTree<Key> bnPosterior =
      hybridBayesNet->discretePosterior(delta.continuous());
  EXPECT(assert_equal(graphPosterior, bnPosterior));
}
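/*
 * discretePosterior normalizes exp(-error) over all discrete assignments,
 * i.e. P(M | X = x) = exp(-E(M, x)) / sum over M' of exp(-E(M', x)), which is
 * why probPrime(delta) == exp(-error) above. A standalone sketch of that
 * arithmetic with plain vectors (assumed semantics, not the library code):
 */
std::vector<double> posteriorFromErrors(const std::vector<double> &errors) {
  std::vector<double> p(errors.size());
  double sum = 0.0;
  for (size_t i = 0; i < errors.size(); ++i) {
    p[i] = std::exp(-errors[i]);  // unnormalized probability of assignment i
    sum += p[i];
  }
  for (auto &v : p) v /= sum;  // normalize to a distribution over assignments
  return p;
}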
/* ************************************************************************* */
// Select a particular continuous factor graph given a discrete assignment
TEST(HybridGaussianFactorGraph, DiscreteSelection) {
  Switching s(3);
@ -546,23 +449,43 @@ TEST(HybridGaussianFactorGraph, optimize) {

// Test adding of gaussian conditional and re-elimination.
TEST(HybridGaussianFactorGraph, Conditionals) {
  Switching switching(4);

  HybridGaussianFactorGraph hfg;
  hfg.push_back(switching.linearizedFactorGraph.at(0));  // P(X0)
  Ordering ordering;
  ordering.push_back(X(0));
  HybridBayesNet::shared_ptr bayes_net = hfg.eliminateSequential(ordering);

  HybridGaussianFactorGraph hfg2;
  hfg2.push_back(*bayes_net);                             // P(X0)
  hfg2.push_back(switching.linearizedFactorGraph.at(1));  // P(X0, X1 | M0)
  hfg2.push_back(switching.linearizedFactorGraph.at(2));  // P(X1, X2 | M1)
  hfg2.push_back(switching.linearizedFactorGraph.at(5));  // P(M1)
  ordering += X(1), X(2), M(0), M(1);

  // Create product of first two factors and check eliminate:
  HybridGaussianFactorGraph fragment;
  fragment.push_back(hfg2[0]);
  fragment.push_back(hfg2[1]);

  // Check the product:
  HybridGaussianProductFactor product = fragment.collectProductFactor();
  auto leaf = fragment(DiscreteValues{{M(0), 0}});
  EXPECT_LONGS_EQUAL(2, leaf.size());

  // Check product and that removeEmpty does not touch it:
  auto pruned = product.removeEmpty();
  LONGS_EQUAL(2, pruned.nrLeaves());

  // Test eliminate:
  auto [hybridConditional, factor] = fragment.eliminate({X(0)});
  EXPECT(hybridConditional->isHybrid());
  EXPECT(hybridConditional->keys() == KeyVector({X(0), X(1), M(0)}));

  EXPECT(dynamic_pointer_cast<HybridGaussianFactor>(factor));
  EXPECT(factor->keys() == KeyVector({X(1), M(0)}));

  bayes_net = hfg2.eliminateSequential(ordering);

  HybridValues result = bayes_net->optimize();

@ -582,55 +505,7 @@ TEST(HybridGaussianFactorGraph, Conditionals) {
}
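/*
 * collectProductFactor() above returns a HybridGaussianProductFactor: a
 * decision tree mapping each discrete assignment to a pair of (collected
 * GaussianFactorGraph, scalar). A toy sketch of that idea using a plain map
 * instead of a decision tree (assumed semantics, for illustration only):
 */
std::map<size_t, std::pair<gtsam::GaussianFactorGraph, double>> toyProduct(
    const std::vector<gtsam::GaussianFactor::shared_ptr> &componentPerMode) {
  std::map<size_t, std::pair<gtsam::GaussianFactorGraph, double>> product;
  for (size_t m = 0; m < componentPerMode.size(); ++m) {
    gtsam::GaussianFactorGraph gfg;
    gfg.push_back(componentPerMode[m]);  // the component selected by mode m
    product[m] = {gfg, 0.0};  // the scalar accumulates per-mode constants
  }
  return product;
}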
/* ****************************************************************************/
// Test hybrid gaussian factor graph error and unnormalized probabilities
TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {
  Switching s(3);

  HybridGaussianFactorGraph graph = s.linearizedFactorGraph;

  HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();

  const HybridValues delta = hybridBayesNet->optimize();
  const double error = graph.error(delta);

  // regression
  EXPECT(assert_equal(1.58886, error, 1e-5));

  // Real test:
  EXPECT(assert_equal(graph.probPrime(delta), exp(-error), 1e-7));
}

/* ****************************************************************************/
// Test hybrid gaussian factor graph error and unnormalized probabilities
TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
  Switching s(3);

  HybridGaussianFactorGraph graph = s.linearizedFactorGraph;

  HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();

  HybridValues delta = hybridBayesNet->optimize();
  auto error_tree = graph.errorTree(delta.continuous());

  std::vector<DiscreteKey> discrete_keys = {{M(0), 2}, {M(1), 2}};
  std::vector<double> leaves = {0.9998558, 0.4902432, 0.5193694, 0.0097568};
  AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);

  // regression
  EXPECT(assert_equal(expected_error, error_tree, 1e-7));

  auto probabilities = graph.probPrime(delta.continuous());
  std::vector<double> prob_leaves = {0.36793249, 0.61247742, 0.59489556,
                                     0.99029064};
  AlgebraicDecisionTree<Key> expected_probabilities(discrete_keys, prob_leaves);

  // regression
  EXPECT(assert_equal(expected_probabilities, probabilities, 1e-7));
}

/* ****************************************************************************/
// Test hybrid gaussian factor graph errorTree during incremental operation
TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
  Switching s(4);

@ -646,16 +521,13 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
  HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();
  EXPECT_LONGS_EQUAL(5, hybridBayesNet->size());

  // Check discrete posterior at optimum
  HybridValues delta = hybridBayesNet->optimize();
  AlgebraicDecisionTree<Key> graphPosterior =
      graph.discretePosterior(delta.continuous());
  AlgebraicDecisionTree<Key> bnPosterior =
      hybridBayesNet->discretePosterior(delta.continuous());
  EXPECT(assert_equal(graphPosterior, bnPosterior));

  graph = HybridGaussianFactorGraph();
  graph.push_back(*hybridBayesNet);

@ -666,28 +538,21 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
  EXPECT_LONGS_EQUAL(7, hybridBayesNet->size());

  delta = hybridBayesNet->optimize();
  graphPosterior = graph.discretePosterior(delta.continuous());
  bnPosterior = hybridBayesNet->discretePosterior(delta.continuous());
  EXPECT(assert_equal(graphPosterior, bnPosterior));
}

/* ****************************************************************************/
// Check that collectProductFactor works correctly.
TEST(HybridGaussianFactorGraph, collectProductFactor) {
  const int num_measurements = 1;
  VectorValues vv{{Z(0), Vector1(5.0)}};
  auto fg = tiny::createHybridGaussianFactorGraph(num_measurements, vv);
  EXPECT_LONGS_EQUAL(3, fg.size());

  // Assemble graph tree:
  auto actual = fg.collectProductFactor();

  // Create expected decision tree with two factor graphs:

@ -706,13 +571,15 @@ TEST(HybridGaussianFactorGraph, collectProductFactor) {
  DiscreteValues d0{{M(0), 0}}, d1{{M(0), 1}};

  // Expected decision tree with two factor graphs:
  // f(x0;mode=0)P(x0)
  GaussianFactorGraph expectedFG0{(*hybrid)(d0).first, prior};
  EXPECT(assert_equal(expectedFG0, actual(d0).first, 1e-5));
  EXPECT(assert_equal(0.0, actual(d0).second, 1e-5));

  // f(x0;mode=1)P(x0)
  GaussianFactorGraph expectedFG1{(*hybrid)(d1).first, prior};
  EXPECT(assert_equal(expectedFG1, actual(d1).first, 1e-5));
  EXPECT(assert_equal(1.79176, actual(d1).second, 1e-5));
}
@ -752,7 +619,6 @@ bool ratioTest(const HybridBayesNet &bn, const VectorValues &measurements,
  // Test ratios for a number of independent samples:
  for (size_t i = 0; i < num_samples; i++) {
    HybridValues sample = bn.sample(&kRng);
    // std::cout << "ratio: " << compute_ratio(&sample) << std::endl;
    if (std::abs(expected_ratio - compute_ratio(&sample)) > 1e-6) return false;
  }
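/*
 * compute_ratio is defined earlier in this file. The idea of ratioTest is
 * that, for a correct Bayes net <-> factor graph conversion, the two
 * unnormalized densities differ only by a constant, so their ratio is the
 * same for every sample. A hedged sketch of such a ratio (an assumed
 * definition, not the exact one used above):
 */
double computeRatioSketch(const gtsam::HybridBayesNet &bn,
                          const gtsam::HybridGaussianFactorGraph &fg,
                          const gtsam::HybridValues &sample) {
  // evaluate() is normalized, probPrime() is unnormalized; the quotient
  // should be constant across samples if fg represents the same density.
  return bn.evaluate(sample) / fg.probPrime(sample);
}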
@ -10,7 +10,7 @@
 * -------------------------------------------------------------------------- */

/**
 * @file testHybridGaussianISAM.cpp
 * @brief Unit tests for incremental inference
 * @author Fan Jiang, Varun Agrawal, Frank Dellaert
 * @date Jan 2021

@ -27,8 +27,6 @@
#include <gtsam/nonlinear/PriorFactor.h>
#include <gtsam/sam/BearingRangeFactor.h>

#include "Switching.h"

// Include for test suite
@ -36,77 +34,63 @@

using namespace std;
using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::W;
using symbol_shorthand::X;
using symbol_shorthand::Y;
using symbol_shorthand::Z;

/* ****************************************************************************/
namespace switching3 {
// ϕ(x0) ϕ(x0,x1,m0) ϕ(x1,x2,m1) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0) ϕ(m0,m1)
const Switching switching(3);
const HybridGaussianFactorGraph &lfg = switching.linearizedFactorGraph;

// First update graph: ϕ(x0) ϕ(x0,x1,m0) ϕ(m0)
const HybridGaussianFactorGraph graph1{lfg.at(0), lfg.at(1), lfg.at(5)};

// Second update graph: ϕ(x1,x2,m1) ϕ(x1;z1) ϕ(x2;z2) ϕ(m0,m1)
const HybridGaussianFactorGraph graph2{lfg.at(2), lfg.at(3), lfg.at(4),
                                       lfg.at(6)};
}  // namespace switching3
/* ****************************************************************************/
// Test if we can perform elimination incrementally.
TEST(HybridGaussianElimination, IncrementalElimination) {
  using namespace switching3;
  HybridGaussianISAM isam;

  // Run first update step
  isam.update(graph1);

  // Check that after update we have 2 hybrid Bayes net nodes:
  // P(M0) and P(X0, X1 | M0)
  EXPECT_LONGS_EQUAL(2, isam.size());
  EXPECT(isam[M(0)]->conditional()->frontals() == KeyVector({M(0)}));
  EXPECT(isam[M(0)]->conditional()->parents() == KeyVector());
  EXPECT(isam[X(0)]->conditional()->frontals() == KeyVector({X(0), X(1)}));
  EXPECT(isam[X(0)]->conditional()->parents() == KeyVector({M(0)}));

  /********************************************************/
  // Run second update step
  isam.update(graph2);

  // Check that after the second update we have 3 hybrid Bayes net nodes:
  // P(M0, M1), P(X1, X2 | M0, M1), and P(X0 | X1, M0)
  EXPECT_LONGS_EQUAL(3, isam.size());
  EXPECT(isam[M(0)]->conditional()->frontals() == KeyVector({M(0), M(1)}));
  EXPECT(isam[M(0)]->conditional()->parents() == KeyVector());
  EXPECT(isam[X(1)]->conditional()->frontals() == KeyVector({X(1), X(2)}));
  EXPECT(isam[X(1)]->conditional()->parents() == KeyVector({M(0), M(1)}));
  EXPECT(isam[X(0)]->conditional()->frontals() == KeyVector{X(0)});
  EXPECT(isam[X(0)]->conditional()->parents() == KeyVector({X(1), M(0)}));
}

/* ****************************************************************************/
// Test if we can incrementally do the inference
TEST(HybridGaussianElimination, IncrementalInference) {
  using namespace switching3;
  HybridGaussianISAM isam;

  // Run update step
  isam.update(graph1);
|
||||||
EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)}));
|
EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)}));
|
||||||
|
|
||||||
/********************************************************/
|
/********************************************************/
|
||||||
// New factor graph for incremental update.
|
// Second incremental update.
|
||||||
HybridGaussianFactorGraph graph2;
|
|
||||||
|
|
||||||
graph2.push_back(switching.linearizedFactorGraph.at(2)); // P(X1, X2 | M1)
|
|
||||||
graph2.push_back(switching.linearizedFactorGraph.at(4)); // P(X2)
|
|
||||||
graph2.push_back(switching.linearizedFactorGraph.at(6)); // P(M0, M1)
|
|
||||||
|
|
||||||
isam.update(graph2);
|
isam.update(graph2);
|
||||||
|
|
||||||
/********************************************************/
|
/********************************************************/
|
||||||
|
|
@ -160,44 +138,19 @@ TEST(HybridGaussianElimination, IncrementalInference) {
|
||||||
// The other discrete probabilities on M(2) are calculated the same way
|
// The other discrete probabilities on M(2) are calculated the same way
|
||||||
const Ordering discreteOrdering{M(0), M(1)};
|
const Ordering discreteOrdering{M(0), M(1)};
|
||||||
HybridBayesTree::shared_ptr discreteBayesTree =
|
HybridBayesTree::shared_ptr discreteBayesTree =
|
||||||
expectedRemainingGraph->BaseEliminateable::eliminateMultifrontal(
|
expectedRemainingGraph->eliminateMultifrontal(discreteOrdering);
|
||||||
discreteOrdering);
|
|
||||||
|
|
||||||
DiscreteValues m00;
|
|
||||||
m00[M(0)] = 0, m00[M(1)] = 0;
|
|
||||||
DiscreteConditional decisionTree =
|
|
||||||
*(*discreteBayesTree)[M(1)]->conditional()->asDiscrete();
|
|
||||||
double m00_prob = decisionTree(m00);
|
|
||||||
|
|
||||||
auto discreteConditional = isam[M(1)]->conditional()->asDiscrete();
|
|
||||||
|
|
||||||
// Test the probability values with regression tests.
|
// Test the probability values with regression tests.
|
||||||
DiscreteValues assignment;
|
auto discrete = isam[M(1)]->conditional()->asDiscrete();
|
||||||
EXPECT(assert_equal(0.0952922, m00_prob, 1e-5));
|
EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
|
||||||
assignment[M(0)] = 0;
|
EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
|
||||||
assignment[M(1)] = 0;
|
EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));
|
||||||
EXPECT(assert_equal(0.0952922, (*discreteConditional)(assignment), 1e-5));
|
EXPECT(assert_equal(0.307775, (*discrete)({{M(0), 1}, {M(1), 1}}), 1e-5));
|
||||||
assignment[M(0)] = 1;
|
|
||||||
assignment[M(1)] = 0;
|
|
||||||
EXPECT(assert_equal(0.282758, (*discreteConditional)(assignment), 1e-5));
|
|
||||||
assignment[M(0)] = 0;
|
|
||||||
assignment[M(1)] = 1;
|
|
||||||
EXPECT(assert_equal(0.314175, (*discreteConditional)(assignment), 1e-5));
|
|
||||||
assignment[M(0)] = 1;
|
|
||||||
assignment[M(1)] = 1;
|
|
||||||
EXPECT(assert_equal(0.307775, (*discreteConditional)(assignment), 1e-5));
|
|
||||||
|
|
||||||
// Check if the clique conditional generated from incremental elimination
|
// Check that the clique conditional generated from incremental elimination
|
||||||
// matches that of batch elimination.
|
// matches that of batch elimination.
|
||||||
auto expectedChordal =
|
auto expectedConditional = (*discreteBayesTree)[M(1)]->conditional();
|
||||||
expectedRemainingGraph->BaseEliminateable::eliminateMultifrontal();
|
auto actualConditional = isam[M(1)]->conditional();
|
||||||
auto actualConditional = dynamic_pointer_cast<DecisionTreeFactor>(
|
|
||||||
isam[M(1)]->conditional()->inner());
|
|
||||||
// Account for the probability terms from evaluating continuous FGs
|
|
||||||
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}};
|
|
||||||
vector<double> probs = {0.095292197, 0.31417524, 0.28275772, 0.30777485};
|
|
||||||
auto expectedConditional =
|
|
||||||
std::make_shared<DecisionTreeFactor>(discrete_keys, probs);
|
|
||||||
EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6));
|
EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -227,7 +180,7 @@ TEST(HybridGaussianElimination, Approx_inference) {
  }

  // Now we calculate the actual factors using full elimination
  const auto [unPrunedHybridBayesTree, unPrunedRemainingGraph] =
      switching.linearizedFactorGraph.eliminatePartialMultifrontal(ordering);

  size_t maxNrLeaves = 5;

@ -236,7 +189,7 @@ TEST(HybridGaussianElimination, Approx_inference) {
  incrementalHybrid.prune(maxNrLeaves);

  /*
  unPruned factor is:
  Choice(m3)
  0 Choice(m2)
  0 0 Choice(m1)

@ -282,8 +235,8 @@ TEST(HybridGaussianElimination, Approx_inference) {

  // Check that the hybrid nodes of the bayes net match those of the pre-pruning
  // bayes net, at the same positions.
  auto &unPrunedLastDensity = *dynamic_pointer_cast<HybridGaussianConditional>(
      unPrunedHybridBayesTree->clique(X(3))->conditional()->inner());
  auto &lastDensity = *dynamic_pointer_cast<HybridGaussianConditional>(
      incrementalHybrid[X(3)]->conditional()->inner());

@ -298,7 +251,7 @@ TEST(HybridGaussianElimination, Approx_inference) {
    EXPECT(lastDensity(assignment) == nullptr);
  } else {
    CHECK(lastDensity(assignment));
    EXPECT(assert_equal(*unPrunedLastDensity(assignment),
                        *lastDensity(assignment)));
  }
}

@ -306,7 +259,7 @@ TEST(HybridGaussianElimination, Approx_inference) {

/* ****************************************************************************/
// Test approximate inference with an additional pruning step.
TEST(HybridGaussianElimination, IncrementalApproximate) {
  Switching switching(5);
  HybridGaussianISAM incrementalHybrid;
  HybridGaussianFactorGraph graph1;

@ -330,7 +283,7 @@ TEST(HybridGaussianElimination, IncrementalApproximate) {
  incrementalHybrid.prune(maxComponents);

  // Check if we have a bayes tree with 4 hybrid nodes,
  // each with 2, 4, 8, and 5 (pruned) leaves respectively.
  EXPECT_LONGS_EQUAL(4, incrementalHybrid.size());
  EXPECT_LONGS_EQUAL(
      2, incrementalHybrid[X(0)]->conditional()->asHybrid()->nrComponents());
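/*
 * prune(maxNrLeaves) keeps only the most probable discrete assignments. A
 * standalone sketch of the idea (assumed semantics, plain vectors): keep the
 * k largest leaf probabilities and zero out the rest.
 */
std::vector<double> pruneLeavesSketch(std::vector<double> p,
                                      size_t maxNrLeaves) {
  if (p.size() <= maxNrLeaves) return p;
  std::vector<double> sorted = p;
  // Partition so sorted[maxNrLeaves - 1] holds the k-th largest value:
  std::nth_element(sorted.begin(), sorted.begin() + maxNrLeaves - 1,
                   sorted.end(), std::greater<double>());
  const double threshold = sorted[maxNrLeaves - 1];
  size_t kept = 0;
  for (auto &v : p) {
    if (v >= threshold && kept < maxNrLeaves) {
      ++kept;   // keep this leaf
    } else {
      v = 0.0;  // pruned leaves get zero probability
    }
  }
  return p;
}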
@ -0,0 +1,201 @@
/* ----------------------------------------------------------------------------

 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)

 * See LICENSE for the license information

 * -------------------------------------------------------------------------- */

/**
 * @file testHybridGaussianProductFactor.cpp
 * @brief Unit tests for HybridGaussianProductFactor
 * @author Frank Dellaert
 * @date October 2024
 */

#include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/inference/Key.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/JacobianFactor.h>

// Include for test suite
#include <CppUnitLite/TestHarness.h>

#include <memory>

using namespace std;
using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;

/* ************************************************************************* */
namespace examples {
static const DiscreteKey m1(M(1), 2), m2(M(2), 3);

const auto A1 = Matrix::Zero(2, 1);
const auto A2 = Matrix::Zero(2, 2);
const auto b = Matrix::Zero(2, 1);

const auto f10 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
const auto f11 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
const HybridGaussianFactor hybridFactorA(m1, {{f10, 10}, {f11, 11}});

const auto A3 = Matrix::Zero(2, 3);
const auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
const auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
const auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);

const HybridGaussianFactor hybridFactorB(m2, {{f20, 20}, {f21, 21}, {f22, 22}});
// Simulate a pruned hybrid factor, in this case m2==1 is nulled out.
const HybridGaussianFactor prunedFactorB(
    m2, {{f20, 20}, {nullptr, 1000}, {f22, 22}});
}  // namespace examples

/* ************************************************************************* */
// Constructor
TEST(HybridGaussianProductFactor, Construct) {
  HybridGaussianProductFactor product;
}

/* ************************************************************************* */
// Add two Gaussian factors and check only one leaf in tree
TEST(HybridGaussianProductFactor, AddTwoGaussianFactors) {
  using namespace examples;

  HybridGaussianProductFactor product;
  product += f10;
  product += f11;

  // Check that the product has only one leaf and no discrete variables.
  EXPECT_LONGS_EQUAL(1, product.nrLeaves());
  EXPECT(product.labels().empty());

  // Retrieve the single leaf
  auto leaf = product(Assignment<Key>());

  // Check that the leaf contains both factors
  EXPECT_LONGS_EQUAL(2, leaf.first.size());
  EXPECT(leaf.first.at(0) == f10);
  EXPECT(leaf.first.at(1) == f11);
  EXPECT_DOUBLES_EQUAL(0, leaf.second, 1e-9);
}

/* ************************************************************************* */
// Add two GaussianConditionals and check the resulting tree
TEST(HybridGaussianProductFactor, AddTwoGaussianConditionals) {
  // Create two GaussianConditionals
  Vector1 d(1.0);
  Matrix11 R = I_1x1, S = I_1x1;
  auto gc1 = std::make_shared<GaussianConditional>(X(1), d, R, X(2), S);
  auto gc2 = std::make_shared<GaussianConditional>(X(2), d, R);

  // Create a HybridGaussianProductFactor and add the conditionals
  HybridGaussianProductFactor product;
  product += std::static_pointer_cast<GaussianFactor>(gc1);
  product += std::static_pointer_cast<GaussianFactor>(gc2);

  // Check that the product has only one leaf and no discrete variables
  EXPECT_LONGS_EQUAL(1, product.nrLeaves());
  EXPECT(product.labels().empty());

  // Retrieve the single leaf
  auto leaf = product(Assignment<Key>());

  // Check that the leaf contains both conditionals
  EXPECT_LONGS_EQUAL(2, leaf.first.size());
  EXPECT(leaf.first.at(0) == gc1);
  EXPECT(leaf.first.at(1) == gc2);
  EXPECT_DOUBLES_EQUAL(0, leaf.second, 1e-9);
}

/* ************************************************************************* */
// Check AsProductFactor
TEST(HybridGaussianProductFactor, AsProductFactor) {
  using namespace examples;
  auto product = hybridFactorA.asProductFactor();

  // Let's check that this worked:
  Assignment<Key> mode;
  mode[m1.first] = 0;
  auto actual = product(mode);
  EXPECT(actual.first.at(0) == f10);
  EXPECT_DOUBLES_EQUAL(10, actual.second, 1e-9);

  mode[m1.first] = 1;
  actual = product(mode);
  EXPECT(actual.first.at(0) == f11);
  EXPECT_DOUBLES_EQUAL(11, actual.second, 1e-9);
}

/* ************************************************************************* */
// "Add" one hybrid factor.
TEST(HybridGaussianProductFactor, AddOne) {
  using namespace examples;
  HybridGaussianProductFactor product;
  product += hybridFactorA;

  // Let's check that this worked:
  Assignment<Key> mode;
  mode[m1.first] = 0;
  auto actual = product(mode);
  EXPECT(actual.first.at(0) == f10);
  EXPECT_DOUBLES_EQUAL(10, actual.second, 1e-9);

  mode[m1.first] = 1;
  actual = product(mode);
  EXPECT(actual.first.at(0) == f11);
  EXPECT_DOUBLES_EQUAL(11, actual.second, 1e-9);
}

/* ************************************************************************* */
// "Add" two hybrid factors together.
TEST(HybridGaussianProductFactor, AddTwo) {
  using namespace examples;

  // Create product of two hybrid factors: it will be a decision tree now on
  // both discrete variables m1 and m2:
  HybridGaussianProductFactor product;
  product += hybridFactorA;
  product += hybridFactorB;

  // Let's check that this worked:
  auto actual00 = product({{M(1), 0}, {M(2), 0}});
  EXPECT(actual00.first.at(0) == f10);
  EXPECT(actual00.first.at(1) == f20);
  EXPECT_DOUBLES_EQUAL(10 + 20, actual00.second, 1e-9);

  auto actual12 = product({{M(1), 1}, {M(2), 2}});
  EXPECT(actual12.first.at(0) == f11);
  EXPECT(actual12.first.at(1) == f22);
  EXPECT_DOUBLES_EQUAL(11 + 22, actual12.second, 1e-9);
}

/* ************************************************************************* */
// "Add" a pruned hybrid factor.
TEST(HybridGaussianProductFactor, AddPruned) {
  using namespace examples;

  // Create product of two hybrid factors: it will be a decision tree now on
  // both discrete variables m1 and m2:
  HybridGaussianProductFactor product;
  product += hybridFactorA;
  product += prunedFactorB;
  EXPECT_LONGS_EQUAL(6, product.nrLeaves());

  auto pruned = product.removeEmpty();
  EXPECT_LONGS_EQUAL(5, pruned.nrLeaves());
}

/* ************************************************************************* */
int main() {
  TestResult tr;
  return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */
@ -192,24 +192,29 @@ TEST(HybridGaussianFactor, TwoStateModel2) {
  HybridBayesNet hbn = CreateBayesNet(hybridMotionModel);
  HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);

  HybridBayesNet::shared_ptr eliminated = gfg.eliminateSequential();

  for (VectorValues vv :
       {VectorValues{{X(0), Vector1(0.0)}, {X(1), Vector1(1.0)}},
        VectorValues{{X(0), Vector1(0.5)}, {X(1), Vector1(3.0)}}}) {
    vv.insert(given);  // add measurements for HBN
    const auto& expectedDiscretePosterior = hbn.discretePosterior(vv);

    // Equality of posteriors asserts that the factor graph is correct (same
    // ratios for all modes)
    EXPECT(assert_equal(expectedDiscretePosterior, gfg.discretePosterior(vv)));

    // This one asserts that HBN resulting from elimination is correct.
    EXPECT(assert_equal(expectedDiscretePosterior,
                        eliminated->discretePosterior(vv)));
  }

  // Importance sampling run with 100k samples gives 50.095/49.905
  // approximateDiscreteMarginal(hbn, hybridMotionModel, given);

  // Since there is no measurement on x1, we get a 50/50 probability
  auto p_m = eliminated->at(2)->asDiscrete();
  EXPECT_DOUBLES_EQUAL(0.5, p_m->operator()({{M(1), 0}}), 1e-9);
  EXPECT_DOUBLES_EQUAL(0.5, p_m->operator()({{M(1), 1}}), 1e-9);
}

@ -221,6 +226,7 @@ TEST(HybridGaussianFactor, TwoStateModel2) {

  HybridBayesNet hbn = CreateBayesNet(hybridMotionModel, true);
  HybridGaussianFactorGraph gfg = hbn.toFactorGraph(given);
  HybridBayesNet::shared_ptr eliminated = gfg.eliminateSequential();

  // Check that ratio of Bayes net and factor graph for different modes is
  // equal for several values of {x0,x1}.

@ -228,17 +234,22 @@ TEST(HybridGaussianFactor, TwoStateModel2) {
  for (VectorValues vv :
       {VectorValues{{X(0), Vector1(0.0)}, {X(1), Vector1(1.0)}},
        VectorValues{{X(0), Vector1(0.5)}, {X(1), Vector1(3.0)}}}) {
    vv.insert(given);  // add measurements for HBN
    const auto& expectedDiscretePosterior = hbn.discretePosterior(vv);

    // Equality of posteriors asserts that the factor graph is correct (same
    // ratios for all modes)
    EXPECT(assert_equal(expectedDiscretePosterior, gfg.discretePosterior(vv)));

    // This one asserts that HBN resulting from elimination is correct.
    EXPECT(assert_equal(expectedDiscretePosterior,
                        eliminated->discretePosterior(vv)));
  }

  // Values taken from an importance sampling run with 100k samples:
  // approximateDiscreteMarginal(hbn, hybridMotionModel, given);
  DiscreteConditional expected(m1, "48.3158/51.6842");
  EXPECT(assert_equal(expected, *(eliminated->at(2)->asDiscrete()), 0.002));
}

{
@ -211,13 +211,44 @@ TEST(HybridNonlinearFactorGraph, PushBack) {
  // EXPECT_LONGS_EQUAL(3, hnfg.size());
}

/* ****************************************************************************/
// Test hybrid nonlinear factor graph errorTree
TEST(HybridNonlinearFactorGraph, ErrorTree) {
  Switching s(3);

  const HybridNonlinearFactorGraph &graph = s.nonlinearFactorGraph();
  const Values &values = s.linearizationPoint;

  auto error_tree = graph.errorTree(s.linearizationPoint);

  auto dkeys = graph.discreteKeys();
  DiscreteKeys discrete_keys(dkeys.begin(), dkeys.end());

  // Compute the sum of errors for each factor.
  auto assignments = DiscreteValues::CartesianProduct(discrete_keys);
  std::vector<double> leaves(assignments.size());
  for (auto &&factor : graph) {
    for (size_t i = 0; i < assignments.size(); ++i) {
      leaves[i] +=
          factor->error(HybridValues(VectorValues(), assignments[i], values));
    }
  }
  // Swap i=1 and i=2 to give correct ordering.
  double temp = leaves[1];
  leaves[1] = leaves[2];
  leaves[2] = temp;
  AlgebraicDecisionTree<Key> expected_error(discrete_keys, leaves);

  EXPECT(assert_equal(expected_error, error_tree, 1e-7));
}
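/*
 * On the swap above: DiscreteValues::CartesianProduct enumerates assignments
 * in an order that differs from the leaf order the AlgebraicDecisionTree
 * constructor expects for keys {M(0), M(1)}, so indices 1 and 2 must be
 * exchanged. A hedged sketch of the assumed enumeration for two binary keys
 * (the exact ordering is an assumption, verify against the library):
 */
void cartesianOrderSketch() {
  using namespace gtsam;
  DiscreteKeys keys{{M(0), 2}, {M(1), 2}};
  const auto assignments = DiscreteValues::CartesianProduct(keys);
  // Assumed order: {m0:0,m1:0}, {m0:1,m1:0}, {m0:0,m1:1}, {m0:1,m1:1},
  // while the tree constructor reads its leaves with M(1) varying fastest.
  assert(assignments.size() == 4);
}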
/****************************************************************************
|
/****************************************************************************
|
||||||
* Test construction of switching-like hybrid factor graph.
|
* Test construction of switching-like hybrid factor graph.
|
||||||
*/
|
*/
|
||||||
TEST(HybridNonlinearFactorGraph, Switching) {
|
TEST(HybridNonlinearFactorGraph, Switching) {
|
||||||
Switching self(3);
|
Switching self(3);
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(7, self.nonlinearFactorGraph.size());
|
EXPECT_LONGS_EQUAL(7, self.nonlinearFactorGraph().size());
|
||||||
EXPECT_LONGS_EQUAL(7, self.linearizedFactorGraph.size());
|
EXPECT_LONGS_EQUAL(7, self.linearizedFactorGraph.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -229,7 +260,7 @@ TEST(HybridNonlinearFactorGraph, Linearization) {
|
||||||
|
|
||||||
// Linearize here:
|
// Linearize here:
|
||||||
HybridGaussianFactorGraph actualLinearized =
|
HybridGaussianFactorGraph actualLinearized =
|
||||||
*self.nonlinearFactorGraph.linearize(self.linearizationPoint);
|
*self.nonlinearFactorGraph().linearize(self.linearizationPoint);
|
||||||
|
|
||||||
EXPECT_LONGS_EQUAL(7, actualLinearized.size());
|
EXPECT_LONGS_EQUAL(7, actualLinearized.size());
|
||||||
}
|
}
|
||||||
|
|
@ -378,7 +409,7 @@ TEST(HybridNonlinearFactorGraph, Partial_Elimination) {
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
TEST(HybridNonlinearFactorGraph, Error) {
|
TEST(HybridNonlinearFactorGraph, Error) {
|
||||||
Switching self(3);
|
Switching self(3);
|
||||||
HybridNonlinearFactorGraph fg = self.nonlinearFactorGraph;
|
HybridNonlinearFactorGraph fg = self.nonlinearFactorGraph();
|
||||||
|
|
||||||
{
|
{
|
||||||
HybridValues values(VectorValues(), DiscreteValues{{M(0), 0}, {M(1), 0}},
|
HybridValues values(VectorValues(), DiscreteValues{{M(0), 0}, {M(1), 0}},
|
||||||
|
|
@ -410,8 +441,9 @@ TEST(HybridNonlinearFactorGraph, Error) {
|
||||||
TEST(HybridNonlinearFactorGraph, PrintErrors) {
|
TEST(HybridNonlinearFactorGraph, PrintErrors) {
|
||||||
Switching self(3);
|
Switching self(3);
|
||||||
|
|
||||||
// Get nonlinear factor graph and add linear factors to be holistic
|
// Get nonlinear factor graph and add linear factors to be holistic.
|
||||||
HybridNonlinearFactorGraph fg = self.nonlinearFactorGraph;
|
// TODO(Frank): ???
|
||||||
|
HybridNonlinearFactorGraph fg = self.nonlinearFactorGraph();
|
||||||
fg.add(self.linearizedFactorGraph);
|
fg.add(self.linearizedFactorGraph);
|
||||||
|
|
||||||
// Optimize to get HybridValues
|
// Optimize to get HybridValues
|
||||||
|
|
@@ -514,14 +546,17 @@ TEST(HybridNonlinearFactorGraph, Printing) {
 #ifdef GTSAM_DT_MERGING
 string expected_hybridFactorGraph = R"(
 size: 7
-factor 0:
+Factor 0
+GaussianFactor:
 
 A[x0] = [
 10
 ]
 b = [ -10 ]
 No noise model
-factor 1:
-HybridGaussianFactor
+Factor 1
+HybridGaussianFactor:
 Hybrid [x0 x1; m0]{
 Choice(m0)
 0 Leaf :
 
@@ -533,6 +568,7 @@ Hybrid [x0 x1; m0]{
 ]
 b = [ -1 ]
 No noise model
+scalar: 0.918939
 
 1 Leaf :
 A[x0] = [
 
@@ -543,10 +579,12 @@ Hybrid [x0 x1; m0]{
 ]
 b = [ -0 ]
 No noise model
+scalar: 0.918939
 
 }
-factor 2:
-HybridGaussianFactor
+Factor 2
+HybridGaussianFactor:
 Hybrid [x1 x2; m1]{
 Choice(m1)
 0 Leaf :
 
@@ -558,6 +596,7 @@ Hybrid [x1 x2; m1]{
 ]
 b = [ -1 ]
 No noise model
+scalar: 0.918939
 
 1 Leaf :
 A[x1] = [
 
@@ -568,24 +607,37 @@ Hybrid [x1 x2; m1]{
 ]
 b = [ -0 ]
 No noise model
+scalar: 0.918939
 
 }
-factor 3:
+Factor 3
+GaussianFactor:
 
 A[x1] = [
 10
 ]
 b = [ -10 ]
 No noise model
-factor 4:
+Factor 4
+GaussianFactor:
 
 A[x2] = [
 10
 ]
 b = [ -10 ]
 No noise model
-factor 5: P( m0 ):
+Factor 5
+DiscreteFactor:
+P( m0 ):
 Leaf 0.5
 
-factor 6: P( m1 | m0 ):
+Factor 6
+DiscreteFactor:
+P( m1 | m0 ):
 Choice(m1)
 0 Choice(m0)
 0 0 Leaf 0.33333333
 
@@ -594,6 +646,7 @@ factor 6: P( m1 | m0 ):
 1 0 Leaf 0.66666667
 1 1 Leaf 0.4
 
 
 )";
 #else
 string expected_hybridFactorGraph = R"(
 
@@ -686,7 +739,7 @@ factor 6: P( m1 | m0 ):
 // Expected output for hybridBayesNet.
 string expected_hybridBayesNet = R"(
 size: 3
-conditional 0: Hybrid P( x0 | x1 m0)
+conditional 0: P( x0 | x1 m0)
 Discrete Keys = (m0, 2),
 logNormalizationConstant: 1.38862
 
 
@@ -705,7 +758,7 @@ conditional 0: Hybrid P( x0 | x1 m0)
 logNormalizationConstant: 1.38862
 No noise model
 
-conditional 1: Hybrid P( x1 | x2 m0 m1)
+conditional 1: P( x1 | x2 m0 m1)
 Discrete Keys = (m0, 2), (m1, 2),
 logNormalizationConstant: 1.3935
 
 
@@ -740,7 +793,7 @@ conditional 1: Hybrid P( x1 | x2 m0 m1)
 logNormalizationConstant: 1.3935
 No noise model
 
-conditional 2: Hybrid P( x2 | m0 m1)
+conditional 2: P( x2 | m0 m1)
 Discrete Keys = (m0, 2), (m1, 2),
 logNormalizationConstant: 1.38857
 
 
@@ -921,8 +974,6 @@ TEST(HybridNonlinearFactorGraph, DifferentMeans) {
 VectorValues cont0 = bn->optimize(dv0);
 double error0 = bn->error(HybridValues(cont0, dv0));
 
-// TODO(Varun) Perform importance sampling to estimate error?
-
 // regression
 EXPECT_DOUBLES_EQUAL(0.69314718056, error0, 1e-9);
 
 
@@ -994,16 +1045,9 @@ TEST(HybridNonlinearFactorGraph, DifferentCovariances) {
 cv.insert(X(0), Vector1(0.0));
 cv.insert(X(1), Vector1(0.0));
 
-// Check that the error values at the MLE point μ.
-AlgebraicDecisionTree<Key> errorTree = hbn->errorTree(cv);
-
 DiscreteValues dv0{{M(1), 0}};
 DiscreteValues dv1{{M(1), 1}};
 
-// regression
-EXPECT_DOUBLES_EQUAL(9.90348755254, errorTree(dv0), 1e-9);
-EXPECT_DOUBLES_EQUAL(0.69314718056, errorTree(dv1), 1e-9);
-
 DiscreteConditional expected_m1(m1, "0.5/0.5");
 DiscreteConditional actual_m1 = *(hbn->at(2)->asDiscrete());
 
@@ -57,10 +57,10 @@ TEST(HybridNonlinearISAM, IncrementalElimination) {
 // | | |
 // X0 -*- X1 -*- X2
 // \*-M0-*/
-graph1.push_back(switching.nonlinearFactorGraph.at(0)); // P(X0)
-graph1.push_back(switching.nonlinearFactorGraph.at(1)); // P(X0, X1 | M0)
-graph1.push_back(switching.nonlinearFactorGraph.at(2)); // P(X1, X2 | M1)
-graph1.push_back(switching.nonlinearFactorGraph.at(5)); // P(M0)
+graph1.push_back(switching.unaryFactors.at(0));   // P(X0)
+graph1.push_back(switching.binaryFactors.at(0));  // P(X0, X1 | M0)
+graph1.push_back(switching.binaryFactors.at(1));  // P(X1, X2 | M1)
+graph1.push_back(switching.modeChain.at(0));      // P(M0)
 
 initial.insert<double>(X(0), 1);
 initial.insert<double>(X(1), 2);
 
@@ -83,9 +83,9 @@ TEST(HybridNonlinearISAM, IncrementalElimination) {
 HybridNonlinearFactorGraph graph2;
 initial = Values();
 
-graph1.push_back(switching.nonlinearFactorGraph.at(3)); // P(X1)
-graph2.push_back(switching.nonlinearFactorGraph.at(4)); // P(X2)
-graph2.push_back(switching.nonlinearFactorGraph.at(6)); // P(M0, M1)
+graph1.push_back(switching.unaryFactors.at(1));  // P(X1)
+graph2.push_back(switching.unaryFactors.at(2));  // P(X2)
+graph2.push_back(switching.modeChain.at(1));     // P(M0, M1)
 
 isam.update(graph2, initial);
 
 
@@ -112,10 +112,10 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
 // X0 -*- X1 -*- X2
 // | |
 // *-M0 - * - M1
-graph1.push_back(switching.nonlinearFactorGraph.at(0)); // P(X0)
-graph1.push_back(switching.nonlinearFactorGraph.at(1)); // P(X0, X1 | M0)
-graph1.push_back(switching.nonlinearFactorGraph.at(3)); // P(X1)
-graph1.push_back(switching.nonlinearFactorGraph.at(5)); // P(M0)
+graph1.push_back(switching.unaryFactors.at(0));   // P(X0)
+graph1.push_back(switching.binaryFactors.at(0));  // P(X0, X1 | M0)
+graph1.push_back(switching.unaryFactors.at(1));   // P(X1)
+graph1.push_back(switching.modeChain.at(0));      // P(M0)
 
 initial.insert<double>(X(0), 1);
 initial.insert<double>(X(1), 2);
 
@@ -134,9 +134,9 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
 
 initial.insert<double>(X(2), 3);
 
-graph2.push_back(switching.nonlinearFactorGraph.at(2)); // P(X1, X2 | M1)
-graph2.push_back(switching.nonlinearFactorGraph.at(4)); // P(X2)
-graph2.push_back(switching.nonlinearFactorGraph.at(6)); // P(M0, M1)
+graph2.push_back(switching.binaryFactors.at(1));  // P(X1, X2 | M1)
+graph2.push_back(switching.unaryFactors.at(2));   // P(X2)
+graph2.push_back(switching.modeChain.at(1));      // P(M0, M1)
 
 isam.update(graph2, initial);
 bayesTree = isam.bayesTree();
@@ -175,46 +175,22 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
 EXPECT(assert_equal(*x2_conditional, *expected_x2_conditional));
 
 // We only perform manual continuous elimination for 0,0.
-// The other discrete probabilities on M(1) are calculated the same way
+// The other discrete probabilities on M(2) are calculated the same way
 const Ordering discreteOrdering{M(0), M(1)};
 HybridBayesTree::shared_ptr discreteBayesTree =
-expectedRemainingGraph->BaseEliminateable::eliminateMultifrontal(
-discreteOrdering);
+expectedRemainingGraph->eliminateMultifrontal(discreteOrdering);
 
-DiscreteValues m00;
-m00[M(0)] = 0, m00[M(1)] = 0;
-DiscreteConditional decisionTree =
-*(*discreteBayesTree)[M(1)]->conditional()->asDiscrete();
-double m00_prob = decisionTree(m00);
-
-auto discreteConditional = bayesTree[M(1)]->conditional()->asDiscrete();
-
 // Test the probability values with regression tests.
-DiscreteValues assignment;
-EXPECT(assert_equal(0.0952922, m00_prob, 1e-5));
-assignment[M(0)] = 0;
-assignment[M(1)] = 0;
-EXPECT(assert_equal(0.0952922, (*discreteConditional)(assignment), 1e-5));
-assignment[M(0)] = 1;
-assignment[M(1)] = 0;
-EXPECT(assert_equal(0.282758, (*discreteConditional)(assignment), 1e-5));
-assignment[M(0)] = 0;
-assignment[M(1)] = 1;
-EXPECT(assert_equal(0.314175, (*discreteConditional)(assignment), 1e-5));
-assignment[M(0)] = 1;
-assignment[M(1)] = 1;
-EXPECT(assert_equal(0.307775, (*discreteConditional)(assignment), 1e-5));
+auto discrete = bayesTree[M(1)]->conditional()->asDiscrete();
+EXPECT(assert_equal(0.095292, (*discrete)({{M(0), 0}, {M(1), 0}}), 1e-5));
+EXPECT(assert_equal(0.282758, (*discrete)({{M(0), 1}, {M(1), 0}}), 1e-5));
+EXPECT(assert_equal(0.314175, (*discrete)({{M(0), 0}, {M(1), 1}}), 1e-5));
+EXPECT(assert_equal(0.307775, (*discrete)({{M(0), 1}, {M(1), 1}}), 1e-5));
 
-// Check if the clique conditional generated from incremental elimination
+// Check that the clique conditional generated from incremental elimination
 // matches that of batch elimination.
-auto expectedChordal = expectedRemainingGraph->eliminateMultifrontal();
-auto actualConditional = dynamic_pointer_cast<DecisionTreeFactor>(
-bayesTree[M(1)]->conditional()->inner());
-// Account for the probability terms from evaluating continuous FGs
-DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}};
-vector<double> probs = {0.095292197, 0.31417524, 0.28275772, 0.30777485};
-auto expectedConditional =
-std::make_shared<DecisionTreeFactor>(discrete_keys, probs);
+auto expectedConditional = (*discreteBayesTree)[M(1)]->conditional();
+auto actualConditional = bayesTree[M(1)]->conditional();
 EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6));
 }
 
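A quick sanity check on the regression values above: the four probabilities are the joint P(m0, m1) over the two binary modes, and they do sum to one, 0.095292 + 0.282758 + 0.314175 + 0.307775 = 1.000000 (the full-precision values 0.095292197 + 0.31417524 + 0.28275772 + 0.30777485 from the deleted code sum to 1 as well), so the rewritten test still checks a properly normalized discrete conditional.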
@@ -227,18 +203,19 @@ TEST(HybridNonlinearISAM, Approx_inference) {
 Values initial;
 
 // Add the 3 hybrid factors, x0-x1, x1-x2, x2-x3
-for (size_t i = 1; i < 4; i++) {
-graph1.push_back(switching.nonlinearFactorGraph.at(i));
+for (size_t i = 0; i < 3; i++) {
+graph1.push_back(switching.binaryFactors.at(i));
 }
 
 // Add the Gaussian factors, 1 prior on X(0),
 // 3 measurements on X(1), X(2), X(3)
-graph1.push_back(switching.nonlinearFactorGraph.at(0));
-for (size_t i = 4; i <= 7; i++) {
-graph1.push_back(switching.nonlinearFactorGraph.at(i));
-initial.insert<double>(X(i - 4), i - 3);
+for (size_t i = 0; i < 4; i++) {
+graph1.push_back(switching.unaryFactors.at(i));
+initial.insert<double>(X(i), i + 1);
 }
 
+// TODO(Frank): no mode chain?
 
 // Create ordering.
 Ordering ordering;
 for (size_t j = 0; j < 4; j++) {
 
@@ -246,7 +223,7 @@ TEST(HybridNonlinearISAM, Approx_inference) {
 }
 
 // Now we calculate the actual factors using full elimination
-const auto [unprunedHybridBayesTree, unprunedRemainingGraph] =
+const auto [unPrunedHybridBayesTree, unPrunedRemainingGraph] =
 switching.linearizedFactorGraph
 .BaseEliminateable::eliminatePartialMultifrontal(ordering);
 
 
@@ -257,7 +234,7 @@ TEST(HybridNonlinearISAM, Approx_inference) {
 bayesTree.prune(maxNrLeaves);
 
 /*
-unpruned factor is:
+unPruned factor is:
 Choice(m3)
 0 Choice(m2)
 0 0 Choice(m1)
 
@@ -303,8 +280,8 @@ TEST(HybridNonlinearISAM, Approx_inference) {
 
 // Check that the hybrid nodes of the bayes net match those of the pre-pruning
 // bayes net, at the same positions.
-auto &unprunedLastDensity = *dynamic_pointer_cast<HybridGaussianConditional>(
-unprunedHybridBayesTree->clique(X(3))->conditional()->inner());
+auto &unPrunedLastDensity = *dynamic_pointer_cast<HybridGaussianConditional>(
+unPrunedHybridBayesTree->clique(X(3))->conditional()->inner());
 auto &lastDensity = *dynamic_pointer_cast<HybridGaussianConditional>(
 bayesTree[X(3)]->conditional()->inner());
 
 
@@ -319,7 +296,7 @@ TEST(HybridNonlinearISAM, Approx_inference) {
 EXPECT(lastDensity(assignment) == nullptr);
 } else {
 CHECK(lastDensity(assignment));
-EXPECT(assert_equal(*unprunedLastDensity(assignment),
+EXPECT(assert_equal(*unPrunedLastDensity(assignment),
 *lastDensity(assignment)));
 }
 }
 
@@ -335,19 +312,20 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
 
 /***** Run Round 1 *****/
 // Add the 3 hybrid factors, x0-x1, x1-x2, x2-x3
-for (size_t i = 1; i < 4; i++) {
-graph1.push_back(switching.nonlinearFactorGraph.at(i));
+for (size_t i = 0; i < 3; i++) {
+graph1.push_back(switching.binaryFactors.at(i));
 }
 
 // Add the Gaussian factors, 1 prior on X(0),
 // 3 measurements on X(1), X(2), X(3)
-graph1.push_back(switching.nonlinearFactorGraph.at(0));
-initial.insert<double>(X(0), 1);
-for (size_t i = 5; i <= 7; i++) {
-graph1.push_back(switching.nonlinearFactorGraph.at(i));
-initial.insert<double>(X(i - 4), i - 3);
+for (size_t i = 0; i < 4; i++) {
+graph1.push_back(switching.unaryFactors.at(i));
+initial.insert<double>(X(i), i + 1);
 }
 
+// TODO(Frank): no mode chain?
 
 // Run update with pruning
 size_t maxComponents = 5;
 incrementalHybrid.update(graph1, initial);
 
@@ -368,8 +346,8 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
 
 /***** Run Round 2 *****/
 HybridGaussianFactorGraph graph2;
-graph2.push_back(switching.nonlinearFactorGraph.at(4)); // x3-x4
-graph2.push_back(switching.nonlinearFactorGraph.at(8)); // x4 measurement
+graph2.push_back(switching.binaryFactors.at(3));  // x3-x4
+graph2.push_back(switching.unaryFactors.at(4));   // x4 measurement
 initial = Values();
 initial.insert<double>(X(4), 5);
 
@@ -52,13 +52,18 @@ BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
 BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
 
 BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor, "gtsam_HybridGaussianFactor");
-BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors,
+BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs,
 "gtsam_HybridGaussianFactor_Factors");
-BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Leaf,
+BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs::Leaf,
 "gtsam_HybridGaussianFactor_Factors_Leaf");
-BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Choice,
+BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs::Choice,
 "gtsam_HybridGaussianFactor_Factors_Choice");
 
+BOOST_CLASS_EXPORT_GUID(GaussianFactorGraphValuePair,
+"gtsam_GaussianFactorGraphValuePair");
+BOOST_CLASS_EXPORT_GUID(HybridGaussianProductFactor,
+"gtsam_HybridGaussianProductFactor");
 
 BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional,
 "gtsam_HybridGaussianConditional");
 BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional::Conditionals,
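A note on why these export lines matter: BOOST_CLASS_EXPORT_GUID registers a polymorphic type under a stable string tag, so Boost.Serialization can reconstruct the correct derived class when reading through a base-class pointer; without the registration, deserializing the newly added HybridGaussianProductFactor members would fail with an unregistered-class error at runtime. A minimal sketch of the mechanism — this is generic Boost usage, not a GTSAM API, and Base/Derived are placeholder names:

#include <boost/archive/text_iarchive.hpp>
#include <boost/archive/text_oarchive.hpp>
#include <boost/serialization/base_object.hpp>
#include <boost/serialization/export.hpp>
#include <boost/serialization/shared_ptr.hpp>

#include <memory>
#include <sstream>

struct Base {
  virtual ~Base() = default;
  template <class Archive>
  void serialize(Archive&, const unsigned int) {}
};

struct Derived : Base {
  template <class Archive>
  void serialize(Archive& ar, const unsigned int) {
    ar & boost::serialization::base_object<Base>(*this);
  }
};

// Same role as the GTSAM exports above: map the dynamic type to a GUID string.
BOOST_CLASS_EXPORT_GUID(Derived, "demo_Derived")

int main() {
  std::shared_ptr<Base> out = std::make_shared<Derived>();
  std::stringstream ss;
  {
    boost::archive::text_oarchive oa(ss);
    oa << out;  // writes the "demo_Derived" GUID, then the object data
  }
  std::shared_ptr<Base> in;
  {
    boost::archive::text_iarchive ia(ss);
    ia >> in;  // GUID lookup reconstructs a Derived behind the Base pointer
  }
  return in ? 0 : 1;
}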
@@ -140,6 +140,9 @@ namespace gtsam {
 /** Access the conditional */
 const sharedConditional& conditional() const { return conditional_; }
 
+/** Write access to the conditional */
+sharedConditional& conditional() { return conditional_; }
+
 /// Return true if this clique is the root of a Bayes tree.
 inline bool isRoot() const { return parent_.expired(); }
 
@@ -0,0 +1,36 @@
+/* ----------------------------------------------------------------------------
+
+ * GTSAM Copyright 2010, Georgia Tech Research Corporation,
+ * Atlanta, Georgia 30332-0415
+ * All Rights Reserved
+ * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
+
+ * See LICENSE for the license information
+
+ * -------------------------------------------------------------------------- */
+
+/**
+ * @file EdgeKey.cpp
+ * @date Oct 24, 2024
+ * @author: Frank Dellaert
+ * @author: Akshay Krishnan
+ */
+
+#include <gtsam/inference/EdgeKey.h>
+
+namespace gtsam {
+
+EdgeKey::operator std::string() const {
+  return "{" + std::to_string(i_) + ", " + std::to_string(j_) + "}";
+}
+
+GTSAM_EXPORT std::ostream& operator<<(std::ostream& os, const EdgeKey& key) {
+  os << (std::string)key;
+  return os;
+}
+
+void EdgeKey::print(const std::string& s) const {
+  std::cout << s << *this << std::endl;
+}
+
+}  // namespace gtsam
@@ -0,0 +1,81 @@
+/* ----------------------------------------------------------------------------
+
+ * GTSAM Copyright 2010, Georgia Tech Research Corporation,
+ * Atlanta, Georgia 30332-0415
+ * All Rights Reserved
+ * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
+
+ * See LICENSE for the license information
+
+ * -------------------------------------------------------------------------- */
+
+#pragma once
+
+/**
+ * @file EdgeKey.h
+ * @date Oct 24, 2024
+ * @author: Frank Dellaert
+ * @author: Akshay Krishnan
+ */
+
+#include <gtsam/base/Testable.h>
+#include <gtsam/inference/Key.h>
+
+namespace gtsam {
+class GTSAM_EXPORT EdgeKey {
+ protected:
+  std::uint32_t i_;  ///< Upper 32 bits
+  std::uint32_t j_;  ///< Lower 32 bits
+
+ public:
+  /// @name Constructors
+  /// @{
+
+  /// Default constructor
+  EdgeKey() : i_(0), j_(0) {}
+
+  /// Constructor
+  EdgeKey(std::uint32_t i, std::uint32_t j) : i_(i), j_(j) {}
+
+  EdgeKey(Key key)
+      : i_(static_cast<std::uint32_t>(key >> 32)),
+        j_(static_cast<std::uint32_t>(key)) {}
+
+  /// @}
+  /// @name API
+  /// @{
+
+  /// Cast to Key
+  operator Key() const { return ((std::uint64_t)i_ << 32) | j_; }
+
+  /// Retrieve high 32 bits
+  inline std::uint32_t i() const { return i_; }
+
+  /// Retrieve low 32 bits
+  inline std::uint32_t j() const { return j_; }
+
+  /** Create a string from the key */
+  operator std::string() const;
+
+  /// Output stream operator
+  friend GTSAM_EXPORT std::ostream& operator<<(std::ostream&, const EdgeKey&);
+
+  /// @}
+  /// @name Testable
+  /// @{
+
+  /// Prints the EdgeKey with an optional prefix string.
+  void print(const std::string& s = "") const;
+
+  /// Checks if this EdgeKey is equal to another, tolerance is ignored.
+  bool equals(const EdgeKey& expected, double tol = 0.0) const {
+    return (*this) == expected;
+  }
+  /// @}
+};
+
+/// traits
+template <>
+struct traits<EdgeKey> : public Testable<EdgeKey> {};
+
+}  // namespace gtsam
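The new EdgeKey packs two 32-bit indices into a single 64-bit Key, so an undirected edge (i, j) can live in the same key space as ordinary variable keys. A minimal round-trip sketch using only the API declared above (values are illustrative):

#include <gtsam/inference/EdgeKey.h>

#include <cassert>

using namespace gtsam;

int main() {
  EdgeKey edge(3, 7);   // pack i=3 into the high bits, j=7 into the low bits
  Key k = edge;         // implicit cast: (3 << 32) | 7
  EdgeKey back(k);      // decode the same Key again
  assert(back.i() == 3 && back.j() == 7);
  edge.print("edge: "); // prints: edge: {3, 7}
  return 0;
}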
@@ -19,9 +19,11 @@
 #pragma once
 
-#include <functional>
+#include <gtsam/base/Testable.h>
 #include <gtsam/inference/Symbol.h>
 
+#include <functional>
+
 namespace gtsam {
 
 /**
 
@@ -38,82 +40,110 @@ protected:
 std::uint64_t j_;
 
 public:
-/** Default constructor */
+/// @name Constructors
+/// @{
+
+/// Default constructor
 LabeledSymbol();
 
-/** Copy constructor */
+/// Copy constructor
 LabeledSymbol(const LabeledSymbol& key);
 
-/** Constructor */
+/// Constructor for characters c and label, and integer j
 LabeledSymbol(unsigned char c, unsigned char label, std::uint64_t j);
 
-/** Constructor that decodes an integer gtsam::Key */
-LabeledSymbol(gtsam::Key key);
+/// Constructor that decodes an integer Key
+LabeledSymbol(Key key);
 
-/** Cast to integer */
-operator gtsam::Key() const;
+/// @}
+/// @name Testable
+/// @{
 
-// Testable Requirements
+/// Prints the LabeledSymbol with an optional prefix string.
 void print(const std::string& s = "") const;
 
+/// Checks if this LabeledSymbol is equal to another, tolerance is ignored.
 bool equals(const LabeledSymbol& expected, double tol = 0.0) const {
 return (*this) == expected;
 }
 
-/** return the integer version */
-gtsam::Key key() const { return (gtsam::Key) *this; }
+/// @}
+/// @name Standard API
+/// @{
 
-/** Retrieve label character */
+/// Cast to Key
+operator Key() const;
+
+/// return the integer version
+Key key() const { return (Key) * this; }
+
+/// Retrieve label character
 inline unsigned char label() const { return label_; }
 
-/** Retrieve key character */
+/// Retrieve key character
 inline unsigned char chr() const { return c_; }
 
-/** Retrieve key index */
+/// Retrieve key index
 inline size_t index() const { return j_; }
 
-/** Create a string from the key */
+/// Create a string from the key
 operator std::string() const;
 
-/** Comparison for use in maps */
+/// Output stream operator that can be used with key_formatter (see Key.h).
+friend GTSAM_EXPORT std::ostream& operator<<(std::ostream&,
+                                             const LabeledSymbol&);
+
+/// @}
+/// @name Comparison
+/// @{
+
 bool operator<(const LabeledSymbol& comp) const;
 bool operator==(const LabeledSymbol& comp) const;
-bool operator==(gtsam::Key comp) const;
+bool operator==(Key comp) const;
 bool operator!=(const LabeledSymbol& comp) const;
-bool operator!=(gtsam::Key comp) const;
+bool operator!=(Key comp) const;
 
-/** Return a filter function that returns true when evaluated on a gtsam::Key whose
- * character (when converted to a LabeledSymbol) matches \c c. Use this with the
- * Values::filter() function to retrieve all key-value pairs with the
- * requested character.
- */
+/// @}
+/// @name Filtering
+/// @{
+
+/// Return a filter function that returns true when evaluated on a Key whose
+/// character (when converted to a LabeledSymbol) matches \c c. Use this with
+/// the Values::filter() function to retrieve all key-value pairs with the
+/// requested character.
 
-// Checks only the type
-static std::function<bool(gtsam::Key)> TypeTest(unsigned char c);
+/// Checks only the type
+static std::function<bool(Key)> TypeTest(unsigned char c);
 
-// Checks only the robot ID (label_)
-static std::function<bool(gtsam::Key)> LabelTest(unsigned char label);
+/// Checks only the robot ID (label_)
+static std::function<bool(Key)> LabelTest(unsigned char label);
 
-// Checks both type and the robot ID
-static std::function<bool(gtsam::Key)> TypeLabelTest(unsigned char c, unsigned char label);
+/// Checks both type and the robot ID
+static std::function<bool(Key)> TypeLabelTest(unsigned char c,
+                                              unsigned char label);
 
-// Converts to upper/lower versions of labels
+/// @}
+/// @name Advanced API
+/// @{
+
+/// Converts to upper/lower versions of labels
 LabeledSymbol upper() const { return LabeledSymbol(c_, toupper(label_), j_); }
 LabeledSymbol lower() const { return LabeledSymbol(c_, tolower(label_), j_); }
 
-// Create a new symbol with a different character.
-LabeledSymbol newChr(unsigned char c) const { return LabeledSymbol(c, label_, j_); }
+/// Create a new symbol with a different character.
+LabeledSymbol newChr(unsigned char c) const {
+  return LabeledSymbol(c, label_, j_);
+}
 
-// Create a new symbol with a different label.
-LabeledSymbol newLabel(unsigned char label) const { return LabeledSymbol(c_, label, j_); }
+/// Create a new symbol with a different label.
+LabeledSymbol newLabel(unsigned char label) const {
+  return LabeledSymbol(c_, label, j_);
+}
 
-/// Output stream operator that can be used with key_formatter (see Key.h).
-friend GTSAM_EXPORT std::ostream &operator<<(std::ostream &, const LabeledSymbol &);
+/// @}
 
 private:
 
 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
-/** Serialization function */
+/// Serialization function
 friend class boost::serialization::access;
 template <class ARCHIVE>
 void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
 
@@ -124,22 +154,24 @@ private:
 #endif
 };  // \class LabeledSymbol
 
-/** Create a symbol key from a character, label and index, i.e. xA5. */
+/// Create a symbol key from a character, label and index, i.e. xA5.
 inline Key mrsymbol(unsigned char c, unsigned char label, size_t j) {
 return (Key)LabeledSymbol(c, label, j);
 }
 
-/** Return the character portion of a symbol key. */
+/// Return the character portion of a symbol key.
 inline unsigned char mrsymbolChr(Key key) { return LabeledSymbol(key).chr(); }
 
-/** Return the label portion of a symbol key. */
-inline unsigned char mrsymbolLabel(Key key) { return LabeledSymbol(key).label(); }
+/// Return the label portion of a symbol key.
+inline unsigned char mrsymbolLabel(Key key) {
+  return LabeledSymbol(key).label();
+}
 
-/** Return the index portion of a symbol key. */
+/// Return the index portion of a symbol key.
 inline size_t mrsymbolIndex(Key key) { return LabeledSymbol(key).index(); }
 
 /// traits
-template<> struct traits<LabeledSymbol> : public Testable<LabeledSymbol> {};
+template <>
+struct traits<LabeledSymbol> : public Testable<LabeledSymbol> {};
 
-} // \namespace gtsam
+}  // namespace gtsam
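Since the header is mostly declarations, a short usage sketch may help; it uses only functions declared above, and the concrete characters and index are illustrative:

#include <gtsam/inference/LabeledSymbol.h>

using namespace gtsam;

int main() {
  // 'x' = variable type, 'A' = robot label, 5 = index, i.e. the key xA5.
  LabeledSymbol xA5('x', 'A', 5);
  Key k = mrsymbol('x', 'A', 5);  // the same key as a raw 64-bit integer
  LabeledSymbol decoded(k);       // decode it back

  // Filter usable with Values::filter(): true for 'x' keys on robot 'A'.
  auto filter = LabeledSymbol::TypeLabelTest('x', 'A');
  bool match = filter(k);  // true
  return match && decoded.index() == 5 ? 0 : 1;
}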
@@ -0,0 +1,57 @@
+/* ----------------------------------------------------------------------------
+
+ * GTSAM Copyright 2010, Georgia Tech Research Corporation,
+ * Atlanta, Georgia 30332-0415
+ * All Rights Reserved
+ * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
+
+ * See LICENSE for the license information
+
+ * -------------------------------------------------------------------------- */
+
+/*
+ * @file testEdgeKey.cpp
+ * @date Oct 24, 2024
+ * @author: Frank Dellaert
+ * @author: Akshay Krishnan
+ */
+
+#include <CppUnitLite/TestHarness.h>
+#include <gtsam/inference/EdgeKey.h>
+
+#include <sstream>
+
+using namespace std;
+using namespace gtsam;
+
+/* ************************************************************************* */
+TEST(EdgeKey, Construction) {
+  EdgeKey edge(1, 2);
+  EXPECT(edge.i() == 1);
+  EXPECT(edge.j() == 2);
+}
+
+/* ************************************************************************* */
+TEST(EdgeKey, Equality) {
+  EdgeKey edge1(1, 2);
+  EdgeKey edge2(1, 2);
+  EdgeKey edge3(2, 3);
+
+  EXPECT(assert_equal(edge1, edge2));
+  EXPECT(!edge1.equals(edge3));
+}
+
+/* ************************************************************************* */
+TEST(EdgeKey, StreamOutput) {
+  EdgeKey edge(1, 2);
+  std::ostringstream oss;
+  oss << edge;
+  EXPECT("{1, 2}" == oss.str());
+}
+
+/* ************************************************************************* */
+int main() {
+  TestResult tr;
+  return TestRegistry::runAllTests(tr);
+}
+/* ************************************************************************* */
@@ -28,11 +28,11 @@ namespace gtsam {
 void ConjugateGradientParameters::print(ostream &os) const {
 Base::print(os);
 cout << "ConjugateGradientParameters" << endl
-<< "minIter: " << minIterations_ << endl
-<< "maxIter: " << maxIterations_ << endl
-<< "resetIter: " << reset_ << endl
-<< "eps_rel: " << epsilon_rel_ << endl
-<< "eps_abs: " << epsilon_abs_ << endl;
+<< "minIter: " << minIterations << endl
+<< "maxIter: " << maxIterations << endl
+<< "resetIter: " << reset << endl
+<< "eps_rel: " << epsilon_rel << endl
+<< "eps_abs: " << epsilon_abs << endl;
 }
 
 /*****************************************************************************/
@@ -24,59 +24,66 @@
 namespace gtsam {
 
 /**
- * parameters for the conjugate gradient method
+ * Parameters for the Conjugate Gradient method
 */
-class GTSAM_EXPORT ConjugateGradientParameters : public IterativeOptimizationParameters {
-public:
+struct GTSAM_EXPORT ConjugateGradientParameters
+    : public IterativeOptimizationParameters {
 typedef IterativeOptimizationParameters Base;
 typedef std::shared_ptr<ConjugateGradientParameters> shared_ptr;
 
-size_t minIterations_;  ///< minimum number of cg iterations
-size_t maxIterations_;  ///< maximum number of cg iterations
-size_t reset_;          ///< number of iterations before reset
-double epsilon_rel_;    ///< threshold for relative error decrease
-double epsilon_abs_;    ///< threshold for absolute error decrease
+size_t minIterations;  ///< minimum number of cg iterations
+size_t maxIterations;  ///< maximum number of cg iterations
+size_t reset;          ///< number of iterations before reset
+double epsilon_rel;    ///< threshold for relative error decrease
+double epsilon_abs;    ///< threshold for absolute error decrease
 
 /* Matrix Operation Kernel */
 enum BLASKernel {
 GTSAM = 0,  ///< Jacobian Factor Graph of GTSAM
-} blas_kernel_ ;
+} blas_kernel;
 
 ConjugateGradientParameters()
-: minIterations_(1), maxIterations_(500), reset_(501), epsilon_rel_(1e-3),
-epsilon_abs_(1e-3), blas_kernel_(GTSAM) {}
+    : minIterations(1),
+      maxIterations(500),
+      reset(501),
+      epsilon_rel(1e-3),
+      epsilon_abs(1e-3),
+      blas_kernel(GTSAM) {}
 
-ConjugateGradientParameters(size_t minIterations, size_t maxIterations, size_t reset,
-double epsilon_rel, double epsilon_abs, BLASKernel blas)
-: minIterations_(minIterations), maxIterations_(maxIterations), reset_(reset),
-epsilon_rel_(epsilon_rel), epsilon_abs_(epsilon_abs), blas_kernel_(blas) {}
+ConjugateGradientParameters(size_t minIterations, size_t maxIterations,
+                            size_t reset, double epsilon_rel,
+                            double epsilon_abs, BLASKernel blas)
+    : minIterations(minIterations),
+      maxIterations(maxIterations),
+      reset(reset),
+      epsilon_rel(epsilon_rel),
+      epsilon_abs(epsilon_abs),
+      blas_kernel(blas) {}
 
 ConjugateGradientParameters(const ConjugateGradientParameters &p)
-: Base(p), minIterations_(p.minIterations_), maxIterations_(p.maxIterations_), reset_(p.reset_),
-epsilon_rel_(p.epsilon_rel_), epsilon_abs_(p.epsilon_abs_), blas_kernel_(GTSAM) {}
+    : Base(p),
+      minIterations(p.minIterations),
+      maxIterations(p.maxIterations),
+      reset(p.reset),
+      epsilon_rel(p.epsilon_rel),
+      epsilon_abs(p.epsilon_abs),
+      blas_kernel(GTSAM) {}
 
-/* general interface */
-inline size_t minIterations() const { return minIterations_; }
-inline size_t maxIterations() const { return maxIterations_; }
-inline size_t reset() const { return reset_; }
-inline double epsilon() const { return epsilon_rel_; }
-inline double epsilon_rel() const { return epsilon_rel_; }
-inline double epsilon_abs() const { return epsilon_abs_; }
+#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V43
+inline size_t getMinIterations() const { return minIterations; }
+inline size_t getMaxIterations() const { return maxIterations; }
+inline size_t getReset() const { return reset; }
+inline double getEpsilon() const { return epsilon_rel; }
+inline double getEpsilon_rel() const { return epsilon_rel; }
+inline double getEpsilon_abs() const { return epsilon_abs; }
 
-inline size_t getMinIterations() const { return minIterations_; }
-inline size_t getMaxIterations() const { return maxIterations_; }
-inline size_t getReset() const { return reset_; }
-inline double getEpsilon() const { return epsilon_rel_; }
-inline double getEpsilon_rel() const { return epsilon_rel_; }
-inline double getEpsilon_abs() const { return epsilon_abs_; }
-
-inline void setMinIterations(size_t value) { minIterations_ = value; }
-inline void setMaxIterations(size_t value) { maxIterations_ = value; }
-inline void setReset(size_t value) { reset_ = value; }
-inline void setEpsilon(double value) { epsilon_rel_ = value; }
-inline void setEpsilon_rel(double value) { epsilon_rel_ = value; }
-inline void setEpsilon_abs(double value) { epsilon_abs_ = value; }
+inline void setMinIterations(size_t value) { minIterations = value; }
+inline void setMaxIterations(size_t value) { maxIterations = value; }
+inline void setReset(size_t value) { reset = value; }
+inline void setEpsilon(double value) { epsilon_rel = value; }
+inline void setEpsilon_rel(double value) { epsilon_rel = value; }
+inline void setEpsilon_abs(double value) { epsilon_abs = value; }
+#endif
 
 void print() const { Base::print(); }
 
@@ -109,16 +116,17 @@ V preconditionedConjugateGradient(const S &system, const V &initial,
 
 double currentGamma = system.dot(residual, residual), prevGamma, alpha, beta;
 
-const size_t iMaxIterations = parameters.maxIterations(),
-iMinIterations = parameters.minIterations(),
-iReset = parameters.reset() ;
-const double threshold = std::max(parameters.epsilon_abs(),
-parameters.epsilon() * parameters.epsilon() * currentGamma);
+const size_t iMaxIterations = parameters.maxIterations,
+             iMinIterations = parameters.minIterations,
+             iReset = parameters.reset;
+const double threshold =
+    std::max(parameters.epsilon_abs,
+             parameters.epsilon_rel * parameters.epsilon_rel * currentGamma);
 
 if (parameters.verbosity() >= ConjugateGradientParameters::COMPLEXITY)
-std::cout << "[PCG] epsilon = " << parameters.epsilon()
-<< ", max = " << parameters.maxIterations()
-<< ", reset = " << parameters.reset()
+std::cout << "[PCG] epsilon = " << parameters.epsilon_rel
+<< ", max = " << parameters.maxIterations
+<< ", reset = " << parameters.reset
 << ", ||r0||^2 = " << currentGamma
 << ", threshold = " << threshold << std::endl;
 
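With the getters and setters gone, client code reads and writes the parameter struct directly; the old get/set wrappers survive only under GTSAM_ALLOW_DEPRECATED_SINCE_V43. A small sketch of the new style, including the convergence threshold the solver derives from these fields (the numeric values are illustrative):

#include <gtsam/linear/ConjugateGradientSolver.h>

#include <algorithm>

using namespace gtsam;

int main() {
  ConjugateGradientParameters params;
  params.maxIterations = 100;  // was params.setMaxIterations(100);
  params.epsilon_rel = 1e-4;   // was params.setEpsilon_rel(1e-4);
  params.epsilon_abs = 1e-6;

  // The iteration stops once the squared residual falls below
  // max(epsilon_abs, epsilon_rel^2 * ||r0||^2), exactly as computed in
  // preconditionedConjugateGradient above.
  double gamma0 = 2.5;  // illustrative initial ||r0||^2
  double threshold = std::max(
      params.epsilon_abs, params.epsilon_rel * params.epsilon_rel * gamma0);
  return threshold > 0 ? 0 : 1;
}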
@@ -34,17 +34,13 @@ namespace gtsam {
 void PCGSolverParameters::print(ostream &os) const {
 Base::print(os);
 os << "PCGSolverParameters:" << endl;
-preconditioner_->print(os);
+preconditioner->print(os);
 }
 
 /*****************************************************************************/
 PCGSolver::PCGSolver(const PCGSolverParameters &p) {
 parameters_ = p;
-preconditioner_ = createPreconditioner(p.preconditioner_);
-}
-
-void PCGSolverParameters::setPreconditionerParams(const std::shared_ptr<PreconditionerParameters> preconditioner) {
-preconditioner_ = preconditioner;
+preconditioner_ = createPreconditioner(p.preconditioner);
 }
 
 void PCGSolverParameters::print(const std::string &s) const {
@@ -31,29 +31,22 @@ class VectorValues;
 struct PreconditionerParameters;
 
 /**
- * Parameters for PCG
+ * Parameters for Preconditioned Conjugate Gradient solver.
 */
 struct GTSAM_EXPORT PCGSolverParameters : public ConjugateGradientParameters {
-public:
 typedef ConjugateGradientParameters Base;
 typedef std::shared_ptr<PCGSolverParameters> shared_ptr;
 
-PCGSolverParameters() {
-}
+std::shared_ptr<PreconditionerParameters> preconditioner;
+
+PCGSolverParameters() {}
+
+PCGSolverParameters(
+    const std::shared_ptr<PreconditionerParameters> &preconditioner)
+    : preconditioner(preconditioner) {}
 
 void print(std::ostream &os) const override;
 
-/* interface to preconditioner parameters */
-inline const PreconditionerParameters& preconditioner() const {
-return *preconditioner_;
-}
-
-// needed for python wrapper
 void print(const std::string &s) const;
-
-std::shared_ptr<PreconditionerParameters> preconditioner_;
-
-void setPreconditionerParams(const std::shared_ptr<PreconditionerParameters> preconditioner);
 };
 
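The preconditioner is now a plain public member with a convenience constructor, so wiring up a PCG solve is a one-liner. A sketch assuming the BlockJacobi preconditioner declared in Preconditioner.h (which the wrapper section below also exposes):

#include <gtsam/linear/PCGSolver.h>
#include <gtsam/linear/Preconditioner.h>

#include <memory>

using namespace gtsam;

int main() {
  // New style: pass the preconditioner parameters straight to the constructor...
  auto pcg = std::make_shared<PCGSolverParameters>(
      std::make_shared<BlockJacobiPreconditionerParameters>());

  // ...or assign the public field directly (replaces setPreconditionerParams).
  pcg->preconditioner = std::make_shared<BlockJacobiPreconditionerParameters>();
  pcg->maxIterations = 50;  // fields inherited from ConjugateGradientParameters

  PCGSolver solver(*pcg);  // builds the actual preconditioner from p.preconditioner
  return 0;
}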
 /**
 
@@ -87,16 +80,16 @@ public:
 * System class needed for calling preconditionedConjugateGradient
 */
 class GTSAM_EXPORT GaussianFactorGraphSystem {
-public:
-
-GaussianFactorGraphSystem(const GaussianFactorGraph &gfg,
-const Preconditioner &preconditioner, const KeyInfo &info,
-const std::map<Key, Vector> &lambda);
-
 const GaussianFactorGraph &gfg_;
 const Preconditioner &preconditioner_;
-const KeyInfo &keyInfo_;
-const std::map<Key, Vector> &lambda_;
+KeyInfo keyInfo_;
+std::map<Key, Vector> lambda_;
 
+public:
+GaussianFactorGraphSystem(const GaussianFactorGraph &gfg,
+                          const Preconditioner &preconditioner,
+                          const KeyInfo &info,
+                          const std::map<Key, Vector> &lambda);
 
 void residual(const Vector &x, Vector &r) const;
 void multiply(const Vector &x, Vector& y) const;
@@ -49,10 +49,12 @@ namespace gtsam {
 
 // init gamma and calculate threshold
 gamma = dot(g,g);
-threshold = std::max(parameters_.epsilon_abs(), parameters_.epsilon() * parameters_.epsilon() * gamma);
+threshold =
+    std::max(parameters_.epsilon_abs,
+             parameters_.epsilon_rel * parameters_.epsilon_rel * gamma);
 
 // Allocate and calculate A*d for first iteration
-if (gamma > parameters_.epsilon_abs()) Ad = Ab * d;
+if (gamma > parameters_.epsilon_abs) Ad = Ab * d;
 }
 
 /* ************************************************************************* */
 
@@ -79,13 +81,13 @@ namespace gtsam {
 // take a step, return true if converged
 bool step(const S& Ab, V& x) {
 
-if ((++k) >= ((int)parameters_.maxIterations())) return true;
+if ((++k) >= ((int)parameters_.maxIterations)) return true;
 
 //---------------------------------->
 double alpha = takeOptimalStep(x);
 
 // update gradient (or re-calculate at reset time)
-if (k % parameters_.reset() == 0) g = Ab.gradient(x);
+if (k % parameters_.reset == 0) g = Ab.gradient(x);
 // axpy(alpha, Ab ^ Ad, g); // g += alpha*(Ab^Ad)
 else Ab.transposeMultiplyAdd(alpha, Ad, g);
 
 
@@ -126,11 +128,10 @@ namespace gtsam {
 CGState<S, V, E> state(Ab, x, parameters, steepest);
 
 if (parameters.verbosity() != ConjugateGradientParameters::SILENT)
-std::cout << "CG: epsilon = " << parameters.epsilon()
-<< ", maxIterations = " << parameters.maxIterations()
+std::cout << "CG: epsilon = " << parameters.epsilon_rel
+<< ", maxIterations = " << parameters.maxIterations
 << ", ||g0||^2 = " << state.gamma
-<< ", threshold = " << state.threshold
-<< std::endl;
+<< ", threshold = " << state.threshold << std::endl;
 
 if ( state.gamma < state.threshold ) {
 if (parameters.verbosity() != ConjugateGradientParameters::SILENT)
@@ -710,17 +710,11 @@ virtual class IterativeOptimizationParameters {
 #include <gtsam/linear/ConjugateGradientSolver.h>
 virtual class ConjugateGradientParameters : gtsam::IterativeOptimizationParameters {
 ConjugateGradientParameters();
-int getMinIterations() const ;
-int getMaxIterations() const ;
-int getReset() const;
-double getEpsilon_rel() const;
-double getEpsilon_abs() const;
+int minIterations;
+int maxIterations;
+int reset;
+double epsilon_rel;
+double epsilon_abs;
 
-void setMinIterations(int value);
-void setMaxIterations(int value);
-void setReset(int value);
-void setEpsilon_rel(double value);
-void setEpsilon_abs(double value);
 };
 
 #include <gtsam/linear/Preconditioner.h>
 
@@ -739,8 +733,10 @@ virtual class BlockJacobiPreconditionerParameters : gtsam::PreconditionerParamet
 #include <gtsam/linear/PCGSolver.h>
 virtual class PCGSolverParameters : gtsam::ConjugateGradientParameters {
 PCGSolverParameters();
+PCGSolverParameters(const gtsam::PreconditionerParameters* preconditioner);
 void print(string s = "");
-void setPreconditionerParams(gtsam::PreconditionerParameters* preconditioner);
+std::shared_ptr<gtsam::PreconditionerParameters> preconditioner;
 };
 
 #include <gtsam/linear/SubgraphSolver.h>
@@ -67,12 +67,14 @@ VectorValues DoglegOptimizerImpl::ComputeBlend(double delta, const VectorValues&
 double tau1 = (-b + sqrt_b_m4ac) / (2.*a);
 double tau2 = (-b - sqrt_b_m4ac) / (2.*a);
 
+// Determine correct solution accounting for machine precision
 double tau;
-if(0.0 <= tau1 && tau1 <= 1.0) {
-assert(!(0.0 <= tau2 && tau2 <= 1.0));
+const double eps = std::numeric_limits<double>::epsilon();
+if(-eps <= tau1 && tau1 <= 1.0 + eps) {
+assert(!(-eps <= tau2 && tau2 <= 1.0 + eps));
 tau = tau1;
 } else {
-assert(0.0 <= tau2 && tau2 <= 1.0);
+assert(-eps <= tau2 && tau2 <= 1.0 + eps);
 tau = tau2;
 }
 
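For reference, a sketch of where these roots come from, assuming ComputeBlend solves the standard dogleg blend condition (the coefficients a, b, c are formed just above this hunk and are not shown here): the blended step between the steepest-descent point x_u and the Newton point x_n must land on the trust-region boundary of radius \Delta,

\| x_u + \tau\,(x_n - x_u) \| = \Delta
\;\Longrightarrow\;
a\,\tau^2 + b\,\tau + c = 0,
\qquad
\tau_{1,2} = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}.

Mathematically exactly one root lies in [0, 1], but floating-point rounding can push that root infinitesimally outside the interval; the change above widens the acceptance test to [-\varepsilon,\, 1 + \varepsilon] with \varepsilon the machine epsilon, so a root that is analytically 0 or 1 is no longer rejected by the asserts.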
@@ -16,11 +16,11 @@
 * @date Jun 11, 2012
 */
 
-#include <gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h>
-#include <gtsam/nonlinear/internal/NonlinearOptimizerState.h>
-#include <gtsam/nonlinear/Values.h>
 #include <gtsam/linear/GaussianFactorGraph.h>
 #include <gtsam/linear/VectorValues.h>
+#include <gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h>
+#include <gtsam/nonlinear/Values.h>
+#include <gtsam/nonlinear/internal/NonlinearOptimizerState.h>
 
 #include <cmath>
 
@@ -42,21 +42,27 @@ static VectorValues gradientInPlace(const NonlinearFactorGraph &nfg,
 }
 
 NonlinearConjugateGradientOptimizer::NonlinearConjugateGradientOptimizer(
-const NonlinearFactorGraph& graph, const Values& initialValues, const Parameters& params)
-: Base(graph, std::unique_ptr<State>(new State(initialValues, graph.error(initialValues)))),
+const NonlinearFactorGraph& graph, const Values& initialValues,
+const Parameters& params, const DirectionMethod& directionMethod)
+: Base(graph, std::unique_ptr<State>(
+           new State(initialValues, graph.error(initialValues)))),
 params_(params) {}
 
-double NonlinearConjugateGradientOptimizer::System::error(const State& state) const {
+double NonlinearConjugateGradientOptimizer::System::error(
+    const State& state) const {
 return graph_.error(state);
 }
 
-NonlinearConjugateGradientOptimizer::System::Gradient NonlinearConjugateGradientOptimizer::System::gradient(
+NonlinearConjugateGradientOptimizer::System::Gradient
+NonlinearConjugateGradientOptimizer::System::gradient(
 const State& state) const {
 return gradientInPlace(graph_, state);
 }
 
-NonlinearConjugateGradientOptimizer::System::State NonlinearConjugateGradientOptimizer::System::advance(
-const State &current, const double alpha, const Gradient &g) const {
+NonlinearConjugateGradientOptimizer::System::State
+NonlinearConjugateGradientOptimizer::System::advance(const State& current,
+                                                     const double alpha,
+                                                     const Gradient& g) const {
 Gradient step = g;
 step *= alpha;
 return current.retract(step);
 
@@ -64,8 +70,10 @@ NonlinearConjugateGradientOptimizer::System::State NonlinearConjugateGradientOpt
 
 GaussianFactorGraph::shared_ptr NonlinearConjugateGradientOptimizer::iterate() {
 const auto [newValues, dummy] = nonlinearConjugateGradient<System, Values>(
-System(graph_), state_->values, params_, true /* single iteration */);
-state_.reset(new State(newValues, graph_.error(newValues), state_->iterations + 1));
+System(graph_), state_->values, params_, true /* single iteration */,
+directionMethod_);
+state_.reset(
+    new State(newValues, graph_.error(newValues), state_->iterations + 1));
 
 // NOTE(frank): We don't linearize this system, so we must return null here.
 return nullptr;
 
@@ -74,11 +82,11 @@ GaussianFactorGraph::shared_ptr NonlinearConjugateGradientOptimizer::iterate() {
 const Values& NonlinearConjugateGradientOptimizer::optimize() {
 // Optimize until convergence
 System system(graph_);
-const auto [newValues, iterations] =
-nonlinearConjugateGradient(system, state_->values, params_, false);
-state_.reset(new State(std::move(newValues), graph_.error(newValues), iterations));
+const auto [newValues, iterations] = nonlinearConjugateGradient(
+    system, state_->values, params_, false, directionMethod_);
+state_.reset(
+    new State(std::move(newValues), graph_.error(newValues), iterations));
 return state_->values;
 }
 
 } /* namespace gtsam */
 
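The optimizer now takes the β update rule as a constructor argument, with Polak-Ribiere as the default per the header below. A usage sketch following the new signature (graph and initial values are assumed inputs):

#include <gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h>

using namespace gtsam;

// Assumed inputs: a NonlinearFactorGraph `graph` and Values `initial`.
Values optimizeWithFletcherReeves(const NonlinearFactorGraph& graph,
                                  const Values& initial) {
  NonlinearOptimizerParams params;
  params.maxIterations = 100;
  // Select the Fletcher-Reeves beta rule instead of the Polak-Ribiere default.
  NonlinearConjugateGradientOptimizer optimizer(
      graph, initial, params, DirectionMethod::FletcherReeves);
  return optimizer.optimize();
}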
@@ -23,9 +23,60 @@
 
 namespace gtsam {
 
-/** An implementation of the nonlinear CG method using the template below */
-class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimizer {
+/// Fletcher-Reeves formula for computing β, the conjugate-direction update
+/// coefficient.
+template <typename Gradient>
+double FletcherReeves(const Gradient &currentGradient,
+                      const Gradient &prevGradient) {
+  // Fletcher-Reeves: beta = g_n'*g_n/g_n-1'*g_n-1
+  const double beta =
+      currentGradient.dot(currentGradient) / prevGradient.dot(prevGradient);
+  return beta;
+}
+
+/// Polak-Ribiere formula for computing β, the conjugate-direction update
+/// coefficient.
+template <typename Gradient>
+double PolakRibiere(const Gradient &currentGradient,
+                    const Gradient &prevGradient) {
+  // Polak-Ribiere: beta = g_n'*(g_n-g_n-1)/g_n-1'*g_n-1
+  const double beta =
+      std::max(0.0, currentGradient.dot(currentGradient - prevGradient) /
+                        prevGradient.dot(prevGradient));
+  return beta;
+}
+
+/// The Hestenes-Stiefel formula for computing β, the conjugate-direction
+/// update coefficient.
+template <typename Gradient>
+double HestenesStiefel(const Gradient &currentGradient,
+                       const Gradient &prevGradient,
+                       const Gradient &direction) {
+  // Hestenes-Stiefel: beta = g_n'*(g_n-g_n-1)/(-s_n-1')*(g_n-g_n-1)
+  Gradient d = currentGradient - prevGradient;
+  const double beta = std::max(0.0, currentGradient.dot(d) / -direction.dot(d));
+  return beta;
+}
+
+/// The Dai-Yuan formula for computing β, the conjugate-direction update
+/// coefficient.
+template <typename Gradient>
+double DaiYuan(const Gradient &currentGradient, const Gradient &prevGradient,
+               const Gradient &direction) {
+  // Dai-Yuan: beta = g_n'*g_n/(-s_n-1')*(g_n-g_n-1)
+  const double beta =
+      std::max(0.0, currentGradient.dot(currentGradient) /
+                        -direction.dot(currentGradient - prevGradient));
+  return beta;
+}
+
+enum class DirectionMethod {
+  FletcherReeves,
+  PolakRibiere,
+  HestenesStiefel,
+  DaiYuan
+};
 
+/** An implementation of the nonlinear CG method using the template below */
+class GTSAM_EXPORT NonlinearConjugateGradientOptimizer
+    : public NonlinearOptimizer {
   /* a class for the nonlinearConjugateGradient template */
   class System {
    public:
 
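For reference, the four β rules added above, written out in conventional notation with g_n the current gradient and s_{n-1} the previous search direction. The max(0, ·) clamps mirror the std::max guards in the code; Fletcher-Reeves is the only rule left unclamped:

\[
\beta^{FR} = \frac{g_n^\top g_n}{g_{n-1}^\top g_{n-1}}, \qquad
\beta^{PR} = \max\!\left(0,\; \frac{g_n^\top (g_n - g_{n-1})}{g_{n-1}^\top g_{n-1}}\right),
\]
\[
\beta^{HS} = \max\!\left(0,\; \frac{g_n^\top (g_n - g_{n-1})}{-s_{n-1}^\top (g_n - g_{n-1})}\right), \qquad
\beta^{DY} = \max\!\left(0,\; \frac{g_n^\top g_n}{-s_{n-1}^\top (g_n - g_{n-1})}\right).
\]

All four coincide on an exact quadratic; they differ in how they recover conjugacy after nonlinear steps, which is presumably why this change exposes the choice instead of hardcoding Polak-Ribiere.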
@@ -37,9 +88,7 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimiz
     const NonlinearFactorGraph &graph_;
 
    public:
-    System(const NonlinearFactorGraph &graph) :
-        graph_(graph) {
-    }
+    System(const NonlinearFactorGraph &graph) : graph_(graph) {}
     double error(const State &state) const;
     Gradient gradient(const State &state) const;
     State advance(const State &current, const double alpha,
 
@@ -47,27 +96,25 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer : public NonlinearOptimiz
   };
 
  public:
 
   typedef NonlinearOptimizer Base;
   typedef NonlinearOptimizerParams Parameters;
   typedef std::shared_ptr<NonlinearConjugateGradientOptimizer> shared_ptr;
 
  protected:
   Parameters params_;
+  DirectionMethod directionMethod_ = DirectionMethod::PolakRibiere;
 
-  const NonlinearOptimizerParams& _params() const override {
-    return params_;
-  }
+  const NonlinearOptimizerParams &_params() const override { return params_; }
 
  public:
 
   /// Constructor
-  NonlinearConjugateGradientOptimizer(const NonlinearFactorGraph& graph,
-      const Values& initialValues, const Parameters& params = Parameters());
+  NonlinearConjugateGradientOptimizer(
+      const NonlinearFactorGraph &graph, const Values &initialValues,
+      const Parameters &params = Parameters(),
+      const DirectionMethod &directionMethod = DirectionMethod::PolakRibiere);
 
   /// Destructor
-  ~NonlinearConjugateGradientOptimizer() override {
-  }
+  ~NonlinearConjugateGradientOptimizer() override {}
 
   /**
    * Perform a single iteration, returning GaussianFactorGraph corresponding to
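A minimal usage sketch of the new constructor argument, mirroring the DirectionMethods test added later in this merge; the graph and initial values come from any nonlinear problem:

#include <gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h>

using namespace gtsam;

Values optimizeWithDirection(const NonlinearFactorGraph& graph,
                             const Values& initialEstimate) {
  NonlinearOptimizerParams params;
  params.maxIterations = 500;  // NCG typically needs many iterations
  // The fourth argument selects the beta formula; it defaults to PolakRibiere.
  NonlinearConjugateGradientOptimizer optimizer(
      graph, initialEstimate, params, DirectionMethod::FletcherReeves);
  return optimizer.optimize();
}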
@@ -85,28 +132,25 @@ public:
 /** Implement the golden-section line search algorithm */
 template <class S, class V, class W>
 double lineSearch(const S &system, const V currentValues, const W &gradient) {
 
   /* normalize it such that it becomes a unit vector */
   const double g = gradient.norm();
 
-  // perform the golden section search algorithm to decide the the optimal step size
-  // detail refer to http://en.wikipedia.org/wiki/Golden_section_search
-  const double phi = 0.5 * (1.0 + std::sqrt(5.0)), resphi = 2.0 - phi, tau =
-      1e-5;
-  double minStep = -1.0 / g, maxStep = 0, newStep = minStep
-      + (maxStep - minStep) / (phi + 1.0);
+  // perform the golden section search algorithm to decide the the optimal step
+  // size detail refer to http://en.wikipedia.org/wiki/Golden_section_search
+  const double phi = 0.5 * (1.0 + std::sqrt(5.0)), resphi = 2.0 - phi,
+               tau = 1e-5;
+  double minStep = -1.0 / g, maxStep = 0,
+         newStep = minStep + (maxStep - minStep) / (phi + 1.0);
 
   V newValues = system.advance(currentValues, newStep, gradient);
   double newError = system.error(newValues);
 
   while (true) {
-    const bool flag = (maxStep - newStep > newStep - minStep) ? true : false;
-    const double testStep =
-        flag ? newStep + resphi * (maxStep - newStep) :
-            newStep - resphi * (newStep - minStep);
+    const bool flag = (maxStep - newStep > newStep - minStep);
+    const double testStep = flag ? newStep + resphi * (maxStep - newStep)
+                                 : newStep - resphi * (newStep - minStep);
 
-    if ((maxStep - minStep)
-        < tau * (std::abs(testStep) + std::abs(newStep))) {
+    if ((maxStep - minStep) < tau * (std::abs(testStep) + std::abs(newStep))) {
       return 0.5 * (minStep + maxStep);
     }
 
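For readers unfamiliar with the algorithm being reformatted here, a standalone sketch of golden-section minimization over an interval; the member template above applies the same bracketing idea to step sizes along the normalized gradient direction. This sketch is illustrative only, not GTSAM API:

#include <cmath>
#include <functional>

double goldenSectionMinimize(const std::function<double(double)>& f, double a,
                             double b, double tol = 1e-5) {
  const double phi = 0.5 * (1.0 + std::sqrt(5.0));  // golden ratio
  const double resphi = 2.0 - phi;                  // = 1 - 1/phi, about 0.382
  // Two interior probes that divide [a, b] in golden proportion.
  double x1 = a + resphi * (b - a), x2 = b - resphi * (b - a);
  double f1 = f(x1), f2 = f(x2);
  while (b - a > tol) {
    if (f1 < f2) {
      // Minimum bracketed in [a, x2]; the old x1 becomes the new upper probe.
      b = x2; x2 = x1; f2 = f1;
      x1 = a + resphi * (b - a); f1 = f(x1);
    } else {
      // Minimum bracketed in [x1, b]; the old x2 becomes the new lower probe.
      a = x1; x1 = x2; f1 = f2;
      x2 = b - resphi * (b - a); f2 = f(x2);
    }
  }
  return 0.5 * (a + b);
}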
@@ -135,19 +179,23 @@ double lineSearch(const S &system, const V currentValues, const W &gradient) {
 }
 
 /**
- * Implement the nonlinear conjugate gradient method using the Polak-Ribiere formula suggested in
+ * Implement the nonlinear conjugate gradient method using the Polak-Ribiere
+ * formula suggested in
  * http://en.wikipedia.org/wiki/Nonlinear_conjugate_gradient_method.
  *
- * The S (system) class requires three member functions: error(state), gradient(state) and
- * advance(state, step-size, direction). The V class denotes the state or the solution.
+ * The S (system) class requires three member functions: error(state),
+ * gradient(state) and advance(state, step-size, direction). The V class denotes
+ * the state or the solution.
  *
- * The last parameter is a switch between gradient-descent and conjugate gradient
+ * The last parameter is a switch between gradient-descent and conjugate
+ * gradient
  */
 template <class S, class V>
-std::tuple<V, int> nonlinearConjugateGradient(const S &system,
-    const V &initial, const NonlinearOptimizerParams &params,
-    const bool singleIteration, const bool gradientDescent = false) {
+std::tuple<V, int> nonlinearConjugateGradient(
+    const S &system, const V &initial, const NonlinearOptimizerParams &params,
+    const bool singleIteration,
+    const DirectionMethod &directionMethod = DirectionMethod::PolakRibiere,
+    const bool gradientDescent = false) {
   // GTSAM_CONCEPT_MANIFOLD_TYPE(V)
 
   size_t iteration = 0;
 
@@ -184,10 +232,23 @@ std::tuple<V, int> nonlinearConjugateGradient(const S &system,
     } else {
       prevGradient = currentGradient;
       currentGradient = system.gradient(currentValues);
-      // Polak-Ribiere: beta = g'*(g_n-g_n-1)/g_n-1'*g_n-1
-      const double beta = std::max(0.0,
-          currentGradient.dot(currentGradient - prevGradient)
-              / prevGradient.dot(prevGradient));
+      double beta;
+      switch (directionMethod) {
+        case DirectionMethod::FletcherReeves:
+          beta = FletcherReeves(currentGradient, prevGradient);
+          break;
+        case DirectionMethod::PolakRibiere:
+          beta = PolakRibiere(currentGradient, prevGradient);
+          break;
+        case DirectionMethod::HestenesStiefel:
+          beta = HestenesStiefel(currentGradient, prevGradient, direction);
+          break;
+        case DirectionMethod::DaiYuan:
+          beta = DaiYuan(currentGradient, prevGradient, direction);
+          break;
+      }
 
       direction = currentGradient + (beta * direction);
     }
 
@@ -205,20 +266,21 @@ std::tuple<V, int> nonlinearConjugateGradient(const S &system,
 
     // Maybe show output
     if (params.verbosity >= NonlinearOptimizerParams::ERROR)
-      std::cout << "iteration: " << iteration << ", currentError: " << currentError << std::endl;
-  } while (++iteration < params.maxIterations && !singleIteration
-      && !checkConvergence(params.relativeErrorTol, params.absoluteErrorTol,
-          params.errorTol, prevError, currentError, params.verbosity));
+      std::cout << "iteration: " << iteration
+                << ", currentError: " << currentError << std::endl;
+  } while (++iteration < params.maxIterations && !singleIteration &&
+           !checkConvergence(params.relativeErrorTol, params.absoluteErrorTol,
+                             params.errorTol, prevError, currentError,
+                             params.verbosity));
 
   // Printing if verbose
-  if (params.verbosity >= NonlinearOptimizerParams::ERROR
-      && iteration >= params.maxIterations)
-    std::cout
-        << "nonlinearConjugateGradient: Terminating because reached maximum iterations"
-        << std::endl;
+  if (params.verbosity >= NonlinearOptimizerParams::ERROR &&
+      iteration >= params.maxIterations)
+    std::cout << "nonlinearConjugateGradient: Terminating because reached "
+                 "maximum iterations"
+              << std::endl;
 
   return {currentValues, iteration};
 }
 
-} // \ namespace gtsam
+} // namespace gtsam
 
@@ -36,7 +36,7 @@ namespace internal {
  * @param alpha Quantile value
  * @return double
  */
-double chi_squared_quantile(const double dofs, const double alpha) {
+inline double chi_squared_quantile(const double dofs, const double alpha) {
   return 2 * igami(dofs / 2, alpha);
 }
 
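The factor of 2 above follows from the gamma form of the chi-squared distribution, assuming igami here follows the cephes convention of inverting the regularized upper incomplete gamma function Q, i.e. Q(a, igami(a, y)) = y. For k degrees of freedom the chi-squared survival function is P(X > x) = Q(k/2, x/2), so the upper-alpha quantile solves

\[
Q\!\left(\tfrac{k}{2}, \tfrac{x}{2}\right) = \alpha
\quad\Longrightarrow\quad
x = 2\,\mathrm{igami}\!\left(\tfrac{k}{2}, \alpha\right).
\]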
@@ -0,0 +1,290 @@
+/**
+ * @file testNonlinearConjugateGradientOptimizer.cpp
+ * @brief Test nonlinear CG optimizer
+ * @author Yong-Dian Jian
+ * @author Varun Agrawal
+ * @date June 11, 2012
+ */
+
+#include <CppUnitLite/TestHarness.h>
+#include <gtsam/geometry/Pose2.h>
+#include <gtsam/inference/Symbol.h>
+#include <gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h>
+#include <gtsam/nonlinear/NonlinearFactorGraph.h>
+#include <gtsam/nonlinear/PriorFactor.h>
+#include <gtsam/nonlinear/Values.h>
+#include <gtsam/nonlinear/factorTesting.h>
+#include <gtsam/slam/BetweenFactor.h>
+
+using namespace std;
+using namespace gtsam;
+
+using symbol_shorthand::X;
+using symbol_shorthand::Y;
+
+// Generate a small PoseSLAM problem
+std::tuple<NonlinearFactorGraph, Values> generateProblem() {
+  // 1. Create graph container and add factors to it
+  NonlinearFactorGraph graph;
+
+  // 2a. Add Gaussian prior
+  Pose2 priorMean(0.0, 0.0, 0.0);  // prior at origin
+  SharedDiagonal priorNoise =
+      noiseModel::Diagonal::Sigmas(Vector3(0.3, 0.3, 0.1));
+  graph.addPrior(1, priorMean, priorNoise);
+
+  // 2b. Add odometry factors
+  SharedDiagonal odometryNoise =
+      noiseModel::Diagonal::Sigmas(Vector3(0.2, 0.2, 0.1));
+  graph.emplace_shared<BetweenFactor<Pose2>>(1, 2, Pose2(2.0, 0.0, 0.0),
+                                             odometryNoise);
+  graph.emplace_shared<BetweenFactor<Pose2>>(2, 3, Pose2(2.0, 0.0, M_PI_2),
+                                             odometryNoise);
+  graph.emplace_shared<BetweenFactor<Pose2>>(3, 4, Pose2(2.0, 0.0, M_PI_2),
+                                             odometryNoise);
+  graph.emplace_shared<BetweenFactor<Pose2>>(4, 5, Pose2(2.0, 0.0, M_PI_2),
+                                             odometryNoise);
+
+  // 2c. Add pose constraint
+  SharedDiagonal constraintUncertainty =
+      noiseModel::Diagonal::Sigmas(Vector3(0.2, 0.2, 0.1));
+  graph.emplace_shared<BetweenFactor<Pose2>>(5, 2, Pose2(2.0, 0.0, M_PI_2),
+                                             constraintUncertainty);
+
+  // 3. Create the data structure to hold the initialEstimate estimate to the
+  // solution
+  Values initialEstimate;
+  Pose2 x1(0.5, 0.0, 0.2);
+  initialEstimate.insert(1, x1);
+  Pose2 x2(2.3, 0.1, -0.2);
+  initialEstimate.insert(2, x2);
+  Pose2 x3(4.1, 0.1, M_PI_2);
+  initialEstimate.insert(3, x3);
+  Pose2 x4(4.0, 2.0, M_PI);
+  initialEstimate.insert(4, x4);
+  Pose2 x5(2.1, 2.1, -M_PI_2);
+  initialEstimate.insert(5, x5);
+
+  return {graph, initialEstimate};
+}
+
+/* ************************************************************************* */
+TEST(NonlinearConjugateGradientOptimizer, Optimize) {
+  const auto [graph, initialEstimate] = generateProblem();
+  // cout << "initial error = " << graph.error(initialEstimate) << endl;
+
+  NonlinearOptimizerParams param;
+  param.maxIterations =
+      500; /* requires a larger number of iterations to converge */
+  param.verbosity = NonlinearOptimizerParams::SILENT;
+
+  NonlinearConjugateGradientOptimizer optimizer(graph, initialEstimate, param);
+  Values result = optimizer.optimize();
+  // cout << "cg final error = " << graph.error(result) << endl;
+
+  EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
+}
+
+namespace rosenbrock {
+
+class Rosenbrock1Factor : public NoiseModelFactorN<double> {
+ private:
+  typedef Rosenbrock1Factor This;
+  typedef NoiseModelFactorN<double> Base;
+
+  double a_;
+
+ public:
+  /** Constructor: key is x */
+  Rosenbrock1Factor(Key key, double a, const SharedNoiseModel& model = nullptr)
+      : Base(model, key), a_(a) {}
+
+  /// evaluate error
+  Vector evaluateError(const double& x, OptionalMatrixType H) const override {
+    double d = x - a_;
+    // Because linearized gradient is -A'b/sigma, it will multiply by d
+    if (H) (*H) = Vector1(1).transpose();
+    return Vector1(d);
+  }
+};
+
+/**
+ * @brief Factor for the second term of the Rosenbrock function.
+ * f2 = (y - x*x)
+ *
+ * We use the noise model to premultiply with `b`.
+ */
+class Rosenbrock2Factor : public NoiseModelFactorN<double, double> {
+ private:
+  typedef Rosenbrock2Factor This;
+  typedef NoiseModelFactorN<double, double> Base;
+
+ public:
+  /** Constructor: key1 is x, key2 is y */
+  Rosenbrock2Factor(Key key1, Key key2, const SharedNoiseModel& model = nullptr)
+      : Base(model, key1, key2) {}
+
+  /// evaluate error
+  Vector evaluateError(const double& x, const double& y, OptionalMatrixType H1,
+                       OptionalMatrixType H2) const override {
+    double x2 = x * x, d = x2 - y;
+    // Because linearized gradient is -A'b/sigma, it will multiply by d
+    if (H1) (*H1) = Vector1(2 * x).transpose();
+    if (H2) (*H2) = Vector1(-1).transpose();
+    return Vector1(d);
+  }
+};
+
+/**
+ * @brief Get a nonlinear factor graph representing
+ * the Rosenbrock Banana function.
+ *
+ * More details: https://en.wikipedia.org/wiki/Rosenbrock_function
+ *
+ * @param a
+ * @param b
+ * @return NonlinearFactorGraph
+ */
+static NonlinearFactorGraph GetRosenbrockGraph(double a = 1.0,
+                                               double b = 100.0) {
+  NonlinearFactorGraph graph;
+  graph.emplace_shared<Rosenbrock1Factor>(
+      X(0), a, noiseModel::Isotropic::Precision(1, 2));
+  graph.emplace_shared<Rosenbrock2Factor>(
+      X(0), Y(0), noiseModel::Isotropic::Precision(1, 2 * b));
+
+  return graph;
+}
+
+/// Compute the Rosenbrock function error given the nonlinear factor graph
+/// and input values.
+double f(const NonlinearFactorGraph& graph, double x, double y) {
+  Values initial;
+  initial.insert<double>(X(0), x);
+  initial.insert<double>(Y(0), y);
+
+  return graph.error(initial);
+}
+
+/// True Rosenbrock Banana function.
+double rosenbrock_func(double x, double y, double a = 1.0, double b = 100.0) {
+  double m = (a - x) * (a - x);
+  double n = b * (y - x * x) * (y - x * x);
+  return m + n;
+}
+}  // namespace rosenbrock
+
+/* ************************************************************************* */
+// Test whether the 2 factors are properly implemented.
+TEST(NonlinearConjugateGradientOptimizer, Rosenbrock) {
+  using namespace rosenbrock;
+  double a = 1.0, b = 100.0;
+  auto graph = GetRosenbrockGraph(a, b);
+  Rosenbrock1Factor f1 =
+      *std::static_pointer_cast<Rosenbrock1Factor>(graph.at(0));
+  Rosenbrock2Factor f2 =
+      *std::static_pointer_cast<Rosenbrock2Factor>(graph.at(1));
+  Values values;
+  values.insert<double>(X(0), 3.0);
+  values.insert<double>(Y(0), 5.0);
+  EXPECT_CORRECT_FACTOR_JACOBIANS(f1, values, 1e-7, 1e-5);
+  EXPECT_CORRECT_FACTOR_JACOBIANS(f2, values, 1e-7, 1e-5);
+
+  std::mt19937 rng(42);
+  std::uniform_real_distribution<double> dist(0.0, 100.0);
+  for (size_t i = 0; i < 50; ++i) {
+    double x = dist(rng);
+    double y = dist(rng);
+
+    auto graph = GetRosenbrockGraph(a, b);
+    EXPECT_DOUBLES_EQUAL(rosenbrock_func(x, y, a, b), f(graph, x, y), 1e-5);
+  }
+}
+
+/* ************************************************************************* */
+// Optimize the Rosenbrock function to verify optimizer works
+TEST(NonlinearConjugateGradientOptimizer, Optimization) {
+  using namespace rosenbrock;
+
+  double a = 12;
+  auto graph = GetRosenbrockGraph(a);
+
+  // Assert that error at global minimum is 0.
+  double error = f(graph, a, a * a);
+  EXPECT_DOUBLES_EQUAL(0.0, error, 1e-12);
+
+  NonlinearOptimizerParams param;
+  param.maxIterations = 350;
+  // param.verbosity = NonlinearOptimizerParams::LINEAR;
+  param.verbosity = NonlinearOptimizerParams::SILENT;
+
+  double x = 3.0, y = 5.0;
+  Values initialEstimate;
+  initialEstimate.insert<double>(X(0), x);
+  initialEstimate.insert<double>(Y(0), y);
+
+  GaussianFactorGraph::shared_ptr linear = graph.linearize(initialEstimate);
+  // std::cout << "error: " << f(graph, x, y) << std::endl;
+  // linear->print();
+  // linear->gradientAtZero().print("Gradient: ");
+
+  NonlinearConjugateGradientOptimizer optimizer(graph, initialEstimate, param);
+  Values result = optimizer.optimize();
+  // result.print();
+  // cout << "cg final error = " << graph.error(result) << endl;
+
+  Values expected;
+  expected.insert<double>(X(0), a);
+  expected.insert<double>(Y(0), a * a);
+  EXPECT(assert_equal(expected, result, 1e-1));
+}
+
+/* ************************************************************************* */
+/// Test different direction methods
+TEST(NonlinearConjugateGradientOptimizer, DirectionMethods) {
+  const auto [graph, initialEstimate] = generateProblem();
+
+  NonlinearOptimizerParams param;
+  param.maxIterations =
+      500; /* requires a larger number of iterations to converge */
+  param.verbosity = NonlinearOptimizerParams::SILENT;
+
+  // Fletcher-Reeves
+  {
+    NonlinearConjugateGradientOptimizer optimizer(
+        graph, initialEstimate, param, DirectionMethod::FletcherReeves);
+    Values result = optimizer.optimize();
+
+    EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
+  }
+  // Polak-Ribiere
+  {
+    NonlinearConjugateGradientOptimizer optimizer(
+        graph, initialEstimate, param, DirectionMethod::PolakRibiere);
+    Values result = optimizer.optimize();
+
+    EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
+  }
+  // Hestenes-Stiefel
+  {
+    NonlinearConjugateGradientOptimizer optimizer(
+        graph, initialEstimate, param, DirectionMethod::HestenesStiefel);
+    Values result = optimizer.optimize();
+
+    EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
+  }
+  // Dai-Yuan
+  {
+    NonlinearConjugateGradientOptimizer optimizer(graph, initialEstimate, param,
+                                                  DirectionMethod::DaiYuan);
+    Values result = optimizer.optimize();
+
+    EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
+  }
+}
+/* ************************************************************************* */
+int main() {
+  TestResult tr;
+  return TestRegistry::runAllTests(tr);
+}
+/* ************************************************************************* */
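The noise models in GetRosenbrockGraph deserve a one-line justification, since the precisions 2 and 2b look arbitrary at first glance: a factor with residual d and isotropic precision tau contributes (1/2) tau d^2 to graph.error(), so

\[
\tfrac{1}{2}\cdot 2\,(x-a)^2 \;+\; \tfrac{1}{2}\cdot 2b\,(x^2-y)^2 \;=\; (a-x)^2 + b\,(y-x^2)^2,
\]

which is exactly rosenbrock_func; that identity is what the Rosenbrock test above verifies at fifty random points.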
@@ -23,7 +23,6 @@
 // numericalDerivative.h : includes things in linear, nonlinear :-(
 // testLie.h: includes numericalDerivative
 #include <gtsam/base/Lie.h>
-#include <gtsam/base/chartTesting.h>
 #include <gtsam/base/cholesky.h>
 #include <gtsam/base/concepts.h>
 #include <gtsam/base/ConcurrentMap.h>
 
@@ -67,20 +67,15 @@ ShonanAveragingParameters<d>::ShonanAveragingParameters(
   builderParameters.augmentationWeight = SubgraphBuilderParameters::SKELETON;
   builderParameters.augmentationFactor = 0.0;
 
-  auto pcg = std::make_shared<PCGSolverParameters>();
-
   // Choose optimization method
   if (method == "SUBGRAPH") {
     lm.iterativeParams =
         std::make_shared<SubgraphSolverParameters>(builderParameters);
   } else if (method == "SGPC") {
-    pcg->preconditioner_ =
-        std::make_shared<SubgraphPreconditionerParameters>(builderParameters);
-    lm.iterativeParams = pcg;
+    lm.iterativeParams = std::make_shared<PCGSolverParameters>(
+        std::make_shared<SubgraphPreconditionerParameters>(builderParameters));
   } else if (method == "JACOBI") {
-    pcg->preconditioner_ =
-        std::make_shared<BlockJacobiPreconditionerParameters>();
-    lm.iterativeParams = pcg;
+    lm.iterativeParams = std::make_shared<PCGSolverParameters>(std::make_shared<BlockJacobiPreconditionerParameters>());
   } else if (method == "QR") {
     lm.setLinearSolverType("MULTIFRONTAL_QR");
   } else if (method == "CHOLESKY") {
 
@@ -0,0 +1,118 @@
+/* ----------------------------------------------------------------------------
+ * GTSAM Copyright 2010-2024, Georgia Tech Research Corporation,
+ * Atlanta, Georgia 30332-0415
+ * All Rights Reserved
+ * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
+ * See LICENSE for the license information
+ * -------------------------------------------------------------------------- */
+
+/*
+ * @file TransferFactor.h
+ * @brief TransferFactor class
+ * @author Frank Dellaert
+ * @date October 24, 2024
+ */
+
+#pragma once
+
+#include <gtsam/base/numericalDerivative.h>
+#include <gtsam/geometry/FundamentalMatrix.h>
+#include <gtsam/inference/EdgeKey.h>
+#include <gtsam/nonlinear/NonlinearFactor.h>
+
+namespace gtsam {
+
+/**
+ * Binary factor in the context of Structure from Motion (SfM).
+ * It is used to transfer transfer corresponding points from two views to a
+ * third based on two fundamental matrices. The factor computes the error
+ * between the transferred points `pa` and `pb`, and the actual point `pc` in
+ * the target view. Jacobians are done using numerical differentiation.
+ */
+template <typename F>
+class TransferFactor : public NoiseModelFactorN<F, F> {
+  EdgeKey key1_, key2_;  ///< the two EdgeKeys
+  std::vector<std::tuple<Point2, Point2, Point2>>
+      triplets_;  ///< Point triplets
+
+ public:
+  /**
+   * @brief Constructor for a single triplet of points
+   *
+   * @note: batching all points for the same transfer will be much faster.
+   *
+   * @param key1 First EdgeKey specifying F1: (a, c) or (c, a).
+   * @param key2 Second EdgeKey specifying F2: (b, c) or (c, b).
+   * @param pa The point in the first view (a).
+   * @param pb The point in the second view (b).
+   * @param pc The point in the third (and transfer target) view (c).
+   * @param model An optional SharedNoiseModel that defines the noise model
+   * for this factor. Defaults to nullptr.
+   */
+  TransferFactor(EdgeKey key1, EdgeKey key2, const Point2& pa, const Point2& pb,
+                 const Point2& pc, const SharedNoiseModel& model = nullptr)
+      : NoiseModelFactorN<F, F>(model, key1, key2),
+        key1_(key1),
+        key2_(key2),
+        triplets_({std::make_tuple(pa, pb, pc)}) {}
+
+  /**
+   * @brief Constructor that accepts a vector of point triplets.
+   *
+   * @param key1 First EdgeKey specifying F1: (a, c) or (c, a).
+   * @param key2 Second EdgeKey specifying F2: (b, c) or (c, b).
+   * @param triplets A vector of triplets containing (pa, pb, pc).
+   * @param model An optional SharedNoiseModel that defines the noise model
+   * for this factor. Defaults to nullptr.
+   */
+  TransferFactor(
+      EdgeKey key1, EdgeKey key2,
+      const std::vector<std::tuple<Point2, Point2, Point2>>& triplets,
+      const SharedNoiseModel& model = nullptr)
+      : NoiseModelFactorN<F, F>(model, key1, key2),
+        key1_(key1),
+        key2_(key2),
+        triplets_(triplets) {}
+
+  // Create Matrix3 objects based on EdgeKey configurations
+  std::pair<Matrix3, Matrix3> getMatrices(const F& F1, const F& F2) const {
+    // Fill Fca and Fcb based on EdgeKey configurations
+    if (key1_.i() == key2_.i()) {
+      return {F1.matrix(), F2.matrix()};
+    } else if (key1_.i() == key2_.j()) {
+      return {F1.matrix(), F2.matrix().transpose()};
+    } else if (key1_.j() == key2_.i()) {
+      return {F1.matrix().transpose(), F2.matrix()};
+    } else if (key1_.j() == key2_.j()) {
+      return {F1.matrix().transpose(), F2.matrix().transpose()};
+    } else {
+      throw std::runtime_error(
+          "TransferFactor: invalid EdgeKey configuration.");
+    }
+  }
+
+  /// vector of errors returns 2*N vector
+  Vector evaluateError(const F& F1, const F& F2,
+                       OptionalMatrixType H1 = nullptr,
+                       OptionalMatrixType H2 = nullptr) const override {
+    std::function<Vector(const F&, const F&)> transfer = [&](const F& F1,
+                                                             const F& F2) {
+      Vector errors(2 * triplets_.size());
+      size_t idx = 0;
+      auto [Fca, Fcb] = getMatrices(F1, F2);
+      for (const auto& tuple : triplets_) {
+        const auto& [pa, pb, pc] = tuple;
+        Point2 transferredPoint = EpipolarTransfer(Fca, pa, Fcb, pb);
+        Vector2 error = transferredPoint - pc;
+        errors.segment<2>(idx) = error;
+        idx += 2;
+      }
+      return errors;
+    };
+    if (H1) *H1 = numericalDerivative21<Vector, F, F>(transfer, F1, F2);
+    if (H2) *H2 = numericalDerivative22<Vector, F, F>(transfer, F1, F2);
+    return transfer(F1, F2);
+  }
+};
+
+}  // namespace gtsam
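The geometry behind the factor, for orientation (standard two-view transfer, not spelled out in the header itself): F_ca applied to the homogeneous point of pa gives the epipolar line of pa in view c, and likewise F_cb for pb, so the transferred point is their intersection in homogeneous coordinates,

\[
\tilde p_c \;\sim\; (F_{ca}\,\tilde p_a) \times (F_{cb}\,\tilde p_b),
\]

which is presumably what EpipolarTransfer computes. getMatrices exists only to orient F1 and F2 so that both map into view c; that is why it transposes each matrix depending on which side of its EdgeKey holds the shared target camera.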
@@ -55,7 +55,7 @@ class TranslationFactor : public NoiseModelFactorN<Point3, Point3> {
       : Base(noiseModel, a, b), measured_w_aZb_(w_aZb.point3()) {}
 
   /**
-   * @brief Caclulate error: (norm(Tb - Ta) - measurement)
+   * @brief Calculate error: (norm(Tb - Ta) - measurement)
    * where Tb and Ta are Point3 translations and measurement is
    * the Unit3 translation direction from a to b.
    *
@@ -120,7 +120,7 @@ class BilinearAngleTranslationFactor
   using NoiseModelFactor2<Point3, Point3, Vector1>::evaluateError;
 
   /**
-   * @brief Caclulate error: (scale * (Tb - Ta) - measurement)
+   * @brief Calculate error: (scale * (Tb - Ta) - measurement)
    * where Tb and Ta are Point3 translations and measurement is
    * the Unit3 translation direction from a to b.
    *
@@ -119,7 +119,7 @@ class GTSAM_EXPORT TranslationRecovery {
    * @param betweenTranslations relative translations (with scale) between 2
    * points in world coordinate frame known a priori.
    * @param rng random number generator
-   * @param intialValues (optional) initial values from a prior
+   * @param initialValues (optional) initial values from a prior
    * @return Values
    */
   Values initializeRandomly(
@@ -156,7 +156,7 @@ class GTSAM_EXPORT TranslationRecovery {
    * points in world coordinate frame known a priori. Unlike
    * relativeTranslations, zero-magnitude betweenTranslations are not treated as
    * hard constraints.
-   * @param initialValues intial values for optimization. Initializes randomly
+   * @param initialValues initial values for optimization. Initializes randomly
    * if not provided.
    * @return Values
    */
 
@@ -0,0 +1,161 @@
+/*
+ * @file testTransferFactor.cpp
+ * @brief Test TransferFactor class
+ * @author Your Name
+ * @date October 23, 2024
+ */
+
+#include <CppUnitLite/TestHarness.h>
+#include <gtsam/geometry/SimpleCamera.h>
+#include <gtsam/nonlinear/factorTesting.h>
+#include <gtsam/sfm/TransferFactor.h>
+
+using namespace gtsam;
+
+//*************************************************************************
+/// Generate three cameras on a circle, looking in
+std::array<Pose3, 3> generateCameraPoses() {
+  std::array<Pose3, 3> cameraPoses;
+  const double radius = 1.0;
+  for (int i = 0; i < 3; ++i) {
+    double angle = i * 2.0 * M_PI / 3.0;
+    double c = cos(angle), s = sin(angle);
+    Rot3 aRb({-s, c, 0}, {0, 0, -1}, {-c, -s, 0});
+    cameraPoses[i] = {aRb, Point3(radius * c, radius * s, 0)};
+  }
+  return cameraPoses;
+}
+
+//*************************************************************************
+// Function to generate a TripleF from camera poses
+TripleF<SimpleFundamentalMatrix> generateTripleF(
+    const std::array<Pose3, 3>& cameraPoses) {
+  std::array<SimpleFundamentalMatrix, 3> F;
+  for (size_t i = 0; i < 3; ++i) {
+    size_t j = (i + 1) % 3;
+    const Pose3 iPj = cameraPoses[i].between(cameraPoses[j]);
+    EssentialMatrix E(iPj.rotation(), Unit3(iPj.translation()));
+    F[i] = {E, 1000.0, 1000.0, Point2(640 / 2, 480 / 2),
+            Point2(640 / 2, 480 / 2)};
+  }
+  return {F[0], F[1], F[2]};  // Return a TripleF instance
+}
+
+//*************************************************************************
+namespace fixture {
+// Generate cameras on a circle
+std::array<Pose3, 3> cameraPoses = generateCameraPoses();
+auto triplet = generateTripleF(cameraPoses);
+double focalLength = 1000;
+Point2 principalPoint(640 / 2, 480 / 2);
+const Cal3_S2 K(focalLength, focalLength, 0.0,  //
+                principalPoint.x(), principalPoint.y());
+}  // namespace fixture
+
+//*************************************************************************
+// Test for getMatrices
+TEST(TransferFactor, GetMatrices) {
+  using namespace fixture;
+  TransferFactor<SimpleFundamentalMatrix> factor{{2, 0}, {1, 2}, {}};
+
+  // Check that getMatrices is correct
+  auto [Fki, Fkj] = factor.getMatrices(triplet.Fca, triplet.Fbc);
+  EXPECT(assert_equal<Matrix3>(triplet.Fca.matrix(), Fki));
+  EXPECT(assert_equal<Matrix3>(triplet.Fbc.matrix().transpose(), Fkj));
+}
+
+//*************************************************************************
+// Test for TransferFactor
+TEST(TransferFactor, Jacobians) {
+  using namespace fixture;
+
+  // Now project a point into the three cameras
+  const Point3 P(0.1, 0.2, 0.3);
+
+  std::array<Point2, 3> p;
+  for (size_t i = 0; i < 3; ++i) {
+    // Project the point into each camera
+    PinholeCameraCal3_S2 camera(cameraPoses[i], K);
+    p[i] = camera.project(P);
+  }
+
+  // Create a TransferFactor
+  EdgeKey key01(0, 1), key12(1, 2), key20(2, 0);
+  TransferFactor<SimpleFundamentalMatrix>  //
+      factor0{key01, key20, p[1], p[2], p[0]},
+      factor1{key12, key01, p[2], p[0], p[1]},
+      factor2{key20, key12, p[0], p[1], p[2]};
+
+  // Create Values with edge keys
+  Values values;
+  values.insert(key01, triplet.Fab);
+  values.insert(key12, triplet.Fbc);
+  values.insert(key20, triplet.Fca);
+
+  // Check error and Jacobians
+  for (auto&& f : {factor0, factor1, factor2}) {
+    Vector error = f.unwhitenedError(values);
+    EXPECT(assert_equal<Vector>(Z_2x1, error));
+    EXPECT_CORRECT_FACTOR_JACOBIANS(f, values, 1e-5, 1e-7);
+  }
+}
+
+//*************************************************************************
+// Test for TransferFactor with multiple tuples
+TEST(TransferFactor, MultipleTuples) {
+  using namespace fixture;
+
+  // Now project multiple points into the three cameras
+  const size_t numPoints = 5;  // Number of points to test
+  std::vector<Point3> points3D;
+  std::vector<std::array<Point2, 3>> projectedPoints;
+
+  // Generate random 3D points and project them
+  for (size_t n = 0; n < numPoints; ++n) {
+    Point3 P(0.1 * n, 0.2 * n, 0.3 + 0.1 * n);
+    points3D.push_back(P);
+
+    std::array<Point2, 3> p;
+    for (size_t i = 0; i < 3; ++i) {
+      // Project the point into each camera
+      PinholeCameraCal3_S2 camera(cameraPoses[i], K);
+      p[i] = camera.project(P);
+    }
+    projectedPoints.push_back(p);
+  }
+
+  // Create a vector of tuples for the TransferFactor
+  EdgeKey key01(0, 1), key12(1, 2), key20(2, 0);
+  std::vector<std::tuple<Point2, Point2, Point2>> tuples;
+
+  for (size_t n = 0; n < numPoints; ++n) {
+    const auto& p = projectedPoints[n];
+    tuples.emplace_back(p[1], p[2], p[0]);
+  }
+
+  // Create TransferFactors using the new constructor
+  TransferFactor<SimpleFundamentalMatrix> factor{key01, key20, tuples};
+
+  // Create Values with edge keys
+  Values values;
+  values.insert(key01, triplet.Fab);
+  values.insert(key12, triplet.Fbc);
+  values.insert(key20, triplet.Fca);
+
+  // Check error and Jacobians for multiple tuples
+  Vector error = factor.unwhitenedError(values);
+  // The error vector should be of size 2 * numPoints
+  EXPECT_LONGS_EQUAL(2 * numPoints, error.size());
+  // Since the points are perfectly matched, the error should be zero
+  EXPECT(assert_equal<Vector>(Vector::Zero(2 * numPoints), error, 1e-9));
+
+  // Check the Jacobians
+  EXPECT_CORRECT_FACTOR_JACOBIANS(factor, values, 1e-5, 1e-7);
+}
+
+//*************************************************************************
+int main() {
+  TestResult tr;
+  return TestRegistry::runAllTests(tr);
+}
+//*************************************************************************
@@ -69,7 +69,7 @@ class TestScenario(GtsamTestCase):
         lmParams = LevenbergMarquardtParams.CeresDefaults()
         lmParams.setLinearSolverType("ITERATIVE")
         cgParams = PCGSolverParameters()
-        cgParams.setPreconditionerParams(DummyPreconditionerParameters())
+        cgParams.preconditioner = DummyPreconditionerParameters()
         lmParams.setIterativeParams(cgParams)
         actual = LevenbergMarquardtOptimizer(self.fg, self.initial_values, lmParams).optimize()
         self.assertAlmostEqual(0, self.fg.error(actual))
 
@@ -72,6 +72,24 @@ TEST(DoglegOptimizer, ComputeBlend) {
   DOUBLES_EQUAL(Delta, xb.vector().norm(), 1e-10);
 }
 
+/* ************************************************************************* */
+TEST(DoglegOptimizer, ComputeBlendEdgeCases) {
+  // Test Derived from Issue #1861
+  // Evaluate ComputeBlend Behavior for edge cases where the trust region
+  // is equal in size to that of the newton step or the gradient step.
+
+  // Simulated Newton (n) and Gradient Descent (u) step vectors w/ ||n|| > ||u||
+  VectorValues::Dims dims;
+  dims[0] = 3;
+  VectorValues n(Vector3(0.3233546123, -0.2133456123, 0.3664345632), dims);
+  VectorValues u(Vector3(0.0023456342, -0.04535687, 0.087345661212), dims);
+
+  // Test upper edge case where trust region is equal to magnitude of newton step
+  EXPECT(assert_equal(n, DoglegOptimizerImpl::ComputeBlend(n.norm(), u, n, false)));
+  // Test lower edge case where trust region is equal to magnitude of gradient step
+  EXPECT(assert_equal(u, DoglegOptimizerImpl::ComputeBlend(u.norm(), u, n, false)));
+}
+
 /* ************************************************************************* */
 TEST(DoglegOptimizer, ComputeDoglegPoint) {
   // Create an arbitrary Bayes Net
 
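A sketch of what the two assertions pin down, in standard dogleg terms: ComputeBlend picks tau on the segment between the gradient step u and the Newton step n so that the blended update lands exactly on the trust-region boundary,

\[
x_b = u + \tau\,(n - u), \qquad \|x_b\| = \Delta,
\]

so when Delta equals ||n|| the root must be tau = 1 (return n), and when Delta equals ||u|| it must be tau = 0 (return u); these are the boundary cases that the referenced issue #1861 reportedly hit.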
@@ -1,90 +0,0 @@
-/**
- * @file NonlinearConjugateGradientOptimizer.cpp
- * @brief Test simple CG optimizer
- * @author Yong-Dian Jian
- * @date June 11, 2012
- */
-
-/**
- * @file testGradientDescentOptimizer.cpp
- * @brief Small test of NonlinearConjugateGradientOptimizer
- * @author Yong-Dian Jian
- * @date Jun 11, 2012
- */
-
-#include <gtsam/slam/BetweenFactor.h>
-#include <gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h>
-#include <gtsam/nonlinear/NonlinearFactorGraph.h>
-#include <gtsam/nonlinear/Values.h>
-#include <gtsam/geometry/Pose2.h>
-
-#include <CppUnitLite/TestHarness.h>
-
-
-using namespace std;
-using namespace gtsam;
-
-// Generate a small PoseSLAM problem
-std::tuple<NonlinearFactorGraph, Values> generateProblem() {
-
-  // 1. Create graph container and add factors to it
-  NonlinearFactorGraph graph;
-
-  // 2a. Add Gaussian prior
-  Pose2 priorMean(0.0, 0.0, 0.0);  // prior at origin
-  SharedDiagonal priorNoise = noiseModel::Diagonal::Sigmas(
-      Vector3(0.3, 0.3, 0.1));
-  graph.addPrior(1, priorMean, priorNoise);
-
-  // 2b. Add odometry factors
-  SharedDiagonal odometryNoise = noiseModel::Diagonal::Sigmas(
-      Vector3(0.2, 0.2, 0.1));
-  graph.emplace_shared<BetweenFactor<Pose2>>(1, 2, Pose2(2.0, 0.0, 0.0), odometryNoise);
-  graph.emplace_shared<BetweenFactor<Pose2>>(2, 3, Pose2(2.0, 0.0, M_PI_2), odometryNoise);
-  graph.emplace_shared<BetweenFactor<Pose2>>(3, 4, Pose2(2.0, 0.0, M_PI_2), odometryNoise);
-  graph.emplace_shared<BetweenFactor<Pose2>>(4, 5, Pose2(2.0, 0.0, M_PI_2), odometryNoise);
-
-  // 2c. Add pose constraint
-  SharedDiagonal constraintUncertainty = noiseModel::Diagonal::Sigmas(
-      Vector3(0.2, 0.2, 0.1));
-  graph.emplace_shared<BetweenFactor<Pose2>>(5, 2, Pose2(2.0, 0.0, M_PI_2),
-      constraintUncertainty);
-
-  // 3. Create the data structure to hold the initialEstimate estimate to the solution
-  Values initialEstimate;
-  Pose2 x1(0.5, 0.0, 0.2);
-  initialEstimate.insert(1, x1);
-  Pose2 x2(2.3, 0.1, -0.2);
-  initialEstimate.insert(2, x2);
-  Pose2 x3(4.1, 0.1, M_PI_2);
-  initialEstimate.insert(3, x3);
-  Pose2 x4(4.0, 2.0, M_PI);
-  initialEstimate.insert(4, x4);
-  Pose2 x5(2.1, 2.1, -M_PI_2);
-  initialEstimate.insert(5, x5);
-
-  return {graph, initialEstimate};
-}
-
-/* ************************************************************************* */
-TEST(NonlinearConjugateGradientOptimizer, Optimize) {
-  const auto [graph, initialEstimate] = generateProblem();
-  // cout << "initial error = " << graph.error(initialEstimate) << endl;
-
-  NonlinearOptimizerParams param;
-  param.maxIterations = 500; /* requires a larger number of iterations to converge */
-  param.verbosity = NonlinearOptimizerParams::SILENT;
-
-  NonlinearConjugateGradientOptimizer optimizer(graph, initialEstimate, param);
-  Values result = optimizer.optimize();
-  // cout << "cg final error = " << graph.error(result) << endl;
-
-  EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
-}
-
-/* ************************************************************************* */
-int main() {
-  TestResult tr;
-  return TestRegistry::runAllTests(tr);
-}
-/* ************************************************************************* */
@@ -95,9 +95,9 @@ TEST( Iterative, conjugateGradientDescent_hard_constraint )
   VectorValues zeros = config.zeroVectors();
 
   ConjugateGradientParameters parameters;
-  parameters.setEpsilon_abs(1e-3);
-  parameters.setEpsilon_rel(1e-5);
-  parameters.setMaxIterations(100);
+  parameters.epsilon_abs = 1e-3;
+  parameters.epsilon_rel = 1e-5;
+  parameters.maxIterations = 100;
   VectorValues actual = conjugateGradientDescent(*fg, zeros, parameters);
 
   VectorValues expected;

@@ -122,9 +122,9 @@ TEST( Iterative, conjugateGradientDescent_soft_constraint )
   VectorValues zeros = config.zeroVectors();
 
   ConjugateGradientParameters parameters;
-  parameters.setEpsilon_abs(1e-3);
-  parameters.setEpsilon_rel(1e-5);
-  parameters.setMaxIterations(100);
+  parameters.epsilon_abs = 1e-3;
+  parameters.epsilon_rel = 1e-5;
+  parameters.maxIterations = 100;
   VectorValues actual = conjugateGradientDescent(*fg, zeros, parameters);
 
   VectorValues expected;
 
@@ -125,8 +125,8 @@ TEST( GaussianFactorGraphSystem, multiply_getb)
 TEST(PCGSolver, dummy) {
   LevenbergMarquardtParams params;
   params.linearSolverType = LevenbergMarquardtParams::Iterative;
-  auto pcg = std::make_shared<PCGSolverParameters>();
-  pcg->preconditioner_ = std::make_shared<DummyPreconditionerParameters>();
+  auto pcg = std::make_shared<PCGSolverParameters>(
+      std::make_shared<DummyPreconditionerParameters>());
   params.iterativeParams = pcg;
 
   NonlinearFactorGraph fg = example::createReallyNonlinearFactorGraph();

@@ -145,9 +145,8 @@ TEST(PCGSolver, dummy) {
 TEST(PCGSolver, blockjacobi) {
   LevenbergMarquardtParams params;
   params.linearSolverType = LevenbergMarquardtParams::Iterative;
-  auto pcg = std::make_shared<PCGSolverParameters>();
-  pcg->preconditioner_ =
-      std::make_shared<BlockJacobiPreconditionerParameters>();
+  auto pcg = std::make_shared<PCGSolverParameters>(
+      std::make_shared<BlockJacobiPreconditionerParameters>());
   params.iterativeParams = pcg;
 
   NonlinearFactorGraph fg = example::createReallyNonlinearFactorGraph();

@@ -166,8 +165,8 @@ TEST(PCGSolver, blockjacobi) {
 TEST(PCGSolver, subgraph) {
   LevenbergMarquardtParams params;
   params.linearSolverType = LevenbergMarquardtParams::Iterative;
-  auto pcg = std::make_shared<PCGSolverParameters>();
-  pcg->preconditioner_ = std::make_shared<SubgraphPreconditionerParameters>();
+  auto pcg = std::make_shared<PCGSolverParameters>(
+      std::make_shared<SubgraphPreconditionerParameters>());
   params.iterativeParams = pcg;
 
   NonlinearFactorGraph fg = example::createReallyNonlinearFactorGraph();
 
@@ -54,21 +54,23 @@ TEST( PCGsolver, verySimpleLinearSystem) {
   // Solve the system using Preconditioned Conjugate Gradient solver
   // Common PCG parameters
   gtsam::PCGSolverParameters::shared_ptr pcg = std::make_shared<gtsam::PCGSolverParameters>();
-  pcg->setMaxIterations(500);
-  pcg->setEpsilon_abs(0.0);
-  pcg->setEpsilon_rel(0.0);
+  pcg->maxIterations = 500;
+  pcg->epsilon_abs = 0.0;
+  pcg->epsilon_rel = 0.0;
   //pcg->setVerbosity("ERROR");
 
   // With Dummy preconditioner
-  pcg->preconditioner_ = std::make_shared<gtsam::DummyPreconditionerParameters>();
+  pcg->preconditioner =
+      std::make_shared<gtsam::DummyPreconditionerParameters>();
   VectorValues deltaPCGDummy = PCGSolver(*pcg).optimize(simpleGFG);
   EXPECT(assert_equal(exactSolution, deltaPCGDummy, 1e-7));
   //deltaPCGDummy.print("PCG Dummy");
 
   // With Block-Jacobi preconditioner
-  pcg->preconditioner_ = std::make_shared<gtsam::BlockJacobiPreconditionerParameters>();
+  pcg->preconditioner =
+      std::make_shared<gtsam::BlockJacobiPreconditionerParameters>();
   // It takes more than 1000 iterations for this test
-  pcg->setMaxIterations(1500);
+  pcg->maxIterations = 1500;
   VectorValues deltaPCGJacobi = PCGSolver(*pcg).optimize(simpleGFG);
 
   EXPECT(assert_equal(exactSolution, deltaPCGJacobi, 1e-5));

@@ -105,19 +107,21 @@ TEST(PCGSolver, simpleLinearSystem) {
   // Solve the system using Preconditioned Conjugate Gradient solver
   // Common PCG parameters
   gtsam::PCGSolverParameters::shared_ptr pcg = std::make_shared<gtsam::PCGSolverParameters>();
-  pcg->setMaxIterations(500);
-  pcg->setEpsilon_abs(0.0);
-  pcg->setEpsilon_rel(0.0);
+  pcg->maxIterations = 500;
+  pcg->epsilon_abs = 0.0;
+  pcg->epsilon_rel = 0.0;
   //pcg->setVerbosity("ERROR");
 
   // With Dummy preconditioner
-  pcg->preconditioner_ = std::make_shared<gtsam::DummyPreconditionerParameters>();
+  pcg->preconditioner =
+      std::make_shared<gtsam::DummyPreconditionerParameters>();
   VectorValues deltaPCGDummy = PCGSolver(*pcg).optimize(simpleGFG);
   EXPECT(assert_equal(expectedSolution, deltaPCGDummy, 1e-5));
   //deltaPCGDummy.print("PCG Dummy");
 
   // With Block-Jacobi preconditioner
-  pcg->preconditioner_ = std::make_shared<gtsam::BlockJacobiPreconditionerParameters>();
+  pcg->preconditioner =
+      std::make_shared<gtsam::BlockJacobiPreconditionerParameters>();
   VectorValues deltaPCGJacobi = PCGSolver(*pcg).optimize(simpleGFG);
   EXPECT(assert_equal(expectedSolution, deltaPCGJacobi, 1e-5));
   //deltaPCGJacobi.print("PCG Jacobi");
 
@@ -48,7 +48,7 @@ static double error(const GaussianFactorGraph& fg, const VectorValues& x) {
 TEST( SubgraphSolver, Parameters )
 {
   LONGS_EQUAL(SubgraphSolverParameters::SILENT, kParameters.verbosity());
-  LONGS_EQUAL(500, kParameters.maxIterations());
+  LONGS_EQUAL(500, kParameters.maxIterations);
 }
 
 /* ************************************************************************* */
 
@@ -83,7 +83,7 @@ int main(int argc, char* argv[]) {
   // params.setVerbosityLM("SUMMARY");
   // params.linearSolverType = LevenbergMarquardtParams::Iterative;
   // auto pcg = std::make_shared<PCGSolverParameters>();
-  // pcg->preconditioner_ =
+  // pcg->preconditioner =
   //     std::make_shared<SubgraphPreconditionerParameters>();
   //     std::make_shared<BlockJacobiPreconditionerParameters>();
   // params.iterativeParams = pcg;