diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py index 83d83305..d2f13aec 100644 --- a/langchain/client/langchain.py +++ b/langchain/client/langchain.py @@ -1,7 +1,5 @@ from __future__ import annotations -import asyncio -import functools import logging import socket from datetime import datetime @@ -10,9 +8,8 @@ from typing import ( TYPE_CHECKING, Any, Callable, - Coroutine, Dict, - Iterable, + Iterator, List, Optional, Tuple, @@ -27,10 +24,8 @@ from requests import Response from tenacity import retry, stop_after_attempt, wait_fixed from langchain.base_language import BaseLanguageModel -from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.schemas import Run, TracerSession from langchain.chains.base import Chain -from langchain.chat_models.base import BaseChatModel from langchain.client.models import ( Dataset, DatasetCreate, @@ -38,15 +33,7 @@ from langchain.client.models import ( ExampleCreate, ListRunsQueryParams, ) -from langchain.llms.base import BaseLLM -from langchain.schema import ( - BaseMessage, - ChatResult, - HumanMessage, - LLMResult, - get_buffer_string, - messages_from_dict, -) +from langchain.client.runner_utils import arun_on_examples, run_on_examples from langchain.utils import raise_for_status_with_text, xor_args if TYPE_CHECKING: @@ -57,10 +44,6 @@ logger = logging.getLogger(__name__) MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel] -class InputFormatError(Exception): - """Raised when input format is invalid.""" - - def _get_link_stem(url: str) -> str: scheme = urlsplit(url).scheme netloc_prefix = urlsplit(url).netloc.split(":")[0] @@ -231,7 +214,7 @@ class LangChainPlusClient(BaseSettings): session_name: Optional[str] = None, run_type: Optional[str] = None, **kwargs: Any, - ) -> List[Run]: + ) -> Iterator[Run]: """List runs from the LangChain+ API.""" if session_name is not None: if session_id is not None: @@ -245,7 +228,7 @@ class LangChainPlusClient(BaseSettings): } response = self._get("/runs", params=filtered_params) raise_for_status_with_text(response) - return [Run(**run) for run in response.json()] + yield from [Run(**run) for run in response.json()] @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5)) @xor_args(("session_id", "session_name")) @@ -279,11 +262,11 @@ class LangChainPlusClient(BaseSettings): return TracerSession(**response.json()) @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5)) - def list_sessions(self) -> List[TracerSession]: + def list_sessions(self) -> Iterator[TracerSession]: """List sessions from the LangChain+ API.""" response = self._get("/sessions") raise_for_status_with_text(response) - return [TracerSession(**session) for session in response.json()] + yield from [TracerSession(**session) for session in response.json()] def create_dataset(self, dataset_name: str, description: str) -> Dataset: """Create a dataset in the LangChain+ API.""" @@ -326,11 +309,11 @@ class LangChainPlusClient(BaseSettings): return Dataset(**result) @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5)) - def list_datasets(self, limit: int = 100) -> Iterable[Dataset]: + def list_datasets(self, limit: int = 100) -> Iterator[Dataset]: """List the datasets on the LangChain+ API.""" response = self._get("/datasets", params={"limit": limit}) raise_for_status_with_text(response) - return [Dataset(**dataset) for dataset in response.json()] + yield from [Dataset(**dataset) for dataset in response.json()] @xor_args(("dataset_id", "dataset_name")) def delete_dataset( @@ -346,7 +329,7 @@ class LangChainPlusClient(BaseSettings): headers=self._headers, ) raise_for_status_with_text(response) - return response.json() + return Dataset(**response.json()) @xor_args(("dataset_id", "dataset_name")) def create_example( @@ -386,7 +369,7 @@ class LangChainPlusClient(BaseSettings): @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5)) def list_examples( self, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None - ) -> Iterable[Example]: + ) -> Iterator[Example]: """List the datasets on the LangChain+ API.""" params = {} if dataset_id is not None: @@ -398,195 +381,7 @@ class LangChainPlusClient(BaseSettings): pass response = self._get("/examples", params=params) raise_for_status_with_text(response) - return [Example(**dataset) for dataset in response.json()] - - @staticmethod - def _get_prompts(inputs: Dict[str, Any]) -> List[str]: - """Get prompts from inputs.""" - if not inputs: - raise InputFormatError("Inputs should not be empty.") - - prompts = [] - - if "prompt" in inputs: - if not isinstance(inputs["prompt"], str): - raise InputFormatError( - "Expected string for 'prompt', got" - f" {type(inputs['prompt']).__name__}" - ) - prompts = [inputs["prompt"]] - elif "prompts" in inputs: - if not isinstance(inputs["prompts"], list) or not all( - isinstance(i, str) for i in inputs["prompts"] - ): - raise InputFormatError( - "Expected list of strings for 'prompts'," - f" got {type(inputs['prompts']).__name__}" - ) - prompts = inputs["prompts"] - elif len(inputs) == 1: - prompt_ = next(iter(inputs.values())) - if isinstance(prompt_, str): - prompts = [prompt_] - elif isinstance(prompt_, list) and all(isinstance(i, str) for i in prompt_): - prompts = prompt_ - else: - raise InputFormatError( - f"LLM Run expects string prompt input. Got {inputs}" - ) - else: - raise InputFormatError( - f"LLM Run expects 'prompt' or 'prompts' in inputs. Got {inputs}" - ) - - return prompts - - @staticmethod - def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]: - """Get Chat Messages from inputs.""" - if not inputs: - raise InputFormatError("Inputs should not be empty.") - - if "messages" in inputs: - single_input = inputs["messages"] - elif len(inputs) == 1: - single_input = next(iter(inputs.values())) - else: - raise InputFormatError( - f"Chat Run expects 'messages' in inputs. Got {inputs}" - ) - if isinstance(single_input, list) and all( - isinstance(i, dict) for i in single_input - ): - raw_messages = [single_input] - elif isinstance(single_input, list) and all( - isinstance(i, list) for i in single_input - ): - raw_messages = single_input - else: - raise InputFormatError( - f"Chat Run expects List[dict] or List[List[dict]] 'messages'" - f" input. Got {inputs}" - ) - return [messages_from_dict(batch) for batch in raw_messages] - - @staticmethod - async def _arun_llm( - llm: BaseLanguageModel, - inputs: Dict[str, Any], - langchain_tracer: LangChainTracer, - ) -> Union[LLMResult, ChatResult]: - if isinstance(llm, BaseLLM): - try: - llm_prompts = LangChainPlusClient._get_prompts(inputs) - llm_output = await llm.agenerate( - llm_prompts, callbacks=[langchain_tracer] - ) - except InputFormatError: - llm_messages = LangChainPlusClient._get_messages(inputs) - buffer_strings = [ - get_buffer_string(messages) for messages in llm_messages - ] - llm_output = await llm.agenerate( - buffer_strings, callbacks=[langchain_tracer] - ) - elif isinstance(llm, BaseChatModel): - try: - messages = LangChainPlusClient._get_messages(inputs) - llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer]) - except InputFormatError: - prompts = LangChainPlusClient._get_prompts(inputs) - converted_messages: List[List[BaseMessage]] = [ - [HumanMessage(content=prompt)] for prompt in prompts - ] - llm_output = await llm.agenerate( - converted_messages, callbacks=[langchain_tracer] - ) - else: - raise ValueError(f"Unsupported LLM type {type(llm)}") - return llm_output - - @staticmethod - async def _arun_llm_or_chain( - example: Example, - langchain_tracer: LangChainTracer, - llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, - n_repetitions: int, - ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: - """Run the chain asynchronously.""" - previous_example_id = langchain_tracer.example_id - langchain_tracer.example_id = example.id - outputs = [] - for _ in range(n_repetitions): - try: - if isinstance(llm_or_chain_factory, BaseLanguageModel): - output: Any = await LangChainPlusClient._arun_llm( - llm_or_chain_factory, example.inputs, langchain_tracer - ) - else: - chain = llm_or_chain_factory() - output = await chain.arun( - example.inputs, callbacks=[langchain_tracer] - ) - outputs.append(output) - except Exception as e: - logger.warning(f"Chain failed for example {example.id}. Error: {e}") - outputs.append({"Error": str(e)}) - langchain_tracer.example_id = previous_example_id - return outputs - - @staticmethod - async def _gather_with_concurrency( - n: int, - initializer: Callable[[], Coroutine[Any, Any, LangChainTracer]], - *async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]], - ) -> List[Any]: - """ - Run coroutines with a concurrency limit. - - Args: - n: The maximum number of concurrent tasks. - initializer: A coroutine that initializes shared resources for the tasks. - async_funcs: The async_funcs to be run concurrently. - - Returns: - A list of results from the coroutines. - """ - semaphore = asyncio.Semaphore(n) - job_state = {"num_processed": 0} - - tracer_queue: asyncio.Queue[LangChainTracer] = asyncio.Queue() - for _ in range(n): - tracer_queue.put_nowait(await initializer()) - - async def run_coroutine_with_semaphore( - async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]] - ) -> Any: - async with semaphore: - tracer = await tracer_queue.get() - try: - result = await async_func(tracer, job_state) - finally: - tracer_queue.put_nowait(tracer) - return result - - return await asyncio.gather( - *(run_coroutine_with_semaphore(function) for function in async_funcs) - ) - - async def _tracer_initializer(self, session_name: str) -> LangChainTracer: - """ - Initialize a tracer to share across tasks. - - Args: - session_name: The session name for the tracer. - - Returns: - A LangChainTracer instance with an active session. - """ - tracer = LangChainTracer(session_name=session_name) - tracer.ensure_session() - return tracer + yield from [Example(**dataset) for dataset in response.json()] async def arun_on_dataset( self, @@ -622,93 +417,15 @@ class LangChainPlusClient(BaseSettings): ) dataset = self.read_dataset(dataset_name=dataset_name) examples = self.list_examples(dataset_id=str(dataset.id)) - results: Dict[str, List[Any]] = {} - - async def process_example( - example: Example, tracer: LangChainTracer, job_state: dict - ) -> None: - """Process a single example.""" - result = await LangChainPlusClient._arun_llm_or_chain( - example, - tracer, - llm_or_chain_factory, - num_repetitions, - ) - results[str(example.id)] = result - job_state["num_processed"] += 1 - if verbose: - print( - f"Processed examples: {job_state['num_processed']}", - end="\r", - flush=True, - ) - await self._gather_with_concurrency( - concurrency_level, - functools.partial(self._tracer_initializer, session_name), - *(functools.partial(process_example, e) for e in examples), + return await arun_on_examples( + examples, + llm_or_chain_factory, + concurrency_level=concurrency_level, + num_repetitions=num_repetitions, + session_name=session_name, + verbose=verbose, ) - return results - - @staticmethod - def run_llm( - llm: BaseLanguageModel, - inputs: Dict[str, Any], - langchain_tracer: LangChainTracer, - ) -> Union[LLMResult, ChatResult]: - """Run the language model on the example.""" - if isinstance(llm, BaseLLM): - try: - llm_prompts = LangChainPlusClient._get_prompts(inputs) - llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer]) - except InputFormatError: - llm_messages = LangChainPlusClient._get_messages(inputs) - buffer_strings = [ - get_buffer_string(messages) for messages in llm_messages - ] - llm_output = llm.generate(buffer_strings, callbacks=[langchain_tracer]) - elif isinstance(llm, BaseChatModel): - try: - messages = LangChainPlusClient._get_messages(inputs) - llm_output = llm.generate(messages, callbacks=[langchain_tracer]) - except InputFormatError: - prompts = LangChainPlusClient._get_prompts(inputs) - converted_messages: List[List[BaseMessage]] = [ - [HumanMessage(content=prompt)] for prompt in prompts - ] - llm_output = llm.generate( - converted_messages, callbacks=[langchain_tracer] - ) - else: - raise ValueError(f"Unsupported LLM type {type(llm)}") - return llm_output - - @staticmethod - def run_llm_or_chain( - example: Example, - langchain_tracer: LangChainTracer, - llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, - n_repetitions: int, - ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: - """Run the chain synchronously.""" - previous_example_id = langchain_tracer.example_id - langchain_tracer.example_id = example.id - outputs = [] - for _ in range(n_repetitions): - try: - if isinstance(llm_or_chain_factory, BaseLanguageModel): - output: Any = LangChainPlusClient.run_llm( - llm_or_chain_factory, example.inputs, langchain_tracer - ) - else: - chain = llm_or_chain_factory() - output = chain.run(example.inputs, callbacks=[langchain_tracer]) - outputs.append(output) - except Exception as e: - logger.warning(f"Chain failed for example {example.id}. Error: {e}") - outputs.append({"Error": str(e)}) - langchain_tracer.example_id = previous_example_id - return outputs def run_on_dataset( self, @@ -741,18 +458,11 @@ class LangChainPlusClient(BaseSettings): session_name, llm_or_chain_factory, dataset_name ) dataset = self.read_dataset(dataset_name=dataset_name) - examples = list(self.list_examples(dataset_id=str(dataset.id))) - results: Dict[str, Any] = {} - tracer = LangChainTracer(session_name=session_name) - tracer.ensure_session() - for i, example in enumerate(examples): - result = self.run_llm_or_chain( - example, - tracer, - llm_or_chain_factory, - num_repetitions, - ) - if verbose: - print(f"{i+1} processed", flush=True, end="\r") - results[str(example.id)] = result - return results + examples = self.list_examples(dataset_id=str(dataset.id)) + return run_on_examples( + examples, + llm_or_chain_factory, + num_repetitions=num_repetitions, + session_name=session_name, + verbose=verbose, + ) diff --git a/langchain/client/runner_utils.py b/langchain/client/runner_utils.py new file mode 100644 index 00000000..316d3585 --- /dev/null +++ b/langchain/client/runner_utils.py @@ -0,0 +1,375 @@ +"""Utilities for running LLMs/Chains over datasets.""" +from __future__ import annotations + +import asyncio +import functools +import logging +from typing import Any, Callable, Coroutine, Dict, Iterator, List, Optional, Union + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.base import BaseCallbackHandler +from langchain.callbacks.manager import Callbacks +from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.chains.base import Chain +from langchain.chat_models.base import BaseChatModel +from langchain.client.models import Example +from langchain.llms.base import BaseLLM +from langchain.schema import ( + BaseMessage, + ChatResult, + HumanMessage, + LLMResult, + get_buffer_string, + messages_from_dict, +) + +logger = logging.getLogger(__name__) + +MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel] + + +class InputFormatError(Exception): + """Raised when input format is invalid.""" + + +def _get_prompts(inputs: Dict[str, Any]) -> List[str]: + """Get prompts from inputs.""" + if not inputs: + raise InputFormatError("Inputs should not be empty.") + + prompts = [] + if "prompt" in inputs: + if not isinstance(inputs["prompt"], str): + raise InputFormatError( + "Expected string for 'prompt', got" + f" {type(inputs['prompt']).__name__}" + ) + prompts = [inputs["prompt"]] + elif "prompts" in inputs: + if not isinstance(inputs["prompts"], list) or not all( + isinstance(i, str) for i in inputs["prompts"] + ): + raise InputFormatError( + "Expected list of strings for 'prompts'," + f" got {type(inputs['prompts']).__name__}" + ) + prompts = inputs["prompts"] + elif len(inputs) == 1: + prompt_ = next(iter(inputs.values())) + if isinstance(prompt_, str): + prompts = [prompt_] + elif isinstance(prompt_, list) and all(isinstance(i, str) for i in prompt_): + prompts = prompt_ + else: + raise InputFormatError(f"LLM Run expects string prompt input. Got {inputs}") + else: + raise InputFormatError( + f"LLM Run expects 'prompt' or 'prompts' in inputs. Got {inputs}" + ) + + return prompts + + +def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]: + """Get Chat Messages from inputs.""" + if not inputs: + raise InputFormatError("Inputs should not be empty.") + + if "messages" in inputs: + single_input = inputs["messages"] + elif len(inputs) == 1: + single_input = next(iter(inputs.values())) + else: + raise InputFormatError(f"Chat Run expects 'messages' in inputs. Got {inputs}") + if isinstance(single_input, list) and all( + isinstance(i, dict) for i in single_input + ): + raw_messages = [single_input] + elif isinstance(single_input, list) and all( + isinstance(i, list) for i in single_input + ): + raw_messages = single_input + else: + raise InputFormatError( + f"Chat Run expects List[dict] or List[List[dict]] 'messages'" + f" input. Got {inputs}" + ) + return [messages_from_dict(batch) for batch in raw_messages] + + +async def _arun_llm( + llm: BaseLanguageModel, + inputs: Dict[str, Any], + langchain_tracer: Optional[LangChainTracer], +) -> Union[LLMResult, ChatResult]: + callbacks: Optional[List[BaseCallbackHandler]] = ( + [langchain_tracer] if langchain_tracer else None + ) + if isinstance(llm, BaseLLM): + try: + llm_prompts = _get_prompts(inputs) + llm_output = await llm.agenerate(llm_prompts, callbacks=callbacks) + except InputFormatError: + llm_messages = _get_messages(inputs) + buffer_strings = [get_buffer_string(messages) for messages in llm_messages] + llm_output = await llm.agenerate(buffer_strings, callbacks=callbacks) + elif isinstance(llm, BaseChatModel): + try: + messages = _get_messages(inputs) + llm_output = await llm.agenerate(messages, callbacks=callbacks) + except InputFormatError: + prompts = _get_prompts(inputs) + converted_messages: List[List[BaseMessage]] = [ + [HumanMessage(content=prompt)] for prompt in prompts + ] + llm_output = await llm.agenerate(converted_messages, callbacks=callbacks) + else: + raise ValueError(f"Unsupported LLM type {type(llm)}") + return llm_output + + +async def _arun_llm_or_chain( + example: Example, + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + n_repetitions: int, + langchain_tracer: Optional[LangChainTracer], +) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: + """Run the chain asynchronously.""" + if langchain_tracer is not None: + previous_example_id = langchain_tracer.example_id + langchain_tracer.example_id = example.id + callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer] + else: + previous_example_id = None + callbacks = None + outputs = [] + for _ in range(n_repetitions): + try: + if isinstance(llm_or_chain_factory, BaseLanguageModel): + output: Any = await _arun_llm( + llm_or_chain_factory, example.inputs, langchain_tracer + ) + else: + chain = llm_or_chain_factory() + output = await chain.arun(example.inputs, callbacks=callbacks) + outputs.append(output) + except Exception as e: + logger.warning(f"Chain failed for example {example.id}. Error: {e}") + outputs.append({"Error": str(e)}) + if langchain_tracer is not None: + langchain_tracer.example_id = previous_example_id + return outputs + + +async def _gather_with_concurrency( + n: int, + initializer: Callable[[], Coroutine[Any, Any, Optional[LangChainTracer]]], + *async_funcs: Callable[[Optional[LangChainTracer], Dict], Coroutine[Any, Any, Any]], +) -> List[Any]: + """ + Run coroutines with a concurrency limit. + + Args: + n: The maximum number of concurrent tasks. + initializer: A coroutine that initializes shared resources for the tasks. + async_funcs: The async_funcs to be run concurrently. + + Returns: + A list of results from the coroutines. + """ + semaphore = asyncio.Semaphore(n) + job_state = {"num_processed": 0} + + tracer_queue: asyncio.Queue[Optional[LangChainTracer]] = asyncio.Queue() + for _ in range(n): + tracer_queue.put_nowait(await initializer()) + + async def run_coroutine_with_semaphore( + async_func: Callable[ + [Optional[LangChainTracer], Dict], Coroutine[Any, Any, Any] + ] + ) -> Any: + async with semaphore: + tracer = await tracer_queue.get() + try: + result = await async_func(tracer, job_state) + finally: + tracer_queue.put_nowait(tracer) + return result + + return await asyncio.gather( + *(run_coroutine_with_semaphore(function) for function in async_funcs) + ) + + +async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChainTracer]: + """ + Initialize a tracer to share across tasks. + + Args: + session_name: The session name for the tracer. + + Returns: + A LangChainTracer instance with an active session. + """ + if session_name: + tracer = LangChainTracer(session_name=session_name) + tracer.ensure_session() + return tracer + else: + return None + + +async def arun_on_examples( + examples: Iterator[Example], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + *, + concurrency_level: int = 5, + num_repetitions: int = 1, + session_name: Optional[str] = None, + verbose: bool = False, +) -> Dict[str, Any]: + """ + Run the chain on examples and store traces to the specified session name. + + Args: + examples: Examples to run the model or chain over + llm_or_chain_factory: Language model or Chain constructor to run + over the dataset. The Chain constructor is used to permit + independent calls on each example without carrying over state. + concurrency_level: The number of async tasks to run concurrently. + num_repetitions: Number of times to run the model on each example. + This is useful when testing success rates or generating confidence + intervals. + session_name: Session name to use when tracing runs. + verbose: Whether to print progress. + + Returns: + A dictionary mapping example ids to the model outputs. + """ + results: Dict[str, List[Any]] = {} + + async def process_example( + example: Example, tracer: LangChainTracer, job_state: dict + ) -> None: + """Process a single example.""" + result = await _arun_llm_or_chain( + example, + llm_or_chain_factory, + num_repetitions, + tracer, + ) + results[str(example.id)] = result + job_state["num_processed"] += 1 + if verbose: + print( + f"Processed examples: {job_state['num_processed']}", + end="\r", + flush=True, + ) + + await _gather_with_concurrency( + concurrency_level, + functools.partial(_tracer_initializer, session_name), + *(functools.partial(process_example, e) for e in examples), + ) + return results + + +def run_llm( + llm: BaseLanguageModel, + inputs: Dict[str, Any], + callbacks: Callbacks, +) -> Union[LLMResult, ChatResult]: + """Run the language model on the example.""" + if isinstance(llm, BaseLLM): + try: + llm_prompts = _get_prompts(inputs) + llm_output = llm.generate(llm_prompts, callbacks=callbacks) + except InputFormatError: + llm_messages = _get_messages(inputs) + buffer_strings = [get_buffer_string(messages) for messages in llm_messages] + llm_output = llm.generate(buffer_strings, callbacks=callbacks) + elif isinstance(llm, BaseChatModel): + try: + messages = _get_messages(inputs) + llm_output = llm.generate(messages, callbacks=callbacks) + except InputFormatError: + prompts = _get_prompts(inputs) + converted_messages: List[List[BaseMessage]] = [ + [HumanMessage(content=prompt)] for prompt in prompts + ] + llm_output = llm.generate(converted_messages, callbacks=callbacks) + else: + raise ValueError(f"Unsupported LLM type {type(llm)}") + return llm_output + + +def run_llm_or_chain( + example: Example, + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + n_repetitions: int, + langchain_tracer: Optional[LangChainTracer] = None, +) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]: + """Run the chain synchronously.""" + if langchain_tracer is not None: + previous_example_id = langchain_tracer.example_id + langchain_tracer.example_id = example.id + callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer] + else: + previous_example_id = None + callbacks = None + outputs = [] + for _ in range(n_repetitions): + try: + if isinstance(llm_or_chain_factory, BaseLanguageModel): + output: Any = run_llm(llm_or_chain_factory, example.inputs, callbacks) + else: + chain = llm_or_chain_factory() + output = chain.run(example.inputs, callbacks=callbacks) + outputs.append(output) + except Exception as e: + logger.warning(f"Chain failed for example {example.id}. Error: {e}") + outputs.append({"Error": str(e)}) + if langchain_tracer is not None: + langchain_tracer.example_id = previous_example_id + return outputs + + +def run_on_examples( + examples: Iterator[Example], + llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY, + *, + num_repetitions: int = 1, + session_name: Optional[str] = None, + verbose: bool = False, +) -> Dict[str, Any]: + """Run the chain on examples and store traces to the specified session name. + + Args: + examples: Examples to run model or chain over. + llm_or_chain_factory: Language model or Chain constructor to run + over the dataset. The Chain constructor is used to permit + independent calls on each example without carrying over state. + concurrency_level: Number of async workers to run in parallel. + num_repetitions: Number of times to run the model on each example. + This is useful when testing success rates or generating confidence + intervals. + session_name: Session name to use when tracing runs. + verbose: Whether to print progress. + Returns: + A dictionary mapping example ids to the model outputs. + """ + results: Dict[str, Any] = {} + tracer = LangChainTracer(session_name=session_name) if session_name else None + for i, example in enumerate(examples): + result = run_llm_or_chain( + example, + llm_or_chain_factory, + num_repetitions, + langchain_tracer=tracer, + ) + if verbose: + print(f"{i+1} processed", flush=True, end="\r") + results[str(example.id)] = result + return results diff --git a/tests/unit_tests/client/test_langchain.py b/tests/unit_tests/client/test_langchain.py index e712b93f..b45ff3ca 100644 --- a/tests/unit_tests/client/test_langchain.py +++ b/tests/unit_tests/client/test_langchain.py @@ -12,14 +12,11 @@ from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.schemas import TracerSession from langchain.chains.base import Chain from langchain.client.langchain import ( - InputFormatError, LangChainPlusClient, _get_link_stem, _is_localhost, ) from langchain.client.models import Dataset, Example -from tests.unit_tests.llms.fake_chat_model import FakeChatModel -from tests.unit_tests.llms.fake_llm import FakeLLM _CREATED_AT = datetime(2015, 1, 1, 0, 0, 0) _TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4" @@ -191,9 +188,9 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: async def mock_arun_chain( example: Example, - tracer: Any, llm_or_chain: Union[BaseLanguageModel, Chain], n_repetitions: int, + tracer: Any, ) -> List[Dict[str, Any]]: return [ {"result": f"Result for example {example.id}"} for _ in range(n_repetitions) @@ -206,8 +203,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: LangChainPlusClient, "read_dataset", new=mock_read_dataset ), mock.patch.object( LangChainPlusClient, "list_examples", new=mock_list_examples - ), mock.patch.object( - LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain + ), mock.patch( + "langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain ), mock.patch.object( LangChainTracer, "ensure_session", new=mock_ensure_session ): @@ -233,85 +230,3 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: for uuid_ in uuids } assert results == expected - - -_EXAMPLE_MESSAGE = { - "data": {"content": "Foo", "example": False, "additional_kwargs": {}}, - "type": "human", -} -_VALID_MESSAGES = [ - {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"}, - {"messages": [], "other_key": "value"}, - { - "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]], - "other_key": "value", - }, - {"any_key": [_EXAMPLE_MESSAGE]}, - {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]]}, -] -_VALID_PROMPTS = [ - {"prompts": ["foo", "bar", "baz"], "other_key": "value"}, - {"prompt": "foo", "other_key": ["bar", "baz"]}, - {"some_key": "foo"}, - {"some_key": ["foo", "bar"]}, -] - - -@pytest.mark.parametrize( - "inputs", - _VALID_MESSAGES, -) -def test__get_messages_valid(inputs: Dict[str, Any]) -> None: - {"messages": []} - LangChainPlusClient._get_messages(inputs) - - -@pytest.mark.parametrize( - "inputs", - _VALID_PROMPTS, -) -def test__get_prompts_valid(inputs: Dict[str, Any]) -> None: - LangChainPlusClient._get_prompts(inputs) - - -@pytest.mark.parametrize( - "inputs", - [ - {"prompts": "foo"}, - {"prompt": ["foo"]}, - {"some_key": 3}, - {"some_key": "foo", "other_key": "bar"}, - ], -) -def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None: - with pytest.raises(InputFormatError): - LangChainPlusClient._get_prompts(inputs) - - -@pytest.mark.parametrize( - "inputs", - [ - {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"}, - { - "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE], - "other_key": "value", - }, - {"prompts": "foo"}, - {}, - ], -) -def test__get_messages_invalid(inputs: Dict[str, Any]) -> None: - with pytest.raises(InputFormatError): - LangChainPlusClient._get_messages(inputs) - - -@pytest.mark.parametrize("inputs", _VALID_PROMPTS + _VALID_MESSAGES) -def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None: - llm = FakeLLM() - LangChainPlusClient.run_llm(llm, inputs, mock.MagicMock()) - - -@pytest.mark.parametrize("inputs", _VALID_MESSAGES + _VALID_PROMPTS) -def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None: - llm = FakeChatModel() - LangChainPlusClient.run_llm(llm, inputs, mock.MagicMock()) diff --git a/tests/unit_tests/client/test_runner_utils.py b/tests/unit_tests/client/test_runner_utils.py new file mode 100644 index 00000000..bfd09920 --- /dev/null +++ b/tests/unit_tests/client/test_runner_utils.py @@ -0,0 +1,95 @@ +"""Test the LangChain+ client.""" +from typing import Any, Dict +from unittest import mock + +import pytest + +from langchain.client.runner_utils import ( + InputFormatError, + _get_messages, + _get_prompts, + run_llm, +) +from tests.unit_tests.llms.fake_chat_model import FakeChatModel +from tests.unit_tests.llms.fake_llm import FakeLLM + +_EXAMPLE_MESSAGE = { + "data": {"content": "Foo", "example": False, "additional_kwargs": {}}, + "type": "human", +} +_VALID_MESSAGES = [ + {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"}, + {"messages": [], "other_key": "value"}, + { + "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]], + "other_key": "value", + }, + {"any_key": [_EXAMPLE_MESSAGE]}, + {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]]}, +] +_VALID_PROMPTS = [ + {"prompts": ["foo", "bar", "baz"], "other_key": "value"}, + {"prompt": "foo", "other_key": ["bar", "baz"]}, + {"some_key": "foo"}, + {"some_key": ["foo", "bar"]}, +] + + +@pytest.mark.parametrize( + "inputs", + _VALID_MESSAGES, +) +def test__get_messages_valid(inputs: Dict[str, Any]) -> None: + {"messages": []} + _get_messages(inputs) + + +@pytest.mark.parametrize( + "inputs", + _VALID_PROMPTS, +) +def test__get_prompts_valid(inputs: Dict[str, Any]) -> None: + _get_prompts(inputs) + + +@pytest.mark.parametrize( + "inputs", + [ + {"prompts": "foo"}, + {"prompt": ["foo"]}, + {"some_key": 3}, + {"some_key": "foo", "other_key": "bar"}, + ], +) +def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None: + with pytest.raises(InputFormatError): + _get_prompts(inputs) + + +@pytest.mark.parametrize( + "inputs", + [ + {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"}, + { + "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE], + "other_key": "value", + }, + {"prompts": "foo"}, + {}, + ], +) +def test__get_messages_invalid(inputs: Dict[str, Any]) -> None: + with pytest.raises(InputFormatError): + _get_messages(inputs) + + +@pytest.mark.parametrize("inputs", _VALID_PROMPTS + _VALID_MESSAGES) +def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None: + llm = FakeLLM() + run_llm(llm, inputs, mock.MagicMock()) + + +@pytest.mark.parametrize("inputs", _VALID_MESSAGES + _VALID_PROMPTS) +def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None: + llm = FakeChatModel() + run_llm(llm, inputs, mock.MagicMock())