diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py index 9d65716c79..8df0219fa8 100644 --- a/langchain/client/langchain.py +++ b/langchain/client/langchain.py @@ -40,6 +40,8 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel] + def _get_link_stem(url: str) -> str: scheme = urlsplit(url).scheme @@ -99,6 +101,21 @@ class LangChainPlusClient(BaseSettings): raise ValueError("No seeded tenant found") return results[0]["id"] + @staticmethod + def _get_session_name( + session_name: Optional[str], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + dataset_name: str, + ) -> str: + if session_name is not None: + return session_name + current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + if isinstance(llm_or_chain_factory, BaseLanguageModel): + model_name = llm_or_chain_factory.__class__.__name__ + else: + model_name = llm_or_chain_factory().__class__.__name__ + return f"{dataset_name}-{model_name}-{current_time}" + def _repr_html_(self) -> str: """Return an HTML representation of the instance with a link to the URL.""" link = _get_link_stem(self.api_url) @@ -312,7 +329,7 @@ class LangChainPlusClient(BaseSettings): async def _arun_llm_or_chain( example: Example, langchain_tracer: LangChainTracerV2, - llm_or_chain: Union[Chain, BaseLanguageModel], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, n_repetitions: int, ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: """Run the chain asynchronously.""" @@ -321,12 +338,13 @@ class LangChainPlusClient(BaseSettings): outputs = [] for _ in range(n_repetitions): try: - if isinstance(llm_or_chain, BaseLanguageModel): + if isinstance(llm_or_chain_factory, BaseLanguageModel): output: Any = await LangChainPlusClient._arun_llm( - llm_or_chain, example.inputs, langchain_tracer + llm_or_chain_factory, example.inputs, langchain_tracer ) else: - output = await llm_or_chain.arun( + chain = llm_or_chain_factory() + 
output = await chain.arun( example.inputs, callbacks=[langchain_tracer] ) outputs.append(output) @@ -388,7 +406,8 @@ class LangChainPlusClient(BaseSettings): async def arun_on_dataset( self, dataset_name: str, - llm_or_chain: Union[Chain, BaseLanguageModel], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + *, concurrency_level: int = 5, num_repetitions: int = 1, session_name: Optional[str] = None, @@ -399,7 +418,9 @@ class LangChainPlusClient(BaseSettings): Args: dataset_name: Name of the dataset to run the chain on. - llm_or_chain: Chain or language model to run over the dataset. + llm_or_chain_factory: Language model or Chain constructor to run + over the dataset. The Chain constructor is used to permit + independent calls on each example without carrying over state. concurrency_level: The number of async tasks to run concurrently. num_repetitions: Number of times to run the model on each example. This is useful when testing success rates or generating confidence @@ -411,11 +432,9 @@ class LangChainPlusClient(BaseSettings): Returns: A dictionary mapping example ids to the model outputs. 
""" - if session_name is None: - current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") - session_name = ( - f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}" - ) + session_name = LangChainPlusClient._get_session_name( + session_name, llm_or_chain_factory, dataset_name + ) dataset = self.read_dataset(dataset_name=dataset_name) examples = self.list_examples(dataset_id=str(dataset.id)) results: Dict[str, List[Any]] = {} @@ -427,7 +446,7 @@ class LangChainPlusClient(BaseSettings): result = await LangChainPlusClient._arun_llm_or_chain( example, tracer, - llm_or_chain, + llm_or_chain_factory, num_repetitions, ) results[str(example.id)] = result @@ -474,7 +493,7 @@ class LangChainPlusClient(BaseSettings): def run_llm_or_chain( example: Example, langchain_tracer: LangChainTracerV2, - llm_or_chain: Union[Chain, BaseLanguageModel], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, n_repetitions: int, ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: """Run the chain synchronously.""" @@ -483,14 +502,13 @@ class LangChainPlusClient(BaseSettings): outputs = [] for _ in range(n_repetitions): try: - if isinstance(llm_or_chain, BaseLanguageModel): + if isinstance(llm_or_chain_factory, BaseLanguageModel): output: Any = LangChainPlusClient.run_llm( - llm_or_chain, example.inputs, langchain_tracer + llm_or_chain_factory, example.inputs, langchain_tracer ) else: - output = llm_or_chain.run( - example.inputs, callbacks=[langchain_tracer] - ) + chain = llm_or_chain_factory() + output = chain.run(example.inputs, callbacks=[langchain_tracer]) outputs.append(output) except Exception as e: logger.warning(f"Chain failed for example {example.id}. 
Error: {e}") @@ -502,7 +520,8 @@ class LangChainPlusClient(BaseSettings): def run_on_dataset( self, dataset_name: str, - llm_or_chain: Union[Chain, BaseLanguageModel], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + *, num_repetitions: int = 1, session_name: Optional[str] = None, verbose: bool = False, @@ -511,7 +530,9 @@ class LangChainPlusClient(BaseSettings): Args: dataset_name: Name of the dataset to run the chain on. - llm_or_chain: Chain or language model to run over the dataset. + llm_or_chain_factory: Language model or Chain constructor to run + over the dataset. The Chain constructor is used to permit + independent calls on each example without carrying over state. concurrency_level: Number of async workers to run in parallel. num_repetitions: Number of times to run the model on each example. This is useful when testing success rates or generating confidence @@ -523,11 +544,9 @@ class LangChainPlusClient(BaseSettings): Returns: A dictionary mapping example ids to the model outputs. 
""" - if session_name is None: - current_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") - session_name = ( - f"{dataset_name}-{llm_or_chain.__class__.__name__}-{current_time}" - ) + session_name = LangChainPlusClient._get_session_name( + session_name, llm_or_chain_factory, dataset_name + ) dataset = self.read_dataset(dataset_name=dataset_name) examples = list(self.list_examples(dataset_id=str(dataset.id))) results: Dict[str, Any] = {} @@ -539,7 +558,7 @@ class LangChainPlusClient(BaseSettings): result = self.run_llm_or_chain( example, tracer, - llm_or_chain, + llm_or_chain_factory, num_repetitions, ) if verbose: diff --git a/langchain/experimental/client/tracing_datasets.ipynb b/langchain/experimental/client/tracing_datasets.ipynb index 7330d66321..dddbc52d63 100644 --- a/langchain/experimental/client/tracing_datasets.ipynb +++ b/langchain/experimental/client/tracing_datasets.ipynb @@ -133,21 +133,19 @@ "output_type": "stream", "text": [ "The current population of Canada as of 2023 is 39,566,248.\n", - "Anwar Hadid's age raised to the 0.43 power is approximately 3.87.\n", + "Anwar Hadid is Dua Lipa's boyfriend and his age raised to the 0.43 power is approximately 3.87.\n", "LLMMathChain._evaluate(\"\n", "(age)**0.43\n", "\") raised error: 'age'. Please try again with a valid numerical expression\n", - "The distance between Paris and Boston is 3448 miles.\n", - "unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n", + "The distance between Paris and Boston is approximately 3448 miles.\n", + "unknown format from LLM: Sorry, I cannot answer this question as it requires information from the future.\n", "LLMMathChain._evaluate(\"\n", "(total number of points scored in the 2023 super bowl)**0.23\n", "\") raised error: invalid syntax. Perhaps you forgot a comma? (, line 1). 
Please try again with a valid numerical expression\n", - "3 points were scored more in the 2023 Super Bowl than in the 2022 Super Bowl.\n", + "Could not parse LLM output: `The final answer is that there were no more points scored in the 2023 Super Bowl than in the 2022 Super Bowl.`\n", "1.9347796717823205\n", - "81\n", - "LLMMathChain._evaluate(\"\n", - "round(0.2791714614499425, 2)\n", - "\") raised error: 'VariableNode' object is not callable. Please try again with a valid numerical expression\n" + "77\n", + "0.2791714614499425\n" ] } ], @@ -254,12 +252,109 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "id": "60d14593-c61f-449f-a38f-772ca43707c2", "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset json (/Users/wfh/.cache/huggingface/datasets/LangChainDatasets___json/LangChainDatasets--agent-search-calculator-8a025c0ce5fb99d2/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c34edde8de5340888b3278d1ac427417", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
inputoutput
0How many people live in canada as of 2023?approximately 38,625,801
1who is dua lipa's boyfriend? what is his age r...her boyfriend is Romain Gravas. his age raised...
2what is dua lipa's boyfriend age raised to the...her boyfriend is Romain Gravas. his age raised...
3how far is it from paris to boston in milesapproximately 3,435 mi
4what was the total number of points scored in ...approximately 2.682651500990882
\n", + "" + ], + "text/plain": [ + " input \\\n", + "0 How many people live in canada as of 2023? \n", + "1 who is dua lipa's boyfriend? what is his age r... \n", + "2 what is dua lipa's boyfriend age raised to the... \n", + "3 how far is it from paris to boston in miles \n", + "4 what was the total number of points scored in ... \n", + "\n", + " output \n", + "0 approximately 38,625,801 \n", + "1 her boyfriend is Romain Gravas. his age raised... \n", + "2 her boyfriend is Romain Gravas. his age raised... \n", + "3 approximately 3,435 mi \n", + "4 approximately 2.682651500990882 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# import pandas as pd\n", "# from langchain.evaluation.loading import load_dataset\n", @@ -272,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "52a7ea76-79ca-4765-abf7-231e884040d6", "metadata": { "tags": [] @@ -308,7 +403,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "id": "c2b59104-b90e-466a-b7ea-c5bd0194263b", "metadata": { "tags": [] @@ -336,7 +431,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "id": "112d7bdf-7e50-4c1a-9285-5bac8473f2ee", "metadata": { "tags": [] @@ -348,7 +443,8 @@ "\u001b[0;31mSignature:\u001b[0m\n", "\u001b[0mclient\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marun_on_dataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mdataset_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'str'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", - "\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Union[Chain, BaseLanguageModel]'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mllm_or_chain_factory\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'MODEL_OR_CHAIN_FACTORY'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m 
\u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mconcurrency_level\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mnum_repetitions\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'int'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0msession_name\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Optional[str]'\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", @@ -359,7 +455,9 @@ "\n", "Args:\n", " dataset_name: Name of the dataset to run the chain on.\n", - " llm_or_chain: Chain or language model to run over the dataset.\n", + " llm_or_chain_factory: Language model or Chain constructor to run\n", + " over the dataset. The Chain constructor is used to permit\n", + " independent calls on each example without carrying over state.\n", " concurrency_level: The number of async tasks to run concurrently.\n", " num_repetitions: Number of times to run the model on each example.\n", " This is useful when testing success rates or generating confidence\n", @@ -384,7 +482,26 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, + "id": "6e10f823", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Since chains can be stateful (e.g. they can have memory), we need to provide\n", + "# a way to initialize a new chain for each row in the dataset. This is done\n", + "# by passing in a factory function that returns a new chain for each row.\n", + "chain_factory = lambda: initialize_agent(tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False)\n", + "\n", + "# If your chain is NOT stateful, your lambda can return the object directly\n", + "# to improve runtime performance. 
For example:\n", + "# chain_factory = lambda: agent" + ] + }, + { + "cell_type": "code", + "execution_count": 12, "id": "a8088b7d-3ab6-4279-94c8-5116fe7cee33", "metadata": { "tags": [] @@ -396,7 +513,9 @@ "text": [ "/Users/wfh/code/lc/lckg/langchain/callbacks/manager.py:78: UserWarning: The experimental tracing v2 is in development. This is not yet stable and may change in the future.\n", " warnings.warn(\n", - "Chain failed for example 92c75ce4-f807-4d44-8f7e-027610f7fcbd. Error: unknown format from LLM: Sorry, I cannot answer this question as it requires information from the future.\n" + "Chain failed for example 5523e460-6bb4-4a64-be37-bec0a98699a4. Error: LLMMathChain._evaluate(\"\n", + "(total number of points scored in the 2023 super bowl)**0.23\n", + "\") raised error: invalid syntax. Perhaps you forgot a comma? (, line 1). Please try again with a valid numerical expression\n" ] }, { @@ -410,25 +529,23 @@ "name": "stderr", "output_type": "stream", "text": [ - "Chain failed for example 9f5d1426-3e21-4628-b5f9-d2ad354bfa8d. Error: LLMMathChain._evaluate(\"\n", - "(age ** 0.43)\n", - "\") raised error: 'age'. Please try again with a valid numerical expression\n" + "Chain failed for example f193a3f6-1147-4ce6-a83e-fab1157dc88d. Error: unknown format from LLM: Assuming we don't have any information about the actual number of points scored in the 2023 super bowl, we cannot provide a mathematical expression to solve this problem.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Processed examples: 4\r" + "Processed examples: 6\r" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Chain failed for example e480f086-6d3f-4659-8669-26316db7e772. Error: LLMMathChain._evaluate(\"\n", - "(total number of points scored in the 2023 super bowl)**0.23\n", - "\") raised error: invalid syntax. Perhaps you forgot a comma? (, line 1). 
Please try again with a valid numerical expression\n" + "Chain failed for example 6d7bbb45-1dc0-4adc-be21-4f76a208a8d2. Error: LLMMathChain._evaluate(\"\n", + "(age ** 0.43)\n", + "\") raised error: 'age'. Please try again with a valid numerical expression\n" ] }, { @@ -442,7 +559,7 @@ "source": [ "chain_results = await client.arun_on_dataset(\n", " dataset_name=dataset_name,\n", - " llm_or_chain=agent,\n", + " llm_or_chain_factory=chain_factory,\n", " verbose=True\n", ")\n", "\n", @@ -463,7 +580,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "136db492-d6ca-4215-96f9-439c23538241", "metadata": { "tags": [] @@ -478,7 +595,7 @@ "LangChainPlusClient (API URL: http://localhost:8000)" ] }, - "execution_count": 14, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -508,7 +625,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 14, "id": "64490d7c-9a18-49ed-a3ac-36049c522cb4", "metadata": { "tags": [] @@ -524,7 +641,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "31a576ae98634602b349046ec0821c0d", + "model_id": "047a8094367f43938f74e863b3e01711", "version_major": 2, "version_minor": 0 }, @@ -606,7 +723,7 @@ "4 [{'data': {'content': 'Here is the topic for a... 
" ] }, - "execution_count": 8, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -622,7 +739,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "id": "348acd86-a927-4d60-8d52-02e64585e4fc", "metadata": { "tags": [] @@ -652,7 +769,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "id": "a69dd183-ad5e-473d-b631-db90706e837f", "metadata": { "tags": [] @@ -691,7 +808,7 @@ "source": [ "chat_model_results = await client.arun_on_dataset(\n", " dataset_name=chat_dataset_name,\n", - " llm_or_chain=chat_model,\n", + " llm_or_chain_factory=chat_model,\n", " concurrency_level=5, # Optional, sets the number of examples to run at a time\n", " num_repetitions=3,\n", " verbose=True\n", @@ -936,7 +1053,7 @@ "# We also offer a synchronous method for running examples if a chain or llm's async methods aren't yet implemented\n", "completions_model_results = client.run_on_dataset(\n", " dataset_name=completions_dataset_name,\n", - " llm_or_chain=llm,\n", + " llm_or_chain_factory=llm,\n", " num_repetitions=1,\n", " verbose=True\n", ")" diff --git a/tests/unit_tests/client/test_langchain.py b/tests/unit_tests/client/test_langchain.py index efa0497179..731c4d694d 100644 --- a/tests/unit_tests/client/test_langchain.py +++ b/tests/unit_tests/client/test_langchain.py @@ -218,7 +218,7 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: num_repetitions = 3 results = await client.arun_on_dataset( dataset_name="test", - llm_or_chain=chain, + llm_or_chain_factory=lambda: chain, concurrency_level=2, session_name="test_session", num_repetitions=num_repetitions,