{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_intro_md"
},
"source": [
"# BayesNet"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_desc_md"
},
"source": [
"A `BayesNet` in GTSAM represents a directed graphical model, specifically the result of running sequential variable elimination on a `FactorGraph`. For Gaussian factor graphs, this elimination corresponds to a sparse Cholesky or QR factorization.\n",
"\n",
"It is essentially a collection of `Conditional` objects, ordered according to the elimination order. Each conditional represents $P(\\text{variable} | \\text{parents})$, where the parents are variables that appear later in the elimination ordering.\n",
"\n",
"A Bayes net represents the joint probability distribution as a product of conditional probabilities stored in the net:\n",
"\n",
"$$\n",
"P(X_1, X_2, \\dots, X_N) = \\prod_{i=1}^N P(X_i | \\text{Parents}(X_i))\n",
"$$\n",
"\n",
"The total log-probability of an assignment is the sum of the log-probabilities of its conditionals:\n",
"\n",
"$$\n",
"\\log P(X_1, \\dots, X_N) = \\sum_{i=1}^N \\log P(X_i | \\text{Parents}(X_i))\n",
"$$\n",
"\n",
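"Just as `FactorGraph` is templated on the type of factor it stores, `BayesNet` is templated on the type of conditional it stores. Concrete instantiations include `GaussianBayesNet` (storing `GaussianConditional` objects) and `DiscreteBayesNet` (storing `DiscreteConditional` objects)."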
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_colab_md"
},
"source": [
"<a href=\"https://colab.research.google.com/github/borglab/gtsam/blob/develop/gtsam/inference/doc/BayesNet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bayesnet_pip_code",
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"%pip install gtsam"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "bayesnet_import_code"
},
"outputs": [],
"source": [
"import gtsam\n",
"import numpy as np\n",
"\n",
"# We need concrete graph types and elimination to get a BayesNet\n",
"from gtsam import GaussianFactorGraph, Ordering, GaussianBayesNet\n",
"from gtsam import symbol_shorthand\n",
"\n",
"X = symbol_shorthand.X\n",
"L = symbol_shorthand.L"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_create_md"
},
"source": [
"## Creating a BayesNet (via Elimination)\n",
"\n",
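"A `BayesNet` is typically obtained by eliminating a `FactorGraph` with `eliminateSequential`, which takes an `Ordering` specifying the order in which variables are eliminated."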
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bayesnet_eliminate_code",
"outputId": "01234567-89ab-cdef-0123-456789abcdef"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original Factor Graph:\n",
"\n",
"size: 3\n",
"factor 0: \n",
" A[x0] = [\n",
"\t-1\n",
"]\n",
" b = [ 0 ]\n",
" Noise model: unit (1) \n",
"factor 1: \n",
" A[x0] = [\n",
"\t-1\n",
"]\n",
" A[x1] = [\n",
"\t1\n",
"]\n",
" b = [ 0 ]\n",
" Noise model: unit (1) \n",
"factor 2: \n",
" A[x1] = [\n",
"\t-1\n",
"]\n",
" A[x2] = [\n",
"\t1\n",
"]\n",
" b = [ 0 ]\n",
" Noise model: unit (1) \n",
"\n",
"Resulting BayesNet:\n",
"\n",
"size: 3\n",
"conditional 0: p(x0 | x1)\n",
" R = [ 1.41421 ]\n",
" S[x1] = [ -0.707107 ]\n",
" d = [ 0 ]\n",
" logNormalizationConstant: -0.572365\n",
" No noise model\n",
"conditional 1: p(x1 | x2)\n",
" R = [ 1.22474 ]\n",
" S[x2] = [ -0.816497 ]\n",
" d = [ 0 ]\n",
" logNormalizationConstant: -0.716206\n",
" No noise model\n",
"conditional 2: p(x2)\n",
" R = [ 0.57735 ]\n",
" d = [ 0 ]\n",
" mean: 1 elements\n",
" x2: 0\n",
" logNormalizationConstant: -1.46824\n",
" No noise model\n"
]
}
],
"source": [
"# Create a simple Gaussian Factor Graph P(x0) P(x1|x0) P(x2|x1)\n",
"graph = GaussianFactorGraph()\n",
"model = gtsam.noiseModel.Isotropic.Sigma(1, 1.0)\n",
"graph.add(X(0), -np.eye(1), np.zeros(1), model)\n",
"graph.add(X(0), -np.eye(1), X(1), np.eye(1), np.zeros(1), model)\n",
"graph.add(X(1), -np.eye(1), X(2), np.eye(1), np.zeros(1), model)\n",
"print(\"Original Factor Graph:\")\n",
"graph.print()\n",
"\n",
"# Eliminate sequentially using a specific ordering\n",
"ordering = Ordering([X(0), X(1), X(2)])\n",
"bayes_net = graph.eliminateSequential(ordering)\n",
"\n",
"print(\"\\nResulting BayesNet:\")\n",
"bayes_net.print()"
]
},
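{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_rmatrix_md"
},
"source": [
"As a rough check on the output above, the printed conditionals can be read as rows of an upper-triangular system. This is a sketch that assumes GTSAM's usual convention for a `GaussianConditional` with frontal variable $x$ and parents $y$, namely $p(x | y) \\propto \\exp(-\\frac{1}{2} \\|R x + S y - d\\|^2)$. Stacking the printed `R` and `S` entries (with $d = 0$) into one matrix $\\bar{R}$ gives\n",
"\n",
"$$\n",
"\\bar{R} = \\begin{bmatrix} 1.41421 & -0.707107 & 0 \\\\\\\\ 0 & 1.22474 & -0.816497 \\\\\\\\ 0 & 0 & 0.57735 \\end{bmatrix}, \\qquad A = \\begin{bmatrix} -1 & 0 & 0 \\\\\\\\ -1 & 1 & 0 \\\\\\\\ 0 & -1 & 1 \\end{bmatrix},\n",
"$$\n",
"\n",
"where $A$ is the Jacobian of the original factor graph (the unit noise model leaves it unscaled). Up to rounding, both yield the same information matrix, as expected of a QR or Cholesky factor:\n",
"\n",
"$$\n",
"\\bar{R}^T \\bar{R} = A^T A = \\begin{bmatrix} 2 & -1 & 0 \\\\\\\\ -1 & 2 & -1 \\\\\\\\ 0 & -1 & 1 \\end{bmatrix}\n",
"$$"
]
},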
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_props_md"
},
"source": [
"## Properties and Access\n",
"\n",
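"A `BayesNet` provides access to its constituent conditionals and basic properties: `size()` returns the number of conditionals, `at(i)` returns the conditional at index `i`, and `keys()` returns the set of all keys involved."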
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bayesnet_access_code",
"outputId": "12345678-9abc-def0-1234-56789abcdef0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BayesNet size: 3\n",
"Conditional at index 1: \n",
"GaussianConditional p(x1 | x2)\n",
" R = [ 1.22474 ]\n",
" S[x2] = [ -0.816497 ]\n",
" d = [ 0 ]\n",
" logNormalizationConstant: -0.716206\n",
" No noise model\n",
"Keys in BayesNet: x0x1x2\n"
]
}
],
"source": [
"print(f\"BayesNet size: {bayes_net.size()}\")\n",
"\n",
"# Access conditional by index\n",
"conditional1 = bayes_net.at(1)\n",
"print(\"Conditional at index 1: \")\n",
"conditional1.print()\n",
"\n",
"# Get all keys involved\n",
"bn_keys = bayes_net.keys()\n",
"print(f\"Keys in BayesNet: {bn_keys}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_eval_md"
},
"source": [
"## Evaluation and Solution\n",
"\n",
"The `logProbability` method computes the log-probability of a complete variable assignment (a `VectorValues` for Gaussian Bayes nets) under the conditional distributions in the Bayes net. For Gaussian Bayes nets, the `optimize()` method finds the maximum likelihood estimate (MLE) solution via back-substitution, solving the conditionals in reverse elimination order."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bayesnet_eval_code",
"outputId": "23456789-abcd-ef01-2345-6789abcdef01"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log Probability at [0,0,0]: -2.7568155996140185\n",
"Optimized Solution (MLE):\n",
"VectorValues: 3 elements\n",
" x0: 0\n",
" x1: 0\n",
" x2: 0\n"
]
}
],
"source": [
"# For GaussianBayesNet, optimize() returns a VectorValues\n",
"mle_solution = bayes_net.optimize()\n",
"\n",
"# Calculate log probability (requires providing values for all variables)\n",
"log_prob = bayes_net.logProbability(mle_solution)\n",
"print(f\"Log Probability at [{mle_solution.at(X(0))[0]:.0f},{mle_solution.at(X(1))[0]:.0f},{mle_solution.at(X(2))[0]:.0f}]: {log_prob}\")\n",
"\n",
"print(\"Optimized Solution (MLE):\")\n",
"mle_solution.print()"
]
},
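{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_logprob_md"
},
"source": [
"The printed log probability can be reproduced by hand, which illustrates the sum-of-conditionals formula from the introduction. At the all-zero solution, the residual of every conditional (its `R`, `S`, and `d` terms evaluated at the assignment) is zero, so each conditional contributes only the `logNormalizationConstant` printed for it earlier:\n",
"\n",
"$$\n",
"\\log P(0, 0, 0) \\approx -0.572365 - 0.716206 - 1.46824 = -2.756811\n",
"$$\n",
"\n",
"Each constant is consistent with the standard normalization of a one-dimensional Gaussian written in square-root information form, $-\\frac{1}{2} \\log(2\\pi) + \\log |R|$ (stated here as an assumption about how GTSAM computes it). For the first conditional, $-0.918939 + \\log(1.41421) \\approx -0.918939 + 0.346574 = -0.572365$."
]
},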
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_viz_md"
},
"source": [
"## Visualization\n",
"\n",
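"The `dot()` method returns a Graphviz description of the Bayes net, which can be rendered inline with the `graphviz` Python package. Each arrow points from a parent to the variable conditioned on it."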
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bayesnet_dot_code",
"outputId": "3456789a-bcde-f012-3456-789abcdef012"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"digraph {\n",
" size=\"5,5\";\n",
"\n",
" var8646911284551352320[label=\"x0\"];\n",
" var8646911284551352321[label=\"x1\"];\n",
" var8646911284551352322[label=\"x2\"];\n",
"\n",
" var8646911284551352322->var8646911284551352321\n",
" var8646911284551352321->var8646911284551352320\n",
"}\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.50.0 (0)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"62pt\" height=\"188pt\"\n",
" viewBox=\"0.00 0.00 62.00 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 58,-184 58,4 -4,4\"/>\n",
"<!-- var8646911284551352320 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>var8646911284551352320</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x0</text>\n",
"</g>\n",
"<!-- var8646911284551352321 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>var8646911284551352321</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x1</text>\n",
"</g>\n",
"<!-- var8646911284551352321&#45;&gt;var8646911284551352320 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>var8646911284551352321&#45;&gt;var8646911284551352320</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M27,-71.7C27,-63.98 27,-54.71 27,-46.11\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-46.1 27,-36.1 23.5,-46.1 30.5,-46.1\"/>\n",
"</g>\n",
"<!-- var8646911284551352322 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>var8646911284551352322</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2</text>\n",
"</g>\n",
"<!-- var8646911284551352322&#45;&gt;var8646911284551352321 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>var8646911284551352322&#45;&gt;var8646911284551352321</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M27,-143.7C27,-135.98 27,-126.71 27,-118.11\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-118.1 27,-108.1 23.5,-118.1 30.5,-118.1\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.sources.Source at 0x2c3022fcc20>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dot_string = bayes_net.dot()\n",
"print(dot_string)\n",
"\n",
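"# To render from the command line, first save dot_string to bayesnet.dot, then run:\n",
"#   dot -Tpng bayesnet.dot -o bayesnet.png\n",
"# Here we render inline with the graphviz Python package instead.\n",
"import graphviz\n",
"graphviz.Source(dot_string)"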
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "gtsam",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}