{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_intro_md"
},
"source": [
"# BayesNet"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_desc_md"
},
"source": [
"A `BayesNet` in GTSAM represents a directed graphical model. It is typically the result of running sequential variable elimination on a `FactorGraph`, which for Gaussian factors amounts to a QR or Cholesky factorization.\n",
"\n",
"It is essentially a collection of `Conditional` objects, ordered according to the elimination order. Each conditional represents $P(\\text{variable} | \\text{parents})$, where the parents are variables that appear later in the elimination ordering.\n",
"\n",
"A Bayes net represents the joint probability distribution as the product of the conditional probabilities stored in the net:\n",
"\n",
"$$\n",
"P(X_1, X_2, \\dots, X_N) = \\prod_{i=1}^N P(X_i | \\text{Parents}(X_i))\n",
"$$\n",
"\n",
"Consequently, the log-probability of a full assignment is the sum of the log-probabilities of its conditionals:\n",
"\n",
"$$\n",
"\\log P(X_1, \\dots, X_N) = \\sum_{i=1}^N \\log P(X_i | \\text{Parents}(X_i))\n",
"$$\n",
"\n",
"Like `FactorGraph`, `BayesNet` is templated on the type of conditional it stores; concrete instantiations include `GaussianBayesNet` and `DiscreteBayesNet`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_colab_md"
},
"source": [
"<a href=\"https://colab.research.google.com/github/borglab/gtsam/blob/develop/gtsam/inference/doc/BayesNet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bayesnet_pip_code",
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"%pip install gtsam"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "bayesnet_import_code"
},
"outputs": [],
"source": [
"import gtsam\n",
"import numpy as np\n",
"\n",
"# We need concrete graph types and elimination to get a BayesNet\n",
"from gtsam import GaussianFactorGraph, Ordering, GaussianBayesNet\n",
"from gtsam import symbol_shorthand\n",
"\n",
"X = symbol_shorthand.X\n",
"L = symbol_shorthand.L"
]
},
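{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_create_md"
},
"source": [
"## Creating a BayesNet (via Elimination)\n",
"\n",
"Bayes nets are typically obtained by eliminating a `FactorGraph`. Below, a small Gaussian chain over $x_0, x_1, x_2$ is eliminated with `eliminateSequential` in the order $x_0, x_1, x_2$, which produces the conditionals $p(x_0 | x_1)$, $p(x_1 | x_2)$, and $p(x_2)$."
]
},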
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bayesnet_eliminate_code",
"outputId": "01234567-89ab-cdef-0123-456789abcdef"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original Factor Graph:\n",
"\n",
"size: 3\n",
"factor 0: \n",
" A[x0] = [\n",
"\t-1\n",
"]\n",
" b = [ 0 ]\n",
" Noise model: unit (1) \n",
"factor 1: \n",
" A[x0] = [\n",
"\t-1\n",
"]\n",
" A[x1] = [\n",
"\t1\n",
"]\n",
" b = [ 0 ]\n",
" Noise model: unit (1) \n",
"factor 2: \n",
" A[x1] = [\n",
"\t-1\n",
"]\n",
" A[x2] = [\n",
"\t1\n",
"]\n",
" b = [ 0 ]\n",
" Noise model: unit (1) \n",
"\n",
"Resulting BayesNet:\n",
"\n",
"size: 3\n",
"conditional 0: p(x0 | x1)\n",
" R = [ 1.41421 ]\n",
" S[x1] = [ -0.707107 ]\n",
" d = [ 0 ]\n",
" logNormalizationConstant: -0.572365\n",
" No noise model\n",
"conditional 1: p(x1 | x2)\n",
" R = [ 1.22474 ]\n",
" S[x2] = [ -0.816497 ]\n",
" d = [ 0 ]\n",
" logNormalizationConstant: -0.716206\n",
" No noise model\n",
"conditional 2: p(x2)\n",
" R = [ 0.57735 ]\n",
" d = [ 0 ]\n",
" mean: 1 elements\n",
" x2: 0\n",
" logNormalizationConstant: -1.46824\n",
" No noise model\n"
]
}
],
"source": [
"# Create a simple Gaussian chain: a prior on x0 plus relative factors on (x0, x1) and (x1, x2)\n",
"graph = GaussianFactorGraph()\n",
"model = gtsam.noiseModel.Isotropic.Sigma(1, 1.0)\n",
"graph.add(X(0), -np.eye(1), np.zeros(1), model)\n",
"graph.add(X(0), -np.eye(1), X(1), np.eye(1), np.zeros(1), model)\n",
"graph.add(X(1), -np.eye(1), X(2), np.eye(1), np.zeros(1), model)\n",
"print(\"Original Factor Graph:\")\n",
"graph.print()\n",
"\n",
"# Eliminate sequentially using a specific ordering\n",
"ordering = Ordering([X(0), X(1), X(2)])\n",
"bayes_net = graph.eliminateSequential(ordering)\n",
"\n",
"print(\"\\nResulting BayesNet:\")\n",
"bayes_net.print()"
]
},
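{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_reorder_md"
},
"source": [
"The elimination ordering determines which conditionals the net contains. As an illustrative sketch that reuses `graph` and `bayes_net` from the cell above (the names `reverse_ordering`, `reverse_bn`, and `test_point` are ours, not part of GTSAM), the next cell eliminates the same graph in the opposite order, which should produce $p(x_2 | x_1)$, $p(x_1 | x_0)$, and $p(x_0)$ instead. Since both nets come from eliminating every variable of the same graph, they should encode the same joint Gaussian, so their log-probabilities at an arbitrary test point are expected to agree up to round-off."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bayesnet_reorder_code"
},
"outputs": [],
"source": [
"# Eliminate the same graph in the reverse order (illustrative sketch)\n",
"reverse_ordering = Ordering([X(2), X(1), X(0)])\n",
"reverse_bn = graph.eliminateSequential(reverse_ordering)\n",
"print('Reverse-order BayesNet:')\n",
"reverse_bn.print()\n",
"\n",
"# Arbitrary (hypothetical) test point at which to compare both factorizations\n",
"test_point = gtsam.VectorValues()\n",
"test_point.insert(X(0), np.array([0.3]))\n",
"test_point.insert(X(1), np.array([-0.2]))\n",
"test_point.insert(X(2), np.array([1.0]))\n",
"\n",
"lp_forward = bayes_net.logProbability(test_point)\n",
"lp_reverse = reverse_bn.logProbability(test_point)\n",
"print(f'log p, forward ordering: {lp_forward:.6f}')\n",
"print(f'log p, reverse ordering: {lp_reverse:.6f}')\n",
"assert np.isclose(lp_forward, lp_reverse)"
]
},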
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_props_md"
},
"source": [
"## Properties and Access\n",
"\n",
"A `BayesNet` provides access to its constituent conditionals and basic properties: `size()` returns the number of conditionals, `at(i)` returns the $i$-th conditional in elimination order, and `keys()` returns the set of all variables the net involves."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bayesnet_access_code",
"outputId": "12345678-9abc-def0-1234-56789abcdef0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BayesNet size: 3\n",
"Conditional at index 1: \n",
"GaussianConditional p(x1 | x2)\n",
" R = [ 1.22474 ]\n",
" S[x2] = [ -0.816497 ]\n",
" d = [ 0 ]\n",
" logNormalizationConstant: -0.716206\n",
" No noise model\n",
"Keys in BayesNet: x0x1x2\n"
]
}
],
"source": [
"print(f\"BayesNet size: {bayes_net.size()}\")\n",
"\n",
"# Access conditional by index\n",
"conditional1 = bayes_net.at(1)\n",
"print(\"Conditional at index 1: \")\n",
"conditional1.print()\n",
"\n",
"# Get all keys involved\n",
"bn_keys = bayes_net.keys()\n",
"print(f\"Keys in BayesNet: {bn_keys}\")"
]
},
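{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_blocks_md"
},
"source": [
"For conditionals without a noise model, as here, each `GaussianConditional` stores its density in square-root form, $p(x | y) \\propto \\exp\\left(-\\frac{1}{2}\\lVert R x + S y - d \\rVert^2\\right)$, where $x$ is the frontal variable and $y$ its parents; these are the `R`, `S[...]`, and `d` blocks in the printout above. The next cell is a sketch that reuses `bayes_net`, walks it in elimination order, and extracts those blocks as NumPy arrays. It assumes the Python wrapper exposes `firstFrontalKey()`, `R()`, `S()`, and `d()` on `GaussianConditional`; check the wrapper of your GTSAM version if a call is missing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bayesnet_blocks_code"
},
"outputs": [],
"source": [
"# Walk the conditionals in elimination order (sketch)\n",
"for i in range(bayes_net.size()):\n",
"    conditional = bayes_net.at(i)\n",
"    # Frontal variable of this conditional, formatted like 'x0'\n",
"    frontal = gtsam.Symbol(conditional.firstFrontalKey()).string()\n",
"    R_block = conditional.R()  # upper-triangular block on the frontal variable\n",
"    S_block = conditional.S()  # block on the parents (no columns for the root)\n",
"    d_vec = conditional.d()    # right-hand side\n",
"    print(f'{i}: frontal={frontal}, R={R_block.ravel()}, S={S_block.ravel()}, d={d_vec}')"
]
},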
{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_eval_md"
},
"source": [
"## Evaluation and Solution\n",
"\n",
"The `logProbability(values)` method computes the log-probability of a complete variable assignment under the conditional distributions in the Bayes net; for a `GaussianBayesNet` the assignment is a `VectorValues`. For Gaussian Bayes nets, the `optimize()` method returns the most probable assignment (the mode of the joint Gaussian, labeled MLE in the code below) by back-substitution, solving the last conditional first and working back to the first."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bayesnet_eval_code",
"outputId": "23456789-abcd-ef01-2345-6789abcdef01"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Log Probability at [0,0,0]: -2.7568155996140185\n",
"Optimized Solution (MLE):\n",
"VectorValues: 3 elements\n",
" x0: 0\n",
" x1: 0\n",
" x2: 0\n"
]
}
],
"source": [
"# For GaussianBayesNet, variable assignments are VectorValues\n",
"mle_solution = bayes_net.optimize()\n",
"\n",
"# Calculate log probability (requires providing values for all variables)\n",
"log_prob = bayes_net.logProbability(mle_solution)\n",
"print(f\"Log Probability at [{mle_solution.at(X(0))[0]:.0f},{mle_solution.at(X(1))[0]:.0f},{mle_solution.at(X(2))[0]:.0f}]: {log_prob}\")\n",
"\n",
"print(\"Optimized Solution (MLE):\")\n",
"mle_solution.print()"
]
},
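{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_backsub_md"
},
"source": [
"The back-substitution that `optimize()` performs can be spelled out by hand: solve the root conditional first, then for each earlier conditional compute $x = R^{-1}(d - S y)$ once its parents $y$ are known. Because every $d$ in the example above is zero, the next cell builds a separate, hypothetical chain (`graph2`, with right-hand sides chosen so that $x_0 = 1$, $x_1 = 3$, $x_2 = 6$ satisfy every factor exactly). The sketch hard-codes the structure $p(x_0 | x_1)\\,p(x_1 | x_2)\\,p(x_2)$ that this ordering produces, and it also checks that the net's `logProbability` equals the sum of the per-conditional `logProbability` values, as in the formula at the top of this page. It assumes `R()`, `S()`, `d()`, and `logProbability()` are exposed on `GaussianConditional` in your GTSAM version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bayesnet_backsub_code"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import gtsam\n",
"from gtsam import GaussianFactorGraph, Ordering\n",
"X = gtsam.symbol_shorthand.X\n",
"\n",
"# Hypothetical chain with nonzero offsets: x0 = 1, x1 - x0 = 2, x2 - x1 = 3\n",
"unit = gtsam.noiseModel.Isotropic.Sigma(1, 1.0)\n",
"graph2 = GaussianFactorGraph()\n",
"graph2.add(X(0), np.eye(1), np.array([1.0]), unit)\n",
"graph2.add(X(0), -np.eye(1), X(1), np.eye(1), np.array([2.0]), unit)\n",
"graph2.add(X(1), -np.eye(1), X(2), np.eye(1), np.array([3.0]), unit)\n",
"bn2 = graph2.eliminateSequential(Ordering([X(0), X(1), X(2)]))\n",
"\n",
"# Back-substitution by hand, root first:\n",
"# conditional 2 is p(x2), conditional 1 is p(x1 | x2), conditional 0 is p(x0 | x1)\n",
"c0, c1, c2 = bn2.at(0), bn2.at(1), bn2.at(2)\n",
"x2 = np.linalg.solve(c2.R(), c2.d())\n",
"x1 = np.linalg.solve(c1.R(), c1.d() - c1.S() @ x2)\n",
"x0 = np.linalg.solve(c0.R(), c0.d() - c0.S() @ x1)\n",
"print('By hand:    ', x0, x1, x2)\n",
"\n",
"# Compare with GTSAM's own back-substitution\n",
"solution2 = bn2.optimize()\n",
"print('optimize(): ', solution2.at(X(0)), solution2.at(X(1)), solution2.at(X(2)))\n",
"assert np.allclose([x0[0], x1[0], x2[0]], [1.0, 3.0, 6.0])\n",
"assert np.allclose(x0, solution2.at(X(0)))\n",
"\n",
"# log p(X) should equal the sum of the per-conditional log-probabilities\n",
"total = sum(bn2.at(i).logProbability(solution2) for i in range(bn2.size()))\n",
"assert np.isclose(total, bn2.logProbability(solution2))\n",
"print('log p at the solution:', bn2.logProbability(solution2))"
]
},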
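{
"cell_type": "markdown",
"metadata": {
"id": "bayesnet_viz_md"
},
"source": [
"## Visualization\n",
"\n",
"The `dot()` method returns a Graphviz DOT description of the Bayes net, which can then be rendered with Graphviz. Each edge points from a parent to the variable conditioned on it, so the chain below reads $x_2 \\to x_1 \\to x_0$."
]
},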
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bayesnet_dot_code",
"outputId": "3456789a-bcde-f012-3456-789abcdef012"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"digraph {\n",
" size=\"5,5\";\n",
"\n",
" var8646911284551352320[label=\"x0\"];\n",
" var8646911284551352321[label=\"x1\"];\n",
" var8646911284551352322[label=\"x2\"];\n",
"\n",
" var8646911284551352322->var8646911284551352321\n",
" var8646911284551352321->var8646911284551352320\n",
"}\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.50.0 (0)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"62pt\" height=\"188pt\"\n",
" viewBox=\"0.00 0.00 62.00 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 58,-184 58,4 -4,4\"/>\n",
"<!-- var8646911284551352320 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>var8646911284551352320</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x0</text>\n",
"</g>\n",
"<!-- var8646911284551352321 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>var8646911284551352321</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x1</text>\n",
"</g>\n",
"<!-- var8646911284551352321->var8646911284551352320 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>var8646911284551352321->var8646911284551352320</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M27,-71.7C27,-63.98 27,-54.71 27,-46.11\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-46.1 27,-36.1 23.5,-46.1 30.5,-46.1\"/>\n",
"</g>\n",
"<!-- var8646911284551352322 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>var8646911284551352322</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">x2</text>\n",
"</g>\n",
"<!-- var8646911284551352322->var8646911284551352321 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>var8646911284551352322->var8646911284551352321</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M27,-143.7C27,-135.98 27,-126.71 27,-118.11\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-118.1 27,-108.1 23.5,-118.1 30.5,-118.1\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.sources.Source at 0x2c3022fcc20>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dot_string = bayes_net.dot()\n",
"print(dot_string)\n",
"\n",
"# To render outside the notebook, save dot_string to bayesnet.dot and run:\n",
"# dot -Tpng bayesnet.dot -o bayesnet.png\n",
"# Inside the notebook, the graphviz Python package renders it directly:\n",
"import graphviz\n",
"graphviz.Source(dot_string)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "gtsam",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}