commit
beceb654b6
|
@ -0,0 +1,391 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Efficent Marginals Computation\n",
|
||||
"\n",
|
||||
"GTSAM can very efficiently calculate marginals in Bayes trees. In this post, we illustrate the “shortcut” mechanism for **caching** the conditional distribution $P(S \\mid R)$ in a Bayes tree, allowing efficient other marginal queries. We assume familiarity with **Bayes trees** from [the previous post](#)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Toy Example\n",
|
||||
"\n",
|
||||
"We create a small Bayes tree:\n",
|
||||
"\n",
|
||||
"\\begin{equation}\n",
|
||||
"P(a \\mid b) P(b,c \\mid r) P(f \\mid e) P(d,e \\mid r) P(r).\n",
|
||||
"\\end{equation}\n",
|
||||
"\n",
|
||||
"Below is some Python code (using GTSAM’s discrete wrappers) to define and build the corresponding Bayes tree. We'll use a discrete example, i.e., we'll create a `DiscreteBayesTree`.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from gtsam import DiscreteConditional, DiscreteBayesTree, DiscreteBayesTreeClique, DecisionTreeFactor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Make discrete keys (key in elimination order, cardinality):\n",
|
||||
"keys = [(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 2), (6, 2)]\n",
|
||||
"names = {0: 'a', 1: 'f', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'r'}\n",
|
||||
"aKey, fKey, bKey, cKey, dKey, eKey, rKey = keys\n",
|
||||
"keyFormatter = lambda key: names[key]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 1. Root Clique: P(r)\n",
|
||||
"cliqueR = DiscreteBayesTreeClique(DiscreteConditional(rKey, \"0.4/0.6\"))\n",
|
||||
"\n",
|
||||
"# 2. Child Clique 1: P(b, c | r)\n",
|
||||
"cliqueBC = DiscreteBayesTreeClique(\n",
|
||||
" DiscreteConditional(\n",
|
||||
" 2, DecisionTreeFactor([bKey, cKey, rKey], \"0.3 0.7 0.1 0.9 0.2 0.8 0.4 0.6\")\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# 3. Child Clique 2: P(d, e | r)\n",
|
||||
"cliqueDE = DiscreteBayesTreeClique(\n",
|
||||
" DiscreteConditional(\n",
|
||||
" 2, DecisionTreeFactor([dKey, eKey, rKey], \"0.1 0.9 0.9 0.1 0.2 0.8 0.3 0.7\")\n",
|
||||
" )\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# 4. Leaf Clique from Child 1: P(a | b)\n",
|
||||
"cliqueA = DiscreteBayesTreeClique(DiscreteConditional(aKey, [bKey], \"1/3 3/1\"))\n",
|
||||
"\n",
|
||||
"# 5. Leaf Clique from Child 2: P(f | e)\n",
|
||||
"cliqueF = DiscreteBayesTreeClique(DiscreteConditional(fKey, [eKey], \"1/3 3/1\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Build the BayesTree:\n",
|
||||
"bayesTree = DiscreteBayesTree()\n",
|
||||
"\n",
|
||||
"# Insert root:\n",
|
||||
"bayesTree.insertRoot(cliqueR)\n",
|
||||
"\n",
|
||||
"# Attach child cliques to root:\n",
|
||||
"bayesTree.addClique(cliqueBC, cliqueR)\n",
|
||||
"bayesTree.addClique(cliqueDE, cliqueR)\n",
|
||||
"\n",
|
||||
"# Attach leaf cliques:\n",
|
||||
"bayesTree.addClique(cliqueA, cliqueBC)\n",
|
||||
"bayesTree.addClique(cliqueF, cliqueDE)\n",
|
||||
"\n",
|
||||
"# bayesTree.print(\"bayesTree\", keyFormatter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"image/svg+xml": [
|
||||
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
||||
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||||
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||||
"<!-- Generated by graphviz version 12.0.0 (0)\n",
|
||||
" -->\n",
|
||||
"<!-- Title: G Pages: 1 -->\n",
|
||||
"<svg width=\"168pt\" height=\"188pt\"\n",
|
||||
" viewBox=\"0.00 0.00 167.97 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
||||
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
|
||||
"<title>G</title>\n",
|
||||
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-184 163.97,-184 163.97,4 -4,4\"/>\n",
|
||||
"<!-- 0 -->\n",
|
||||
"<g id=\"node1\" class=\"node\">\n",
|
||||
"<title>0</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"79.49\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"79.49\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">r</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 1 -->\n",
|
||||
"<g id=\"node2\" class=\"node\">\n",
|
||||
"<title>1</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"35.49\" cy=\"-90\" rx=\"35.49\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"35.49\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">b, c : r</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 0->1 -->\n",
|
||||
"<g id=\"edge1\" class=\"edge\">\n",
|
||||
"<title>0->1</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M69.5,-145.12C64.29,-136.82 57.78,-126.46 51.85,-117.03\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"54.83,-115.19 46.54,-108.59 48.9,-118.92 54.83,-115.19\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 3 -->\n",
|
||||
"<g id=\"node4\" class=\"node\">\n",
|
||||
"<title>3</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"124.49\" cy=\"-90\" rx=\"35.49\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"124.49\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">d, e : r</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 0->3 -->\n",
|
||||
"<g id=\"edge3\" class=\"edge\">\n",
|
||||
"<title>0->3</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M89.7,-145.12C95.09,-136.73 101.83,-126.24 107.94,-116.73\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"110.72,-118.88 113.18,-108.58 104.83,-115.1 110.72,-118.88\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 2 -->\n",
|
||||
"<g id=\"node3\" class=\"node\">\n",
|
||||
"<title>2</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"35.49\" cy=\"-18\" rx=\"27.3\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"35.49\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">a : b</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 1->2 -->\n",
|
||||
"<g id=\"edge2\" class=\"edge\">\n",
|
||||
"<title>1->2</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M35.49,-71.7C35.49,-64.41 35.49,-55.73 35.49,-47.54\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"38.99,-47.62 35.49,-37.62 31.99,-47.62 38.99,-47.62\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- 4 -->\n",
|
||||
"<g id=\"node5\" class=\"node\">\n",
|
||||
"<title>4</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"124.49\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"124.49\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">f : e</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- 3->4 -->\n",
|
||||
"<g id=\"edge4\" class=\"edge\">\n",
|
||||
"<title>3->4</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M124.49,-71.7C124.49,-64.41 124.49,-55.73 124.49,-47.54\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"127.99,-47.62 124.49,-37.62 120.99,-47.62 127.99,-47.62\"/>\n",
|
||||
"</g>\n",
|
||||
"</g>\n",
|
||||
"</svg>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<graphviz.sources.Source at 0x10796f1a0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import graphviz\n",
|
||||
"graphviz.Source(bayesTree.dot(keyFormatter))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Naive Computation of P(a)\n",
|
||||
"The marginal $P(a)$ can be computed by summing out the other variables in the tree:\n",
|
||||
"$$\n",
|
||||
"P(a) = \\sum_{b, c, d, e, f, r} P(a, b, c, d, e, f, r)\n",
|
||||
"$$\n",
|
||||
"\n",
|
||||
"Using the Bayes tree structure, we have\n",
|
||||
"\n",
|
||||
"$$\n",
|
||||
"P(a) = \\sum_{b, c, d, e, f, r} P(a \\mid b) P(b, c \\mid r) P(f \\mid e) P(d, e \\mid r) P(r) \n",
|
||||
"$$\n",
|
||||
"\n",
|
||||
"but we can ignore variables $e$ and $f$ not on the path from $a$ to the root $r$. Indeed, by associativity we have\n",
|
||||
"\n",
|
||||
"$$\n",
|
||||
"P(a) = \\sum_{r} \\Bigl\\{ \\sum_{e,f} P(f \\mid e) P(d, e \\mid r) \\Bigr\\} \\sum_{b, c, d} P(a \\mid b) P(b, c \\mid r) P(r)\n",
|
||||
"$$\n",
|
||||
"\n",
|
||||
"where the grouped terms sum to one for any value of $r$, and hence\n",
|
||||
"\n",
|
||||
"$$\n",
|
||||
"P(a) = \\sum_{r, b, c, d} P(a \\mid b) P(b, c \\mid r) P(r).\n",
|
||||
"$$"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Memoization via Shortcuts\n",
|
||||
"\n",
|
||||
"In GTSAM, we compute this recursively\n",
|
||||
"\n",
|
||||
"#### Step 1\n",
|
||||
"We want to compute the marginal via\n",
|
||||
"$$\n",
|
||||
"P(a) = \\sum_{r, b} P(a \\mid b) P(b).\n",
|
||||
"$$\n",
|
||||
"where $P(b)$ is the separator of this clique.\n",
|
||||
"\n",
|
||||
"#### Step 2\n",
|
||||
"To compute the separator marginal, we use the **shortcut** $P(b|r)$:\n",
|
||||
"$$\n",
|
||||
"P(b) = \\sum_{r} P(b \\mid r) P(r).\n",
|
||||
"$$\n",
|
||||
"In general, a shortcut $P(S|R)$ directly conditions this clique's separator $S$ on the root clique $R$, even if there are many other cliques in-between. That is why it is called a *shortcut*.\n",
|
||||
"\n",
|
||||
"#### Step 3 (optional)\n",
|
||||
"If the shortcut was already computed, then we are done! If not, we compute it recursively:\n",
|
||||
"$$\n",
|
||||
"P(S\\mid R) = \\sum_{F_p,\\,S_p \\setminus S}P(F_p \\mid S_p) P(S_p \\mid R).\n",
|
||||
"$$\n",
|
||||
"Above $P(F_p \\mid S_p)$ is the parent clique, and by the running intersection property we know that the seprator $S$ is a subset of the parent clique's variables.\n",
|
||||
"Note that the recursion is because we might not have $P(S_p \\mid R)$ yet, so it might have to be computed in turn, etc. The recursion ends at nodes below the root, and **after we have obtained $P(S\\mid R)$ we cache it**.\n",
|
||||
"\n",
|
||||
"In our example, the computation is simply\n",
|
||||
"$$\n",
|
||||
"P(b|r) = \\sum_{c} P(b, c \\mid r),\n",
|
||||
"$$\n",
|
||||
"because this the parent separator is already the root, so $P(S_p \\mid R)$ is omitted. This is also the end of the recursion.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0\n",
|
||||
"Marginal P(a):\n",
|
||||
" Discrete Conditional\n",
|
||||
" P( 0 ):\n",
|
||||
" Choice(0) \n",
|
||||
" 0 Leaf 0.51\n",
|
||||
" 1 Leaf 0.49\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"3\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Marginal of the leaf variable 'a':\n",
|
||||
"print(bayesTree.numCachedSeparatorMarginals())\n",
|
||||
"marg_a = bayesTree.marginalFactor(aKey[0])\n",
|
||||
"print(\"Marginal P(a):\\n\", marg_a)\n",
|
||||
"print(bayesTree.numCachedSeparatorMarginals())\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"3\n",
|
||||
"Marginal P(b):\n",
|
||||
" Discrete Conditional\n",
|
||||
" P( 2 ):\n",
|
||||
" Choice(2) \n",
|
||||
" 0 Leaf 0.48\n",
|
||||
" 1 Leaf 0.52\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"3\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"# Marginal of the internal variable 'b':\n",
|
||||
"print(bayesTree.numCachedSeparatorMarginals())\n",
|
||||
"marg_b = bayesTree.marginalFactor(bKey[0])\n",
|
||||
"print(\"Marginal P(b):\\n\", marg_b)\n",
|
||||
"print(bayesTree.numCachedSeparatorMarginals())\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"3\n",
|
||||
"Joint P(a, f):\n",
|
||||
" DiscreteBayesNet\n",
|
||||
" \n",
|
||||
"size: 2\n",
|
||||
"conditional 0: P( 0 | 1 ):\n",
|
||||
" Choice(1) \n",
|
||||
" 0 Choice(0) \n",
|
||||
" 0 0 Leaf 0.51758893\n",
|
||||
" 0 1 Leaf 0.48241107\n",
|
||||
" 1 Choice(0) \n",
|
||||
" 1 0 Leaf 0.50222672\n",
|
||||
" 1 1 Leaf 0.49777328\n",
|
||||
"\n",
|
||||
"conditional 1: P( 1 ):\n",
|
||||
" Choice(1) \n",
|
||||
" 0 Leaf 0.506\n",
|
||||
" 1 Leaf 0.494\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"3\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"# Joint of leaf variables 'a' and 'f': P(a, f)\n",
|
||||
"# This effectively needs to gather info from two different branches\n",
|
||||
"print(bayesTree.numCachedSeparatorMarginals())\n",
|
||||
"marg_af = bayesTree.jointBayesNet(aKey[0], fKey[0])\n",
|
||||
"print(\"Joint P(a, f):\\n\", marg_af)\n",
|
||||
"print(bayesTree.numCachedSeparatorMarginals())\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "py312",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -670,17 +670,17 @@ virtual class NonlinearEquality2 : gtsam::NoiseModelFactor {
|
|||
};
|
||||
|
||||
#include <gtsam/nonlinear/FixedLagSmoother.h>
|
||||
// This class is not available in python, just use a dictionary
|
||||
class FixedLagSmootherKeyTimestampMapValue {
|
||||
FixedLagSmootherKeyTimestampMapValue(size_t key, double timestamp);
|
||||
FixedLagSmootherKeyTimestampMapValue(const gtsam::FixedLagSmootherKeyTimestampMapValue& other);
|
||||
};
|
||||
|
||||
// This class is not available in python, just use a dictionary
|
||||
class FixedLagSmootherKeyTimestampMap {
|
||||
FixedLagSmootherKeyTimestampMap();
|
||||
FixedLagSmootherKeyTimestampMap(const gtsam::FixedLagSmootherKeyTimestampMap& other);
|
||||
|
||||
// Note: no print function
|
||||
|
||||
// common STL methods
|
||||
size_t size() const;
|
||||
bool empty() const;
|
||||
|
@ -740,6 +740,7 @@ virtual class IncrementalFixedLagSmoother : gtsam::FixedLagSmoother {
|
|||
|
||||
void print(string s = "IncrementalFixedLagSmoother:\n") const;
|
||||
|
||||
gtsam::Matrix marginalCovariance(size_t key) const;
|
||||
gtsam::ISAM2Params params() const;
|
||||
|
||||
gtsam::NonlinearFactorGraph getFactors() const;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Symbolic Module
|
||||
# Symbolic
|
||||
|
||||
The `symbolic` module in GTSAM deals with the *structure* of factor graphs and Bayesian networks, independent of the specific numerical types of factors (like Gaussian or discrete). It allows for analyzing graph connectivity, determining optimal variable elimination orders, and understanding the sparsity structure of the resulting inference objects.
|
||||
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -1,97 +0,0 @@
|
|||
"""
|
||||
GTSAM Copyright 2010-2018, Georgia Tech Research Corporation,
|
||||
Atlanta, Georgia 30332-0415
|
||||
All Rights Reserved
|
||||
Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
See LICENSE for the license information
|
||||
|
||||
Simple robotics example using odometry measurements and bearing-range (laser) measurements
|
||||
Author: Alex Cunningham (C++), Kevin Deng & Frank Dellaert (Python)
|
||||
"""
|
||||
# pylint: disable=invalid-name, E1101
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import gtsam
|
||||
import numpy as np
|
||||
from gtsam.symbol_shorthand import L, X
|
||||
|
||||
# Create noise models
|
||||
PRIOR_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.3, 0.3, 0.1]))
|
||||
ODOMETRY_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.2, 0.2, 0.1]))
|
||||
MEASUREMENT_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.1, 0.2]))
|
||||
|
||||
|
||||
def main():
|
||||
"""Main runner"""
|
||||
|
||||
# Create an empty nonlinear factor graph
|
||||
graph = gtsam.NonlinearFactorGraph()
|
||||
|
||||
# Create the keys corresponding to unknown variables in the factor graph
|
||||
X1 = X(1)
|
||||
X2 = X(2)
|
||||
X3 = X(3)
|
||||
L1 = L(4)
|
||||
L2 = L(5)
|
||||
|
||||
# Add a prior on pose X1 at the origin. A prior factor consists of a mean and a noise model
|
||||
graph.add(
|
||||
gtsam.PriorFactorPose2(X1, gtsam.Pose2(0.0, 0.0, 0.0), PRIOR_NOISE))
|
||||
|
||||
# Add odometry factors between X1,X2 and X2,X3, respectively
|
||||
graph.add(
|
||||
gtsam.BetweenFactorPose2(X1, X2, gtsam.Pose2(2.0, 0.0, 0.0),
|
||||
ODOMETRY_NOISE))
|
||||
graph.add(
|
||||
gtsam.BetweenFactorPose2(X2, X3, gtsam.Pose2(2.0, 0.0, 0.0),
|
||||
ODOMETRY_NOISE))
|
||||
|
||||
# Add Range-Bearing measurements to two different landmarks L1 and L2
|
||||
graph.add(
|
||||
gtsam.BearingRangeFactor2D(X1, L1, gtsam.Rot2.fromDegrees(45),
|
||||
np.sqrt(4.0 + 4.0), MEASUREMENT_NOISE))
|
||||
graph.add(
|
||||
gtsam.BearingRangeFactor2D(X2, L1, gtsam.Rot2.fromDegrees(90), 2.0,
|
||||
MEASUREMENT_NOISE))
|
||||
graph.add(
|
||||
gtsam.BearingRangeFactor2D(X3, L2, gtsam.Rot2.fromDegrees(90), 2.0,
|
||||
MEASUREMENT_NOISE))
|
||||
|
||||
# Print graph
|
||||
print("Factor Graph:\n{}".format(graph))
|
||||
|
||||
# Create (deliberately inaccurate) initial estimate
|
||||
initial_estimate = gtsam.Values()
|
||||
initial_estimate.insert(X1, gtsam.Pose2(-0.25, 0.20, 0.15))
|
||||
initial_estimate.insert(X2, gtsam.Pose2(2.30, 0.10, -0.20))
|
||||
initial_estimate.insert(X3, gtsam.Pose2(4.10, 0.10, 0.10))
|
||||
initial_estimate.insert(L1, gtsam.Point2(1.80, 2.10))
|
||||
initial_estimate.insert(L2, gtsam.Point2(4.10, 1.80))
|
||||
|
||||
# Print
|
||||
print("Initial Estimate:\n{}".format(initial_estimate))
|
||||
|
||||
# Optimize using Levenberg-Marquardt optimization. The optimizer
|
||||
# accepts an optional set of configuration parameters, controlling
|
||||
# things like convergence criteria, the type of linear system solver
|
||||
# to use, and the amount of information displayed during optimization.
|
||||
# Here we will use the default set of parameters. See the
|
||||
# documentation for the full set of parameters.
|
||||
params = gtsam.LevenbergMarquardtParams()
|
||||
optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial_estimate,
|
||||
params)
|
||||
result = optimizer.optimize()
|
||||
print("\nFinal Result:\n{}".format(result))
|
||||
|
||||
# Calculate and print marginal covariances for all variables
|
||||
marginals = gtsam.Marginals(graph, result)
|
||||
for (key, s) in [(X1, "X1"), (X2, "X2"), (X3, "X3"), (L1, "L1"),
|
||||
(L2, "L2")]:
|
||||
print("{} covariance:\n{}\n".format(s,
|
||||
marginals.marginalCovariance(key)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
File diff suppressed because one or more lines are too long
|
@ -1,102 +0,0 @@
|
|||
"""
|
||||
GTSAM Copyright 2010-2018, Georgia Tech Research Corporation,
|
||||
Atlanta, Georgia 30332-0415
|
||||
All Rights Reserved
|
||||
Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
See LICENSE for the license information
|
||||
|
||||
Simple robotics example using odometry measurements and bearing-range (laser) measurements
|
||||
Author: Alex Cunningham (C++), Kevin Deng & Frank Dellaert (Python)
|
||||
"""
|
||||
# pylint: disable=invalid-name, E1101
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import gtsam
|
||||
import gtsam.utils.plot as gtsam_plot
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def main():
|
||||
"""Main runner."""
|
||||
# Create noise models
|
||||
PRIOR_NOISE = gtsam.noiseModel.Diagonal.Sigmas(gtsam.Point3(0.3, 0.3, 0.1))
|
||||
ODOMETRY_NOISE = gtsam.noiseModel.Diagonal.Sigmas(
|
||||
gtsam.Point3(0.2, 0.2, 0.1))
|
||||
|
||||
# 1. Create a factor graph container and add factors to it
|
||||
graph = gtsam.NonlinearFactorGraph()
|
||||
|
||||
# 2a. Add a prior on the first pose, setting it to the origin
|
||||
# A prior factor consists of a mean and a noise ODOMETRY_NOISE (covariance matrix)
|
||||
graph.add(gtsam.PriorFactorPose2(1, gtsam.Pose2(0, 0, 0), PRIOR_NOISE))
|
||||
|
||||
# 2b. Add odometry factors
|
||||
# Create odometry (Between) factors between consecutive poses
|
||||
graph.add(
|
||||
gtsam.BetweenFactorPose2(1, 2, gtsam.Pose2(2, 0, 0), ODOMETRY_NOISE))
|
||||
graph.add(
|
||||
gtsam.BetweenFactorPose2(2, 3, gtsam.Pose2(2, 0, math.pi / 2),
|
||||
ODOMETRY_NOISE))
|
||||
graph.add(
|
||||
gtsam.BetweenFactorPose2(3, 4, gtsam.Pose2(2, 0, math.pi / 2),
|
||||
ODOMETRY_NOISE))
|
||||
graph.add(
|
||||
gtsam.BetweenFactorPose2(4, 5, gtsam.Pose2(2, 0, math.pi / 2),
|
||||
ODOMETRY_NOISE))
|
||||
|
||||
# 2c. Add the loop closure constraint
|
||||
# This factor encodes the fact that we have returned to the same pose. In real
|
||||
# systems, these constraints may be identified in many ways, such as appearance-based
|
||||
# techniques with camera images. We will use another Between Factor to enforce this constraint:
|
||||
graph.add(
|
||||
gtsam.BetweenFactorPose2(5, 2, gtsam.Pose2(2, 0, math.pi / 2),
|
||||
ODOMETRY_NOISE))
|
||||
print("\nFactor Graph:\n{}".format(graph)) # print
|
||||
|
||||
# 3. Create the data structure to hold the initial_estimate estimate to the
|
||||
# solution. For illustrative purposes, these have been deliberately set to incorrect values
|
||||
initial_estimate = gtsam.Values()
|
||||
initial_estimate.insert(1, gtsam.Pose2(0.5, 0.0, 0.2))
|
||||
initial_estimate.insert(2, gtsam.Pose2(2.3, 0.1, -0.2))
|
||||
initial_estimate.insert(3, gtsam.Pose2(4.1, 0.1, math.pi / 2))
|
||||
initial_estimate.insert(4, gtsam.Pose2(4.0, 2.0, math.pi))
|
||||
initial_estimate.insert(5, gtsam.Pose2(2.1, 2.1, -math.pi / 2))
|
||||
print("\nInitial Estimate:\n{}".format(initial_estimate)) # print
|
||||
|
||||
# 4. Optimize the initial values using a Gauss-Newton nonlinear optimizer
|
||||
# The optimizer accepts an optional set of configuration parameters,
|
||||
# controlling things like convergence criteria, the type of linear
|
||||
# system solver to use, and the amount of information displayed during
|
||||
# optimization. We will set a few parameters as a demonstration.
|
||||
parameters = gtsam.GaussNewtonParams()
|
||||
|
||||
# Stop iterating once the change in error between steps is less than this value
|
||||
parameters.setRelativeErrorTol(1e-5)
|
||||
# Do not perform more than N iteration steps
|
||||
parameters.setMaxIterations(100)
|
||||
# Create the optimizer ...
|
||||
optimizer = gtsam.GaussNewtonOptimizer(graph, initial_estimate, parameters)
|
||||
# ... and optimize
|
||||
result = optimizer.optimize()
|
||||
print("Final Result:\n{}".format(result))
|
||||
|
||||
# 5. Calculate and print marginal covariances for all variables
|
||||
marginals = gtsam.Marginals(graph, result)
|
||||
for i in range(1, 6):
|
||||
print("X{} covariance:\n{}\n".format(i,
|
||||
marginals.marginalCovariance(i)))
|
||||
|
||||
for i in range(1, 6):
|
||||
gtsam_plot.plot_pose2(0, result.atPose2(i), 0.5,
|
||||
marginals.marginalCovariance(i))
|
||||
|
||||
plt.axis('equal')
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,519 @@
|
|||
# gtsam_plotly.py
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import graphviz
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
from tqdm.notebook import tqdm # Optional progress bar
|
||||
|
||||
import gtsam
|
||||
|
||||
|
||||
# --- Dataclass for History ---
|
||||
@dataclass
|
||||
class SlamFrameData:
|
||||
"""Holds all data needed for a single frame of the SLAM animation."""
|
||||
|
||||
step_index: int
|
||||
results: gtsam.Values # Estimates for variables active at this step
|
||||
marginals: Optional[gtsam.Marginals] # Marginals for variables active at this step
|
||||
graph_dot_string: Optional[str] = None # Graphviz DOT string for visualization
|
||||
|
||||
|
||||
# --- Core Ellipse Calculation & Path Generation ---
|
||||
|
||||
|
||||
def create_ellipse_path_from_cov(
|
||||
cx: float, cy: float, cov_matrix: np.ndarray, scale: float = 2.0, N: int = 60
|
||||
) -> str:
|
||||
"""Generates SVG path string for an ellipse from 2x2 covariance."""
|
||||
cov = cov_matrix[:2, :2] + np.eye(2) * 1e-9 # Ensure positive definite 2x2
|
||||
try:
|
||||
eigvals, eigvecs = np.linalg.eigh(cov)
|
||||
eigvals = np.maximum(eigvals, 1e-9) # Ensure positive eigenvalues
|
||||
minor_std, major_std = np.sqrt(eigvals) # eigh sorts ascending
|
||||
v_minor, v_major = eigvecs[:, 0], eigvecs[:, 1]
|
||||
except np.linalg.LinAlgError:
|
||||
# Fallback to a small circle if decomposition fails
|
||||
radius = 0.1 * scale
|
||||
t = np.linspace(0, 2 * np.pi, N)
|
||||
x_p = cx + radius * np.cos(t)
|
||||
y_p = cy + radius * np.sin(t)
|
||||
else:
|
||||
# Parametric equation using eigenvectors and eigenvalues
|
||||
t = np.linspace(0, 2 * np.pi, N)
|
||||
cos_t, sin_t = np.cos(t), np.sin(t)
|
||||
x_p = cx + scale * (
|
||||
major_std * cos_t * v_major[0] + minor_std * sin_t * v_minor[0]
|
||||
)
|
||||
y_p = cy + scale * (
|
||||
major_std * cos_t * v_major[1] + minor_std * sin_t * v_minor[1]
|
||||
)
|
||||
|
||||
# Build SVG path string
|
||||
path = (
|
||||
f"M {x_p[0]},{y_p[0]} "
|
||||
+ " ".join(f"L{x_},{y_}" for x_, y_ in zip(x_p[1:], y_p[1:]))
|
||||
+ " Z"
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
# --- Plotly Element Generators ---
|
||||
|
||||
|
||||
def create_gt_landmarks_trace(
|
||||
landmarks_gt: Optional[np.ndarray],
|
||||
) -> Optional[go.Scatter]:
|
||||
"""Creates scatter trace for ground truth landmarks."""
|
||||
if landmarks_gt is None or landmarks_gt.size == 0:
|
||||
return None
|
||||
return go.Scatter(
|
||||
x=landmarks_gt[0, :],
|
||||
y=landmarks_gt[1, :],
|
||||
mode="markers",
|
||||
marker=dict(color="black", size=8, symbol="star"),
|
||||
name="Landmarks GT",
|
||||
)
|
||||
|
||||
|
||||
def create_gt_path_trace(poses_gt: Optional[List[gtsam.Pose2]]) -> Optional[go.Scatter]:
|
||||
"""Creates line trace for ground truth path."""
|
||||
if not poses_gt:
|
||||
return None
|
||||
return go.Scatter(
|
||||
x=[p.x() for p in poses_gt],
|
||||
y=[p.y() for p in poses_gt],
|
||||
mode="lines",
|
||||
line=dict(color="gray", width=1, dash="dash"),
|
||||
name="Path GT",
|
||||
)
|
||||
|
||||
|
||||
def create_est_path_trace(
|
||||
est_path_x: List[float], est_path_y: List[float]
|
||||
) -> go.Scatter:
|
||||
"""Creates trace for the estimated path (all poses up to current)."""
|
||||
return go.Scatter(
|
||||
x=est_path_x,
|
||||
y=est_path_y,
|
||||
mode="lines+markers",
|
||||
line=dict(color="rgba(255, 0, 0, 0.3)", width=1), # Fainter line for history
|
||||
marker=dict(size=4, color="red"), # Keep markers prominent
|
||||
name="Path Est",
|
||||
)
|
||||
|
||||
|
||||
def create_current_pose_trace(
|
||||
current_pose: Optional[gtsam.Pose2],
|
||||
) -> Optional[go.Scatter]:
|
||||
"""Creates a single marker trace for the current estimated pose."""
|
||||
if current_pose is None:
|
||||
return None
|
||||
return go.Scatter(
|
||||
x=[current_pose.x()],
|
||||
y=[current_pose.y()],
|
||||
mode="markers",
|
||||
marker=dict(color="magenta", size=10, symbol="cross"),
|
||||
name="Current Pose",
|
||||
)
|
||||
|
||||
|
||||
def create_est_landmarks_trace(
|
||||
est_landmarks_x: List[float], est_landmarks_y: List[float]
|
||||
) -> Optional[go.Scatter]:
|
||||
"""Creates trace for currently estimated landmarks."""
|
||||
if not est_landmarks_x:
|
||||
return None
|
||||
return go.Scatter(
|
||||
x=est_landmarks_x,
|
||||
y=est_landmarks_y,
|
||||
mode="markers",
|
||||
marker=dict(color="blue", size=6, symbol="x"),
|
||||
name="Landmarks Est",
|
||||
)
|
||||
|
||||
|
||||
def _create_ellipse_shape_dict(
|
||||
cx: float, cy: float, cov: np.ndarray, scale: float, fillcolor: str, line_color: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Helper: Creates dict for a Plotly ellipse shape from covariance."""
|
||||
path = create_ellipse_path_from_cov(cx, cy, cov, scale)
|
||||
return dict(
|
||||
type="path",
|
||||
path=path,
|
||||
xref="x",
|
||||
yref="y",
|
||||
fillcolor=fillcolor,
|
||||
line_color=line_color,
|
||||
opacity=0.7, # Make ellipses slightly transparent
|
||||
)
|
||||
|
||||
|
||||
def create_pose_ellipse_shape(
|
||||
pose_mean_xy: np.ndarray, pose_cov: np.ndarray, scale: float
|
||||
) -> Dict[str, Any]:
|
||||
"""Creates shape dict for a pose covariance ellipse."""
|
||||
return _create_ellipse_shape_dict(
|
||||
cx=pose_mean_xy[0],
|
||||
cy=pose_mean_xy[1],
|
||||
cov=pose_cov,
|
||||
scale=scale,
|
||||
fillcolor="rgba(255,0,255,0.2)", # Magenta fill
|
||||
line_color="rgba(255,0,255,0.5)", # Magenta line
|
||||
)
|
||||
|
||||
|
||||
def create_landmark_ellipse_shape(
|
||||
lm_mean_xy: np.ndarray, lm_cov: np.ndarray, scale: float
|
||||
) -> Dict[str, Any]:
|
||||
"""Creates shape dict for a landmark covariance ellipse."""
|
||||
return _create_ellipse_shape_dict(
|
||||
cx=lm_mean_xy[0],
|
||||
cy=lm_mean_xy[1],
|
||||
cov=lm_cov,
|
||||
scale=scale,
|
||||
fillcolor="rgba(0,0,255,0.1)", # Blue fill
|
||||
line_color="rgba(0,0,255,0.3)", # Blue line
|
||||
)
|
||||
|
||||
|
||||
def dot_string_to_base64_svg(
|
||||
dot_string: Optional[str], engine: str = "neato"
|
||||
) -> Optional[str]:
|
||||
"""Renders a DOT string to a base64 encoded SVG using graphviz."""
|
||||
if not dot_string:
|
||||
return None
|
||||
try:
|
||||
source = graphviz.Source(dot_string, engine=engine)
|
||||
svg_bytes = source.pipe(format="svg")
|
||||
encoded_string = base64.b64encode(svg_bytes).decode("utf-8")
|
||||
return f"data:image/svg+xml;base64,{encoded_string}"
|
||||
except graphviz.backend.execute.CalledProcessError as e:
|
||||
print(f"Graphviz rendering error ({engine}): {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error during Graphviz SVG generation: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# --- Frame Content Generation ---
|
||||
def generate_frame_content(
|
||||
frame_data: SlamFrameData,
|
||||
X: Callable[[int], int],
|
||||
L: Callable[[int], int],
|
||||
max_landmark_index: int,
|
||||
ellipse_scale: float = 2.0,
|
||||
graphviz_engine: str = "neato",
|
||||
verbose: bool = False,
|
||||
) -> Tuple[List[go.Scatter], List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
||||
"""Generates dynamic traces, shapes, and layout image for a single frame."""
|
||||
k = frame_data.step_index
|
||||
# Use the results specific to this frame for current elements
|
||||
step_results = frame_data.results
|
||||
step_marginals = frame_data.marginals
|
||||
|
||||
frame_dynamic_traces: List[go.Scatter] = []
|
||||
frame_shapes: List[Dict[str, Any]] = []
|
||||
layout_image: Optional[Dict[str, Any]] = None
|
||||
|
||||
# 1. Estimated Path (Full History or Partial)
|
||||
est_path_x = []
|
||||
est_path_y = []
|
||||
current_pose_est = None
|
||||
|
||||
# Plot poses currently existing in the step_results (e.g., within lag)
|
||||
for i in range(k + 1): # Check poses up to current step index
|
||||
pose_key = X(i)
|
||||
if step_results.exists(pose_key):
|
||||
pose = step_results.atPose2(pose_key)
|
||||
est_path_x.append(pose.x())
|
||||
est_path_y.append(pose.y())
|
||||
if i == k:
|
||||
current_pose_est = pose
|
||||
|
||||
path_trace = create_est_path_trace(est_path_x, est_path_y)
|
||||
if path_trace:
|
||||
frame_dynamic_traces.append(path_trace)
|
||||
|
||||
# Add a distinct marker for the current pose estimate
|
||||
current_pose_trace = create_current_pose_trace(current_pose_est)
|
||||
if current_pose_trace:
|
||||
frame_dynamic_traces.append(current_pose_trace)
|
||||
|
||||
# 2. Estimated Landmarks (Only those present in step_results)
|
||||
est_landmarks_x, est_landmarks_y, landmark_keys = [], [], []
|
||||
for j in range(max_landmark_index + 1):
|
||||
lm_key = L(j)
|
||||
# Check existence in the results for the *current frame*
|
||||
if step_results.exists(lm_key):
|
||||
lm_point = step_results.atPoint2(lm_key)
|
||||
est_landmarks_x.append(lm_point[0])
|
||||
est_landmarks_y.append(lm_point[1])
|
||||
landmark_keys.append(lm_key) # Store keys for covariance lookup
|
||||
|
||||
lm_trace = create_est_landmarks_trace(est_landmarks_x, est_landmarks_y)
|
||||
if lm_trace:
|
||||
frame_dynamic_traces.append(lm_trace)
|
||||
|
||||
# 3. Covariance Ellipses (Only for items in step_results AND step_marginals)
|
||||
if step_marginals:
|
||||
# Current Pose Ellipse
|
||||
pose_key = X(k)
|
||||
# Retrieve cov from marginals specific to this frame
|
||||
cov = step_marginals.marginalCovariance(pose_key)
|
||||
# Ensure mean comes from the pose in current results
|
||||
mean = step_results.atPose2(pose_key).translation()
|
||||
frame_shapes.append(create_pose_ellipse_shape(mean, cov, ellipse_scale))
|
||||
|
||||
# Landmark Ellipses (Iterate over keys found in step_results)
|
||||
for lm_key in landmark_keys:
|
||||
try:
|
||||
# Retrieve cov from marginals specific to this frame
|
||||
cov = step_marginals.marginalCovariance(lm_key)
|
||||
# Ensure mean comes from the landmark in current results
|
||||
mean = step_results.atPoint2(lm_key)
|
||||
frame_shapes.append(
|
||||
create_landmark_ellipse_shape(mean, cov, ellipse_scale)
|
||||
)
|
||||
except RuntimeError: # Covariance might not be available
|
||||
if verbose:
|
||||
print(
|
||||
f"Warn: LM {gtsam.Symbol(lm_key).index()} cov not in marginals @ step {k}"
|
||||
)
|
||||
except Exception as e:
|
||||
if verbose:
|
||||
print(
|
||||
f"Warn: LM {gtsam.Symbol(lm_key).index()} cov OTHER err @ step {k}: {e}"
|
||||
)
|
||||
|
||||
# 4. Graph Image for Layout
|
||||
img_src = dot_string_to_base64_svg(
|
||||
frame_data.graph_dot_string, engine=graphviz_engine
|
||||
)
|
||||
if img_src:
|
||||
layout_image = dict(
|
||||
source=img_src,
|
||||
xref="paper",
|
||||
yref="paper",
|
||||
x=0,
|
||||
y=1,
|
||||
sizex=0.48,
|
||||
sizey=1,
|
||||
xanchor="left",
|
||||
yanchor="top",
|
||||
layer="below",
|
||||
sizing="contain",
|
||||
)
|
||||
|
||||
# Return dynamic elements for this frame
|
||||
return frame_dynamic_traces, frame_shapes, layout_image
|
||||
|
||||
|
||||
# --- Figure Configuration ---
|
||||
|
||||
|
||||
def configure_figure_layout(
|
||||
fig: go.Figure,
|
||||
num_steps: int,
|
||||
world_size: float,
|
||||
initial_shapes: List[Dict[str, Any]],
|
||||
initial_image: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""Configures Plotly figure layout for side-by-side display."""
|
||||
steps = list(range(num_steps + 1))
|
||||
plot_domain = [0.52, 1.0] # Right pane for the SLAM plot
|
||||
|
||||
sliders = [
|
||||
dict(
|
||||
active=0,
|
||||
currentvalue={"prefix": "Step: "},
|
||||
pad={"t": 50},
|
||||
steps=[
|
||||
dict(
|
||||
label=str(k),
|
||||
method="animate",
|
||||
args=[
|
||||
[str(k)],
|
||||
dict(
|
||||
mode="immediate",
|
||||
frame=dict(duration=100, redraw=True),
|
||||
transition=dict(duration=0),
|
||||
),
|
||||
],
|
||||
)
|
||||
for k in steps
|
||||
],
|
||||
)
|
||||
]
|
||||
updatemenus = [
|
||||
dict(
|
||||
type="buttons",
|
||||
showactive=False,
|
||||
direction="left",
|
||||
pad={"r": 10, "t": 87},
|
||||
x=plot_domain[0],
|
||||
xanchor="left",
|
||||
y=0,
|
||||
yanchor="top",
|
||||
buttons=[
|
||||
dict(
|
||||
label="Play",
|
||||
method="animate",
|
||||
args=[
|
||||
None,
|
||||
dict(
|
||||
mode="immediate",
|
||||
frame=dict(duration=100, redraw=True),
|
||||
transition=dict(duration=0),
|
||||
fromcurrent=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
dict(
|
||||
label="Pause",
|
||||
method="animate",
|
||||
args=[
|
||||
[None],
|
||||
dict(
|
||||
mode="immediate",
|
||||
frame=dict(duration=0, redraw=False),
|
||||
transition=dict(duration=0),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
fig.update_layout(
|
||||
title="Factor Graph SLAM Animation (Graph Left, Results Right)",
|
||||
xaxis=dict(
|
||||
range=[-world_size / 2 - 2, world_size / 2 + 2],
|
||||
domain=plot_domain,
|
||||
constrain="domain",
|
||||
),
|
||||
yaxis=dict(
|
||||
range=[-world_size / 2 - 2, world_size / 2 + 2],
|
||||
scaleanchor="x",
|
||||
scaleratio=1,
|
||||
domain=[0, 1],
|
||||
),
|
||||
width=1000,
|
||||
height=600,
|
||||
hovermode="closest",
|
||||
updatemenus=updatemenus,
|
||||
sliders=sliders,
|
||||
shapes=initial_shapes, # Initial shapes (frame 0)
|
||||
images=([initial_image] if initial_image else []), # Initial image (frame 0)
|
||||
showlegend=True, # Keep legend for clarity
|
||||
legend=dict(
|
||||
x=plot_domain[0],
|
||||
y=1,
|
||||
traceorder="normal", # Position legend
|
||||
bgcolor="rgba(255,255,255,0.5)",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# --- Main Animation Orchestrator ---
|
||||
|
||||
|
||||
def create_slam_animation(
|
||||
history: List[SlamFrameData],
|
||||
X: Callable[[int], int],
|
||||
L: Callable[[int], int],
|
||||
max_landmark_index: int,
|
||||
landmarks_gt_array: Optional[np.ndarray] = None,
|
||||
poses_gt: Optional[List[gtsam.Pose2]] = None,
|
||||
world_size: float = 20.0,
|
||||
ellipse_scale: float = 2.0,
|
||||
graphviz_engine: str = "neato",
|
||||
verbose_cov_errors: bool = False,
|
||||
) -> go.Figure:
|
||||
"""Creates a side-by-side Plotly SLAM animation using a history of dataclasses."""
|
||||
if not history:
|
||||
raise ValueError("History cannot be empty.")
|
||||
print("Generating Plotly animation...")
|
||||
num_steps = history[-1].step_index
|
||||
fig = go.Figure()
|
||||
|
||||
# 1. Create static GT traces ONCE
|
||||
gt_traces = []
|
||||
gt_lm_trace = create_gt_landmarks_trace(landmarks_gt_array)
|
||||
if gt_lm_trace:
|
||||
gt_traces.append(gt_lm_trace)
|
||||
gt_path_trace = create_gt_path_trace(poses_gt)
|
||||
if gt_path_trace:
|
||||
gt_traces.append(gt_path_trace)
|
||||
|
||||
# 2. Generate content for the initial frame (k=0) to set up the figure
|
||||
initial_frame_data = next((item for item in history if item.step_index == 0), None)
|
||||
if initial_frame_data is None:
|
||||
raise ValueError("History must contain data for step 0.")
|
||||
|
||||
(
|
||||
initial_dynamic_traces,
|
||||
initial_shapes,
|
||||
initial_image,
|
||||
) = generate_frame_content(
|
||||
initial_frame_data,
|
||||
X,
|
||||
L,
|
||||
max_landmark_index,
|
||||
ellipse_scale,
|
||||
graphviz_engine,
|
||||
verbose_cov_errors,
|
||||
)
|
||||
|
||||
# 3. Add initial traces (GT + dynamic frame 0)
|
||||
for trace in gt_traces:
|
||||
fig.add_trace(trace)
|
||||
for trace in initial_dynamic_traces:
|
||||
fig.add_trace(trace)
|
||||
|
||||
# 4. Generate frames for the animation (k=0 to num_steps)
|
||||
frames = []
|
||||
steps_iterable = range(num_steps + 1)
|
||||
steps_iterable = tqdm(steps_iterable, desc="Creating Frames")
|
||||
|
||||
for k in steps_iterable:
|
||||
frame_data = next((item for item in history if item.step_index == k), None)
|
||||
|
||||
# Generate dynamic content specific to this frame
|
||||
frame_dynamic_traces, frame_shapes, layout_image = generate_frame_content(
|
||||
frame_data,
|
||||
X,
|
||||
L,
|
||||
max_landmark_index,
|
||||
ellipse_scale,
|
||||
graphviz_engine,
|
||||
verbose_cov_errors,
|
||||
)
|
||||
|
||||
# Frame definition: includes static GT + dynamic traces for this step
|
||||
# Layout updates only include shapes and images for this step
|
||||
frames.append(
|
||||
go.Frame(
|
||||
data=gt_traces
|
||||
+ frame_dynamic_traces, # GT must be in each frame's data
|
||||
name=str(k),
|
||||
layout=go.Layout(
|
||||
shapes=frame_shapes, # Replaces shapes list for this frame
|
||||
images=(
|
||||
[layout_image] if layout_image else []
|
||||
), # Replaces image list
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# 5. Assign frames to the figure
|
||||
fig.update(frames=frames)
|
||||
|
||||
# 6. Configure overall layout (sliders, buttons, axes, etc.)
|
||||
configure_figure_layout(fig, num_steps, world_size, initial_shapes, initial_image)
|
||||
|
||||
print("Plotly animation generated.")
|
||||
return fig
|
|
@ -0,0 +1,137 @@
|
|||
# simulation.py
|
||||
import numpy as np
|
||||
|
||||
import gtsam
|
||||
|
||||
|
||||
def generate_simulation_data(
|
||||
num_landmarks,
|
||||
world_size,
|
||||
robot_radius,
|
||||
robot_angular_vel,
|
||||
num_steps,
|
||||
dt,
|
||||
odometry_noise_model,
|
||||
measurement_noise_model,
|
||||
max_sensor_range,
|
||||
X, # Symbol generator function
|
||||
L, # Symbol generator function
|
||||
odom_seed=42,
|
||||
meas_seed=42,
|
||||
landmark_seed=42,
|
||||
):
|
||||
"""Generates ground truth and simulated measurements for SLAM.
|
||||
|
||||
Args:
|
||||
num_landmarks: Number of landmarks to generate.
|
||||
world_size: Size of the square world environment.
|
||||
robot_radius: Radius of the robot's circular path.
|
||||
robot_angular_vel: Angular velocity of the robot (rad/step).
|
||||
num_steps: Number of simulation steps.
|
||||
dt: Time step duration.
|
||||
odometry_noise_model: GTSAM noise model for odometry.
|
||||
measurement_noise_model: GTSAM noise model for bearing-range.
|
||||
max_sensor_range: Maximum range of the bearing-range sensor.
|
||||
X: GTSAM symbol shorthand function for poses.
|
||||
L: GTSAM symbol shorthand function for landmarks.
|
||||
odom_seed: Random seed for odometry noise.
|
||||
meas_seed: Random seed for measurement noise.
|
||||
landmark_seed: Random seed for landmark placement.
|
||||
|
||||
Returns:
|
||||
tuple: Contains:
|
||||
- landmarks_gt_dict (dict): L(i) -> gtsam.Point2 ground truth.
|
||||
- poses_gt (list): List of gtsam.Pose2 ground truth poses.
|
||||
- odometry_measurements (list): List of noisy gtsam.Pose2 odometry.
|
||||
- measurements_sim (list): List of lists, measurements_sim[k] contains
|
||||
tuples (L(lm_id), bearing, range) for step k.
|
||||
- landmarks_gt_array (np.array): 2xN numpy array of landmark positions.
|
||||
"""
|
||||
np.random.seed(landmark_seed)
|
||||
odometry_noise_sampler = gtsam.Sampler(odometry_noise_model, odom_seed)
|
||||
measurement_noise_sampler = gtsam.Sampler(measurement_noise_model, meas_seed)
|
||||
|
||||
# 1. Ground Truth Landmarks
|
||||
landmarks_gt_array = (np.random.rand(2, num_landmarks) - 0.5) * world_size
|
||||
landmarks_gt_dict = {
|
||||
L(i): gtsam.Point2(landmarks_gt_array[:, i]) for i in range(num_landmarks)
|
||||
}
|
||||
|
||||
# 2. Ground Truth Robot Path
|
||||
poses_gt = []
|
||||
current_pose_gt = gtsam.Pose2(robot_radius, 0, np.pi / 2) # Start on circle edge
|
||||
poses_gt.append(current_pose_gt)
|
||||
|
||||
for _ in range(num_steps):
|
||||
delta_theta = robot_angular_vel * dt
|
||||
arc_length = robot_angular_vel * robot_radius * dt
|
||||
motion_command = gtsam.Pose2(arc_length, 0, delta_theta)
|
||||
current_pose_gt = current_pose_gt.compose(motion_command)
|
||||
poses_gt.append(current_pose_gt)
|
||||
|
||||
# 3. Simulate Noisy Odometry Measurements
|
||||
odometry_measurements = []
|
||||
for k in range(num_steps):
|
||||
pose_k = poses_gt[k]
|
||||
pose_k1 = poses_gt[k + 1]
|
||||
true_odom = pose_k.between(pose_k1)
|
||||
|
||||
# Sample noise directly for Pose2 composition (approximate)
|
||||
odom_noise_vec = odometry_noise_sampler.sample()
|
||||
noisy_odom = true_odom.compose(
|
||||
gtsam.Pose2(odom_noise_vec[0], odom_noise_vec[1], odom_noise_vec[2])
|
||||
)
|
||||
odometry_measurements.append(noisy_odom)
|
||||
|
||||
# 4. Simulate Noisy Bearing-Range Measurements
|
||||
measurements_sim = [[] for _ in range(num_steps + 1)]
|
||||
for k in range(num_steps + 1):
|
||||
robot_pose = poses_gt[k]
|
||||
for lm_id in range(num_landmarks):
|
||||
lm_gt_pt = landmarks_gt_dict[L(lm_id)]
|
||||
try:
|
||||
measurement_factor = gtsam.BearingRangeFactor2D(
|
||||
X(k),
|
||||
L(lm_id),
|
||||
robot_pose.bearing(lm_gt_pt),
|
||||
robot_pose.range(lm_gt_pt),
|
||||
measurement_noise_model,
|
||||
)
|
||||
true_range = measurement_factor.measured().range()
|
||||
true_bearing = measurement_factor.measured().bearing()
|
||||
|
||||
# Check sensor limits (range and Field of View - e.g. +/- 45 degrees)
|
||||
if (
|
||||
true_range <= max_sensor_range
|
||||
and abs(true_bearing.theta()) < np.pi / 2
|
||||
):
|
||||
# Sample noise
|
||||
noise_vec = measurement_noise_sampler.sample()
|
||||
noisy_bearing = true_bearing.retract(
|
||||
np.array([noise_vec[0]])
|
||||
) # Retract on SO(2)
|
||||
noisy_range = true_range + noise_vec[1]
|
||||
|
||||
if noisy_range > 0: # Ensure range is positive
|
||||
measurements_sim[k].append(
|
||||
(L(lm_id), noisy_bearing, noisy_range)
|
||||
)
|
||||
except Exception as e:
|
||||
# Catch potential errors like point being too close to the pose
|
||||
# print(f"Sim Warning at step {k}, landmark {lm_id}: {e}") # Can be verbose
|
||||
pass
|
||||
|
||||
print(f"Simulation Generated: {num_landmarks} landmarks.")
|
||||
print(
|
||||
f"Simulation Generated: {num_steps + 1} ground truth poses and {num_steps} odometry measurements."
|
||||
)
|
||||
num_meas_total = sum(len(m_list) for m_list in measurements_sim)
|
||||
print(f"Simulation Generated: {num_meas_total} bearing-range measurements.")
|
||||
|
||||
return (
|
||||
landmarks_gt_dict,
|
||||
poses_gt,
|
||||
odometry_measurements,
|
||||
measurements_sim,
|
||||
landmarks_gt_array,
|
||||
)
|
Loading…
Reference in New Issue