Asia example
parent
60d00621ff
commit
6ea62e1f04
|
@ -15,7 +15,7 @@
|
|||
"id": "bayesnet_desc_md"
|
||||
},
|
||||
"source": [
|
||||
"A `BayesNet` in GTSAM represents a directed graphical model, specifically the result of running sequential variable elimination (like Cholesky or QR factorization) on a `FactorGraph`.\n",
|
||||
"A `BayesNet` in GTSAM represents a directed graphical model, created by running sequential variable elimination (like Cholesky or QR factorization) on a `FactorGraph` or constructing from scratch.\n",
|
||||
"\n",
|
||||
"It is essentially a collection of `Conditional` objects, ordered according to the elimination order. Each conditional represents $P(\\text{variable} | \\text{parents})$, where the parents are variables that appear later in the elimination ordering.\n",
|
||||
"\n",
|
||||
|
@ -57,7 +57,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 44,
|
||||
"metadata": {
|
||||
"id": "bayesnet_import_code"
|
||||
},
|
||||
|
@ -69,6 +69,8 @@
|
|||
"\n",
|
||||
"# We need concrete graph types and elimination to get a BayesNet\n",
|
||||
"from gtsam import GaussianFactorGraph, Ordering, GaussianBayesNet\n",
|
||||
"# For the Asia example\n",
|
||||
"from gtsam import DiscreteBayesNet, DiscreteConditional, DiscreteKeys, DiscreteValues, symbol\n",
|
||||
"from gtsam import symbol_shorthand\n",
|
||||
"\n",
|
||||
"X = symbol_shorthand.X\n",
|
||||
|
@ -88,7 +90,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 45,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -144,7 +146,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -199,7 +201,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 47,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -250,7 +252,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 48,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -297,7 +299,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 49,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
|
@ -353,10 +355,10 @@
|
|||
"</svg>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<graphviz.sources.Source at 0x18b7818a990>"
|
||||
"<graphviz.sources.Source at 0x18b7757b020>"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"execution_count": 49,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -364,6 +366,364 @@
|
|||
"source": [
|
||||
"graphviz.Source(bayes_net.dot())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bayesnet_discrete_md"
|
||||
},
|
||||
"source": [
|
||||
"## Example: DiscreteBayesNet (Asia Network)\n",
|
||||
"\n",
|
||||
"While the previous examples focused on `GaussianBayesNet`, GTSAM also supports `DiscreteBayesNet` for representing probability distributions over discrete variables. Here we construct the classic 'Asia' network example directly by adding `DiscreteConditional` objects."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"metadata": {
|
||||
"id": "bayesnet_discrete_imports_code"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define keys for the Asia network variables\n",
|
||||
"A = symbol('A', 8) # Visit to Asia?\n",
|
||||
"S = symbol('S', 7) # Smoker?\n",
|
||||
"T = symbol('T', 6) # Tuberculosis?\n",
|
||||
"L = symbol('L', 5) # Lung Cancer?\n",
|
||||
"B = symbol('B', 4) # Bronchitis?\n",
|
||||
"E = symbol('E', 3) # Tuberculosis or Lung Cancer?\n",
|
||||
"X = symbol('X', 2) # Positive X-Ray?\n",
|
||||
"D = symbol('D', 1) # Dyspnea (Shortness of breath)?\n",
|
||||
"\n",
|
||||
"# Define cardinalities (all are binary in this case)\n",
|
||||
"cardinalities = { A: 2, S: 2, T: 2, L: 2, B: 2, E: 2, X: 2, D: 2 }\n",
|
||||
"\n",
|
||||
"# Helper to create DiscreteKeys object\n",
|
||||
"def make_keys(keys_list):\n",
|
||||
" dk = DiscreteKeys()\n",
|
||||
" for k in keys_list:\n",
|
||||
" dk.push_back((k, cardinalities[k]))\n",
|
||||
" return dk"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "bayesnet_discrete_build_code",
|
||||
"outputId": "456789ab-cdef-0123-4567-89abcdef0123"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Asia Bayes Net:\n",
|
||||
"DiscreteBayesNet\n",
|
||||
" \n",
|
||||
"size: 8\n",
|
||||
"conditional 0: P( D1 | E3 B4 ):\n",
|
||||
" Choice(E3) \n",
|
||||
" 0 Choice(D1) \n",
|
||||
" 0 0 Choice(B4) \n",
|
||||
" 0 0 0 Leaf 0.9\n",
|
||||
" 0 0 1 Leaf 0.2\n",
|
||||
" 0 1 Choice(B4) \n",
|
||||
" 0 1 0 Leaf 0.1\n",
|
||||
" 0 1 1 Leaf 0.8\n",
|
||||
" 1 Choice(D1) \n",
|
||||
" 1 0 Choice(B4) \n",
|
||||
" 1 0 0 Leaf 0.3\n",
|
||||
" 1 0 1 Leaf 0.1\n",
|
||||
" 1 1 Choice(B4) \n",
|
||||
" 1 1 0 Leaf 0.7\n",
|
||||
" 1 1 1 Leaf 0.9\n",
|
||||
"\n",
|
||||
"conditional 1: P( X2 | E3 ):\n",
|
||||
" Choice(X2) \n",
|
||||
" 0 Choice(E3) \n",
|
||||
" 0 0 Leaf 0.95\n",
|
||||
" 0 1 Leaf 0.02\n",
|
||||
" 1 Choice(E3) \n",
|
||||
" 1 0 Leaf 0.05\n",
|
||||
" 1 1 Leaf 0.98\n",
|
||||
"\n",
|
||||
"conditional 2: P( E3 | T6 L5 ):\n",
|
||||
" Choice(T6) \n",
|
||||
" 0 Choice(L5) \n",
|
||||
" 0 0 Choice(E3) \n",
|
||||
" 0 0 0 Leaf 1\n",
|
||||
" 0 0 1 Leaf 0\n",
|
||||
" 0 1 Choice(E3) \n",
|
||||
" 0 1 0 Leaf 0\n",
|
||||
" 0 1 1 Leaf 1\n",
|
||||
" 1 Choice(L5) \n",
|
||||
" 1 0 Choice(E3) \n",
|
||||
" 1 0 0 Leaf 0\n",
|
||||
" 1 0 1 Leaf 1\n",
|
||||
" 1 1 Choice(E3) \n",
|
||||
" 1 1 0 Leaf 0\n",
|
||||
" 1 1 1 Leaf 1\n",
|
||||
"\n",
|
||||
"conditional 3: P( B4 | S7 ):\n",
|
||||
" Choice(S7) \n",
|
||||
" 0 Choice(B4) \n",
|
||||
" 0 0 Leaf 0.7\n",
|
||||
" 0 1 Leaf 0.3\n",
|
||||
" 1 Choice(B4) \n",
|
||||
" 1 0 Leaf 0.4\n",
|
||||
" 1 1 Leaf 0.6\n",
|
||||
"\n",
|
||||
"conditional 4: P( L5 | S7 ):\n",
|
||||
" Choice(S7) \n",
|
||||
" 0 Choice(L5) \n",
|
||||
" 0 0 Leaf 0.99\n",
|
||||
" 0 1 Leaf 0.01\n",
|
||||
" 1 Choice(L5) \n",
|
||||
" 1 0 Leaf 0.9\n",
|
||||
" 1 1 Leaf 0.1\n",
|
||||
"\n",
|
||||
"conditional 5: P( T6 | A8 ):\n",
|
||||
" Choice(T6) \n",
|
||||
" 0 Choice(A8) \n",
|
||||
" 0 0 Leaf 0.99\n",
|
||||
" 0 1 Leaf 0.95\n",
|
||||
" 1 Choice(A8) \n",
|
||||
" 1 0 Leaf 0.01\n",
|
||||
" 1 1 Leaf 0.05\n",
|
||||
"\n",
|
||||
"conditional 6: P( S7 ):\n",
|
||||
" Leaf 0.5\n",
|
||||
"\n",
|
||||
"conditional 7: P( A8 ):\n",
|
||||
" Choice(A8) \n",
|
||||
" 0 Leaf 0.99\n",
|
||||
" 1 Leaf 0.01\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Create the DiscreteBayesNet\n",
|
||||
"asia_net = DiscreteBayesNet()\n",
|
||||
"\n",
|
||||
"# Helper function to create parent list in correct format\n",
|
||||
"def make_parent_tuples(parent_keys):\n",
|
||||
" return [(pk, cardinalities[pk]) for pk in parent_keys]\n",
|
||||
"\n",
|
||||
"# P(D | E, B) - Dyspnea given Either and Bronchitis\n",
|
||||
"asia_net.add(DiscreteConditional((D, cardinalities[D]), make_parent_tuples([E, B]), \"9/1 2/8 3/7 1/9\"))\n",
|
||||
"\n",
|
||||
"# P(X | E) - X-Ray result given Either\n",
|
||||
"asia_net.add(DiscreteConditional((X, cardinalities[X]), make_parent_tuples([E]), \"95/5 2/98\"))\n",
|
||||
"\n",
|
||||
"# P(E | T, L) - Either Tub. or Lung Cancer (OR gate)\n",
|
||||
"# \"F T T T\" means P(E=1|T=0,L=0)=0, P(E=1|T=0,L=1)=1, P(E=1|T=1,L=0)=1, P(E=1|T=1,L=1)=1\n",
|
||||
"asia_net.add(DiscreteConditional((E, cardinalities[E]), make_parent_tuples([T, L]), \"F T T T\"))\n",
|
||||
"\n",
|
||||
"# P(B | S) - Bronchitis given Smoker\n",
|
||||
"asia_net.add(DiscreteConditional((B, cardinalities[B]), make_parent_tuples([S]), \"70/30 40/60\"))\n",
|
||||
"\n",
|
||||
"# P(L | S) - Lung Cancer given Smoker\n",
|
||||
"asia_net.add(DiscreteConditional((L, cardinalities[L]), make_parent_tuples([S]), \"99/1 90/10\"))\n",
|
||||
"\n",
|
||||
"# P(T | A) - Tuberculosis given Asia\n",
|
||||
"asia_net.add(DiscreteConditional((T, cardinalities[T]), make_parent_tuples([A]), \"99/1 95/5\"))\n",
|
||||
"\n",
|
||||
"# P(S) - Prior on Smoking\n",
|
||||
"asia_net.add(DiscreteConditional((S, cardinalities[S]), [], \"1/1\")) # or \"50/50\"\n",
|
||||
"\n",
|
||||
"# Add conditional probability tables (CPTs) using C++ sugar syntax\n",
|
||||
"# P(A) - Prior on Asia\n",
|
||||
"asia_net.add(DiscreteConditional((A, cardinalities[A]), [], \"99/1\"))\n",
|
||||
"\n",
|
||||
"print(\"Asia Bayes Net:\")\n",
|
||||
"asia_net.print()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 58,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "bayesnet_discrete_viz_eval_code",
|
||||
"outputId": "56789abc-def0-1234-5678-9abcdef01234"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"image/svg+xml": [
|
||||
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
||||
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||||
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||||
"<!-- Generated by graphviz version 2.50.0 (0)\n",
|
||||
" -->\n",
|
||||
"<!-- Pages: 1 -->\n",
|
||||
"<svg width=\"189pt\" height=\"260pt\"\n",
|
||||
" viewBox=\"0.00 0.00 189.00 260.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
||||
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n",
|
||||
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 185,-256 185,4 -4,4\"/>\n",
|
||||
"<!-- var4683743612465315848 -->\n",
|
||||
"<g id=\"node1\" class=\"node\">\n",
|
||||
"<title>var4683743612465315848</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"27\" y=\"-230.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">A8</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- var6052837899185946630 -->\n",
|
||||
"<g id=\"node7\" class=\"node\">\n",
|
||||
"<title>var6052837899185946630</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"27\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">T6</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- var4683743612465315848->var6052837899185946630 -->\n",
|
||||
"<g id=\"edge1\" class=\"edge\">\n",
|
||||
"<title>var4683743612465315848->var6052837899185946630</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M27,-215.7C27,-207.98 27,-198.71 27,-190.11\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"30.5,-190.1 27,-180.1 23.5,-190.1 30.5,-190.1\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- var4755801206503243780 -->\n",
|
||||
"<g id=\"node2\" class=\"node\">\n",
|
||||
"<title>var4755801206503243780</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"154\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"154\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">B4</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- var4899916394579099649 -->\n",
|
||||
"<g id=\"node3\" class=\"node\">\n",
|
||||
"<title>var4899916394579099649</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"154\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"154\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">D1</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- var4755801206503243780->var4899916394579099649 -->\n",
|
||||
"<g id=\"edge8\" class=\"edge\">\n",
|
||||
"<title>var4755801206503243780->var4899916394579099649</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M154,-71.7C154,-63.98 154,-54.71 154,-46.11\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"157.5,-46.1 154,-36.1 150.5,-46.1 157.5,-46.1\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- var4971973988617027587 -->\n",
|
||||
"<g id=\"node4\" class=\"node\">\n",
|
||||
"<title>var4971973988617027587</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"82\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"82\" y=\"-86.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">E3</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- var4971973988617027587->var4899916394579099649 -->\n",
|
||||
"<g id=\"edge7\" class=\"edge\">\n",
|
||||
"<title>var4971973988617027587->var4899916394579099649</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M96.57,-74.83C106.75,-64.94 120.52,-51.55 132.03,-40.36\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"134.47,-42.87 139.2,-33.38 129.59,-37.85 134.47,-42.87\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- var6341068275337658370 -->\n",
|
||||
"<g id=\"node8\" class=\"node\">\n",
|
||||
"<title>var6341068275337658370</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"82\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"82\" y=\"-14.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">X2</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- var4971973988617027587->var6341068275337658370 -->\n",
|
||||
"<g id=\"edge6\" class=\"edge\">\n",
|
||||
"<title>var4971973988617027587->var6341068275337658370</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M82,-71.7C82,-63.98 82,-54.71 82,-46.11\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"85.5,-46.1 82,-36.1 78.5,-46.1 85.5,-46.1\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- var5476377146882523141 -->\n",
|
||||
"<g id=\"node5\" class=\"node\">\n",
|
||||
"<title>var5476377146882523141</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"99\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"99\" y=\"-158.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">L5</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- var5476377146882523141->var4971973988617027587 -->\n",
|
||||
"<g id=\"edge5\" class=\"edge\">\n",
|
||||
"<title>var5476377146882523141->var4971973988617027587</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M94.88,-144.05C92.99,-136.26 90.7,-126.82 88.58,-118.08\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"91.96,-117.17 86.2,-108.28 85.15,-118.82 91.96,-117.17\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- var5980780305148018695 -->\n",
|
||||
"<g id=\"node6\" class=\"node\">\n",
|
||||
"<title>var5980780305148018695</title>\n",
|
||||
"<ellipse fill=\"none\" stroke=\"black\" cx=\"126\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n",
|
||||
"<text text-anchor=\"middle\" x=\"126\" y=\"-230.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">S7</text>\n",
|
||||
"</g>\n",
|
||||
"<!-- var5980780305148018695->var4755801206503243780 -->\n",
|
||||
"<g id=\"edge3\" class=\"edge\">\n",
|
||||
"<title>var5980780305148018695->var4755801206503243780</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M129.38,-215.87C134.17,-191.56 142.99,-146.82 148.67,-118.01\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"152.11,-118.68 150.61,-108.19 145.24,-117.32 152.11,-118.68\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- var5980780305148018695->var5476377146882523141 -->\n",
|
||||
"<g id=\"edge2\" class=\"edge\">\n",
|
||||
"<title>var5980780305148018695->var5476377146882523141</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M119.6,-216.41C116.49,-208.34 112.67,-198.43 109.17,-189.35\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"112.4,-188.03 105.54,-179.96 105.87,-190.55 112.4,-188.03\"/>\n",
|
||||
"</g>\n",
|
||||
"<!-- var6052837899185946630->var4971973988617027587 -->\n",
|
||||
"<g id=\"edge4\" class=\"edge\">\n",
|
||||
"<title>var6052837899185946630->var4971973988617027587</title>\n",
|
||||
"<path fill=\"none\" stroke=\"black\" d=\"M38.93,-145.81C46.21,-136.55 55.66,-124.52 63.85,-114.09\"/>\n",
|
||||
"<polygon fill=\"black\" stroke=\"black\" points=\"66.66,-116.18 70.09,-106.16 61.16,-111.86 66.66,-116.18\"/>\n",
|
||||
"</g>\n",
|
||||
"</g>\n",
|
||||
"</svg>\n"
|
||||
],
|
||||
"text/plain": [
|
||||
"<graphviz.sources.Source at 0x18b782e0520>"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Log Probability of all zeros: -1.2366269421045588\n",
|
||||
"Sampled Values (basic print):\n",
|
||||
"DiscreteValues{4683743612465315848: 0, 4755801206503243780: 1, 4899916394579099649: 1, 4971973988617027587: 0, 5476377146882523141: 0, 5980780305148018695: 1, 6052837899185946630: 0, 6341068275337658370: 0}\n",
|
||||
"Sampled Values (pretty print):\n",
|
||||
" A8: 0\n",
|
||||
" B4: 1\n",
|
||||
" D1: 1\n",
|
||||
" E3: 0\n",
|
||||
" L5: 0\n",
|
||||
" S7: 1\n",
|
||||
" T6: 0\n",
|
||||
" X2: 0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Visualize the network structure\n",
|
||||
"dot_string = asia_net.dot()\n",
|
||||
"display(graphviz.Source(dot_string))\n",
|
||||
"\n",
|
||||
"# Evaluate the log probability of a specific assignment\n",
|
||||
"# Example: Calculate P(A=0, S=0, T=0, L=0, B=0, E=0, X=0, D=0)\n",
|
||||
"values = DiscreteValues()\n",
|
||||
"for key, card in cardinalities.items():\n",
|
||||
" values[key] = 0 # Assign 0 to all variables to start\n",
|
||||
"\n",
|
||||
"log_prob_zeros = asia_net.logProbability(values)\n",
|
||||
"print(f\"Log Probability of all zeros: {log_prob_zeros}\")\n",
|
||||
"\n",
|
||||
"# Sample from the Bayes Net\n",
|
||||
"sample = asia_net.sample()\n",
|
||||
"print(\"Sampled Values (basic print):\")\n",
|
||||
"print(sample)\n",
|
||||
"\n",
|
||||
"# --- Pretty Print ---\n",
|
||||
"print(\"Sampled Values (pretty print):\")\n",
|
||||
"# Create a reverse mapping from integer key to string like 'A8'\n",
|
||||
"# We defined A=symbol('A',8), S=symbol('S',7), etc. above\n",
|
||||
"symbol_map = { A: 'A8', S: 'S7', T: 'T6', L: 'L5', B: 'B4', E: 'E3', X: 'X2', D: 'D1' }\n",
|
||||
"# Iterate through the sampled values and print nicely\n",
|
||||
"# Sort items by the symbol string for consistent order (optional)\n",
|
||||
"for key, value in sorted(sample.items(), key=lambda item: symbol_map.get(item[0], str(item[0]))):\n",
|
||||
" symbol_str = symbol_map.get(key, f\"UnknownKey({key})\") # Get 'A8' from key A\n",
|
||||
" print(f\" {symbol_str}: {value}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
Loading…
Reference in New Issue