mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
693 lines
30 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "82f3f65d-fbcb-4e8e-b04b-959856283643",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Causal program-aided language (CPAL) chain\n",
|
|
"\n",
|
"The CPAL chain builds on the recent program-aided language (PAL) chain to reduce LLM hallucination. PAL tends to hallucinate on math problems with a nested chain of dependence. CPAL's innovation is to include causal structure to counter this hallucination.\n",
|
|
"\n",
|
|
"The original [PR's description](https://github.com/langchain-ai/langchain/pull/6255) contains a full overview.\n",
|
|
"\n",
|
|
"Using the CPAL chain, the LLM translated this\n",
|
|
"\n",
|
|
" \"Tim buys the same number of pets as Cindy and Boris.\"\n",
|
|
" \"Cindy buys the same number of pets as Bill plus Bob.\"\n",
|
|
" \"Boris buys the same number of pets as Ben plus Beth.\"\n",
|
|
" \"Bill buys the same number of pets as Obama.\"\n",
|
|
" \"Bob buys the same number of pets as Obama.\"\n",
|
|
" \"Ben buys the same number of pets as Obama.\"\n",
|
|
" \"Beth buys the same number of pets as Obama.\"\n",
|
|
" \"If Obama buys one pet, how many pets total does everyone buy?\"\n",
|
|
"\n",
|
|
"\n",
|
|
"into this\n",
|
|
"\n",
|
|
"![complex-graph.png](/img/cpal_diagram.png).\n",
|
|
"\n",
|
|
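"The graph above can be read as a small program. The following is a minimal sketch in plain Python, with no LLM involved, of what evaluating it amounts to. The `graph` dict and the `evaluate` helper are hypothetical names used for illustration and are not part of the CPAL API.\n",
"\n",
"```python\n",
"from graphlib import TopologicalSorter\n",
"\n",
"# Hypothetical encoding of the narrative: each entity maps to the\n",
"# entities whose pet counts it adds up (its parents in the DAG).\n",
"graph = {\n",
"    'obama': [],\n",
"    'bill': ['obama'],\n",
"    'bob': ['obama'],\n",
"    'ben': ['obama'],\n",
"    'beth': ['obama'],\n",
"    'cindy': ['bill', 'bob'],\n",
"    'boris': ['ben', 'beth'],\n",
"    'tim': ['cindy', 'boris'],\n",
"}\n",
"\n",
"\n",
"def evaluate(graph, exogenous):\n",
"    # Start from the values the narrative states directly.\n",
"    values = dict(exogenous)\n",
"    # static_order() yields every node after all of its parents.\n",
"    for name in TopologicalSorter(graph).static_order():\n",
"        if name not in values:\n",
"            values[name] = sum(values[parent] for parent in graph[name])\n",
"    return values\n",
"\n",
"\n",
"values = evaluate(graph, {'obama': 1})\n",
"total = sum(values.values())\n",
"print(values['tim'], total)\n",
"```\n",
"\n",
"Walking the nodes in dependency order is what keeps the nested chain consistent: Tim's value is only computed once Cindy's and Boris's values are known.\n",
"\n",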
"Outline of code examples demoed in this notebook.\n",
|
|
"\n",
|
|
"1. CPAL's value against hallucination: CPAL vs PAL \n",
|
|
" 1.1 Complex narrative \n",
|
|
" 1.2 Unanswerable math word problem \n",
|
|
"2. CPAL's three types of causal diagrams ([The Book of Why](https://en.wikipedia.org/wiki/The_Book_of_Why)). \n",
|
|
" 2.1 Mediator \n",
|
|
" 2.2 Collider \n",
|
|
" 2.3 Confounder "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "1370e40f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from IPython.display import SVG\n",
|
|
"from langchain.llms import OpenAI\n",
|
|
"from langchain_experimental.cpal.base import CPALChain\n",
|
|
"from langchain_experimental.pal_chain import PALChain\n",
|
|
"\n",
|
|
"llm = OpenAI(temperature=0, max_tokens=512)\n",
|
|
"cpal_chain = CPALChain.from_univariate_prompt(llm=llm, verbose=True)\n",
|
|
"pal_chain = PALChain.from_math_prompt(llm=llm, verbose=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "858a87d9-a9bd-4850-9687-9af4b0856b62",
|
|
"metadata": {},
|
|
"source": [
|
|
"## CPAL's value against hallucination: CPAL vs PAL\n",
|
|
"\n",
|
|
"Like PAL, CPAL intends to reduce large language model (LLM) hallucination.\n",
|
|
"\n",
|
"The CPAL chain differs from the PAL chain in two ways:\n",
"\n",
"- CPAL adds a causal structure, a directed acyclic graph (DAG), that links entity actions (math expressions).\n",
"- CPAL's math expressions model a chain of cause-and-effect relations that can be intervened upon, whereas PAL's math expressions are projected math identities.\n"
|
|
]
|
|
},
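{
"cell_type": "markdown",
"id": "cpal-intervention-sketch",
"metadata": {},
"source": [
"To make the idea of intervening on a cause-and-effect chain concrete, here is a minimal sketch in plain Python. It uses a three-entity story (Cindy has four pets, Marcia has two more than Cindy, Jan has three times as many as Marcia). The `equations` dict, the `simulate` helper, and its `do` parameter are hypothetical names for illustration only and do not reflect how CPAL is implemented.\n",
"\n",
"```python\n",
"# Each entity's value is a function of the values computed so far.\n",
"# The dict is written in causal order: causes before effects.\n",
"equations = {\n",
"    'cindy': lambda v: 4,\n",
"    'marcia': lambda v: v['cindy'] + 2,\n",
"    'jan': lambda v: v['marcia'] * 3,\n",
"}\n",
"\n",
"\n",
"def simulate(equations, do=None):\n",
"    # `do` pins an entity to a fixed value, overriding its equation.\n",
"    do = do or {}\n",
"    values = {}\n",
"    for name, equation in equations.items():\n",
"        values[name] = do[name] if name in do else equation(values)\n",
"    return values\n",
"\n",
"\n",
"baseline = simulate(equations)\n",
"intervened = simulate(equations, do={'marcia': 10})\n",
"print(baseline)\n",
"print(intervened)\n",
"```\n",
"\n",
"Pinning Marcia's value changes Jan's value downstream but leaves Cindy's untouched, which is the behavior a flat list of math identities does not express."
]
},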
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "496403c5-d268-43ae-8852-2bd9903ce444",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 1.1 Complex narrative\n",
|
|
"\n",
|
"Takeaway: PAL hallucinates (it returns 5), whereas CPAL returns the correct total of 13."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "d5dad768-2892-4825-8093-9b840f643a8a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Tim buys the same number of pets as Cindy and Boris.\"\n",
|
|
" \"Cindy buys the same number of pets as Bill plus Bob.\"\n",
|
|
" \"Boris buys the same number of pets as Ben plus Beth.\"\n",
|
|
" \"Bill buys the same number of pets as Obama.\"\n",
|
|
" \"Bob buys the same number of pets as Obama.\"\n",
|
|
" \"Ben buys the same number of pets as Obama.\"\n",
|
|
" \"Beth buys the same number of pets as Obama.\"\n",
|
|
" \"If Obama buys one pet, how many pets total does everyone buy?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "bbffa7a0-3c22-4a1d-ab2d-f230973073b0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mdef solution():\n",
|
|
" \"\"\"Tim buys the same number of pets as Cindy and Boris.Cindy buys the same number of pets as Bill plus Bob.Boris buys the same number of pets as Ben plus Beth.Bill buys the same number of pets as Obama.Bob buys the same number of pets as Obama.Ben buys the same number of pets as Obama.Beth buys the same number of pets as Obama.If Obama buys one pet, how many pets total does everyone buy?\"\"\"\n",
|
|
" obama_pets = 1\n",
|
|
" tim_pets = obama_pets\n",
|
|
" cindy_pets = obama_pets + obama_pets\n",
|
|
" boris_pets = obama_pets + obama_pets\n",
|
|
" total_pets = tim_pets + cindy_pets + boris_pets\n",
|
|
" result = total_pets\n",
|
|
" return result\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'5'"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"pal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "35a70d1d-86f8-4abc-b818-fbd083f072e9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 obama pass 1.0 []\n",
|
|
"1 bill bill.value = obama.value 1.0 [obama]\n",
|
|
"2 bob bob.value = obama.value 1.0 [obama]\n",
|
|
"3 ben ben.value = obama.value 1.0 [obama]\n",
|
|
"4 beth beth.value = obama.value 1.0 [obama]\n",
|
|
"5 cindy cindy.value = bill.value + bob.value 2.0 [bill, bob]\n",
|
|
"6 boris boris.value = ben.value + beth.value 2.0 [ben, beth]\n",
|
|
"7 tim tim.value = cindy.value + boris.value 4.0 [cindy, boris]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many pets total does everyone buy?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"13.0"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "ccb6b2b0-9de6-4f66-a8fb-fc59229ee316",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"292pt\" height=\"260pt\" viewBox=\"0.00 0.00 292.00 260.00\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 288,-256 288,4 -4,4\"/>\n<!-- obama -->\n<g id=\"node1\" class=\"node\">\n<title>obama</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"137\" cy=\"-234\" rx=\"41.69\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"137\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">obama</text>\n</g>\n<!-- bill -->\n<g id=\"node2\" class=\"node\">\n<title>bill</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"27\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">bill</text>\n</g>\n<!-- obama->bill -->\n<g id=\"edge1\" class=\"edge\">\n<title>obama->bill</title>\n<path fill=\"none\" stroke=\"black\" d=\"M114.47,-218.67C97.08,-207.6 72.94,-192.23 54.42,-180.45\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"56.15,-177.4 45.84,-174.99 52.4,-183.31 56.15,-177.4\"/>\n</g>\n<!-- bob -->\n<g id=\"node3\" class=\"node\">\n<title>bob</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"100\" cy=\"-162\" rx=\"28\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"100\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">bob</text>\n</g>\n<!-- obama->bob -->\n<g id=\"edge2\" class=\"edge\">\n<title>obama->bob</title>\n<path fill=\"none\" stroke=\"black\" d=\"M128.04,-216.05C123.66,-207.77 118.3,-197.62 113.44,-188.42\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"116.39,-186.51 108.62,-179.31 110.2,-189.79 116.39,-186.51\"/>\n</g>\n<!-- ben -->\n<g id=\"node4\" class=\"node\">\n<title>ben</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"174\" cy=\"-162\" rx=\"28\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"174\" y=\"-158.3\" 
font-family=\"Times,serif\" font-size=\"14.00\">ben</text>\n</g>\n<!-- obama->ben -->\n<g id=\"edge3\" class=\"edge\">\n<title>obama->ben</title>\n<path fill=\"none\" stroke=\"black\" d=\"M145.96,-216.05C150.34,-207.77 155.7,-197.62 160.56,-188.42\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"163.8,-189.79 165.38,-179.31 157.61,-186.51 163.8,-189.79\"/>\n</g>\n<!-- beth -->\n<g id=\"node5\" class=\"node\">\n<title>beth</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"252\" cy=\"-162\" rx=\"32\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"252\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">beth</text>\n</g>\n<!-- obama->beth -->\n<g id=\"edge4\" class=\"edge\">\n<title>obama->beth</title>\n<path fill=\"none\" stroke=\"black\" d=\"M160.27,-218.83C178.18,-207.94 203.04,-192.8 222.37,-181.04\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"224.36,-183.92 231.08,-175.73 220.72,-177.95 224.36,-183.92\"/>\n</g>\n<!-- cindy -->\n<g id=\"node6\" class=\"node\">\n<title>cindy</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"93\" cy=\"-90\" rx=\"36\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"93\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n</g>\n<!-- bill->cindy -->\n<g id=\"edge5\" class=\"edge\">\n<title>bill->cindy</title>\n<path fill=\"none\" stroke=\"black\" d=\"M41,-146.15C49.77,-136.85 61.25,-124.67 71.2,-114.12\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"73.79,-116.47 78.11,-106.8 68.7,-111.67 73.79,-116.47\"/>\n</g>\n<!-- bob->cindy -->\n<g id=\"edge6\" class=\"edge\">\n<title>bob->cindy</title>\n<path fill=\"none\" stroke=\"black\" d=\"M98.27,-143.7C97.5,-135.98 96.57,-126.71 95.71,-118.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"99.19,-117.7 94.71,-108.1 92.22,-118.4 99.19,-117.7\"/>\n</g>\n<!-- boris -->\n<g id=\"node7\" class=\"node\">\n<title>boris</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"181\" cy=\"-90\" rx=\"34.5\" ry=\"18\"/>\n<text 
text-anchor=\"middle\" x=\"181\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">boris</text>\n</g>\n<!-- ben->boris -->\n<g id=\"edge7\" class=\"edge\">\n<title>ben->boris</title>\n<path fill=\"none\" stroke=\"black\" d=\"M175.73,-143.7C176.5,-135.98 177.43,-126.71 178.29,-118.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"181.78,-118.4 179.29,-108.1 174.81,-117.7 181.78,-118.4\"/>\n</g>\n<!-- beth->boris -->\n<g id=\"edge8\" class=\"edge\">\n<title>beth->boris</title>\n<path fill=\"none\" stroke=\"black\" d=\"M236.59,-145.81C227.01,-136.36 214.51,-124.04 203.8,-113.48\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"205.96,-110.69 196.38,-106.16 201.04,-115.67 205.96,-110.69\"/>\n</g>\n<!-- tim -->\n<g id=\"node8\" class=\"node\">\n<title>tim</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"137\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"137\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">tim</text>\n</g>\n<!-- cindy->tim -->\n<g id=\"edge9\" class=\"edge\">\n<title>cindy->tim</title>\n<path fill=\"none\" stroke=\"black\" d=\"M103.43,-72.41C108.82,-63.83 115.51,-53.19 121.49,-43.67\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"124.59,-45.32 126.95,-34.99 118.66,-41.59 124.59,-45.32\"/>\n</g>\n<!-- boris->tim -->\n<g id=\"edge10\" class=\"edge\">\n<title>boris->tim</title>\n<path fill=\"none\" stroke=\"black\" d=\"M170.79,-72.77C165.41,-64.19 158.68,-53.49 152.65,-43.9\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"155.43,-41.75 147.15,-35.15 149.51,-45.48 155.43,-41.75\"/>\n</g>\n</g>\n</svg>",
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1f6f345a-bb16-4e64-83c4-cbbc789a8325",
|
|
"metadata": {},
|
|
"source": [
|
"### 1.2 Unanswerable math word problem\n",
"\n",
"Takeaway: PAL hallucinates an answer (36), whereas CPAL, rather than hallucinate, raises an error with the message _\"unanswerable, query and outcome are incoherent\"_."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "068afd79-fd41-4ec2-b4d0-c64140dc413f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has three times the number of pets as Marcia.\"\n",
|
|
" \"Marcia has two more pets than Cindy.\"\n",
|
|
" \"If Cindy has ten pets, how many pets does Barak have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "02f77db2-72e8-46c2-90b3-5e37ca42f80d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mdef solution():\n",
|
|
" \"\"\"Jan has three times the number of pets as Marcia.Marcia has two more pets than Cindy.If Cindy has ten pets, how many pets does Barak have?\"\"\"\n",
|
|
" cindy_pets = 10\n",
|
|
" marcia_pets = cindy_pets + 2\n",
|
|
" jan_pets = marcia_pets * 3\n",
|
|
" result = jan_pets\n",
|
|
" return result\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'36'"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"pal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "925958de-e998-4ffa-8b2e-5a00ddae5026",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 10.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 12.0 [cindy]\n",
|
|
"2 jan jan.value = marcia.value * 3 36.0 [marcia]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many pets does barak have?\",\n",
|
|
" \"expression\": \"SELECT name, value FROM df WHERE name = 'barak'\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"unanswerable, query and outcome are incoherent\n",
|
|
"\n",
|
|
"outcome:\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 10.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 12.0 [cindy]\n",
|
|
"2 jan jan.value = marcia.value * 3 36.0 [marcia]\n",
|
|
"query:\n",
|
|
"{'question': 'how many pets does barak have?', 'expression': \"SELECT name, value FROM df WHERE name = 'barak'\", 'llm_error_msg': ''}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"try:\n",
|
|
" cpal_chain.run(question)\n",
|
|
"except Exception as e_msg:\n",
|
|
" print(e_msg)"
|
|
]
|
|
},
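{
"cell_type": "markdown",
"id": "cpal-unanswerable-sketch",
"metadata": {},
"source": [
"The notebook does not show how CPAL detects this case internally. As a rough mental model only, the check can be pictured as looking up the queried entity in the outcome table and refusing to answer when it is missing. The `outcome` dict and the `answer` helper below are hypothetical names, not CPAL internals.\n",
"\n",
"```python\n",
"# Values taken from the story outcome table printed above.\n",
"outcome = {'cindy': 10.0, 'marcia': 12.0, 'jan': 36.0}\n",
"\n",
"\n",
"def answer(outcome, entity):\n",
"    # Refuse to answer when the query names an entity the story never defined.\n",
"    if entity not in outcome:\n",
"        raise ValueError('unanswerable, query and outcome are incoherent')\n",
"    return outcome[entity]\n",
"\n",
"\n",
"try:\n",
"    result = answer(outcome, 'barak')\n",
"except ValueError as err:\n",
"    result = None\n",
"    message = str(err)\n",
"    print(message)\n",
"```\n",
"\n",
"PAL, by contrast, returned Jan's value for a question about Barak, because nothing tied the question back to the entities in the story."
]
},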
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "095adc76",
|
|
"metadata": {},
|
|
"source": [
|
"## Basic math\n",
"\n",
"### Causal mediator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "3ecf03fa-8350-4c4e-8080-84a307ba6ad4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has three times the number of pets as Marcia. \"\n",
|
|
" \"Marcia has two more pets than Cindy. \"\n",
|
|
" \"If Cindy has four pets, how many total pets do the three have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "74e49c47-3eed-4abe-98b7-8e97bcd15944",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"PAL"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "2e88395f-d014-4362-abb0-88f6800860bb",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mdef solution():\n",
|
|
" \"\"\"Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. If Cindy has four pets, how many total pets do the three have?\"\"\"\n",
|
|
" cindy_pets = 4\n",
|
|
" marcia_pets = cindy_pets + 2\n",
|
|
" jan_pets = marcia_pets * 3\n",
|
|
" total_pets = cindy_pets + marcia_pets + jan_pets\n",
|
|
" result = total_pets\n",
|
|
" return result\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'28'"
|
|
]
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"pal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "20ba6640-3d17-4b59-8101-aaba89d68cf4",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"CPAL"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "312a0943-a482-4ed0-a064-1e7a72e9479b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 4.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 6.0 [cindy]\n",
|
|
"2 jan jan.value = marcia.value * 3 18.0 [marcia]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many total pets do the three have?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"28.0"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "4466b975-ae2b-4252-972b-b3182a089ade",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"92pt\" height=\"188pt\" viewBox=\"0.00 0.00 92.49 188.00\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 88.49,-184 88.49,4 -4,4\"/>\n<!-- cindy -->\n<g id=\"node1\" class=\"node\">\n<title>cindy</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-162\" rx=\"36\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"42.25\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n</g>\n<!-- marcia -->\n<g id=\"node2\" class=\"node\">\n<title>marcia</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-90\" rx=\"42.49\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"42.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">marcia</text>\n</g>\n<!-- cindy->marcia -->\n<g id=\"edge1\" class=\"edge\">\n<title>cindy->marcia</title>\n<path fill=\"none\" stroke=\"black\" d=\"M42.25,-143.7C42.25,-135.98 42.25,-126.71 42.25,-118.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"45.75,-118.1 42.25,-108.1 38.75,-118.1 45.75,-118.1\"/>\n</g>\n<!-- jan -->\n<g id=\"node3\" class=\"node\">\n<title>jan</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"42.25\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">jan</text>\n</g>\n<!-- marcia->jan -->\n<g id=\"edge2\" class=\"edge\">\n<title>marcia->jan</title>\n<path fill=\"none\" stroke=\"black\" d=\"M42.25,-71.7C42.25,-63.98 42.25,-54.71 42.25,-46.11\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"45.75,-46.1 42.25,-36.1 38.75,-46.1 45.75,-46.1\"/>\n</g>\n</g>\n</svg>",
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "29fa7b8a-75a3-4270-82a2-2c31939cd7e0",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Causal collider"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "618eddac-f0ef-4ab5-90ed-72e880fdeba3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has the number of pets as Marcia plus the number of pets as Cindy. \"\n",
|
|
" \"Marcia has no pets. \"\n",
|
|
" \"If Cindy has four pets, how many total pets do the three have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "a01563f3-7974-4de4-8bd9-0b7d710aa0d3",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 marcia pass 0.0 []\n",
|
|
"1 cindy pass 4.0 []\n",
|
|
"2 jan jan.value = marcia.value + cindy.value 4.0 [marcia, cindy]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many total pets do the three have?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"8.0"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "0fbe7243-0522-4946-b9a2-6e21e7c49a42",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"182pt\" height=\"116pt\" viewBox=\"0.00 0.00 182.00 116.00\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 112)\">\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-112 178,-112 178,4 -4,4\"/>\n<!-- marcia -->\n<g id=\"node1\" class=\"node\">\n<title>marcia</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-90\" rx=\"42.49\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"42.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">marcia</text>\n</g>\n<!-- jan -->\n<g id=\"node2\" class=\"node\">\n<title>jan</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"90.25\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"90.25\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">jan</text>\n</g>\n<!-- marcia->jan -->\n<g id=\"edge1\" class=\"edge\">\n<title>marcia->jan</title>\n<path fill=\"none\" stroke=\"black\" d=\"M53.62,-72.41C59.57,-63.74 66.95,-52.97 73.53,-43.38\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"76.51,-45.21 79.28,-34.99 70.74,-41.26 76.51,-45.21\"/>\n</g>\n<!-- cindy -->\n<g id=\"node3\" class=\"node\">\n<title>cindy</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"138.25\" cy=\"-90\" rx=\"36\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"138.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n</g>\n<!-- cindy->jan -->\n<g id=\"edge2\" class=\"edge\">\n<title>cindy->jan</title>\n<path fill=\"none\" stroke=\"black\" d=\"M127.11,-72.77C121.09,-63.98 113.54,-52.96 106.83,-43.19\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"109.53,-40.94 100.99,-34.67 103.75,-44.89 109.53,-40.94\"/>\n</g>\n</g>\n</svg>",
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d4082538-ec03-44f0-aac3-07e03aad7555",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Causal confounder"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "83932c30-950b-435a-b328-7993ce8cc6bd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has the number of pets as Marcia plus the number of pets as Cindy. \"\n",
|
|
" \"Marcia has two more pets than Cindy. \"\n",
|
|
" \"If Cindy has four pets, how many total pets do the three have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "570de307-7c6b-4fdc-80c3-4361daa8a629",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 4.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 6.0 [cindy]\n",
|
|
"2 jan jan.value = cindy.value + marcia.value 10.0 [cindy, marcia]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many total pets do the three have?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"20.0"
|
|
]
|
|
},
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "00375615-6b6d-4357-bdb8-f64f682f7605",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": "<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"121pt\" height=\"188pt\" viewBox=\"0.00 0.00 120.99 188.00\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 116.99,-184 116.99,4 -4,4\"/>\n<!-- cindy -->\n<g id=\"node1\" class=\"node\">\n<title>cindy</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"77.25\" cy=\"-162\" rx=\"36\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"77.25\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n</g>\n<!-- marcia -->\n<g id=\"node2\" class=\"node\">\n<title>marcia</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-90\" rx=\"42.49\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"42.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">marcia</text>\n</g>\n<!-- cindy->marcia -->\n<g id=\"edge1\" class=\"edge\">\n<title>cindy->marcia</title>\n<path fill=\"none\" stroke=\"black\" d=\"M68.95,-144.41C64.87,-136.25 59.86,-126.22 55.28,-117.07\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"58.33,-115.34 50.72,-107.96 52.07,-118.47 58.33,-115.34\"/>\n</g>\n<!-- jan -->\n<g id=\"node3\" class=\"node\">\n<title>jan</title>\n<ellipse fill=\"none\" stroke=\"black\" cx=\"77.25\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n<text text-anchor=\"middle\" x=\"77.25\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">jan</text>\n</g>\n<!-- cindy->jan -->\n<g id=\"edge2\" class=\"edge\">\n<title>cindy->jan</title>\n<path fill=\"none\" stroke=\"black\" d=\"M83.73,-144.1C87.32,-133.84 91.42,-120.36 93.25,-108 95.58,-92.17 95.58,-87.83 93.25,-72 91.95,-63.21 89.5,-53.86 86.91,-45.5\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"90.19,-44.29 83.73,-35.9 83.55,-46.49 90.19,-44.29\"/>\n</g>\n<!-- marcia->jan -->\n<g id=\"edge3\" class=\"edge\">\n<title>marcia->jan</title>\n<path fill=\"none\" stroke=\"black\" 
d=\"M50.72,-72.06C54.86,-63.77 59.94,-53.62 64.53,-44.42\"/>\n<polygon fill=\"black\" stroke=\"black\" points=\"67.75,-45.82 69.09,-35.31 61.49,-42.69 67.75,-45.82\"/>\n</g>\n</g>\n</svg>",
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "255683de-0c1c-4131-b277-99d09f5ac1fc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%load_ext autoreload\n",
|
|
"%autoreload 2"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|