Convert Chain to a Chain Factory (#4605)

## Change Chain argument in client to accept a chain factory

The `run_over_dataset` functionality seeks to treat each iteration of an
example as an independent trial.
Chains have memory, so it's easier to permit this type of behavior if we
accept a factory method rather than the chain object directly.

There are still corner cases / UX pains people will likely run into, such as:
- Caching may cause issues
- If memory is persisted to a shared object (e.g., the same Redis queue),
this could impact what is retrieved
- If we're running the async methods with concurrency using local
models, and someone naively instantiates the chain (loading the model)
each time, it could lead to excessive disk I/O or OOM
textloader_autodetect_encodings
Zander Chase 1 year ago committed by GitHub
parent ed0d557ede
commit 0c6ed657ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -40,6 +40,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
def _get_link_stem(url: str) -> str:
scheme = urlsplit(url).scheme
@ -99,6 +101,21 @@ class LangChainPlusClient(BaseSettings):
raise ValueError("No seeded tenant found")
return results[0]["id"]
@staticmethod
def _get_session_name(
session_name: Optional[str],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
dataset_name: str,
) -> str:
if session_name is not None:
return session_name
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
if isinstance(llm_or_chain_factory, BaseLanguageModel):
model_name = llm_or_chain_factory.__class__.__name__
else:
model_name = llm_or_chain_factory().__class__.__name__
return f"{dataset_name}-{model_name}-{current_time}"
def _repr_html_(self) -> str:
"""Return an HTML representation of the instance with a link to the URL."""
link = _get_link_stem(self.api_url)
@ -312,7 +329,7 @@ class LangChainPlusClient(BaseSettings):
async def _arun_llm_or_chain(
example: Example,
langchain_tracer: LangChainTracerV2,
llm_or_chain: Union[Chain, BaseLanguageModel],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain asynchronously."""
@ -321,12 +338,13 @@ class LangChainPlusClient(BaseSettings):
outputs = []
for _ in range(n_repetitions):
try:
if isinstance(llm_or_chain, BaseLanguageModel):
if isinstance(llm_or_chain_factory, BaseLanguageModel):
output: Any = await LangChainPlusClient._arun_llm(
llm_or_chain, example.inputs, langchain_tracer
llm_or_chain_factory, example.inputs, langchain_tracer
)
else:
output = await llm_or_chain.arun(
chain = llm_or_chain_factory()
output = await chain.arun(
example.inputs, callbacks=[langchain_tracer]
)
outputs.append(output)
@ -388,7 +406,8 @@ class LangChainPlusClient(BaseSettings):
async def arun_on_dataset(
self,
dataset_name: str,
llm_or_chain: Union[Chain, BaseLanguageModel],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
concurrency_level: int = 5,
num_repetitions: int = 1,
session_name: Optional[str] = None,
@ -399,7 +418,9 @@ class LangChainPlusClient(BaseSettings):
Args:
dataset_name: Name of the dataset to run the chain on.
llm_or_chain: Chain or language model to run over the dataset.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
concurrency_level: The number of async tasks to run concurrently.
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
@ -411,11 +432,9 @@ class LangChainPlusClient(BaseSettings):
Returns:
A dictionary mapping example ids to the model outputs.
"""
if session_name is None:
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
session_name = (
f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
)
session_name = LangChainPlusClient._get_session_name(
session_name, llm_or_chain_factory, dataset_name
)
dataset = self.read_dataset(dataset_name=dataset_name)
examples = self.list_examples(dataset_id=str(dataset.id))
results: Dict[str, List[Any]] = {}
@ -427,7 +446,7 @@ class LangChainPlusClient(BaseSettings):
result = await LangChainPlusClient._arun_llm_or_chain(
example,
tracer,
llm_or_chain,
llm_or_chain_factory,
num_repetitions,
)
results[str(example.id)] = result
@ -474,7 +493,7 @@ class LangChainPlusClient(BaseSettings):
def run_llm_or_chain(
example: Example,
langchain_tracer: LangChainTracerV2,
llm_or_chain: Union[Chain, BaseLanguageModel],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
n_repetitions: int,
) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
"""Run the chain synchronously."""
@ -483,14 +502,13 @@ class LangChainPlusClient(BaseSettings):
outputs = []
for _ in range(n_repetitions):
try:
if isinstance(llm_or_chain, BaseLanguageModel):
if isinstance(llm_or_chain_factory, BaseLanguageModel):
output: Any = LangChainPlusClient.run_llm(
llm_or_chain, example.inputs, langchain_tracer
llm_or_chain_factory, example.inputs, langchain_tracer
)
else:
output = llm_or_chain.run(
example.inputs, callbacks=[langchain_tracer]
)
chain = llm_or_chain_factory()
output = chain.run(example.inputs, callbacks=[langchain_tracer])
outputs.append(output)
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
@ -502,7 +520,8 @@ class LangChainPlusClient(BaseSettings):
def run_on_dataset(
self,
dataset_name: str,
llm_or_chain: Union[Chain, BaseLanguageModel],
llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
*,
num_repetitions: int = 1,
session_name: Optional[str] = None,
verbose: bool = False,
@ -511,7 +530,9 @@ class LangChainPlusClient(BaseSettings):
Args:
dataset_name: Name of the dataset to run the chain on.
llm_or_chain: Chain or language model to run over the dataset.
llm_or_chain_factory: Language model or Chain constructor to run
over the dataset. The Chain constructor is used to permit
independent calls on each example without carrying over state.
concurrency_level: Number of async workers to run in parallel.
num_repetitions: Number of times to run the model on each example.
This is useful when testing success rates or generating confidence
@ -523,11 +544,9 @@ class LangChainPlusClient(BaseSettings):
Returns:
A dictionary mapping example ids to the model outputs.
"""
if session_name is None:
current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
session_name = (
f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}"
)
session_name = LangChainPlusClient._get_session_name(
session_name, llm_or_chain_factory, dataset_name
)
dataset = self.read_dataset(dataset_name=dataset_name)
examples = list(self.list_examples(dataset_id=str(dataset.id)))
results: Dict[str, Any] = {}
@ -539,7 +558,7 @@ class LangChainPlusClient(BaseSettings):
result = self.run_llm_or_chain(
example,
tracer,
llm_or_chain,
llm_or_chain_factory,
num_repetitions,
)
if verbose:

@ -133,21 +133,19 @@
"output_type": "stream",
"text": [
"The current population of Canada as of 2023 is 39,566,248.\n",
"Anwar Hadid's age raised to the 0.43 power is approximately 3.87.\n",
"Anwar Hadid is Dua Lipa's boyfriend and his age raised to the 0.43 power is approximately 3.87.\n",
"LLMMathChain._evaluate(\"\n",
"(age)**0.43\n",
"\") raised error: 'age'. Please try again with a valid numerical expression\n",
"The distance between Paris and Boston is 3448 miles.\n",
"unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n",
"The distance between Paris and Boston is approximately 3448 miles.\n",
"unknown format from LLM: Sorry, I cannot answer this question as it requires information from the future.\n",
"LLMMathChain._evaluate(\"\n",
"(total number of points scored in the 2023 super bowl)**0.23\n",
"\") raised error: invalid syntax. Perhaps you forgot a comma? (<expr>, line 1). Please try again with a valid numerical expression\n",
"3 points were scored more in the 2023 Super Bowl than in the 2022 Super Bowl.\n",
"Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n",
"1.9347796717823205\n",
"81\n",
"LLMMathChain._evaluate(\"\n",
"round(0.2791714614499425, 2)\n",
"\") raised error: 'VariableNode' object is not callable. Please try again with a valid numerical expression\n"
"77\n",
"0.2791714614499425\n"
]
}
],
@ -254,12 +252,109 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"id": "60d14593-c61f-449f-a38f-772ca43707c2",
"metadata": {
"tags": []
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset json (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-search-calculator-8a025c0ce5fb99d2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c34edde8de5340888b3278d1ac427417",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>input</th>\n",
" <th>output</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>How many people live in canada as of 2023?</td>\n",
" <td>approximately 38,625,801</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>who is dua lipa's boyfriend? what is his age r...</td>\n",
" <td>her boyfriend is Romain Gravas. his age raised...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>what is dua lipa's boyfriend age raised to the...</td>\n",
" <td>her boyfriend is Romain Gravas. his age raised...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>how far is it from paris to boston in miles</td>\n",
" <td>approximately 3,435 mi</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>what was the total number of points scored in ...</td>\n",
" <td>approximately 2.682651500990882</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" input \\\n",
"0 How many people live in canada as of 2023? \n",
"1 who is dua lipa's boyfriend? what is his age r... \n",
"2 what is dua lipa's boyfriend age raised to the... \n",
"3 how far is it from paris to boston in miles \n",
"4 what was the total number of points scored in ... \n",
"\n",
" output \n",
"0 approximately 38,625,801 \n",
"1 her boyfriend is Romain Gravas. his age raised... \n",
"2 her boyfriend is Romain Gravas. his age raised... \n",
"3 approximately 3,435 mi \n",
"4 approximately 2.682651500990882 "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# import pandas as pd\n",
"# from langchain.evaluation.loading import load_dataset\n",
@ -272,7 +367,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"id": "52a7ea76-79ca-4765-abf7-231e884040d6",
"metadata": {
"tags": []
@ -308,7 +403,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"id": "c2b59104-b90e-466a-b7ea-c5bd0194263b",
"metadata": {
"tags": []
@ -336,7 +431,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"id": "112d7bdf-7e50-4c1a-9285-5bac8473f2ee",
"metadata": {
"tags": []
@ -348,7 +443,8 @@
"\u001b[0;31mSignature:\u001b[0m\n",
"\u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marun_on_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mdataset_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Union[Chain, BaseLanguageModel]'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain_factory\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'MODEL_OR_CHAIN_FACTORY'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mconcurrency_level\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mnum_repetitions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
@ -359,7 +455,9 @@
"\n",
"Args:\n",
" dataset_name: Name of the dataset to run the chain on.\n",
" llm_or_chain: Chain or language model to run over the dataset.\n",
" llm_or_chain_factory: Language model or Chain constructor to run\n",
" over the dataset. The Chain constructor is used to permit\n",
" independent calls on each example without carrying over state.\n",
" concurrency_level: The number of async tasks to run concurrently.\n",
" num_repetitions: Number of times to run the model on each example.\n",
" This is useful when testing success rates or generating confidence\n",
@ -384,7 +482,26 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 11,
"id": "6e10f823",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Since chains can be stateful (e.g. they can have memory), we need provide\n",
"# a way to initialize a new chain for each row in the dataset. This is done\n",
"# by passing in a factory function that returns a new chain for each row.\n",
"chain_factory = lambda: initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)\n",
"\n",
"# If your chain is NOT stateful, your lambda can return the object directly\n",
"# to improve runtime performance. For example:\n",
"# chain_factory = lambda: agent"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33",
"metadata": {
"tags": []
@ -396,7 +513,9 @@
"text": [
"/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n",
" warnings.warn(\n",
"Chain failed for example 92c75ce4-f807-4d44-8f7e-027610f7fcbd. Error: unknown format from LLM: Sorry, I cannot answer this question as it requires information from the future.\n"
"Chain failed for example 5523e460-6bb4-4a64-be37-bec0a98699a4. Error: LLMMathChain._evaluate(\"\n",
"(total number of points scored in the 2023 super bowl)**0.23\n",
"\") raised error: invalid syntax. Perhaps you forgot a comma? (<expr>, line 1). Please try again with a valid numerical expression\n"
]
},
{
@ -410,25 +529,23 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Chain failed for example 9f5d1426-3e21-4628-b5f9-d2ad354bfa8d. Error: LLMMathChain._evaluate(\"\n",
"(age ** 0.43)\n",
"\") raised error: 'age'. Please try again with a valid numerical expression\n"
"Chain failed for example f193a3f6-1147-4ce6-a83e-fab1157dc88d. Error: unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Processed examples: 4\r"
"Processed examples: 6\r"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Chain failed for example e480f086-6d3f-4659-8669-26316db7e772. Error: LLMMathChain._evaluate(\"\n",
"(total number of points scored in the 2023 super bowl)**0.23\n",
"\") raised error: invalid syntax. Perhaps you forgot a comma? (<expr>, line 1). Please try again with a valid numerical expression\n"
"Chain failed for example 6d7bbb45-1dc0-4adc-be21-4f76a208a8d2. Error: LLMMathChain._evaluate(\"\n",
"(age ** 0.43)\n",
"\") raised error: 'age'. Please try again with a valid numerical expression\n"
]
},
{
@ -442,7 +559,7 @@
"source": [
"chain_results = await client.arun_on_dataset(\n",
" dataset_name=dataset_name,\n",
" llm_or_chain=agent,\n",
" llm_or_chain_factory=chain_factory,\n",
" verbose=True\n",
")\n",
"\n",
@ -463,7 +580,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 13,
"id": "136db492-d6ca-4215-96f9-439c23538241",
"metadata": {
"tags": []
@ -478,7 +595,7 @@
"LangChainPlusClient (API URL: http://localhost:8000)"
]
},
"execution_count": 14,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@ -508,7 +625,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 14,
"id": "64490d7c-9a18-49ed-a3ac-36049c522cb4",
"metadata": {
"tags": []
@ -524,7 +641,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "31a576ae98634602b349046ec0821c0d",
"model_id": "047a8094367f43938f74e863b3e01711",
"version_major": 2,
"version_minor": 0
},
@ -606,7 +723,7 @@
"4 [{'data': {'content': 'Here is the topic for a... "
]
},
"execution_count": 8,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@ -622,7 +739,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 15,
"id": "348acd86-a927-4d60-8d52-02e64585e4fc",
"metadata": {
"tags": []
@ -652,7 +769,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 16,
"id": "a69dd183-ad5e-473d-b631-db90706e837f",
"metadata": {
"tags": []
@ -691,7 +808,7 @@
"source": [
"chat_model_results = await client.arun_on_dataset(\n",
" dataset_name=chat_dataset_name,\n",
" llm_or_chain=chat_model,\n",
" llm_or_chain_factory=chat_model,\n",
" concurrency_level=5, # Optional, sets the number of examples to run at a time\n",
" num_repetitions=3,\n",
" verbose=True\n",
@ -936,7 +1053,7 @@
"# We also offer a synchronous method for running examples if a chain or llm's async methods aren't yet implemented\n",
"completions_model_results = client.run_on_dataset(\n",
" dataset_name=completions_dataset_name,\n",
" llm_or_chain=llm,\n",
" llm_or_chain_factory=llm,\n",
" num_repetitions=1,\n",
" verbose=True\n",
")"

@ -218,7 +218,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
num_repetitions = 3
results = await client.arun_on_dataset(
dataset_name="test",
llm_or_chain=chain,
llm_or_chain_factory=lambda: chain,
concurrency_level=2,
session_name="test_session",
num_repetitions=num_repetitions,

Loading…
Cancel
Save