From 6ea62e1f043215b9d8fa2d415406ec2b1e8fbf42 Mon Sep 17 00:00:00 2001 From: p-zach Date: Wed, 16 Apr 2025 16:29:11 -0400 Subject: [PATCH] Asia example --- gtsam/inference/doc/BayesNet.ipynb | 378 ++++++++++++++++++++++++++++- 1 file changed, 369 insertions(+), 9 deletions(-) diff --git a/gtsam/inference/doc/BayesNet.ipynb b/gtsam/inference/doc/BayesNet.ipynb index db64eda3d..c379b32ce 100644 --- a/gtsam/inference/doc/BayesNet.ipynb +++ b/gtsam/inference/doc/BayesNet.ipynb @@ -15,7 +15,7 @@ "id": "bayesnet_desc_md" }, "source": [ - "A `BayesNet` in GTSAM represents a directed graphical model, specifically the result of running sequential variable elimination (like Cholesky or QR factorization) on a `FactorGraph`.\n", + "A `BayesNet` in GTSAM represents a directed graphical model, created by running sequential variable elimination (like Cholesky or QR factorization) on a `FactorGraph` or constructing from scratch.\n", "\n", "It is essentially a collection of `Conditional` objects, ordered according to the elimination order. Each conditional represents $P(\\text{variable} | \\text{parents})$, where the parents are variables that appear later in the elimination ordering.\n", "\n", @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 44, "metadata": { "id": "bayesnet_import_code" }, @@ -69,6 +69,8 @@ "\n", "# We need concrete graph types and elimination to get a BayesNet\n", "from gtsam import GaussianFactorGraph, Ordering, GaussianBayesNet\n", + "# For the Asia example\n", + "from gtsam import DiscreteBayesNet, DiscreteConditional, DiscreteKeys, DiscreteValues, symbol\n", "from gtsam import symbol_shorthand\n", "\n", "X = symbol_shorthand.X\n", @@ -88,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 45, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -144,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 46, "metadata": {}, "outputs": [ { @@ -199,7 +201,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 47, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -250,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 48, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -297,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 49, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -353,10 +355,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 12, + "execution_count": 49, "metadata": {}, "output_type": "execute_result" } @@ -364,6 +366,364 @@ "source": [ "graphviz.Source(bayes_net.dot())" ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bayesnet_discrete_md" + }, + "source": [ + "## Example: DiscreteBayesNet (Asia Network)\n", + "\n", + "While the previous examples focused on `GaussianBayesNet`, GTSAM also supports `DiscreteBayesNet` for representing probability distributions over discrete variables. Here we construct the classic 'Asia' network example directly by adding `DiscreteConditional` objects." + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": { + "id": "bayesnet_discrete_imports_code" + }, + "outputs": [], + "source": [ + "# Define keys for the Asia network variables\n", + "A = symbol('A', 8) # Visit to Asia?\n", + "S = symbol('S', 7) # Smoker?\n", + "T = symbol('T', 6) # Tuberculosis?\n", + "L = symbol('L', 5) # Lung Cancer?\n", + "B = symbol('B', 4) # Bronchitis?\n", + "E = symbol('E', 3) # Tuberculosis or Lung Cancer?\n", + "X = symbol('X', 2) # Positive X-Ray?\n", + "D = symbol('D', 1) # Dyspnea (Shortness of breath)?\n", + "\n", + "# Define cardinalities (all are binary in this case)\n", + "cardinalities = { A: 2, S: 2, T: 2, L: 2, B: 2, E: 2, X: 2, D: 2 }\n", + "\n", + "# Helper to create DiscreteKeys object\n", + "def make_keys(keys_list):\n", + " dk = DiscreteKeys()\n", + " for k in keys_list:\n", + " dk.push_back((k, cardinalities[k]))\n", + " return dk" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bayesnet_discrete_build_code", + "outputId": "456789ab-cdef-0123-4567-89abcdef0123" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Asia Bayes Net:\n", + "DiscreteBayesNet\n", + " \n", + "size: 8\n", + "conditional 0: P( D1 | E3 B4 ):\n", + " Choice(E3) \n", + " 0 Choice(D1) \n", + " 0 0 Choice(B4) \n", + " 0 0 0 Leaf 0.9\n", + " 0 0 1 Leaf 0.2\n", + " 0 1 Choice(B4) \n", + " 0 1 0 Leaf 0.1\n", + " 0 1 1 Leaf 0.8\n", + " 1 Choice(D1) \n", + " 1 0 Choice(B4) \n", + " 1 0 0 Leaf 0.3\n", + " 1 0 1 Leaf 0.1\n", + " 1 1 Choice(B4) \n", + " 1 1 0 Leaf 0.7\n", + " 1 1 1 Leaf 0.9\n", + "\n", + "conditional 1: P( X2 | E3 ):\n", + " Choice(X2) \n", + " 0 Choice(E3) \n", + " 0 0 Leaf 0.95\n", + " 0 1 Leaf 0.02\n", + " 1 Choice(E3) \n", + " 1 0 Leaf 0.05\n", + " 1 1 Leaf 0.98\n", + "\n", + "conditional 2: P( E3 | T6 L5 ):\n", + " Choice(T6) \n", + " 0 Choice(L5) \n", + " 0 0 Choice(E3) \n", + " 0 0 0 Leaf 1\n", + " 0 0 1 Leaf 0\n", + " 0 1 Choice(E3) \n", + " 0 1 0 Leaf 0\n", + " 0 1 1 Leaf 1\n", + " 1 Choice(L5) \n", + " 1 0 Choice(E3) \n", + " 1 0 0 Leaf 0\n", + " 1 0 1 Leaf 1\n", + " 1 1 Choice(E3) \n", + " 1 1 0 Leaf 0\n", + " 1 1 1 Leaf 1\n", + "\n", + "conditional 3: P( B4 | S7 ):\n", + " Choice(S7) \n", + " 0 Choice(B4) \n", + " 0 0 Leaf 0.7\n", + " 0 1 Leaf 0.3\n", + " 1 Choice(B4) \n", + " 1 0 Leaf 0.4\n", + " 1 1 Leaf 0.6\n", + "\n", + "conditional 4: P( L5 | S7 ):\n", + " Choice(S7) \n", + " 0 Choice(L5) \n", + " 0 0 Leaf 0.99\n", + " 0 1 Leaf 0.01\n", + " 1 Choice(L5) \n", + " 1 0 Leaf 0.9\n", + " 1 1 Leaf 0.1\n", + "\n", + "conditional 5: P( T6 | A8 ):\n", + " Choice(T6) \n", + " 0 Choice(A8) \n", + " 0 0 Leaf 0.99\n", + " 0 1 Leaf 0.95\n", + " 1 Choice(A8) \n", + " 1 0 Leaf 0.01\n", + " 1 1 Leaf 0.05\n", + "\n", + "conditional 6: P( S7 ):\n", + " Leaf 0.5\n", + "\n", + "conditional 7: P( A8 ):\n", + " Choice(A8) \n", + " 0 Leaf 0.99\n", + " 1 Leaf 0.01\n", + "\n" + ] + } + ], + "source": [ + "# Create the DiscreteBayesNet\n", + "asia_net = DiscreteBayesNet()\n", + "\n", + "# Helper function to create parent list in correct format\n", + "def make_parent_tuples(parent_keys):\n", + " return [(pk, cardinalities[pk]) for pk in parent_keys]\n", + "\n", + "# P(D | E, B) - Dyspnea given Either and Bronchitis\n", + "asia_net.add(DiscreteConditional((D, cardinalities[D]), make_parent_tuples([E, B]), \"9/1 2/8 3/7 1/9\"))\n", + "\n", + "# P(X | E) - X-Ray result given Either\n", + "asia_net.add(DiscreteConditional((X, cardinalities[X]), make_parent_tuples([E]), \"95/5 2/98\"))\n", + "\n", + "# P(E | T, L) - Either Tub. or Lung Cancer (OR gate)\n", + "# \"F T T T\" means P(E=1|T=0,L=0)=0, P(E=1|T=0,L=1)=1, P(E=1|T=1,L=0)=1, P(E=1|T=1,L=1)=1\n", + "asia_net.add(DiscreteConditional((E, cardinalities[E]), make_parent_tuples([T, L]), \"F T T T\"))\n", + "\n", + "# P(B | S) - Bronchitis given Smoker\n", + "asia_net.add(DiscreteConditional((B, cardinalities[B]), make_parent_tuples([S]), \"70/30 40/60\"))\n", + "\n", + "# P(L | S) - Lung Cancer given Smoker\n", + "asia_net.add(DiscreteConditional((L, cardinalities[L]), make_parent_tuples([S]), \"99/1 90/10\"))\n", + "\n", + "# P(T | A) - Tuberculosis given Asia\n", + "asia_net.add(DiscreteConditional((T, cardinalities[T]), make_parent_tuples([A]), \"99/1 95/5\"))\n", + "\n", + "# P(S) - Prior on Smoking\n", + "asia_net.add(DiscreteConditional((S, cardinalities[S]), [], \"1/1\")) # or \"50/50\"\n", + "\n", + "# Add conditional probability tables (CPTs) using C++ sugar syntax\n", + "# P(A) - Prior on Asia\n", + "asia_net.add(DiscreteConditional((A, cardinalities[A]), [], \"99/1\"))\n", + "\n", + "print(\"Asia Bayes Net:\")\n", + "asia_net.print()" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bayesnet_discrete_viz_eval_code", + "outputId": "56789abc-def0-1234-5678-9abcdef01234" + }, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var4683743612465315848\n", + "\n", + "A8\n", + "\n", + "\n", + "\n", + "var6052837899185946630\n", + "\n", + "T6\n", + "\n", + "\n", + "\n", + "var4683743612465315848->var6052837899185946630\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var4755801206503243780\n", + "\n", + "B4\n", + "\n", + "\n", + "\n", + "var4899916394579099649\n", + "\n", + "D1\n", + "\n", + "\n", + "\n", + "var4755801206503243780->var4899916394579099649\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var4971973988617027587\n", + "\n", + "E3\n", + "\n", + "\n", + "\n", + "var4971973988617027587->var4899916394579099649\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var6341068275337658370\n", + "\n", + "X2\n", + "\n", + "\n", + "\n", + "var4971973988617027587->var6341068275337658370\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var5476377146882523141\n", + "\n", + "L5\n", + "\n", + "\n", + "\n", + "var5476377146882523141->var4971973988617027587\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var5980780305148018695\n", + "\n", + "S7\n", + "\n", + "\n", + "\n", + "var5980780305148018695->var4755801206503243780\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var5980780305148018695->var5476377146882523141\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "var6052837899185946630->var4971973988617027587\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Log Probability of all zeros: -1.2366269421045588\n", + "Sampled Values (basic print):\n", + "DiscreteValues{4683743612465315848: 0, 4755801206503243780: 1, 4899916394579099649: 1, 4971973988617027587: 0, 5476377146882523141: 0, 5980780305148018695: 1, 6052837899185946630: 0, 6341068275337658370: 0}\n", + "Sampled Values (pretty print):\n", + " A8: 0\n", + " B4: 1\n", + " D1: 1\n", + " E3: 0\n", + " L5: 0\n", + " S7: 1\n", + " T6: 0\n", + " X2: 0\n" + ] + } + ], + "source": [ + "# Visualize the network structure\n", + "dot_string = asia_net.dot()\n", + "display(graphviz.Source(dot_string))\n", + "\n", + "# Evaluate the log probability of a specific assignment\n", + "# Example: Calculate P(A=0, S=0, T=0, L=0, B=0, E=0, X=0, D=0)\n", + "values = DiscreteValues()\n", + "for key, card in cardinalities.items():\n", + " values[key] = 0 # Assign 0 to all variables to start\n", + "\n", + "log_prob_zeros = asia_net.logProbability(values)\n", + "print(f\"Log Probability of all zeros: {log_prob_zeros}\")\n", + "\n", + "# Sample from the Bayes Net\n", + "sample = asia_net.sample()\n", + "print(\"Sampled Values (basic print):\")\n", + "print(sample)\n", + "\n", + "# --- Pretty Print ---\n", + "print(\"Sampled Values (pretty print):\")\n", + "# Create a reverse mapping from integer key to string like 'A8'\n", + "# We defined A=symbol('A',8), S=symbol('S',7), etc. above\n", + "symbol_map = { A: 'A8', S: 'S7', T: 'T6', L: 'L5', B: 'B4', E: 'E3', X: 'X2', D: 'D1' }\n", + "# Iterate through the sampled values and print nicely\n", + "# Sort items by the symbol string for consistent order (optional)\n", + "for key, value in sorted(sample.items(), key=lambda item: symbol_map.get(item[0], str(item[0]))):\n", + " symbol_str = symbol_map.get(key, f\"UnknownKey({key})\") # Get 'A8' from key A\n", + " print(f\" {symbol_str}: {value}\")" + ] } ], "metadata": {