# Causal program-aided language (CPAL) chain

## Motivation

This chain builds on the recent [PAL](https://arxiv.org/abs/2211.10435)
approach to reducing LLM hallucination. The problem with PAL is that it
can hallucinate on a math problem with a nested chain of dependence. The
innovation in CPAL is to add causal structure so that such nested
dependencies are resolved correctly.

For example, on the word problem below, PAL answers 5 and CPAL answers
13.

    "Tim buys the same number of pets as Cindy and Boris."
    "Cindy buys the same number of pets as Bill plus Bob."
    "Boris buys the same number of pets as Ben plus Beth."
    "Bill buys the same number of pets as Obama."
    "Bob buys the same number of pets as Obama."
    "Ben buys the same number of pets as Obama."
    "Beth buys the same number of pets as Obama."
    "If Obama buys one pet, how many pets total does everyone buy?"

The CPAL chain represents the causal structure of the above narrative as
a causal graph or DAG, which it can also plot, as shown below.


![complex-graph](https://github.com/hwchase17/langchain/assets/367522/d938db15-f941-493d-8605-536ad530f576)
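
As a rough illustration of why the causal ordering matters, the sketch below hand-codes the DAG for the pet story and propagates values in topological order. It uses only the standard library's `graphlib`, not the chain's own classes; the `depends_on` dictionary and the sum-of-parents rule are hand-written assumptions for this one story, not output of the chain.

```python
from graphlib import TopologicalSorter

# Hand-written dependency map for the story above: child -> parents.
depends_on = {
    "obama": [],
    "bill": ["obama"],
    "bob": ["obama"],
    "ben": ["obama"],
    "beth": ["obama"],
    "cindy": ["bill", "bob"],
    "boris": ["ben", "beth"],
    "tim": ["cindy", "boris"],
}

values = {"obama": 1.0}  # hypothetical condition: Obama buys one pet

# static_order() yields each node only after all of its parents.
for name in TopologicalSorter(depends_on).static_order():
    parents = depends_on[name]
    if parents:
        # In this story every entity buys the sum of its parents' pets.
        values[name] = sum(values[p] for p in parents)

total = sum(values.values())
print(total)  # 13.0
```

Summing over every entity, rather than only Tim, Cindy and Boris, is what separates the answer 13 from the answer 5.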


The two major sections below are:

1. Technical overview
2. Future application

Also see [this Jupyter
notebook](https://github.com/borisdev/langchain/blob/master/docs/extras/modules/chains/additional/cpal.ipynb)
doc.


## 1. Technical overview

### CPAL versus PAL

Like [PAL](https://arxiv.org/abs/2211.10435), CPAL intends to reduce
large language model (LLM) hallucination.

The CPAL chain is different from the PAL chain for a couple of reasons. 

* CPAL adds a causal structure (or DAG) to link entity actions (or math
expressions).
* The CPAL math expressions model a chain of cause-and-effect relations,
which can be intervened upon, whereas the PAL chain's math expressions
are projected math identities.
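
To make "can be intervened upon" concrete, here is a minimal sketch of an intervention on a small slice of the pet story. The `propagate` helper is hypothetical and written only for this illustration; it mirrors the idea of cutting an entity's incoming edges before fixing its value, and it assumes every non-root entity is the sum of its parents.

```python
from graphlib import TopologicalSorter


def propagate(depends_on, initial, interventions=None):
    """Sum-of-parents propagation with optional interventions (hypothetical helper)."""
    interventions = interventions or {}
    # An intervened entity no longer depends on its parents.
    graph = {
        name: [] if name in interventions else parents
        for name, parents in depends_on.items()
    }
    values = {**initial, **interventions}
    for name in TopologicalSorter(graph).static_order():
        if graph[name]:
            values[name] = sum(values[p] for p in graph[name])
    return values


depends_on = {
    "obama": [],
    "bill": ["obama"],
    "bob": ["obama"],
    "cindy": ["bill", "bob"],
}

baseline = propagate(depends_on, {"obama": 1.0})
what_if = propagate(depends_on, {"obama": 1.0}, interventions={"bill": 5.0})
print(baseline["cindy"], what_if["cindy"])  # 2.0 6.0
```

Fixing Bill at five pets changes Cindy's count but leaves Bob untouched, because Bob still depends only on Obama.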

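PAL's generated Python code is wrong. It hallucinates when complexity
increases: below, it sets `tim_pets` equal to `obama_pets` and leaves
Obama, Bill, Bob, Ben and Beth out of the total.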

```python
def solution():
    """Tim buys the same number of pets as Cindy and Boris.Cindy buys the same number of pets as Bill plus Bob.Boris buys the same number of pets as Ben plus Beth.Bill buys the same number of pets as Obama.Bob buys the same number of pets as Obama.Ben buys the same number of pets as Obama.Beth buys the same number of pets as Obama.If Obama buys one pet, how many pets total does everyone buy?"""
    obama_pets = 1
    tim_pets = obama_pets
    cindy_pets = obama_pets + obama_pets
    boris_pets = obama_pets + obama_pets
    total_pets = tim_pets + cindy_pets + boris_pets
    result = total_pets
    return result  # math result is 5
```

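CPAL's generated Python code is correct. The chain output below lists the
code generated for each entity and the SQL query run over the outcome
table.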

```text
story outcome data
    name                                   code  value      depends_on
0  obama                                   pass    1.0              []
1   bill               bill.value = obama.value    1.0         [obama]
2    bob                bob.value = obama.value    1.0         [obama]
3    ben                ben.value = obama.value    1.0         [obama]
4   beth               beth.value = obama.value    1.0         [obama]
5  cindy   cindy.value = bill.value + bob.value    2.0     [bill, bob]
6  boris   boris.value = ben.value + beth.value    2.0     [ben, beth]
7    tim  tim.value = cindy.value + boris.value    4.0  [cindy, boris]

query data
{
    "question": "how many pets total does everyone buy?",
    "expression": "SELECT SUM(value) FROM df",
    "llm_error_msg": ""
}
# query result is 13
```
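
The last step of that output can be reproduced in isolation. The sketch below loads the outcome rows shown above into a table named `df` and runs the same SQL expression. The chain itself runs the expression with duckdb against a pandas DataFrame; the standard library's `sqlite3` is swapped in here only to keep the example dependency-free.

```python
import sqlite3

# Outcome rows copied from the table above: (name, value).
rows = [
    ("obama", 1.0), ("bill", 1.0), ("bob", 1.0), ("ben", 1.0),
    ("beth", 1.0), ("cindy", 2.0), ("boris", 2.0), ("tim", 4.0),
]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE df (name TEXT, value REAL)")
conn.executemany("INSERT INTO df VALUES (?, ?)", rows)

# The same expression the query step produced.
(query_result,) = conn.execute("SELECT SUM(value) FROM df").fetchone()
conn.close()
print(query_result)  # 13.0
```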

Based on the comments below, CPAL's intended location in the library is
`experimental/chains/cpal` and PAL's location is `chains/pal`.

### CPAL vs Graph QA

Both the CPAL chain and the Graph QA chain extract entity-action-entity
relations into a DAG.

The CPAL chain is different from the Graph QA chain for a few reasons.

* Graph QA does not connect entities to math expressions.
* Graph QA does not associate actions in a sequence of dependence.
* Graph QA does not decompose the narrative into these three parts:
  1. Story plot or causal model
  2. Hypothetical question
  3. Hypothetical condition
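
The three-part decomposition can be pictured with one of the few-shot examples from the narrative prompt in this PR. The `NarrativeParts` dataclass is a hypothetical stand-in for the PR's `NarrativeModel`; only the field names are taken from the source, and the split here is written by hand rather than produced by an LLM.

```python
from dataclasses import dataclass


@dataclass
class NarrativeParts:  # hypothetical stand-in for NarrativeModel
    story_plot: str
    story_hypothetical: str
    story_outcome_question: str


narrative = (
    "Boris has seven times the number of pets as Marcia. "
    "Jan has three times the number of pets as Marcia. "
    "Marcia has two more pets than Cindy. "
    "If Cindy has four pets, how many total pets do the three have?"
)

parts = NarrativeParts(
    story_plot=(
        "Boris has seven times the number of pets as Marcia. "
        "Jan has three times the number of pets as Marcia. "
        "Marcia has two more pets than Cindy."
    ),
    story_hypothetical="If Cindy has four pets",
    story_outcome_question="how many total pets do the three have?",
)

# The three parts partition the narrative: nothing is added or dropped.
rebuilt = (
    f"{parts.story_plot} {parts.story_hypothetical}, "
    f"{parts.story_outcome_question}"
)
print(rebuilt == narrative)
```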

### Evaluation

Preliminary evaluation on simple math word problems shows that the CPAL
chain hallucinates less than the PAL chain when answering questions
about a causal narrative. Two examples are in [this Jupyter
notebook](https://github.com/borisdev/langchain/blob/master/docs/extras/modules/chains/additional/cpal.ipynb)
doc.

## 2. Future application

### "Describe as Narrative, Test as Code"

The thesis here is that the Describe as Narrative, Test as Code approach
allows you to represent a causal mental model both as code and as a
narrative, giving you the best of both worlds.

#### Why describe a causal mental model as a narrative?

The narrative form is quick. At a consensus-building meeting, people use
narratives to persuade others of their causal mental model, a.k.a. their
plan. You can share, version-control and index a narrative.

#### Why test a causal mental model as code?

Code is testable; complex narratives are not. Though fast, narratives
become problematic as their complexity increases, because both LLMs and
humans are prone to hallucination when predicting the outcome of a
narrative. The cost of building consensus around the validity of a
narrative's outcome grows with the narrative's complexity. Code does
not require tribal knowledge or social power to validate.

Code is composable; complex narratives are not. The answer of one CPAL
chain can be the hypothetical condition of another CPAL chain. For
stochastic simulations, a composable plan can be integrated with the
[DoWhy library](https://github.com/py-why/dowhy). Lastly, for the
futuristic folk, a composable plan as code allows ordinary community
folk to design a plan that can be integrated with a blockchain for
funding.

An explanation of a dependency planning application is
[here](https://github.com/borisdev/cpal-llm-chain-demo).

--- 
Twitter handle: @boris_dev

---------

Co-authored-by: Boris Dev <borisdev@Boriss-MacBook-Air.local>
pull/7544/head
Boris 1 year ago committed by GitHub
parent 2e4047e5e7
commit 9129318466


@ -0,0 +1,4 @@
# Causal program-aided language (CPAL) chain
see https://github.com/hwchase17/langchain/pull/6255

@ -0,0 +1,271 @@
"""
CPAL Chain and its subchains
"""
from __future__ import annotations

import json
from typing import Any, ClassVar, Dict, List, Optional, Type

import pydantic

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.experimental.cpal.constants import Constant
from langchain.experimental.cpal.models import (
    CausalModel,
    InterventionModel,
    NarrativeModel,
    QueryModel,
    StoryModel,
)
from langchain.experimental.cpal.templates.univariate.causal import (
    template as causal_template,
)
from langchain.experimental.cpal.templates.univariate.intervention import (
    template as intervention_template,
)
from langchain.experimental.cpal.templates.univariate.narrative import (
    template as narrative_template,
)
from langchain.experimental.cpal.templates.univariate.query import (
    template as query_template,
)
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts.prompt import PromptTemplate


class _BaseStoryElementChain(Chain):
    chain: LLMChain
    input_key: str = Constant.narrative_input.value  #: :meta private:
    output_key: str = Constant.chain_answer.value  #: :meta private:
    pydantic_model: ClassVar[
        Optional[Type[pydantic.BaseModel]]
    ] = None  #: :meta private:
    template: ClassVar[Optional[str]] = None  #: :meta private:

    @classmethod
    def parser(cls) -> PydanticOutputParser:
        """Parse LLM output into a pydantic object."""
        if cls.pydantic_model is None:
            raise NotImplementedError(
                f"pydantic_model not implemented for {cls.__name__}"
            )
        return PydanticOutputParser(pydantic_object=cls.pydantic_model)

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_univariate_prompt(
        cls,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> Any:
        return cls(
            chain=LLMChain(
                llm=llm,
                prompt=PromptTemplate(
                    input_variables=[Constant.narrative_input.value],
                    template=kwargs.get("template", cls.template),
                    partial_variables={
                        "format_instructions": cls.parser().get_format_instructions()
                    },
                ),
            ),
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        completion = self.chain.run(inputs[self.input_key])
        pydantic_data = self.__class__.parser().parse(completion)
        return {
            Constant.chain_data.value: pydantic_data,
            Constant.chain_answer.value: None,
        }


class NarrativeChain(_BaseStoryElementChain):
    """Decompose the narrative into its story elements

    - causal model
    - query
    - intervention
    """

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = NarrativeModel
    template: ClassVar[str] = narrative_template


class CausalChain(_BaseStoryElementChain):
    """Translate the causal narrative into a stack of operations."""

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = CausalModel
    template: ClassVar[str] = causal_template


class InterventionChain(_BaseStoryElementChain):
    """Set the hypothetical conditions for the causal model."""

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = InterventionModel
    template: ClassVar[str] = intervention_template


class QueryChain(_BaseStoryElementChain):
    """Query the outcome table using SQL."""

    pydantic_model: ClassVar[Type[pydantic.BaseModel]] = QueryModel
    template: ClassVar[str] = query_template  # TODO: incl. table schema


class CPALChain(_BaseStoryElementChain):
    llm: BaseLanguageModel
    narrative_chain: Optional[NarrativeChain] = None
    causal_chain: Optional[CausalChain] = None
    intervention_chain: Optional[InterventionChain] = None
    query_chain: Optional[QueryChain] = None
    _story: StoryModel = pydantic.PrivateAttr(default=None)  # TODO: change name ?

    @classmethod
    def from_univariate_prompt(
        cls,
        llm: BaseLanguageModel,
        **kwargs: Any,
    ) -> CPALChain:
        """instantiation depends on component chains"""
        return cls(
            llm=llm,
            chain=LLMChain(
                llm=llm,
                prompt=PromptTemplate(
                    input_variables=["question", "query_result"],
                    template=(
                        "Summarize this answer '{query_result}' to this "
                        "question '{question}'? "
                    ),
                ),
            ),
            narrative_chain=NarrativeChain.from_univariate_prompt(llm=llm),
            causal_chain=CausalChain.from_univariate_prompt(llm=llm),
            intervention_chain=InterventionChain.from_univariate_prompt(llm=llm),
            query_chain=QueryChain.from_univariate_prompt(llm=llm),
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        # instantiate component chains
        if self.narrative_chain is None:
            self.narrative_chain = NarrativeChain.from_univariate_prompt(llm=self.llm)
        if self.causal_chain is None:
            self.causal_chain = CausalChain.from_univariate_prompt(llm=self.llm)
        if self.intervention_chain is None:
            self.intervention_chain = InterventionChain.from_univariate_prompt(
                llm=self.llm
            )
        if self.query_chain is None:
            self.query_chain = QueryChain.from_univariate_prompt(llm=self.llm)

        # decompose narrative into three causal story elements
        narrative = self.narrative_chain(inputs[Constant.narrative_input.value])[
            Constant.chain_data.value
        ]

        story = StoryModel(
            causal_operations=self.causal_chain(narrative.story_plot)[
                Constant.chain_data.value
            ],
            intervention=self.intervention_chain(narrative.story_hypothetical)[
                Constant.chain_data.value
            ],
            query=self.query_chain(narrative.story_outcome_question)[
                Constant.chain_data.value
            ],
        )
        self._story = story

        def pretty_print_str(title: str, d: str) -> str:
            return title + "\n" + d

        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(
            pretty_print_str("story outcome data", story._outcome_table.to_string()),
            color="green",
            end="\n\n",
            verbose=self.verbose,
        )

        def pretty_print_dict(title: str, d: dict) -> str:
            return title + "\n" + json.dumps(d, indent=4)

        _run_manager.on_text(
            pretty_print_dict("query data", story.query.dict()),
            color="blue",
            end="\n\n",
            verbose=self.verbose,
        )
        if story.query._result_table.empty:
            # prevent piping bad data into subsequent chains
            raise ValueError(
                (
                    "unanswerable, query and outcome are incoherent\n"
                    "\n"
                    "outcome:\n"
                    f"{story._outcome_table}\n"
                    "query:\n"
                    f"{story.query.dict()}"
                )
            )
        else:
            query_result = float(story.query._result_table.values[0][-1])
            if False:
                """TODO: add this back in when demanded by composable chains"""
                reporting_chain = self.chain
                human_report = reporting_chain.run(
                    question=story.query.question, query_result=query_result
                )
                query_result = {
                    "query_result": query_result,
                    "human_report": human_report,
                }
        output = {
            Constant.chain_data.value: story,
            self.output_key: query_result,
            **kwargs,
        }
        return output

    def draw(self, **kwargs: Any) -> None:
        """
        CPAL chain can draw its resulting DAG.

        Usage in a jupyter notebook:

            >>> from IPython.display import SVG
            >>> cpal_chain.draw(path="graph.svg")
            >>> SVG('graph.svg')
        """
        self._story._networkx_wrapper.draw_graphviz(**kwargs)

@ -0,0 +1,7 @@
from enum import Enum


class Constant(Enum):
    narrative_input = "narrative_input"
    chain_answer = "chain_answer"  # natural language answer
    chain_data = "chain_data"  # pydantic instance

@ -0,0 +1,245 @@
from __future__ import annotations  # allows pydantic model to reference itself

import re
from typing import Any, Optional, Union

import duckdb
import pandas as pd
from pydantic import BaseModel, Field, PrivateAttr, root_validator, validator

from langchain.experimental.cpal.constants import Constant
from langchain.graphs.networkx_graph import NetworkxEntityGraph


class NarrativeModel(BaseModel):
    """
    Represent the narrative input as three story elements.
    """

    story_outcome_question: str
    story_hypothetical: str
    story_plot: str  # causal stack of operations

    @validator("*", pre=True)
    def empty_str_to_none(cls, v: str) -> Union[str, None]:
        """Empty strings are not allowed"""
        if v == "":
            return None
        return v


class EntityModel(BaseModel):
    name: str = Field(description="entity name")
    code: str = Field(description="entity actions")
    value: float = Field(description="entity initial value")
    depends_on: list[str] = Field(default=[], description="ancestor entities")

    # TODO: generalize to multivariate math
    # TODO: acyclic graph

    class Config:
        validate_assignment = True

    @validator("name")
    def lower_case_name(cls, v: str) -> str:
        v = v.lower()
        return v


class CausalModel(BaseModel):
    attribute: str = Field(description="name of the attribute to be calculated")
    entities: list[EntityModel] = Field(description="entities in the story")

    # TODO: root validate each `entity.depends_on` using system's entity names


class EntitySettingModel(BaseModel):
    """
    Initial conditions for an entity

    {"name": "bud", "attribute": "pet_count", "value": 12}
    """

    name: str = Field(description="name of the entity")
    attribute: str = Field(description="name of the attribute to be calculated")
    value: float = Field(description="entity's attribute value (calculated)")

    @validator("name")
    def lower_case_transform(cls, v: str) -> str:
        v = v.lower()
        return v


class SystemSettingModel(BaseModel):
    """
    Initial global conditions for the system.

    {"parameter": "interest_rate", "value": .05}
    """

    parameter: str
    value: float


class InterventionModel(BaseModel):
    """
    aka initial conditions

    >>> intervention.dict()
    {
        entity_settings: [
            {"name": "bud", "attribute": "pet_count", "value": 12},
            {"name": "pat", "attribute": "pet_count", "value": 0},
        ],
        system_settings: None,
    }
    """

    entity_settings: list[EntitySettingModel]
    system_settings: Optional[list[SystemSettingModel]] = None

    @validator("system_settings")
    def lower_case_name(cls, v: str) -> Union[str, None]:
        if v is not None:
            raise NotImplementedError("system_setting is not implemented yet")
        return v


class QueryModel(BaseModel):
    """translate a question about the story outcome into a programmatic expression"""

    question: str = Field(alias=Constant.narrative_input.value)  # input
    expression: str  # output, part of llm completion
    llm_error_msg: str  # output, part of llm completion
    _result_table: str = PrivateAttr()  # result of the executed query


class ResultModel(BaseModel):
    question: str = Field(alias=Constant.narrative_input.value)  # input
    _result_table: str = PrivateAttr()  # result of the executed query


class StoryModel(BaseModel):
    causal_operations: Any = Field(required=True)
    intervention: Any = Field(required=True)
    query: Any = Field(required=True)
    _outcome_table: pd.DataFrame = PrivateAttr(default=None)
    _networkx_wrapper: Any = PrivateAttr(default=None)

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._compute()

    # TODO: when langchain adopts pydantic.v2 replace w/ `__post_init__`
    # misses hints github.com/pydantic/pydantic/issues/1729#issuecomment-1300576214

    @root_validator
    def check_intervention_is_valid(cls, values: dict) -> dict:
        valid_names = [e.name for e in values["causal_operations"].entities]
        for setting in values["intervention"].entity_settings:
            if setting.name not in valid_names:
                error_msg = f"""
                    Hypothetical question has an invalid entity name.
                    `{setting.name}` not in `{valid_names}`
                """
                raise ValueError(error_msg)
        return values

    def _block_back_door_paths(self) -> None:
        # stop intervention entities from depending on others
        intervention_entities = [
            entity_setting.name for entity_setting in self.intervention.entity_settings
        ]
        for entity in self.causal_operations.entities:
            if entity.name in intervention_entities:
                entity.depends_on = []
                entity.code = "pass"

    def _set_initial_conditions(self) -> None:
        for entity_setting in self.intervention.entity_settings:
            for entity in self.causal_operations.entities:
                if entity.name == entity_setting.name:
                    entity.value = entity_setting.value

    def _make_graph(self) -> None:
        self._networkx_wrapper = NetworkxEntityGraph()
        for entity in self.causal_operations.entities:
            for parent_name in entity.depends_on:
                self._networkx_wrapper._graph.add_edge(
                    parent_name, entity.name, relation=entity.code
                )

        # TODO: is it correct to drop entities with no impact on the outcome (?)
        self.causal_operations.entities = [
            entity
            for entity in self.causal_operations.entities
            if entity.name in self._networkx_wrapper.get_topological_sort()
        ]

    def _sort_entities(self) -> None:
        # order the sequence of causal actions
        sorted_nodes = self._networkx_wrapper.get_topological_sort()
        self.causal_operations.entities.sort(key=lambda x: sorted_nodes.index(x.name))

    def _forward_propagate(self) -> None:
        entity_scope = {
            entity.name: entity for entity in self.causal_operations.entities
        }
        for entity in self.causal_operations.entities:
            if entity.code == "pass":
                continue
            else:
                # gist.github.com/dean0x7d/df5ce97e4a1a05be4d56d1378726ff92
                exec(entity.code, globals(), entity_scope)
        row_values = [entity.dict() for entity in entity_scope.values()]
        self._outcome_table = pd.DataFrame(row_values)

    def _run_query(self) -> None:
        def humanize_sql_error_msg(error: str) -> str:
            pattern = r"column\s+(.*?)\s+not found"
            col_match = re.search(pattern, error)
            if col_match:
                return (
                    "SQL error: "
                    + col_match.group(1)
                    + " is not an attribute in your story!"
                )
            else:
                return str(error)

        if self.query.llm_error_msg == "":
            try:
                df = self._outcome_table  # noqa
                query_result = duckdb.sql(self.query.expression).df()
                self.query._result_table = query_result
            except duckdb.BinderException as e:
                self.query._result_table = humanize_sql_error_msg(str(e))
            except Exception as e:
                self.query._result_table = str(e)
        else:
            msg = "LLM maybe failed to translate question to SQL query."
            raise ValueError(
                {
                    "question": self.query.question,
                    "llm_error_msg": self.query.llm_error_msg,
                    "msg": msg,
                }
            )

    def _compute(self) -> Any:
        self._block_back_door_paths()
        self._set_initial_conditions()
        self._make_graph()
        self._sort_entities()
        self._forward_propagate()
        self._run_query()

    def print_debug_report(self) -> None:
        report = {
            "outcome": self._outcome_table,
            "query": self.query.dict(),
            "result": self.query._result_table,
        }
        from pprint import pprint

        pprint(report)
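
The forward-propagation step above executes each entity's `code` string with the entities exposed by name. The sketch below isolates that mechanism on the jan/marcia/cindy story used in the tests. It is an illustration under assumptions: `SimpleNamespace` stands in for `EntityModel`, the entities are listed already in causal order, and an empty globals dict is passed instead of `globals()`.

```python
from types import SimpleNamespace

# Entities already sorted so that parents come before children.
entities = [
    SimpleNamespace(name="cindy", code="pass", value=10.0),
    SimpleNamespace(name="marcia", code="marcia.value = cindy.value + 2", value=0.0),
    SimpleNamespace(name="jan", code="jan.value = marcia.value * 3", value=0.0),
]

# Each entity is reachable by its name inside the executed snippet.
scope = {entity.name: entity for entity in entities}
for entity in entities:
    if entity.code != "pass":
        # exec runs arbitrary code, so this is only acceptable for trusted input.
        exec(entity.code, {}, scope)

print(scope["jan"].value)  # 36.0
```

Because the snippets mutate shared objects, the order of execution matters, which is why the entities are sorted topologically before this step.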

@ -0,0 +1,113 @@
# flake8: noqa E501
# fmt: off
template = (
"""
Transform the math story plot into a JSON object. Don't guess at any of the parts.
{format_instructions}
Story: Boris has seven times the number of pets as Marcia. Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy.
# JSON:
{{
"attribute": "pet_count",
"entities": [
{{
"name": "cindy",
"value": 0,
"depends_on": [],
"code": "pass"
}},
{{
"name": "marcia",
"value": 0,
"depends_on": ["cindy"],
"code": "marcia.value = cindy.value + 2"
}},
{{
"name": "boris",
"value": 0,
"depends_on": ["marcia"],
"code": "boris.value = marcia.value * 7"
}},
{{
"name": "jan",
"value": 0,
"depends_on": ["marcia"],
"code": "jan.value = marcia.value * 3"
}}
]
}}
Story: Boris gives 20 percent of his money to Marcia. Marcia gives 10
percent of her money to Cindy. Cindy gives 5 percent of her money to Jan.
# JSON:
{{
"attribute": "money",
"entities": [
{{
"name": "boris",
"value": 0,
"depends_on": [],
"code": "pass"
}},
{{
"name": "marcia",
"value": 0,
"depends_on": ["boris"],
"code": "
marcia.value = boris.value * 0.2
boris.value = boris.value * 0.8
"
}},
{{
"name": "cindy",
"value": 0,
"depends_on": ["marcia"],
"code": "
cindy.value = marcia.value * 0.1
marcia.value = marcia.value * 0.9
"
}},
{{
"name": "jan",
"value": 0,
"depends_on": ["cindy"],
"code": "
jan.value = cindy.value * 0.05
cindy.value = cindy.value * 0.95
"
}}
]
}}
Story: {narrative_input}
# JSON:
""".strip()
+ "\n"
)
# fmt: on
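
A note on the doubled braces in the template above: they are escapes, so that the few-shot JSON survives formatting while `{narrative_input}` is substituted. The sketch below shows the effect with plain `str.format` on a cut-down, hypothetical template; it assumes the prompt is rendered with the same brace rules, which is how the default f-string style of `PromptTemplate` is understood to behave.

```python
# A cut-down, hypothetical template in the same style as the one above.
template = (
    """
Story: one example story
# JSON:
{{
"attribute": "pet_count"
}}
Story: {narrative_input}
# JSON:
""".strip()
    + "\n"
)

# Doubled braces collapse to literal braces; the single-brace field is filled in.
prompt = template.format(narrative_input="Marcia has two more pets than Cindy.")
print(prompt)
```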

@ -0,0 +1,59 @@
# flake8: noqa E501
# fmt: off
template = (
"""
Transform the hypothetical whatif statement into JSON. Don't guess at any of the parts. Write NONE if you are unsure.
{format_instructions}
statement: if cindy's pet count was 4
# JSON:
{{
"entity_settings" : [
{{ "name": "cindy", "attribute": "pet_count", "value": "4" }}
]
}}
statement: Let's say boris has ten dollars and Bill has 20 dollars.
# JSON:
{{
"entity_settings" : [
{{ "name": "boris", "attribute": "dollars", "value": "10" }},
{{ "name": "bill", "attribute": "dollars", "value": "20" }}
]
}}
Statement: {narrative_input}
# JSON:
""".strip()
+ "\n\n\n"
)
# fmt: on

@ -0,0 +1,79 @@
# flake8: noqa E501
# fmt: off
template = (
"""
Split the given text into three parts: the question, the story_hypothetical, and the logic. Don't guess at any of the parts. Write NONE if you are unsure.
{format_instructions}
Q: Boris has seven times the number of pets as Marcia. Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy. If Cindy has four pets, how many total pets do the three have?
# JSON
{{
"story_outcome_question": "how many total pets do the three have?",
"story_hypothetical": "If Cindy has four pets",
"story_plot": "Boris has seven times the number of pets as Marcia. Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy."
}}
Q: boris gives ten percent of his money to marcia. marcia gives ten
percent of her money to andy. If boris has 100 dollars, how much money
will andy have?
# JSON
{{
"story_outcome_question": "how much money will andy have?",
"story_hypothetical": "If boris has 100 dollars",
"story_plot": "boris gives ten percent of his money to marcia. marcia gives ten percent of her money to andy."
}}
Q: boris gives ten percent of his candy to marcia. marcia gives ten
percent of her candy to andy. If boris has 100 pounds of candy and marcia has
200 pounds of candy, then how many pounds of candy will andy have?
# JSON
{{
"story_outcome_question": "how many pounds of candy will andy have?",
"story_hypothetical": "If boris has 100 pounds of candy and marcia has 200 pounds of candy",
"story_plot": "boris gives ten percent of his candy to marcia. marcia gives ten percent of her candy to andy."
}}
Q: {narrative_input}
# JSON
""".strip()
+ "\n\n\n"
)
# fmt: on

@ -0,0 +1,270 @@
# flake8: noqa E501
# fmt: off
template = (
"""
Transform the narrative_input into an SQL expression. If you are
unsure, then do not guess, instead add a llm_error_msg that explains why you are unsure.
{format_instructions}
narrative_input: how much money will boris have?
# JSON:
{{
"narrative_input": "how much money will boris have?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name = 'boris'"
}}
narrative_input: How much money does ted have?
# JSON:
{{
"narrative_input": "How much money does ted have?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name = 'ted'"
}}
narrative_input: what is the sum of pet count for all the people?
# JSON:
{{
"narrative_input": "what is the sum of pet count for all the people?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df"
}}
narrative_input: what's the average of the pet counts for all the people?
# JSON:
{{
"narrative_input": "what's the average of the pet counts for all the people?",
"llm_error_msg": "",
"expression": "SELECT AVG(value) FROM df"
}}
narrative_input: what's the maximum of the pet counts for all the people?
# JSON:
{{
"narrative_input": "what's the maximum of the pet counts for all the people?",
"llm_error_msg": "",
"expression": "SELECT MAX(value) FROM df"
}}
narrative_input: what's the minimum of the pet counts for all the people?
# JSON:
{{
"narrative_input": "what's the minimum of the pet counts for all the people?",
"llm_error_msg": "",
"expression": "SELECT MIN(value) FROM df"
}}
narrative_input: what's the number of people with pet counts greater than 10?
# JSON:
{{
"narrative_input": "what's the number of people with pet counts greater than 10?",
"llm_error_msg": "",
"expression": "SELECT COUNT(*) FROM df WHERE value > 10"
}}
narrative_input: what's the pet count for boris?
# JSON:
{{
"narrative_input": "what's the pet count for boris?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name = 'boris'"
}}
narrative_input: what's the pet count for cindy and marcia?
# JSON:
{{
"narrative_input": "what's the pet count for cindy and marcia?",
"llm_error_msg": "",
"expression": "SELECT name, value FROM df WHERE name IN ('cindy', 'marcia')"
}}
narrative_input: what's the total pet count for cindy and marcia?
# JSON:
{{
"narrative_input": "what's the total pet count for cindy and marcia?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name IN ('cindy', 'marcia')"
}}
narrative_input: what's the total pet count for TED?
# JSON:
{{
"narrative_input": "what's the total pet count for TED?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name = 'TED'"
}}
narrative_input: what's the total dollar count for TED and cindy?
# JSON:
{{
"narrative_input": "what's the total dollar count for TED and cindy?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name IN ('TED', 'cindy')"
}}
narrative_input: what's the total pet count for TED and cindy?
# JSON:
{{
"narrative_input": "what's the total pet count for TED and cindy?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df WHERE name IN ('TED', 'cindy')"
}}
narrative_input: what's the best for TED and cindy?
# JSON:
{{
"narrative_input": "what's the best for TED and cindy?",
"llm_error_msg": "ambiguous narrative_input, not sure what 'best' means",
"expression": ""
}}
narrative_input: what's the value?
# JSON:
{{
"narrative_input": "what's the value?",
"llm_error_msg": "ambiguous narrative_input, not sure what entity is being asked about",
"expression": ""
}}
narrative_input: how many total pets do the three have?
# JSON:
{{
"narrative_input": "how many total pets do the three have?",
"llm_error_msg": "",
"expression": "SELECT SUM(value) FROM df"
}}
narrative_input: {narrative_input}
# JSON:
""".strip()
+ "\n\n\n"
)
# fmt: on

@ -122,3 +122,48 @@ class NetworkxEntityGraph:
    def clear(self) -> None:
        """Clear the graph."""
        self._graph.clear()

    def get_topological_sort(self) -> List[str]:
        """Get a list of entity names in the graph sorted by causal dependence."""
        import networkx as nx

        return list(nx.topological_sort(self._graph))

    def draw_graphviz(self, **kwargs: Any) -> None:
        """
        Provides better drawing

        Usage in a jupyter notebook:

            >>> from IPython.display import SVG
            >>> self.draw_graphviz(prog="dot", path="web.svg")
            >>> SVG('web.svg')
        """
        from networkx.drawing.nx_agraph import to_agraph

        try:
            import pygraphviz  # noqa: F401

        except ImportError as e:
            if e.name == "_graphviz":
                """
                >>> e.msg  # pygraphviz throws this error
                ImportError: libcgraph.so.6: cannot open shared object file
                """
                raise ImportError(
                    "Could not import graphviz debian package. "
                    "Please install it with: "
                    "`sudo apt-get update` "
                    "`sudo apt-get install graphviz graphviz-dev`"
                )
            else:
                raise ImportError(
                    "Could not import pygraphviz python package. "
                    "Please install it with: "
                    "`pip install pygraphviz`."
                )

        graph = to_agraph(self._graph)  # --> pygraphviz.agraph.AGraph
        # pygraphviz.github.io/documentation/stable/tutorial.html#layout-and-drawing
        graph.layout(prog=kwargs.get("prog", "dot"))
        graph.draw(kwargs.get("path", "graph.svg"))

@ -0,0 +1,554 @@
"""Test CPAL chain."""
import json
import unittest
from typing import Type
from unittest import mock
import pydantic
import pytest
from langchain import OpenAI
from langchain.experimental.cpal.base import (
CausalChain,
CPALChain,
InterventionChain,
NarrativeChain,
QueryChain,
)
from langchain.experimental.cpal.constants import Constant
from langchain.experimental.cpal.models import (
CausalModel,
EntityModel,
EntitySettingModel,
InterventionModel,
NarrativeModel,
QueryModel,
)
from langchain.experimental.cpal.templates.univariate.causal import (
template as causal_template,
)
from langchain.experimental.cpal.templates.univariate.intervention import (
template as intervention_template,
)
from langchain.experimental.cpal.templates.univariate.narrative import (
template as narrative_template,
)
from langchain.experimental.cpal.templates.univariate.query import (
template as query_template,
)
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts.prompt import PromptTemplate
from tests.unit_tests.llms.fake_llm import FakeLLM


class TestUnitCPALChain_MathWordProblems(unittest.TestCase):
    """Unit Test the CPAL chain and its component chains on math word problems.

    These tests can't run in the standard unit test directory because of
    this issue, https://github.com/hwchase17/langchain/issues/7451
    """

    def setUp(self) -> None:
        self.fake_llm = self.make_fake_llm()

    def make_fake_llm(self) -> FakeLLM:
        """
        Fake LLM service for testing the CPAL chain and its component chains
        on univariate math examples.
        """

        class LLMMockData(pydantic.BaseModel):
            question: str
            completion: str
            template: str
            data_model: Type[pydantic.BaseModel]

            @property
            def prompt(self) -> str:
                """Create LLM prompt with the question."""
                prompt_template = PromptTemplate(
                    input_variables=[Constant.narrative_input.value],
                    template=self.template,
                    partial_variables={
                        "format_instructions": PydanticOutputParser(
                            pydantic_object=self.data_model
                        ).get_format_instructions()
                    },
                )
                prompt = prompt_template.format(narrative_input=self.question)
                return prompt

        narrative = LLMMockData(
            **{
                "question": (
                    "jan has three times the number of pets as marcia. "
                    "marcia has two more pets than cindy."
                    "if cindy has ten pets, how many pets does jan have? "
                ),
                "completion": json.dumps(
                    {
                        "story_outcome_question": "how many pets does jan have? ",
                        "story_hypothetical": "if cindy has ten pets",
                        "story_plot": "jan has three times the number of pets as marcia. marcia has two more pets than cindy.",  # noqa: E501
                    }
                ),
                "template": narrative_template,
                "data_model": NarrativeModel,
            }
        )

        causal_model = LLMMockData(
            **{
                "question": (
                    "jan has three times the number of pets as marcia. "
                    "marcia has two more pets than cindy."
                ),
                "completion": (
                    "\n"
                    "{\n"
                    ' "attribute": "pet_count",\n'
                    ' "entities": [\n'
                    " {\n"
                    ' "name": "cindy",\n'
                    ' "value": 0,\n'
                    ' "depends_on": [],\n'
                    ' "code": "pass"\n'
                    " },\n"
                    " {\n"
                    ' "name": "marcia",\n'
                    ' "value": 0,\n'
                    ' "depends_on": ["cindy"],\n'
                    ' "code": "marcia.value = cindy.value + 2"\n'
                    " },\n"
                    " {\n"
                    ' "name": "jan",\n'
                    ' "value": 0,\n'
                    ' "depends_on": ["marcia"],\n'
                    ' "code": "jan.value = marcia.value * 3"\n'
                    " }\n"
                    " ]\n"
                    "}"
                ),
                "template": causal_template,
                "data_model": CausalModel,
            }
        )

        intervention = LLMMockData(
            **{
                "question": ("if cindy has ten pets"),
                "completion": (
                    "{\n"
                    ' "entity_settings" : [\n'
                    ' { "name": "cindy", "attribute": "pet_count", "value": "10" }\n'  # noqa: E501
                    " ]\n"
                    "}"
                ),
                "template": intervention_template,
                "data_model": InterventionModel,
            }
        )

        query = LLMMockData(
            **{
                "question": ("how many pets does jan have? "),
                "completion": (
                    "{\n"
                    ' "narrative_input": "how many pets does jan have? ",\n'
                    ' "llm_error_msg": "",\n'
                    ' "expression": "SELECT name, value FROM df WHERE name = \'jan\'"\n'  # noqa: E501
                    "}"
                ),
                "template": query_template,
                "data_model": QueryModel,
            }
        )

        # Map each fully rendered prompt to its canned completion.
        fake_llm = FakeLLM()
        fake_llm.queries = {}
        for mock_data in [narrative, causal_model, intervention, query]:
            fake_llm.queries.update({mock_data.prompt: mock_data.completion})
        return fake_llm

    def test_narrative_chain(self) -> None:
        """Test narrative chain returns the three main elements of the causal
        narrative as a pydantic object.
        """
        narrative_chain = NarrativeChain.from_univariate_prompt(llm=self.fake_llm)
        output = narrative_chain(
            (
                "jan has three times the number of pets as marcia. "
                "marcia has two more pets than cindy."
                "if cindy has ten pets, how many pets does jan have? "
            )
        )
        expected_output = {
            "chain_answer": None,
            "chain_data": NarrativeModel(
                story_outcome_question="how many pets does jan have? ",
                story_hypothetical="if cindy has ten pets",
                story_plot="jan has three times the number of pets as marcia. marcia has two more pets than cindy.",  # noqa: E501
            ),
            "narrative_input": "jan has three times the number of pets as marcia. marcia "  # noqa: E501
            "has two more pets than cindy.if cindy has ten pets, how "
            "many pets does jan have? ",
        }
        assert output == expected_output

    def test_causal_chain(self) -> None:
        """
        Test causal chain returns a DAG as a pydantic object.
        """
        causal_chain = CausalChain.from_univariate_prompt(llm=self.fake_llm)
        output = causal_chain(
            (
                "jan has three times the number of pets as "
                "marcia. marcia has two more pets than cindy."
            )
        )
        expected_output = {
            "chain_answer": None,
            "chain_data": CausalModel(
                attribute="pet_count",
                entities=[
                    EntityModel(name="cindy", code="pass", value=0.0, depends_on=[]),
                    EntityModel(
                        name="marcia",
                        code="marcia.value = cindy.value + 2",
                        value=0.0,
                        depends_on=["cindy"],
                    ),
                    EntityModel(
                        name="jan",
                        code="jan.value = marcia.value * 3",
                        value=0.0,
                        depends_on=["marcia"],
                    ),
                ],
            ),
            "narrative_input": "jan has three times the number of pets as marcia. marcia "  # noqa: E501
            "has two more pets than cindy.",
        }
        assert output == expected_output

    def test_intervention_chain(self) -> None:
        """
        Test intervention chain correctly transforms
        the LLM's text completion into a setting-like object.
        """
        intervention_chain = InterventionChain.from_univariate_prompt(llm=self.fake_llm)
        output = intervention_chain("if cindy has ten pets")
        expected_output = {
            "chain_answer": None,
            "chain_data": InterventionModel(
                entity_settings=[
                    EntitySettingModel(name="cindy", attribute="pet_count", value=10),
                ]
            ),
            "narrative_input": "if cindy has ten pets",
        }
        assert output == expected_output

    def test_query_chain(self) -> None:
        """
        Test query chain correctly transforms
        the LLM's text completion into a query-like object.
        """
        query_chain = QueryChain.from_univariate_prompt(llm=self.fake_llm)
        output = query_chain("how many pets does jan have? ")
        expected_output = {
            "chain_answer": None,
            "chain_data": QueryModel(
                narrative_input="how many pets does jan have? ",
                llm_error_msg="",
                expression="SELECT name, value FROM df WHERE name = 'jan'",
            ),
            "narrative_input": "how many pets does jan have? ",
        }
        assert output == expected_output

    def test_cpal_chain(self) -> None:
        """
        Patch is required because the `networkx` package is not part of the
        unit test environment.
        """
        with mock.patch(
            "langchain.experimental.cpal.models.NetworkxEntityGraph"
        ) as mock_networkx:
            graph_instance = mock_networkx.return_value
            graph_instance.get_topological_sort.return_value = [
                "cindy",
                "marcia",
                "jan",
            ]
            cpal_chain = CPALChain.from_univariate_prompt(
                llm=self.fake_llm, verbose=True
            )
            cpal_chain.run(
                (
                    "jan has three times the number of pets as "
                    "marcia. marcia has two more pets than cindy."
                    "if cindy has ten pets, how many pets does jan have? "
                )
            )


class TestCPALChain_MathWordProblems(unittest.TestCase):
    """Test the CPAL chain and its component chains on math word problems."""

    def test_causal_chain(self) -> None:
        """Test CausalChain can translate a narrative's plot into a causal model
        containing operations linked by a DAG."""

        llm = OpenAI(temperature=0, max_tokens=512)
        causal_chain = CausalChain.from_univariate_prompt(llm)
        narrative_plot = (
            "Jan has three times the number of pets as Marcia. "
            "Marcia has two more pets than Cindy. "
        )
        output = causal_chain(narrative_plot)
        expected_output = {
            "chain_answer": None,
            "chain_data": CausalModel(
                attribute="pet_count",
                entities=[
                    EntityModel(name="cindy", code="pass", value=0.0, depends_on=[]),
                    EntityModel(
                        name="marcia",
                        code="marcia.value = cindy.value + 2",
                        value=0.0,
                        depends_on=["cindy"],
                    ),
                    EntityModel(
                        name="jan",
                        code="jan.value = marcia.value * 3",
                        value=0.0,
                        depends_on=["marcia"],
                    ),
                ],
            ),
            "narrative_input": "Jan has three times the number of pets as Marcia. Marcia "  # noqa: E501
            "has two more pets than Cindy. ",
        }
        self.assertDictEqual(output, expected_output)
        self.assertIsInstance(output[Constant.chain_data.value], CausalModel)

    def test_intervention_chain(self) -> None:
        """Test InterventionChain translates a hypothetical into a new value setting."""

        llm = OpenAI(temperature=0, max_tokens=512)
        story_conditions_chain = InterventionChain.from_univariate_prompt(llm)
        question = "if cindy has ten pets"
        data = story_conditions_chain(question)[Constant.chain_data.value]
        self.assertEqual(type(data), InterventionModel)

    def test_intervention_chain_2(self) -> None:
        """Test InterventionChain translates a hypothetical into a new value setting."""

        llm = OpenAI(temperature=0, max_tokens=512)
        story_conditions_chain = InterventionChain.from_univariate_prompt(llm)
        narrative_condition = "What if Cindy has ten pets and Boris has 5 pets? "
        data = story_conditions_chain(narrative_condition)[Constant.chain_data.value]
        self.assertEqual(type(data), InterventionModel)

    def test_query_chain(self) -> None:
        """Test QueryChain translates a question into a query expression."""
        llm = OpenAI(temperature=0, max_tokens=512)
        query_chain = QueryChain.from_univariate_prompt(llm)
        narrative_question = "How many pets will Marcia end up with? "
        data = query_chain(narrative_question)[Constant.chain_data.value]
        self.assertEqual(type(data), QueryModel)

    def test_narrative_chain(self) -> None:
        """Test NarrativeChain decomposes a human's narrative into three story elements:

        - causal model
        - intervention model
        - query model
        """
        narrative = (
            "Jan has three times the number of pets as Marcia. "
            "Marcia has two more pets than Cindy. "
            "If Cindy has ten pets, how many pets does Jan have? "
        )
        llm = OpenAI(temperature=0, max_tokens=512)
        narrative_chain = NarrativeChain.from_univariate_prompt(llm)
        data = narrative_chain(narrative)[Constant.chain_data.value]
        self.assertEqual(type(data), NarrativeModel)

        out = narrative_chain(narrative)
        expected_narrative_out = {
            "chain_answer": None,
            "chain_data": NarrativeModel(
                story_outcome_question="how many pets does Jan have?",
                story_hypothetical="If Cindy has ten pets",
                story_plot="Jan has three times the number of pets as Marcia. Marcia has two more pets than Cindy.",  # noqa: E501
            ),
            "narrative_input": "Jan has three times the number of pets as Marcia. Marcia "  # noqa: E501
            "has two more pets than Cindy. If Cindy has ten pets, how "
            "many pets does Jan have? ",
        }
        self.assertDictEqual(out, expected_narrative_out)

    def test_against_pal_chain_doc(self) -> None:
        """
        Test CPAL chain against the first example in the PAL chain notebook doc:

        https://github.com/hwchase17/langchain/blob/master/docs/extras/modules/chains/additional/pal.ipynb
        """

        narrative_input = (
            "Jan has three times the number of pets as Marcia."
            " Marcia has two more pets than Cindy."
            " If Cindy has four pets, how many total pets do the three have?"
        )

        llm = OpenAI(temperature=0, max_tokens=512)
        cpal_chain = CPALChain.from_univariate_prompt(llm=llm, verbose=True)
        answer = cpal_chain.run(narrative_input)

        """
        >>> story._outcome_table
             name                            code  value depends_on
        0   cindy                            pass    4.0         []
        1  marcia  marcia.value = cindy.value + 2    6.0    [cindy]
        2     jan    jan.value = marcia.value * 3   18.0   [marcia]

        """
        self.assertEqual(answer, 28.0)

    def test_simple(self) -> None:
        """
        Given a simple math word problem, test and illustrate the
        data structures that are produced by the CPAL chain.
        """

        narrative_input = (
            "jan has three times the number of pets as marcia."
            "marcia has two more pets than cindy."
            "If cindy has ten pets, how many pets does jan have?"
        )
        llm = OpenAI(temperature=0, max_tokens=512)
        cpal_chain = CPALChain.from_univariate_prompt(llm=llm, verbose=True)
        output = cpal_chain(narrative_input)
        data = output[Constant.chain_data.value]

        expected_output = {
            "causal_operations": {
                "attribute": "pet_count",
                "entities": [
                    {"code": "pass", "depends_on": [], "name": "cindy", "value": 10.0},
                    {
                        "code": "marcia.value = cindy.value + 2",
                        "depends_on": ["cindy"],
                        "name": "marcia",
                        "value": 12.0,
                    },
                    {
                        "code": "jan.value = marcia.value * 3",
                        "depends_on": ["marcia"],
                        "name": "jan",
                        "value": 36.0,
                    },
                ],
            },
            "intervention": {
                "entity_settings": [
                    {"attribute": "pet_count", "name": "cindy", "value": 10.0}
                ],
                "system_settings": None,
            },
            "query": {
                "expression": "SELECT name, value FROM df WHERE name = 'jan'",
                "llm_error_msg": "",
                "question": "how many pets does jan have?",
            },
        }
        self.assertDictEqual(data.dict(), expected_output)

        """
        Illustrate the query model's result table as a printed pandas dataframe
        >>> data._outcome_table
             name                            code  value depends_on
        0   cindy                            pass   10.0         []
        1  marcia  marcia.value = cindy.value + 2   12.0    [cindy]
        2     jan    jan.value = marcia.value * 3   36.0   [marcia]
        """

        expected_output = {
            "code": {
                0: "pass",
                1: "marcia.value = cindy.value + 2",
                2: "jan.value = marcia.value * 3",
            },
            "depends_on": {0: [], 1: ["cindy"], 2: ["marcia"]},
            "name": {0: "cindy", 1: "marcia", 2: "jan"},
            "value": {0: 10.0, 1: 12.0, 2: 36.0},
        }
        self.assertDictEqual(data._outcome_table.to_dict(), expected_output)

        expected_output = {"name": {0: "jan"}, "value": {0: 36.0}}
        self.assertDictEqual(data.query._result_table.to_dict(), expected_output)

        # TODO: use an LLM chain to translate numbers to words
        df = data.query._result_table
        expr = "name == 'jan'"
        answer = df.query(expr).iloc[0]["value"]
        self.assertEqual(float(answer), 36.0)

    def test_hallucinating(self) -> None:
        """
        Test that the CPAL approach does not hallucinate when given
        an invalid entity in the question.

        The PAL chain would hallucinate here!
        """

        narrative_input = (
            "Jan has three times the number of pets as Marcia."
            "Marcia has two more pets than Cindy."
            "If Cindy has ten pets, how many pets does Barak have?"
        )
        llm = OpenAI(temperature=0, max_tokens=512)
        cpal_chain = CPALChain.from_univariate_prompt(llm=llm, verbose=True)
        with pytest.raises(Exception) as e_info:
            cpal_chain.run(narrative_input)
        # Print only after the block exits, once e_info holds the raised exception.
        print(e_info)

    def test_causal_mediator(self) -> None:
        """
        Test CPAL approach on causal mediator.
        """

        narrative_input = (
            "jan has three times the number of pets as marcia."
            "marcia has two more pets than cindy."
            "If marcia has ten pets, how many pets does jan have?"
        )
        llm = OpenAI(temperature=0, max_tokens=512)
        cpal_chain = CPALChain.from_univariate_prompt(llm=llm, verbose=True)
        answer = cpal_chain.run(narrative_input)
        self.assertEqual(answer, 30.0)

    @pytest.mark.skip(reason="requires manual install of debian and py packages")
    def test_draw(self) -> None:
        """
        Test CPAL chain can draw its resulting DAG.
        """
        import os

        narrative_input = (
            "Jan has three times the number of pets as Marcia."
            "Marcia has two more pets than Cindy."
            "If Marcia has ten pets, how many pets does Jan have?"
        )
        llm = OpenAI(temperature=0, max_tokens=512)
        cpal_chain = CPALChain.from_univariate_prompt(llm=llm, verbose=True)
        cpal_chain.run(narrative_input)
        path = "graph.svg"
        cpal_chain.draw(path=path)
        self.assertTrue(os.path.exists(path))
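
The numeric answers asserted in the tests above (28.0 for the PAL doc example, 36.0 in `test_simple`, 30.0 for the causal mediator) can be checked by hand by walking the three-entity DAG in dependency order and holding any intervened entity fixed. The following is a minimal sketch of that arithmetic only, not the CPAL implementation: `Entity`, `evaluate`, and the use of the standard library's `graphlib.TopologicalSorter` are assumptions made for illustration, and the update rules are plain Python callables rather than LLM-generated `code` strings.

```python
from dataclasses import dataclass, field
from graphlib import TopologicalSorter
from typing import Callable, Dict, List, Optional


@dataclass
class Entity:
    """Hypothetical stand-in for one node of the causal model."""

    name: str
    depends_on: List[str] = field(default_factory=list)
    # Maps the current values of all entities to this entity's new value.
    update: Optional[Callable[[Dict[str, float]], float]] = None


def evaluate(entities: List[Entity], intervention: Dict[str, float]) -> Dict[str, float]:
    """Evaluate entities in dependency order, holding intervened values fixed."""
    by_name = {e.name: e for e in entities}
    # TopologicalSorter takes {node: predecessors}; static_order() yields
    # every predecessor before the nodes that depend on it.
    order = TopologicalSorter({e.name: e.depends_on for e in entities}).static_order()
    values = {e.name: 0.0 for e in entities}
    for name in order:
        if name in intervention:
            # An intervention overrides whatever the entity's own rule would compute.
            values[name] = float(intervention[name])
        elif by_name[name].update is not None:
            values[name] = float(by_name[name].update(values))
    return values


pets = [
    Entity("cindy"),
    Entity("marcia", ["cindy"], lambda v: v["cindy"] + 2),
    Entity("jan", ["marcia"], lambda v: v["marcia"] * 3),
]

pal_doc = evaluate(pets, {"cindy": 4})
simple = evaluate(pets, {"cindy": 10})
mediator = evaluate(pets, {"marcia": 10})

print(sum(pal_doc.values()))  # 28.0
print(simple["jan"])  # 36.0
print(mediator["jan"])  # 30.0
```

In the mediator case the intervention on `marcia` cuts her dependence on `cindy`, so `jan` is computed from the fixed value 10 regardless of what `cindy` holds.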