mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
2667ddc686
922 lines
33 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"attachments": {},
|
|
"cell_type": "markdown",
|
|
"id": "82f3f65d-fbcb-4e8e-b04b-959856283643",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Causal program-aided language (CPAL) chain\n",
|
|
"\n",
"The CPAL chain builds on the recent PAL chain to reduce LLM hallucination. PAL tends to hallucinate on math problems that contain a nested chain of dependence. CPAL adds causal structure to address this.\n",
|
|
"\n",
|
|
"The original [PR's description](https://github.com/hwchase17/langchain/pull/6255) contains a full overview.\n",
|
|
"\n",
|
|
"Using the CPAL chain, the LLM translated this\n",
|
|
"\n",
|
|
" \"Tim buys the same number of pets as Cindy and Boris.\"\n",
|
|
" \"Cindy buys the same number of pets as Bill plus Bob.\"\n",
|
|
" \"Boris buys the same number of pets as Ben plus Beth.\"\n",
|
|
" \"Bill buys the same number of pets as Obama.\"\n",
|
|
" \"Bob buys the same number of pets as Obama.\"\n",
|
|
" \"Ben buys the same number of pets as Obama.\"\n",
|
|
" \"Beth buys the same number of pets as Obama.\"\n",
|
|
" \"If Obama buys one pet, how many pets total does everyone buy?\"\n",
|
|
"\n",
|
|
"\n",
|
|
"into this\n",
|
|
"\n",
|
|
"![complex-graph.png](/img/cpal_diagram.png).\n",
|
|
"\n",
"Outline of the code examples in this notebook:\n",
|
|
"\n",
|
|
"1. CPAL's value against hallucination: CPAL vs PAL \n",
|
|
" 1.1 Complex narrative \n",
|
|
" 1.2 Unanswerable math word problem \n",
|
|
"2. CPAL's three types of causal diagrams ([The Book of Why](https://en.wikipedia.org/wiki/The_Book_of_Why)). \n",
|
|
" 2.1 Mediator \n",
|
|
" 2.2 Collider \n",
|
|
" 2.3 Confounder "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "1370e40f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from IPython.display import SVG\n",
|
|
"\n",
|
|
"from langchain.experimental.cpal.base import CPALChain\n",
|
|
"from langchain.chains import PALChain\n",
|
|
"from langchain import OpenAI\n",
|
|
"\n",
|
|
"llm = OpenAI(temperature=0, max_tokens=512)\n",
|
|
"cpal_chain = CPALChain.from_univariate_prompt(llm=llm, verbose=True)\n",
|
|
"pal_chain = PALChain.from_math_prompt(llm=llm, verbose=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "858a87d9-a9bd-4850-9687-9af4b0856b62",
|
|
"metadata": {},
|
|
"source": [
|
|
"## CPAL's value against hallucination: CPAL vs PAL\n",
|
|
"\n",
|
|
"Like PAL, CPAL intends to reduce large language model (LLM) hallucination.\n",
|
|
"\n",
|
"CPAL differs from PAL in two ways:\n",
"\n",
"1. CPAL adds a causal structure, a directed acyclic graph (DAG), that links entity actions (math expressions).\n",
"2. CPAL's math expressions model a chain of cause-and-effect relations that can be intervened upon, whereas PAL's math expressions are projected math identities.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "496403c5-d268-43ae-8852-2bd9903ce444",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 1.1 Complex narrative\n",
|
|
"\n",
"Takeaway: PAL hallucinates; CPAL does not."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "d5dad768-2892-4825-8093-9b840f643a8a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Tim buys the same number of pets as Cindy and Boris.\"\n",
|
|
" \"Cindy buys the same number of pets as Bill plus Bob.\"\n",
|
|
" \"Boris buys the same number of pets as Ben plus Beth.\"\n",
|
|
" \"Bill buys the same number of pets as Obama.\"\n",
|
|
" \"Bob buys the same number of pets as Obama.\"\n",
|
|
" \"Ben buys the same number of pets as Obama.\"\n",
|
|
" \"Beth buys the same number of pets as Obama.\"\n",
|
|
" \"If Obama buys one pet, how many pets total does everyone buy?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "bbffa7a0-3c22-4a1d-ab2d-f230973073b0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mdef solution():\n",
|
|
" \"\"\"Tim buys the same number of pets as Cindy and Boris.Cindy buys the same number of pets as Bill plus Bob.Boris buys the same number of pets as Ben plus Beth.Bill buys the same number of pets as Obama.Bob buys the same number of pets as Obama.Ben buys the same number of pets as Obama.Beth buys the same number of pets as Obama.If Obama buys one pet, how many pets total does everyone buy?\"\"\"\n",
|
|
" obama_pets = 1\n",
|
|
" tim_pets = obama_pets\n",
|
|
" cindy_pets = obama_pets + obama_pets\n",
|
|
" boris_pets = obama_pets + obama_pets\n",
|
|
" total_pets = tim_pets + cindy_pets + boris_pets\n",
|
|
" result = total_pets\n",
|
|
" return result\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'5'"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"pal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "35a70d1d-86f8-4abc-b818-fbd083f072e9",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 obama pass 1.0 []\n",
|
|
"1 bill bill.value = obama.value 1.0 [obama]\n",
|
|
"2 bob bob.value = obama.value 1.0 [obama]\n",
|
|
"3 ben ben.value = obama.value 1.0 [obama]\n",
|
|
"4 beth beth.value = obama.value 1.0 [obama]\n",
|
|
"5 cindy cindy.value = bill.value + bob.value 2.0 [bill, bob]\n",
|
|
"6 boris boris.value = ben.value + beth.value 2.0 [ben, beth]\n",
|
|
"7 tim tim.value = cindy.value + boris.value 4.0 [cindy, boris]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many pets total does everyone buy?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"13.0"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "ccb6b2b0-9de6-4f66-a8fb-fc59229ee316",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": [
|
|
"<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"292pt\" height=\"260pt\" viewBox=\"0.00 0.00 292.00 260.00\">\n",
|
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n",
|
|
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-256 288,-256 288,4 -4,4\"/>\n",
|
|
"<!-- obama -->\n",
|
|
"<g id=\"node1\" class=\"node\">\n",
|
|
"<title>obama</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"137\" cy=\"-234\" rx=\"41.69\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"137\" y=\"-230.3\" font-family=\"Times,serif\" font-size=\"14.00\">obama</text>\n",
|
|
"</g>\n",
|
|
"<!-- bill -->\n",
|
|
"<g id=\"node2\" class=\"node\">\n",
|
|
"<title>bill</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"27\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">bill</text>\n",
|
|
"</g>\n",
|
|
"<!-- obama->bill -->\n",
|
|
"<g id=\"edge1\" class=\"edge\">\n",
|
|
"<title>obama->bill</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M114.47,-218.67C97.08,-207.6 72.94,-192.23 54.42,-180.45\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"56.15,-177.4 45.84,-174.99 52.4,-183.31 56.15,-177.4\"/>\n",
|
|
"</g>\n",
|
|
"<!-- bob -->\n",
|
|
"<g id=\"node3\" class=\"node\">\n",
|
|
"<title>bob</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"100\" cy=\"-162\" rx=\"28\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"100\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">bob</text>\n",
|
|
"</g>\n",
|
|
"<!-- obama->bob -->\n",
|
|
"<g id=\"edge2\" class=\"edge\">\n",
|
|
"<title>obama->bob</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M128.04,-216.05C123.66,-207.77 118.3,-197.62 113.44,-188.42\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"116.39,-186.51 108.62,-179.31 110.2,-189.79 116.39,-186.51\"/>\n",
|
|
"</g>\n",
|
|
"<!-- ben -->\n",
|
|
"<g id=\"node4\" class=\"node\">\n",
|
|
"<title>ben</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"174\" cy=\"-162\" rx=\"28\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"174\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">ben</text>\n",
|
|
"</g>\n",
|
|
"<!-- obama->ben -->\n",
|
|
"<g id=\"edge3\" class=\"edge\">\n",
|
|
"<title>obama->ben</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M145.96,-216.05C150.34,-207.77 155.7,-197.62 160.56,-188.42\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"163.8,-189.79 165.38,-179.31 157.61,-186.51 163.8,-189.79\"/>\n",
|
|
"</g>\n",
|
|
"<!-- beth -->\n",
|
|
"<g id=\"node5\" class=\"node\">\n",
|
|
"<title>beth</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"252\" cy=\"-162\" rx=\"32\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"252\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">beth</text>\n",
|
|
"</g>\n",
|
|
"<!-- obama->beth -->\n",
|
|
"<g id=\"edge4\" class=\"edge\">\n",
|
|
"<title>obama->beth</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M160.27,-218.83C178.18,-207.94 203.04,-192.8 222.37,-181.04\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"224.36,-183.92 231.08,-175.73 220.72,-177.95 224.36,-183.92\"/>\n",
|
|
"</g>\n",
|
|
"<!-- cindy -->\n",
|
|
"<g id=\"node6\" class=\"node\">\n",
|
|
"<title>cindy</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"93\" cy=\"-90\" rx=\"36\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"93\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n",
|
|
"</g>\n",
|
|
"<!-- bill->cindy -->\n",
|
|
"<g id=\"edge5\" class=\"edge\">\n",
|
|
"<title>bill->cindy</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M41,-146.15C49.77,-136.85 61.25,-124.67 71.2,-114.12\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"73.79,-116.47 78.11,-106.8 68.7,-111.67 73.79,-116.47\"/>\n",
|
|
"</g>\n",
|
|
"<!-- bob->cindy -->\n",
|
|
"<g id=\"edge6\" class=\"edge\">\n",
|
|
"<title>bob->cindy</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M98.27,-143.7C97.5,-135.98 96.57,-126.71 95.71,-118.11\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"99.19,-117.7 94.71,-108.1 92.22,-118.4 99.19,-117.7\"/>\n",
|
|
"</g>\n",
|
|
"<!-- boris -->\n",
|
|
"<g id=\"node7\" class=\"node\">\n",
|
|
"<title>boris</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"181\" cy=\"-90\" rx=\"34.5\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"181\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">boris</text>\n",
|
|
"</g>\n",
|
|
"<!-- ben->boris -->\n",
|
|
"<g id=\"edge7\" class=\"edge\">\n",
|
|
"<title>ben->boris</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M175.73,-143.7C176.5,-135.98 177.43,-126.71 178.29,-118.11\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"181.78,-118.4 179.29,-108.1 174.81,-117.7 181.78,-118.4\"/>\n",
|
|
"</g>\n",
|
|
"<!-- beth->boris -->\n",
|
|
"<g id=\"edge8\" class=\"edge\">\n",
|
|
"<title>beth->boris</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M236.59,-145.81C227.01,-136.36 214.51,-124.04 203.8,-113.48\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"205.96,-110.69 196.38,-106.16 201.04,-115.67 205.96,-110.69\"/>\n",
|
|
"</g>\n",
|
|
"<!-- tim -->\n",
|
|
"<g id=\"node8\" class=\"node\">\n",
|
|
"<title>tim</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"137\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"137\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">tim</text>\n",
|
|
"</g>\n",
|
|
"<!-- cindy->tim -->\n",
|
|
"<g id=\"edge9\" class=\"edge\">\n",
|
|
"<title>cindy->tim</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M103.43,-72.41C108.82,-63.83 115.51,-53.19 121.49,-43.67\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"124.59,-45.32 126.95,-34.99 118.66,-41.59 124.59,-45.32\"/>\n",
|
|
"</g>\n",
|
|
"<!-- boris->tim -->\n",
|
|
"<g id=\"edge10\" class=\"edge\">\n",
|
|
"<title>boris->tim</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M170.79,-72.77C165.41,-64.19 158.68,-53.49 152.65,-43.9\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"155.43,-41.75 147.15,-35.15 149.51,-45.48 155.43,-41.75\"/>\n",
|
|
"</g>\n",
|
|
"</g>\n",
|
|
"</svg>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
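{
"cell_type": "markdown",
"id": "cpal-sketch-dag-evaluation",
"metadata": {},
"source": [
"The *story outcome data* table above can be reproduced by hand, which makes it easier to see why CPAL reports 13 where PAL reports 5. The following is a minimal sketch in plain Python, assuming only the standard library (`graphlib`, Python 3.9+). The `story` dictionary and the `evaluate` helper are hypothetical names introduced for illustration; this is not how `CPALChain` is implemented internally.\n",
"\n",
"```python\n",
"from graphlib import TopologicalSorter\n",
"\n",
"# Each entity maps to (parents, function of the parents' values).\n",
"# This mirrors the story outcome table; it is NOT CPALChain's internals.\n",
"story = {\n",
"    'obama': ([], lambda: 1),\n",
"    'bill': (['obama'], lambda obama: obama),\n",
"    'bob': (['obama'], lambda obama: obama),\n",
"    'ben': (['obama'], lambda obama: obama),\n",
"    'beth': (['obama'], lambda obama: obama),\n",
"    'cindy': (['bill', 'bob'], lambda bill, bob: bill + bob),\n",
"    'boris': (['ben', 'beth'], lambda ben, beth: ben + beth),\n",
"    'tim': (['cindy', 'boris'], lambda cindy, boris: cindy + boris),\n",
"}\n",
"\n",
"\n",
"def evaluate(story):\n",
"    # Visit every entity only after all of its parents (topological order).\n",
"    graph = {name: deps for name, (deps, _) in story.items()}\n",
"    values = {}\n",
"    for name in TopologicalSorter(graph).static_order():\n",
"        deps, fn = story[name]\n",
"        values[name] = fn(*(values[d] for d in deps))\n",
"    return values\n",
"\n",
"\n",
"values = evaluate(story)\n",
"total = sum(values.values())\n",
"print(values['tim'], total)  # 4 13\n",
"```\n",
"\n",
"PAL's generated program set `tim_pets = obama_pets` and summed only three entities, which is how it arrived at 5. Evaluating every node in dependency order and summing all eight values gives 13."
]
},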
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1f6f345a-bb16-4e64-83c4-cbbc789a8325",
|
|
"metadata": {},
|
|
"source": [
"### 1.2 Unanswerable math word problem\n",
|
|
"\n",
"Takeaway: PAL hallucinates an answer, whereas CPAL refuses with the message _\"unanswerable, query and outcome are incoherent\"_."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "068afd79-fd41-4ec2-b4d0-c64140dc413f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has three times the number of pets as Marcia.\"\n",
|
|
" \"Marcia has two more pets than Cindy.\"\n",
|
|
" \"If Cindy has ten pets, how many pets does Barak have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "02f77db2-72e8-46c2-90b3-5e37ca42f80d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mdef solution():\n",
|
|
" \"\"\"Jan has three times the number of pets as Marcia.Marcia has two more pets than Cindy.If Cindy has ten pets, how many pets does Barak have?\"\"\"\n",
|
|
" cindy_pets = 10\n",
|
|
" marcia_pets = cindy_pets + 2\n",
|
|
" jan_pets = marcia_pets * 3\n",
|
|
" result = jan_pets\n",
|
|
" return result\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'36'"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"pal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "925958de-e998-4ffa-8b2e-5a00ddae5026",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 10.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 12.0 [cindy]\n",
|
|
"2 jan jan.value = marcia.value * 3 36.0 [marcia]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many pets does barak have?\",\n",
|
|
" \"expression\": \"SELECT name, value FROM df WHERE name = 'barak'\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"unanswerable, query and outcome are incoherent\n",
|
|
"\n",
|
|
"outcome:\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 10.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 12.0 [cindy]\n",
|
|
"2 jan jan.value = marcia.value * 3 36.0 [marcia]\n",
|
|
"query:\n",
|
|
"{'question': 'how many pets does barak have?', 'expression': \"SELECT name, value FROM df WHERE name = 'barak'\", 'llm_error_msg': ''}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"try:\n",
|
|
" cpal_chain.run(question)\n",
|
|
"except Exception as e_msg:\n",
|
|
" print(e_msg)"
|
|
]
|
|
},
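{
"cell_type": "markdown",
"id": "cpal-sketch-unanswerable-query",
"metadata": {},
"source": [
"One way to read the refusal above: the query asks about `barak`, but the story outcome table only contains `cindy`, `marcia`, and `jan`. The sketch below is a hypothetical plain-Python analogue of that check, not the `CPALChain` code path; the `outcome` dictionary and the `lookup` helper are assumptions made for illustration.\n",
"\n",
"```python\n",
"# Values copied from the story outcome table printed above.\n",
"outcome = {'cindy': 10.0, 'marcia': 12.0, 'jan': 36.0}\n",
"\n",
"\n",
"def lookup(outcome, name):\n",
"    # Refuse to answer when the queried entity never appears in the story.\n",
"    if name not in outcome:\n",
"        raise ValueError('unanswerable, query and outcome are incoherent')\n",
"    return outcome[name]\n",
"\n",
"\n",
"print(lookup(outcome, 'jan'))\n",
"try:\n",
"    lookup(outcome, 'barak')\n",
"except ValueError as err:\n",
"    print(err)\n",
"```\n",
"\n",
"PAL has no such check: it returned Jan's value (36) for a question about someone who is never mentioned in the narrative."
]
},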
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "095adc76",
|
|
"metadata": {},
|
|
"source": [
"## CPAL's three types of causal diagrams\n",
"\n",
"### 2.1 Causal mediator"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "3ecf03fa-8350-4c4e-8080-84a307ba6ad4",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has three times the number of pets as Marcia. \"\n",
|
|
" \"Marcia has two more pets than Cindy. \"\n",
|
|
" \"If Cindy has four pets, how many total pets do the three have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "74e49c47-3eed-4abe-98b7-8e97bcd15944",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"PAL"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "2e88395f-d014-4362-abb0-88f6800860bb",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mdef solution():\n",
|
|
" \"\"\"Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. If Cindy has four pets, how many total pets do the three have?\"\"\"\n",
|
|
" cindy_pets = 4\n",
|
|
" marcia_pets = cindy_pets + 2\n",
|
|
" jan_pets = marcia_pets * 3\n",
|
|
" total_pets = cindy_pets + marcia_pets + jan_pets\n",
|
|
" result = total_pets\n",
|
|
" return result\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'28'"
|
|
]
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"pal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "20ba6640-3d17-4b59-8101-aaba89d68cf4",
|
|
"metadata": {},
|
|
"source": [
|
|
"---\n",
|
|
"CPAL"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "312a0943-a482-4ed0-a064-1e7a72e9479b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 4.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 6.0 [cindy]\n",
|
|
"2 jan jan.value = marcia.value * 3 18.0 [marcia]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many total pets do the three have?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"28.0"
|
|
]
|
|
},
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "4466b975-ae2b-4252-972b-b3182a089ade",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": [
|
|
"<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"92pt\" height=\"188pt\" viewBox=\"0.00 0.00 92.49 188.00\">\n",
|
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
|
|
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 88.49,-184 88.49,4 -4,4\"/>\n",
|
|
"<!-- cindy -->\n",
|
|
"<g id=\"node1\" class=\"node\">\n",
|
|
"<title>cindy</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-162\" rx=\"36\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"42.25\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n",
|
|
"</g>\n",
|
|
"<!-- marcia -->\n",
|
|
"<g id=\"node2\" class=\"node\">\n",
|
|
"<title>marcia</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-90\" rx=\"42.49\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"42.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">marcia</text>\n",
|
|
"</g>\n",
|
|
"<!-- cindy->marcia -->\n",
|
|
"<g id=\"edge1\" class=\"edge\">\n",
|
|
"<title>cindy->marcia</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M42.25,-143.7C42.25,-135.98 42.25,-126.71 42.25,-118.11\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"45.75,-118.1 42.25,-108.1 38.75,-118.1 45.75,-118.1\"/>\n",
|
|
"</g>\n",
|
|
"<!-- jan -->\n",
|
|
"<g id=\"node3\" class=\"node\">\n",
|
|
"<title>jan</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"42.25\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">jan</text>\n",
|
|
"</g>\n",
|
|
"<!-- marcia->jan -->\n",
|
|
"<g id=\"edge2\" class=\"edge\">\n",
|
|
"<title>marcia->jan</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M42.25,-71.7C42.25,-63.98 42.25,-54.71 42.25,-46.11\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"45.75,-46.1 42.25,-36.1 38.75,-46.1 45.75,-46.1\"/>\n",
|
|
"</g>\n",
|
|
"</g>\n",
|
|
"</svg>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
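{
"cell_type": "markdown",
"id": "cpal-sketch-intervention",
"metadata": {},
"source": [
"The introduction states that CPAL's expressions model cause-and-effect relations that can be intervened upon. Below is a minimal sketch of what an intervention means for the mediator story above, in plain Python: forcing `marcia` to a fixed value cuts her dependence on `cindy` and changes everything downstream. The `run_story` helper and its `do` parameter are hypothetical names for illustration, not part of the `CPALChain` API.\n",
"\n",
"```python\n",
"def run_story(cindy=4, do=None):\n",
"    # do: optional dict of forced values, for example {'marcia': 10}.\n",
"    do = do or {}\n",
"    values = {'cindy': do.get('cindy', cindy)}\n",
"    # A forced value replaces the structural equation for that entity.\n",
"    values['marcia'] = do.get('marcia', values['cindy'] + 2)\n",
"    values['jan'] = do.get('jan', values['marcia'] * 3)\n",
"    return values\n",
"\n",
"\n",
"observed = run_story()\n",
"intervened = run_story(do={'marcia': 10})\n",
"print(observed, sum(observed.values()))\n",
"print(intervened, sum(intervened.values()))\n",
"```\n",
"\n",
"Without an intervention the values are 4, 6, and 18 (total 28, matching the CPAL output above). With `marcia` forced to 10, `jan` becomes 30 while `cindy` stays at 4."
]
},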
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "29fa7b8a-75a3-4270-82a2-2c31939cd7e0",
|
|
"metadata": {},
|
|
"source": [
"### 2.2 Causal collider"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "618eddac-f0ef-4ab5-90ed-72e880fdeba3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has the number of pets as Marcia plus the number of pets as Cindy. \"\n",
|
|
" \"Marcia has no pets. \"\n",
|
|
" \"If Cindy has four pets, how many total pets do the three have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "a01563f3-7974-4de4-8bd9-0b7d710aa0d3",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 marcia pass 0.0 []\n",
|
|
"1 cindy pass 4.0 []\n",
|
|
"2 jan jan.value = marcia.value + cindy.value 4.0 [marcia, cindy]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many total pets do the three have?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"8.0"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "0fbe7243-0522-4946-b9a2-6e21e7c49a42",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": [
|
|
"<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"182pt\" height=\"116pt\" viewBox=\"0.00 0.00 182.00 116.00\">\n",
|
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 112)\">\n",
|
|
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-112 178,-112 178,4 -4,4\"/>\n",
|
|
"<!-- marcia -->\n",
|
|
"<g id=\"node1\" class=\"node\">\n",
|
|
"<title>marcia</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-90\" rx=\"42.49\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"42.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">marcia</text>\n",
|
|
"</g>\n",
|
|
"<!-- jan -->\n",
|
|
"<g id=\"node2\" class=\"node\">\n",
|
|
"<title>jan</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"90.25\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"90.25\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">jan</text>\n",
|
|
"</g>\n",
|
|
"<!-- marcia->jan -->\n",
|
|
"<g id=\"edge1\" class=\"edge\">\n",
|
|
"<title>marcia->jan</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M53.62,-72.41C59.57,-63.74 66.95,-52.97 73.53,-43.38\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"76.51,-45.21 79.28,-34.99 70.74,-41.26 76.51,-45.21\"/>\n",
|
|
"</g>\n",
|
|
"<!-- cindy -->\n",
|
|
"<g id=\"node3\" class=\"node\">\n",
|
|
"<title>cindy</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"138.25\" cy=\"-90\" rx=\"36\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"138.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n",
|
|
"</g>\n",
|
|
"<!-- cindy->jan -->\n",
|
|
"<g id=\"edge2\" class=\"edge\">\n",
|
|
"<title>cindy->jan</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M127.11,-72.77C121.09,-63.98 113.54,-52.96 106.83,-43.19\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"109.53,-40.94 100.99,-34.67 103.75,-44.89 109.53,-40.94\"/>\n",
|
|
"</g>\n",
|
|
"</g>\n",
|
|
"</svg>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d4082538-ec03-44f0-aac3-07e03aad7555",
|
|
"metadata": {},
|
|
"source": [
"### 2.3 Causal confounder"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "83932c30-950b-435a-b328-7993ce8cc6bd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"question = (\n",
|
|
" \"Jan has the number of pets as Marcia plus the number of pets as Cindy. \"\n",
|
|
" \"Marcia has two more pets than Cindy. \"\n",
|
|
" \"If Cindy has four pets, how many total pets do the three have?\"\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "570de307-7c6b-4fdc-80c3-4361daa8a629",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Entering new chain...\u001b[0m\n",
|
|
"\u001b[32;1m\u001b[1;3mstory outcome data\n",
|
|
" name code value depends_on\n",
|
|
"0 cindy pass 4.0 []\n",
|
|
"1 marcia marcia.value = cindy.value + 2 6.0 [cindy]\n",
|
|
"2 jan jan.value = cindy.value + marcia.value 10.0 [cindy, marcia]\u001b[0m\n",
|
|
"\n",
|
|
"\u001b[36;1m\u001b[1;3mquery data\n",
|
|
"{\n",
|
|
" \"question\": \"how many total pets do the three have?\",\n",
|
|
" \"expression\": \"SELECT SUM(value) FROM df\",\n",
|
|
" \"llm_error_msg\": \"\"\n",
|
|
"}\u001b[0m\n",
|
|
"\n",
|
|
"\n",
|
|
"\u001b[1m> Finished chain.\u001b[0m\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"20.0"
|
|
]
|
|
},
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"cpal_chain.run(question)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "00375615-6b6d-4357-bdb8-f64f682f7605",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/svg+xml": [
|
|
"<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"121pt\" height=\"188pt\" viewBox=\"0.00 0.00 120.99 188.00\">\n",
|
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
|
|
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-184 116.99,-184 116.99,4 -4,4\"/>\n",
|
|
"<!-- cindy -->\n",
|
|
"<g id=\"node1\" class=\"node\">\n",
|
|
"<title>cindy</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"77.25\" cy=\"-162\" rx=\"36\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"77.25\" y=\"-158.3\" font-family=\"Times,serif\" font-size=\"14.00\">cindy</text>\n",
|
|
"</g>\n",
|
|
"<!-- marcia -->\n",
|
|
"<g id=\"node2\" class=\"node\">\n",
|
|
"<title>marcia</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"42.25\" cy=\"-90\" rx=\"42.49\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"42.25\" y=\"-86.3\" font-family=\"Times,serif\" font-size=\"14.00\">marcia</text>\n",
|
|
"</g>\n",
|
|
"<!-- cindy->marcia -->\n",
|
|
"<g id=\"edge1\" class=\"edge\">\n",
|
|
"<title>cindy->marcia</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M68.95,-144.41C64.87,-136.25 59.86,-126.22 55.28,-117.07\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"58.33,-115.34 50.72,-107.96 52.07,-118.47 58.33,-115.34\"/>\n",
|
|
"</g>\n",
|
|
"<!-- jan -->\n",
|
|
"<g id=\"node3\" class=\"node\">\n",
|
|
"<title>jan</title>\n",
|
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"77.25\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
|
|
"<text text-anchor=\"middle\" x=\"77.25\" y=\"-14.3\" font-family=\"Times,serif\" font-size=\"14.00\">jan</text>\n",
|
|
"</g>\n",
|
|
"<!-- cindy->jan -->\n",
|
|
"<g id=\"edge2\" class=\"edge\">\n",
|
|
"<title>cindy->jan</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M83.73,-144.1C87.32,-133.84 91.42,-120.36 93.25,-108 95.58,-92.17 95.58,-87.83 93.25,-72 91.95,-63.21 89.5,-53.86 86.91,-45.5\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"90.19,-44.29 83.73,-35.9 83.55,-46.49 90.19,-44.29\"/>\n",
|
|
"</g>\n",
|
|
"<!-- marcia->jan -->\n",
|
|
"<g id=\"edge3\" class=\"edge\">\n",
|
|
"<title>marcia->jan</title>\n",
|
|
"<path fill=\"none\" stroke=\"black\" d=\"M50.72,-72.06C54.86,-63.77 59.94,-53.62 64.53,-44.42\"/>\n",
|
|
"<polygon fill=\"black\" stroke=\"black\" points=\"67.75,-45.82 69.09,-35.31 61.49,-42.69 67.75,-45.82\"/>\n",
|
|
"</g>\n",
|
|
"</g>\n",
|
|
"</svg>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.SVG object>"
|
|
]
|
|
},
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# wait 20 secs to see display\n",
|
|
"cpal_chain.draw(path=\"web.svg\")\n",
|
|
"SVG(\"web.svg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "255683de-0c1c-4131-b277-99d09f5ac1fc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%load_ext autoreload\n",
|
|
"%autoreload 2"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|