From 0c6ed657efc4488957cb1fcb913b2d0d1b1cab00 Mon Sep 17 00:00:00 2001 From: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Date: Sat, 13 May 2023 02:13:21 +0000 Subject: [PATCH] Convert Chain to a Chain Factory (#4605) ## Change Chain argument in client to accept a chain factory The `run_over_dataset` functionality seeks to treat each iteration of an example as an independent trial. Chains have memory, so it's easier to permit this type of behavior if we accept a factory method rather than the chain object directly. There's still corner cases / UX pains people will likely run into, like: - Caching may cause issues - if memory is persisted to a shared object (e.g., same redis queue) , this could impact what is retrieved - If we're running the async methods with concurrency using local models, if someone naively instantiates the chain and loads each time, it could lead to tons of disk I/O or OOM --- langchain/client/langchain.py | 71 ++++--- .../client/tracing_datasets.ipynb | 185 ++++++++++++++---- tests/unit_tests/client/test_langchain.py | 2 +- 3 files changed, 197 insertions(+), 61 deletions(-) diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py index 9d65716c..8df0219f 100644 --- a/langchain/client/langchain.py +++ b/langchain/client/langchain.py @@ -40,6 +40,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel] + def _get_link_stem(url: str) -> str: scheme = urlsplit(url).scheme @@ -99,6 +101,21 @@ class LangChainPlusClient(BaseSettings): raise ValueError("No seeded tenant found") return results[0]["id"] + @staticmethod + def _get_session_name( + session_name: Optional[str], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + dataset_name: str, + ) -> str: + if session_name is not None: + return session_name + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + if isinstance(llm_or_chain_factory, BaseLanguageModel): + model_name = llm_or_chain_factory.__class__.__name__ + else: + model_name = llm_or_chain_factory().__class__.__name__ + return f"{dataset_name}-{model_name}-{current_time}" + def _repr_html_(self) -> str: """Return an HTML representation of the instance with a link to the URL.""" link = _get_link_stem(self.api_url) @@ -312,7 +329,7 @@ class LangChainPlusClient(BaseSettings): async def _arun_llm_or_chain( example: Example, langchain_tracer: LangChainTracerV2, - llm_or_chain: Union[Chain, BaseLanguageModel], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, n_repetitions: int, ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: """Run the chain asynchronously.""" @@ -321,12 +338,13 @@ class LangChainPlusClient(BaseSettings): outputs = [] for _ in range(n_repetitions): try: - if isinstance(llm_or_chain, BaseLanguageModel): + if isinstance(llm_or_chain_factory, BaseLanguageModel): output: Any = await LangChainPlusClient._arun_llm( - llm_or_chain, example.inputs, langchain_tracer + llm_or_chain_factory, example.inputs, langchain_tracer ) else: - output = await llm_or_chain.arun( + chain = llm_or_chain_factory() + output = await chain.arun( example.inputs, callbacks=[langchain_tracer] ) outputs.append(output) @@ -388,7 +406,8 @@ class LangChainPlusClient(BaseSettings): async def arun_on_dataset( self, dataset_name: str, - llm_or_chain: Union[Chain, BaseLanguageModel], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + *, concurrency_level: int = 5, num_repetitions: int = 1, session_name: Optional[str] = None, @@ -399,7 +418,9 @@ class LangChainPlusClient(BaseSettings): Args: dataset_name: Name of the dataset to run the chain on. - llm_or_chain: Chain or language model to run over the dataset. + llm_or_chain_factory: Language model or Chain constructor to run + over the dataset. The Chain constructor is used to permit + independent calls on each example without carrying over state. concurrency_level: The number of async tasks to run concurrently. num_repetitions: Number of times to run the model on each example. This is useful when testing success rates or generating confidence @@ -411,11 +432,9 @@ class LangChainPlusClient(BaseSettings): Returns: A dictionary mapping example ids to the model outputs. """ - if session_name is None: - current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") - session_name = ( - f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}" - ) + session_name = LangChainPlusClient._get_session_name( + session_name, llm_or_chain_factory, dataset_name + ) dataset = self.read_dataset(dataset_name=dataset_name) examples = self.list_examples(dataset_id=str(dataset.id)) results: Dict[str, List[Any]] = {} @@ -427,7 +446,7 @@ class LangChainPlusClient(BaseSettings): result = await LangChainPlusClient._arun_llm_or_chain( example, tracer, - llm_or_chain, + llm_or_chain_factory, num_repetitions, ) results[str(example.id)] = result @@ -474,7 +493,7 @@ class LangChainPlusClient(BaseSettings): def run_llm_or_chain( example: Example, langchain_tracer: LangChainTracerV2, - llm_or_chain: Union[Chain, BaseLanguageModel], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, n_repetitions: int, ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: """Run the chain synchronously.""" @@ -483,14 +502,13 @@ class LangChainPlusClient(BaseSettings): outputs = [] for _ in range(n_repetitions): try: - if isinstance(llm_or_chain, BaseLanguageModel): + if isinstance(llm_or_chain_factory, BaseLanguageModel): output: Any = LangChainPlusClient.run_llm( - llm_or_chain, example.inputs, langchain_tracer + llm_or_chain_factory, example.inputs, langchain_tracer ) else: - output = llm_or_chain.run( - example.inputs, callbacks=[langchain_tracer] - ) + chain = llm_or_chain_factory() + output = chain.run(example.inputs, callbacks=[langchain_tracer]) outputs.append(output) except Exception as e: logger.warning(f"Chain failed for example {example.id}. Error: {e}") @@ -502,7 +520,8 @@ class LangChainPlusClient(BaseSettings): def run_on_dataset( self, dataset_name: str, - llm_or_chain: Union[Chain, BaseLanguageModel], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + *, num_repetitions: int = 1, session_name: Optional[str] = None, verbose: bool = False, @@ -511,7 +530,9 @@ class LangChainPlusClient(BaseSettings): Args: dataset_name: Name of the dataset to run the chain on. - llm_or_chain: Chain or language model to run over the dataset. + llm_or_chain_factory: Language model or Chain constructor to run + over the dataset. The Chain constructor is used to permit + independent calls on each example without carrying over state. concurrency_level: Number of async workers to run in parallel. num_repetitions: Number of times to run the model on each example. This is useful when testing success rates or generating confidence @@ -523,11 +544,9 @@ class LangChainPlusClient(BaseSettings): Returns: A dictionary mapping example ids to the model outputs. """ - if session_name is None: - current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") - session_name = ( - f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}" - ) + session_name = LangChainPlusClient._get_session_name( + session_name, llm_or_chain_factory, dataset_name + ) dataset = self.read_dataset(dataset_name=dataset_name) examples = list(self.list_examples(dataset_id=str(dataset.id))) results: Dict[str, Any] = {} @@ -539,7 +558,7 @@ class LangChainPlusClient(BaseSettings): result = self.run_llm_or_chain( example, tracer, - llm_or_chain, + llm_or_chain_factory, num_repetitions, ) if verbose: diff --git a/langchain/experimental/client/tracing_datasets.ipynb b/langchain/experimental/client/tracing_datasets.ipynb index 7330d663..dddbc52d 100644 --- a/langchain/experimental/client/tracing_datasets.ipynb +++ b/langchain/experimental/client/tracing_datasets.ipynb @@ -133,21 +133,19 @@ "output_type": "stream", "text": [ "The current population of Canada as of 2023 is 39,566,248.\n", - "Anwar Hadid's age raised to the 0.43 power is approximately 3.87.\n", + "Anwar Hadid is Dua Lipa's boyfriend and his age raised to the 0.43 power is approximately 3.87.\n", "LLMMathChain._evaluate(\"\n", "(age)**0.43\n", "\") raised error: 'age'. Please try again with a valid numerical expression\n", - "The distance between Paris and Boston is 3448 miles.\n", - "unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n", + "The distance between Paris and Boston is approximately 3448 miles.\n", + "unknown format from LLM: Sorry, I cannot answer this question as it requires information from the future.\n", "LLMMathChain._evaluate(\"\n", "(total number of points scored in the 2023 super bowl)**0.23\n", "\") raised error: invalid syntax. Perhaps you forgot a comma? (, line 1). Please try again with a valid numerical expression\n", - "3 points were scored more in the 2023 Super Bowl than in the 2022 Super Bowl.\n", + "Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n", "1.9347796717823205\n", - "81\n", - "LLMMathChain._evaluate(\"\n", - "round(0.2791714614499425, 2)\n", - "\") raised error: 'VariableNode' object is not callable. Please try again with a valid numerical expression\n" + "77\n", + "0.2791714614499425\n" ] } ], @@ -254,12 +252,109 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "id": "60d14593-c61f-449f-a38f-772ca43707c2", "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset json (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-search-calculator-8a025c0ce5fb99d2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c34edde8de5340888b3278d1ac427417", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
inputoutput
0How many people live in canada as of 2023?approximately 38,625,801
1who is dua lipa's boyfriend? what is his age r...her boyfriend is Romain Gravas. his age raised...
2what is dua lipa's boyfriend age raised to the...her boyfriend is Romain Gravas. his age raised...
3how far is it from paris to boston in milesapproximately 3,435 mi
4what was the total number of points scored in ...approximately 2.682651500990882
\n", + "" + ], + "text/plain": [ + " input \\\n", + "0 How many people live in canada as of 2023? \n", + "1 who is dua lipa's boyfriend? what is his age r... \n", + "2 what is dua lipa's boyfriend age raised to the... \n", + "3 how far is it from paris to boston in miles \n", + "4 what was the total number of points scored in ... \n", + "\n", + " output \n", + "0 approximately 38,625,801 \n", + "1 her boyfriend is Romain Gravas. his age raised... \n", + "2 her boyfriend is Romain Gravas. his age raised... \n", + "3 approximately 3,435 mi \n", + "4 approximately 2.682651500990882 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# import pandas as pd\n", "# from langchain.evaluation.loading import load_dataset\n", @@ -272,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "52a7ea76-79ca-4765-abf7-231e884040d6", "metadata": { "tags": [] @@ -308,7 +403,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "id": "c2b59104-b90e-466a-b7ea-c5bd0194263b", "metadata": { "tags": [] @@ -336,7 +431,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "id": "112d7bdf-7e50-4c1a-9285-5bac8473f2ee", "metadata": { "tags": [] @@ -348,7 +443,8 @@ "\u001b[0;31mSignature:\u001b[0m\n", "\u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marun_on_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mdataset_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Union[Chain, BaseLanguageModel]'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain_factory\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'MODEL_OR_CHAIN_FACTORY'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mconcurrency_level\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mnum_repetitions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", @@ -359,7 +455,9 @@ "\n", "Args:\n", " dataset_name: Name of the dataset to run the chain on.\n", - " llm_or_chain: Chain or language model to run over the dataset.\n", + " llm_or_chain_factory: Language model or Chain constructor to run\n", + " over the dataset. The Chain constructor is used to permit\n", + " independent calls on each example without carrying over state.\n", " concurrency_level: The number of async tasks to run concurrently.\n", " num_repetitions: Number of times to run the model on each example.\n", " This is useful when testing success rates or generating confidence\n", @@ -384,7 +482,26 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, + "id": "6e10f823", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Since chains can be stateful (e.g. they can have memory), we need provide\n", + "# a way to initialize a new chain for each row in the dataset. This is done\n", + "# by passing in a factory function that returns a new chain for each row.\n", + "chain_factory = lambda: initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)\n", + "\n", + "# If your chain is NOT stateful, your lambda can return the object directly\n", + "# to improve runtime performance. For example:\n", + "# chain_factory = lambda: agent" + ] + }, + { + "cell_type": "code", + "execution_count": 12, "id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33", "metadata": { "tags": [] @@ -396,7 +513,9 @@ "text": [ "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n", " warnings.warn(\n", - "Chain failed for example 92c75ce4-f807-4d44-8f7e-027610f7fcbd. Error: unknown format from LLM: Sorry, I cannot answer this question as it requires information from the future.\n" + "Chain failed for example 5523e460-6bb4-4a64-be37-bec0a98699a4. Error: LLMMathChain._evaluate(\"\n", + "(total number of points scored in the 2023 super bowl)**0.23\n", + "\") raised error: invalid syntax. Perhaps you forgot a comma? (, line 1). Please try again with a valid numerical expression\n" ] }, { @@ -410,25 +529,23 @@ "name": "stderr", "output_type": "stream", "text": [ - "Chain failed for example 9f5d1426-3e21-4628-b5f9-d2ad354bfa8d. Error: LLMMathChain._evaluate(\"\n", - "(age ** 0.43)\n", - "\") raised error: 'age'. Please try again with a valid numerical expression\n" + "Chain failed for example f193a3f6-1147-4ce6-a83e-fab1157dc88d. Error: unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Processed examples: 4\r" + "Processed examples: 6\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Chain failed for example e480f086-6d3f-4659-8669-26316db7e772. Error: LLMMathChain._evaluate(\"\n", - "(total number of points scored in the 2023 super bowl)**0.23\n", - "\") raised error: invalid syntax. Perhaps you forgot a comma? (, line 1). Please try again with a valid numerical expression\n" + "Chain failed for example 6d7bbb45-1dc0-4adc-be21-4f76a208a8d2. Error: LLMMathChain._evaluate(\"\n", + "(age ** 0.43)\n", + "\") raised error: 'age'. Please try again with a valid numerical expression\n" ] }, { @@ -442,7 +559,7 @@ "source": [ "chain_results = await client.arun_on_dataset(\n", " dataset_name=dataset_name,\n", - " llm_or_chain=agent,\n", + " llm_or_chain_factory=chain_factory,\n", " verbose=True\n", ")\n", "\n", @@ -463,7 +580,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "136db492-d6ca-4215-96f9-439c23538241", "metadata": { "tags": [] @@ -478,7 +595,7 @@ "LangChainPlusClient (API URL: http://localhost:8000)" ] }, - "execution_count": 14, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -508,7 +625,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 14, "id": "64490d7c-9a18-49ed-a3ac-36049c522cb4", "metadata": { "tags": [] @@ -524,7 +641,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "31a576ae98634602b349046ec0821c0d", + "model_id": "047a8094367f43938f74e863b3e01711", "version_major": 2, "version_minor": 0 }, @@ -606,7 +723,7 @@ "4 [{'data': {'content': 'Here is the topic for a... " ] }, - "execution_count": 8, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -622,7 +739,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "id": "348acd86-a927-4d60-8d52-02e64585e4fc", "metadata": { "tags": [] @@ -652,7 +769,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "id": "a69dd183-ad5e-473d-b631-db90706e837f", "metadata": { "tags": [] @@ -691,7 +808,7 @@ "source": [ "chat_model_results = await client.arun_on_dataset(\n", " dataset_name=chat_dataset_name,\n", - " llm_or_chain=chat_model,\n", + " llm_or_chain_factory=chat_model,\n", " concurrency_level=5, # Optional, sets the number of examples to run at a time\n", " num_repetitions=3,\n", " verbose=True\n", @@ -936,7 +1053,7 @@ "# We also offer a synchronous method for running examples if a chain or llm's async methods aren't yet implemented\n", "completions_model_results = client.run_on_dataset(\n", " dataset_name=completions_dataset_name,\n", - " llm_or_chain=llm,\n", + " llm_or_chain_factory=llm,\n", " num_repetitions=1,\n", " verbose=True\n", ")" diff --git a/tests/unit_tests/client/test_langchain.py b/tests/unit_tests/client/test_langchain.py index efa04971..731c4d69 100644 --- a/tests/unit_tests/client/test_langchain.py +++ b/tests/unit_tests/client/test_langchain.py @@ -218,7 +218,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: num_repetitions = 3 results = await client.arun_on_dataset( dataset_name="test", - llm_or_chain=chain, + llm_or_chain_factory=lambda: chain, concurrency_level=2, session_name="test_session", num_repetitions=num_repetitions,