Separate Runner Functions from Client (#5079)
Extract the methods specific to running an LLM or Chain on a dataset into separate utility functions. This simplifies the client a bit and lets us separate the concerns of LangChainPlus (LCP) client details from running examples (e.g., for evals).
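As a rough sketch of the intended usage after this split (not part of the commit): the client keeps the dataset/session CRUD surface, while the extracted helpers in langchain.client.runner_utils do the actual running. The dataset and session names below are hypothetical, FakeListLLM is only a stand-in model, and constructing LangChainPlusClient() with no arguments assumes the usual environment-based configuration from its BaseSettings base.

    # Minimal sketch, assuming an existing dataset and an env-configured client.
    from langchain.client.langchain import LangChainPlusClient
    from langchain.client.runner_utils import run_on_examples
    from langchain.llms.fake import FakeListLLM  # stand-in model for illustration

    client = LangChainPlusClient()  # assumes LANGCHAIN_* env vars are set
    dataset = client.read_dataset(dataset_name="my-eval-dataset")  # hypothetical name
    examples = client.list_examples(dataset_id=str(dataset.id))

    # Run the model over every example; traces are stored under session_name.
    results = run_on_examples(
        examples,
        FakeListLLM(responses=["ok"]),
        num_repetitions=1,
        session_name="eval-session",
        verbose=True,
    )

The async variant, arun_on_examples, takes the same arguments plus a concurrency_level, and LangChainPlusClient.run_on_dataset / arun_on_dataset now simply delegate to these helpers.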
parent 443ebe22f4
commit ef7d015be5
langchain/client/langchain.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-import asyncio
-import functools
 import logging
 import socket
 from datetime import datetime
@@ -10,9 +8,8 @@ from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
-    Coroutine,
     Dict,
-    Iterable,
+    Iterator,
     List,
     Optional,
     Tuple,
@@ -27,10 +24,8 @@ from requests import Response
 from tenacity import retry, stop_after_attempt, wait_fixed
 
 from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.tracers.langchain import LangChainTracer
 from langchain.callbacks.tracers.schemas import Run, TracerSession
 from langchain.chains.base import Chain
-from langchain.chat_models.base import BaseChatModel
 from langchain.client.models import (
     Dataset,
     DatasetCreate,
@@ -38,15 +33,7 @@ from langchain.client.models import (
     ExampleCreate,
     ListRunsQueryParams,
 )
-from langchain.llms.base import BaseLLM
-from langchain.schema import (
-    BaseMessage,
-    ChatResult,
-    HumanMessage,
-    LLMResult,
-    get_buffer_string,
-    messages_from_dict,
-)
+from langchain.client.runner_utils import arun_on_examples, run_on_examples
 from langchain.utils import raise_for_status_with_text, xor_args
 
 if TYPE_CHECKING:
@@ -57,10 +44,6 @@ logger = logging.getLogger(__name__)
 MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
 
 
-class InputFormatError(Exception):
-    """Raised when input format is invalid."""
-
-
 def _get_link_stem(url: str) -> str:
     scheme = urlsplit(url).scheme
     netloc_prefix = urlsplit(url).netloc.split(":")[0]
@@ -231,7 +214,7 @@ class LangChainPlusClient(BaseSettings):
         session_name: Optional[str] = None,
         run_type: Optional[str] = None,
         **kwargs: Any,
-    ) -> List[Run]:
+    ) -> Iterator[Run]:
         """List runs from the LangChain+ API."""
         if session_name is not None:
             if session_id is not None:
@@ -245,7 +228,7 @@ class LangChainPlusClient(BaseSettings):
         }
         response = self._get("/runs", params=filtered_params)
         raise_for_status_with_text(response)
-        return [Run(**run) for run in response.json()]
+        yield from [Run(**run) for run in response.json()]
 
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
     @xor_args(("session_id", "session_name"))
@@ -279,11 +262,11 @@ class LangChainPlusClient(BaseSettings):
         return TracerSession(**response.json())
 
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
-    def list_sessions(self) -> List[TracerSession]:
+    def list_sessions(self) -> Iterator[TracerSession]:
         """List sessions from the LangChain+ API."""
         response = self._get("/sessions")
         raise_for_status_with_text(response)
-        return [TracerSession(**session) for session in response.json()]
+        yield from [TracerSession(**session) for session in response.json()]
 
     def create_dataset(self, dataset_name: str, description: str) -> Dataset:
         """Create a dataset in the LangChain+ API."""
@@ -326,11 +309,11 @@ class LangChainPlusClient(BaseSettings):
         return Dataset(**result)
 
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
-    def list_datasets(self, limit: int = 100) -> Iterable[Dataset]:
+    def list_datasets(self, limit: int = 100) -> Iterator[Dataset]:
         """List the datasets on the LangChain+ API."""
         response = self._get("/datasets", params={"limit": limit})
         raise_for_status_with_text(response)
-        return [Dataset(**dataset) for dataset in response.json()]
+        yield from [Dataset(**dataset) for dataset in response.json()]
 
     @xor_args(("dataset_id", "dataset_name"))
     def delete_dataset(
@@ -346,7 +329,7 @@ class LangChainPlusClient(BaseSettings):
             headers=self._headers,
         )
         raise_for_status_with_text(response)
-        return response.json()
+        return Dataset(**response.json())
 
     @xor_args(("dataset_id", "dataset_name"))
     def create_example(
@@ -386,7 +369,7 @@ class LangChainPlusClient(BaseSettings):
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
     def list_examples(
         self, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None
-    ) -> Iterable[Example]:
+    ) -> Iterator[Example]:
         """List the datasets on the LangChain+ API."""
         params = {}
         if dataset_id is not None:
@@ -398,195 +381,7 @@ class LangChainPlusClient(BaseSettings):
         pass
         response = self._get("/examples", params=params)
         raise_for_status_with_text(response)
-        return [Example(**dataset) for dataset in response.json()]
+        yield from [Example(**dataset) for dataset in response.json()]
 
-    @staticmethod
-    def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
-        """Get prompts from inputs."""
-        if not inputs:
-            raise InputFormatError("Inputs should not be empty.")
-
-        prompts = []
-
-        if "prompt" in inputs:
-            if not isinstance(inputs["prompt"], str):
-                raise InputFormatError(
-                    "Expected string for 'prompt', got"
-                    f" {type(inputs['prompt']).__name__}"
-                )
-            prompts = [inputs["prompt"]]
-        elif "prompts" in inputs:
-            if not isinstance(inputs["prompts"], list) or not all(
-                isinstance(i, str) for i in inputs["prompts"]
-            ):
-                raise InputFormatError(
-                    "Expected list of strings for 'prompts',"
-                    f" got {type(inputs['prompts']).__name__}"
-                )
-            prompts = inputs["prompts"]
-        elif len(inputs) == 1:
-            prompt_ = next(iter(inputs.values()))
-            if isinstance(prompt_, str):
-                prompts = [prompt_]
-            elif isinstance(prompt_, list) and all(isinstance(i, str) for i in prompt_):
-                prompts = prompt_
-            else:
-                raise InputFormatError(
-                    f"LLM Run expects string prompt input. Got {inputs}"
-                )
-        else:
-            raise InputFormatError(
-                f"LLM Run expects 'prompt' or 'prompts' in inputs. Got {inputs}"
-            )
-
-        return prompts
-
-    @staticmethod
-    def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
-        """Get Chat Messages from inputs."""
-        if not inputs:
-            raise InputFormatError("Inputs should not be empty.")
-
-        if "messages" in inputs:
-            single_input = inputs["messages"]
-        elif len(inputs) == 1:
-            single_input = next(iter(inputs.values()))
-        else:
-            raise InputFormatError(
-                f"Chat Run expects 'messages' in inputs. Got {inputs}"
-            )
-        if isinstance(single_input, list) and all(
-            isinstance(i, dict) for i in single_input
-        ):
-            raw_messages = [single_input]
-        elif isinstance(single_input, list) and all(
-            isinstance(i, list) for i in single_input
-        ):
-            raw_messages = single_input
-        else:
-            raise InputFormatError(
-                f"Chat Run expects List[dict] or List[List[dict]] 'messages'"
-                f" input. Got {inputs}"
-            )
-        return [messages_from_dict(batch) for batch in raw_messages]
-
-    @staticmethod
-    async def _arun_llm(
-        llm: BaseLanguageModel,
-        inputs: Dict[str, Any],
-        langchain_tracer: LangChainTracer,
-    ) -> Union[LLMResult, ChatResult]:
-        if isinstance(llm, BaseLLM):
-            try:
-                llm_prompts = LangChainPlusClient._get_prompts(inputs)
-                llm_output = await llm.agenerate(
-                    llm_prompts, callbacks=[langchain_tracer]
-                )
-            except InputFormatError:
-                llm_messages = LangChainPlusClient._get_messages(inputs)
-                buffer_strings = [
-                    get_buffer_string(messages) for messages in llm_messages
-                ]
-                llm_output = await llm.agenerate(
-                    buffer_strings, callbacks=[langchain_tracer]
-                )
-        elif isinstance(llm, BaseChatModel):
-            try:
-                messages = LangChainPlusClient._get_messages(inputs)
-                llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
-            except InputFormatError:
-                prompts = LangChainPlusClient._get_prompts(inputs)
-                converted_messages: List[List[BaseMessage]] = [
-                    [HumanMessage(content=prompt)] for prompt in prompts
-                ]
-                llm_output = await llm.agenerate(
-                    converted_messages, callbacks=[langchain_tracer]
-                )
-        else:
-            raise ValueError(f"Unsupported LLM type {type(llm)}")
-        return llm_output
-
-    @staticmethod
-    async def _arun_llm_or_chain(
-        example: Example,
-        langchain_tracer: LangChainTracer,
-        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
-        n_repetitions: int,
-    ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
-        """Run the chain asynchronously."""
-        previous_example_id = langchain_tracer.example_id
-        langchain_tracer.example_id = example.id
-        outputs = []
-        for _ in range(n_repetitions):
-            try:
-                if isinstance(llm_or_chain_factory, BaseLanguageModel):
-                    output: Any = await LangChainPlusClient._arun_llm(
-                        llm_or_chain_factory, example.inputs, langchain_tracer
-                    )
-                else:
-                    chain = llm_or_chain_factory()
-                    output = await chain.arun(
-                        example.inputs, callbacks=[langchain_tracer]
-                    )
-                outputs.append(output)
-            except Exception as e:
-                logger.warning(f"Chain failed for example {example.id}. Error: {e}")
-                outputs.append({"Error": str(e)})
-        langchain_tracer.example_id = previous_example_id
-        return outputs
-
-    @staticmethod
-    async def _gather_with_concurrency(
-        n: int,
-        initializer: Callable[[], Coroutine[Any, Any, LangChainTracer]],
-        *async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]],
-    ) -> List[Any]:
-        """
-        Run coroutines with a concurrency limit.
-
-        Args:
-            n: The maximum number of concurrent tasks.
-            initializer: A coroutine that initializes shared resources for the tasks.
-            async_funcs: The async_funcs to be run concurrently.
-
-        Returns:
-            A list of results from the coroutines.
-        """
-        semaphore = asyncio.Semaphore(n)
-        job_state = {"num_processed": 0}
-
-        tracer_queue: asyncio.Queue[LangChainTracer] = asyncio.Queue()
-        for _ in range(n):
-            tracer_queue.put_nowait(await initializer())
-
-        async def run_coroutine_with_semaphore(
-            async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]]
-        ) -> Any:
-            async with semaphore:
-                tracer = await tracer_queue.get()
-                try:
-                    result = await async_func(tracer, job_state)
-                finally:
-                    tracer_queue.put_nowait(tracer)
-            return result
-
-        return await asyncio.gather(
-            *(run_coroutine_with_semaphore(function) for function in async_funcs)
-        )
-
-    async def _tracer_initializer(self, session_name: str) -> LangChainTracer:
-        """
-        Initialize a tracer to share across tasks.
-
-        Args:
-            session_name: The session name for the tracer.
-
-        Returns:
-            A LangChainTracer instance with an active session.
-        """
-        tracer = LangChainTracer(session_name=session_name)
-        tracer.ensure_session()
-        return tracer
 
     async def arun_on_dataset(
         self,
@@ -622,93 +417,15 @@ class LangChainPlusClient(BaseSettings):
         )
         dataset = self.read_dataset(dataset_name=dataset_name)
        examples = self.list_examples(dataset_id=str(dataset.id))
-        results: Dict[str, List[Any]] = {}
-
-        async def process_example(
-            example: Example, tracer: LangChainTracer, job_state: dict
-        ) -> None:
-            """Process a single example."""
-            result = await LangChainPlusClient._arun_llm_or_chain(
-                example,
-                tracer,
-                llm_or_chain_factory,
-                num_repetitions,
-            )
-            results[str(example.id)] = result
-            job_state["num_processed"] += 1
-            if verbose:
-                print(
-                    f"Processed examples: {job_state['num_processed']}",
-                    end="\r",
-                    flush=True,
-                )
-
-        await self._gather_with_concurrency(
-            concurrency_level,
-            functools.partial(self._tracer_initializer, session_name),
-            *(functools.partial(process_example, e) for e in examples),
-        )
-        return results
-
-    @staticmethod
-    def run_llm(
-        llm: BaseLanguageModel,
-        inputs: Dict[str, Any],
-        langchain_tracer: LangChainTracer,
-    ) -> Union[LLMResult, ChatResult]:
-        """Run the language model on the example."""
-        if isinstance(llm, BaseLLM):
-            try:
-                llm_prompts = LangChainPlusClient._get_prompts(inputs)
-                llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
-            except InputFormatError:
-                llm_messages = LangChainPlusClient._get_messages(inputs)
-                buffer_strings = [
-                    get_buffer_string(messages) for messages in llm_messages
-                ]
-                llm_output = llm.generate(buffer_strings, callbacks=[langchain_tracer])
-        elif isinstance(llm, BaseChatModel):
-            try:
-                messages = LangChainPlusClient._get_messages(inputs)
-                llm_output = llm.generate(messages, callbacks=[langchain_tracer])
-            except InputFormatError:
-                prompts = LangChainPlusClient._get_prompts(inputs)
-                converted_messages: List[List[BaseMessage]] = [
-                    [HumanMessage(content=prompt)] for prompt in prompts
-                ]
-                llm_output = llm.generate(
-                    converted_messages, callbacks=[langchain_tracer]
-                )
-        else:
-            raise ValueError(f"Unsupported LLM type {type(llm)}")
-        return llm_output
-
-    @staticmethod
-    def run_llm_or_chain(
-        example: Example,
-        langchain_tracer: LangChainTracer,
-        llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
-        n_repetitions: int,
-    ) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
-        """Run the chain synchronously."""
-        previous_example_id = langchain_tracer.example_id
-        langchain_tracer.example_id = example.id
-        outputs = []
-        for _ in range(n_repetitions):
-            try:
-                if isinstance(llm_or_chain_factory, BaseLanguageModel):
-                    output: Any = LangChainPlusClient.run_llm(
-                        llm_or_chain_factory, example.inputs, langchain_tracer
-                    )
-                else:
-                    chain = llm_or_chain_factory()
-                    output = chain.run(example.inputs, callbacks=[langchain_tracer])
-                outputs.append(output)
-            except Exception as e:
-                logger.warning(f"Chain failed for example {example.id}. Error: {e}")
-                outputs.append({"Error": str(e)})
-        langchain_tracer.example_id = previous_example_id
-        return outputs
+        return await arun_on_examples(
+            examples,
+            llm_or_chain_factory,
+            concurrency_level=concurrency_level,
+            num_repetitions=num_repetitions,
+            session_name=session_name,
+            verbose=verbose,
+        )
 
     def run_on_dataset(
         self,
@@ -741,18 +458,11 @@ class LangChainPlusClient(BaseSettings):
             session_name, llm_or_chain_factory, dataset_name
         )
         dataset = self.read_dataset(dataset_name=dataset_name)
-        examples = list(self.list_examples(dataset_id=str(dataset.id)))
-        results: Dict[str, Any] = {}
-        tracer = LangChainTracer(session_name=session_name)
-        tracer.ensure_session()
-        for i, example in enumerate(examples):
-            result = self.run_llm_or_chain(
-                example,
-                tracer,
-                llm_or_chain_factory,
-                num_repetitions,
-            )
-            if verbose:
-                print(f"{i+1} processed", flush=True, end="\r")
-            results[str(example.id)] = result
-        return results
+        examples = self.list_examples(dataset_id=str(dataset.id))
+        return run_on_examples(
+            examples,
+            llm_or_chain_factory,
+            num_repetitions=num_repetitions,
+            session_name=session_name,
+            verbose=verbose,
+        )
langchain/client/runner_utils.py (new file, 375 lines)
@@ -0,0 +1,375 @@
+"""Utilities for running LLMs/Chains over datasets."""
+from __future__ import annotations
+
+import asyncio
+import functools
+import logging
+from typing import Any, Callable, Coroutine, Dict, Iterator, List, Optional, Union
+
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.callbacks.manager import Callbacks
+from langchain.callbacks.tracers.langchain import LangChainTracer
+from langchain.chains.base import Chain
+from langchain.chat_models.base import BaseChatModel
+from langchain.client.models import Example
+from langchain.llms.base import BaseLLM
+from langchain.schema import (
+    BaseMessage,
+    ChatResult,
+    HumanMessage,
+    LLMResult,
+    get_buffer_string,
+    messages_from_dict,
+)
+
+logger = logging.getLogger(__name__)
+
+MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
+
+
+class InputFormatError(Exception):
+    """Raised when input format is invalid."""
+
+
+def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
+    """Get prompts from inputs."""
+    if not inputs:
+        raise InputFormatError("Inputs should not be empty.")
+
+    prompts = []
+    if "prompt" in inputs:
+        if not isinstance(inputs["prompt"], str):
+            raise InputFormatError(
+                "Expected string for 'prompt', got"
+                f" {type(inputs['prompt']).__name__}"
+            )
+        prompts = [inputs["prompt"]]
+    elif "prompts" in inputs:
+        if not isinstance(inputs["prompts"], list) or not all(
+            isinstance(i, str) for i in inputs["prompts"]
+        ):
+            raise InputFormatError(
+                "Expected list of strings for 'prompts',"
+                f" got {type(inputs['prompts']).__name__}"
+            )
+        prompts = inputs["prompts"]
+    elif len(inputs) == 1:
+        prompt_ = next(iter(inputs.values()))
+        if isinstance(prompt_, str):
+            prompts = [prompt_]
+        elif isinstance(prompt_, list) and all(isinstance(i, str) for i in prompt_):
+            prompts = prompt_
+        else:
+            raise InputFormatError(f"LLM Run expects string prompt input. Got {inputs}")
+    else:
+        raise InputFormatError(
+            f"LLM Run expects 'prompt' or 'prompts' in inputs. Got {inputs}"
+        )
+
+    return prompts
+
+
+def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
+    """Get Chat Messages from inputs."""
+    if not inputs:
+        raise InputFormatError("Inputs should not be empty.")
+
+    if "messages" in inputs:
+        single_input = inputs["messages"]
+    elif len(inputs) == 1:
+        single_input = next(iter(inputs.values()))
+    else:
+        raise InputFormatError(f"Chat Run expects 'messages' in inputs. Got {inputs}")
+    if isinstance(single_input, list) and all(
+        isinstance(i, dict) for i in single_input
+    ):
+        raw_messages = [single_input]
+    elif isinstance(single_input, list) and all(
+        isinstance(i, list) for i in single_input
+    ):
+        raw_messages = single_input
+    else:
+        raise InputFormatError(
+            f"Chat Run expects List[dict] or List[List[dict]] 'messages'"
+            f" input. Got {inputs}"
+        )
+    return [messages_from_dict(batch) for batch in raw_messages]
+
+
+async def _arun_llm(
+    llm: BaseLanguageModel,
+    inputs: Dict[str, Any],
+    langchain_tracer: Optional[LangChainTracer],
+) -> Union[LLMResult, ChatResult]:
+    callbacks: Optional[List[BaseCallbackHandler]] = (
+        [langchain_tracer] if langchain_tracer else None
+    )
+    if isinstance(llm, BaseLLM):
+        try:
+            llm_prompts = _get_prompts(inputs)
+            llm_output = await llm.agenerate(llm_prompts, callbacks=callbacks)
+        except InputFormatError:
+            llm_messages = _get_messages(inputs)
+            buffer_strings = [get_buffer_string(messages) for messages in llm_messages]
+            llm_output = await llm.agenerate(buffer_strings, callbacks=callbacks)
+    elif isinstance(llm, BaseChatModel):
+        try:
+            messages = _get_messages(inputs)
+            llm_output = await llm.agenerate(messages, callbacks=callbacks)
+        except InputFormatError:
+            prompts = _get_prompts(inputs)
+            converted_messages: List[List[BaseMessage]] = [
+                [HumanMessage(content=prompt)] for prompt in prompts
+            ]
+            llm_output = await llm.agenerate(converted_messages, callbacks=callbacks)
+    else:
+        raise ValueError(f"Unsupported LLM type {type(llm)}")
+    return llm_output
+
+
+async def _arun_llm_or_chain(
+    example: Example,
+    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    n_repetitions: int,
+    langchain_tracer: Optional[LangChainTracer],
+) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
+    """Run the chain asynchronously."""
+    if langchain_tracer is not None:
+        previous_example_id = langchain_tracer.example_id
+        langchain_tracer.example_id = example.id
+        callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer]
+    else:
+        previous_example_id = None
+        callbacks = None
+    outputs = []
+    for _ in range(n_repetitions):
+        try:
+            if isinstance(llm_or_chain_factory, BaseLanguageModel):
+                output: Any = await _arun_llm(
+                    llm_or_chain_factory, example.inputs, langchain_tracer
+                )
+            else:
+                chain = llm_or_chain_factory()
+                output = await chain.arun(example.inputs, callbacks=callbacks)
+            outputs.append(output)
+        except Exception as e:
+            logger.warning(f"Chain failed for example {example.id}. Error: {e}")
+            outputs.append({"Error": str(e)})
+    if langchain_tracer is not None:
+        langchain_tracer.example_id = previous_example_id
+    return outputs
+
+
+async def _gather_with_concurrency(
+    n: int,
+    initializer: Callable[[], Coroutine[Any, Any, Optional[LangChainTracer]]],
+    *async_funcs: Callable[[Optional[LangChainTracer], Dict], Coroutine[Any, Any, Any]],
+) -> List[Any]:
+    """
+    Run coroutines with a concurrency limit.
+
+    Args:
+        n: The maximum number of concurrent tasks.
+        initializer: A coroutine that initializes shared resources for the tasks.
+        async_funcs: The async_funcs to be run concurrently.
+
+    Returns:
+        A list of results from the coroutines.
+    """
+    semaphore = asyncio.Semaphore(n)
+    job_state = {"num_processed": 0}
+
+    tracer_queue: asyncio.Queue[Optional[LangChainTracer]] = asyncio.Queue()
+    for _ in range(n):
+        tracer_queue.put_nowait(await initializer())
+
+    async def run_coroutine_with_semaphore(
+        async_func: Callable[
+            [Optional[LangChainTracer], Dict], Coroutine[Any, Any, Any]
+        ]
+    ) -> Any:
+        async with semaphore:
+            tracer = await tracer_queue.get()
+            try:
+                result = await async_func(tracer, job_state)
+            finally:
+                tracer_queue.put_nowait(tracer)
+        return result
+
+    return await asyncio.gather(
+        *(run_coroutine_with_semaphore(function) for function in async_funcs)
+    )
+
+
+async def _tracer_initializer(session_name: Optional[str]) -> Optional[LangChainTracer]:
+    """
+    Initialize a tracer to share across tasks.
+
+    Args:
+        session_name: The session name for the tracer.
+
+    Returns:
+        A LangChainTracer instance with an active session.
+    """
+    if session_name:
+        tracer = LangChainTracer(session_name=session_name)
+        tracer.ensure_session()
+        return tracer
+    else:
+        return None
+
+
+async def arun_on_examples(
+    examples: Iterator[Example],
+    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    *,
+    concurrency_level: int = 5,
+    num_repetitions: int = 1,
+    session_name: Optional[str] = None,
+    verbose: bool = False,
+) -> Dict[str, Any]:
+    """
+    Run the chain on examples and store traces to the specified session name.
+
+    Args:
+        examples: Examples to run the model or chain over
+        llm_or_chain_factory: Language model or Chain constructor to run
+            over the dataset. The Chain constructor is used to permit
+            independent calls on each example without carrying over state.
+        concurrency_level: The number of async tasks to run concurrently.
+        num_repetitions: Number of times to run the model on each example.
+            This is useful when testing success rates or generating confidence
+            intervals.
+        session_name: Session name to use when tracing runs.
+        verbose: Whether to print progress.
+
+    Returns:
+        A dictionary mapping example ids to the model outputs.
+    """
+    results: Dict[str, List[Any]] = {}
+
+    async def process_example(
+        example: Example, tracer: LangChainTracer, job_state: dict
+    ) -> None:
+        """Process a single example."""
+        result = await _arun_llm_or_chain(
+            example,
+            llm_or_chain_factory,
+            num_repetitions,
+            tracer,
+        )
+        results[str(example.id)] = result
+        job_state["num_processed"] += 1
+        if verbose:
+            print(
+                f"Processed examples: {job_state['num_processed']}",
+                end="\r",
+                flush=True,
+            )
+
+    await _gather_with_concurrency(
+        concurrency_level,
+        functools.partial(_tracer_initializer, session_name),
+        *(functools.partial(process_example, e) for e in examples),
+    )
+    return results
+
+
+def run_llm(
+    llm: BaseLanguageModel,
+    inputs: Dict[str, Any],
+    callbacks: Callbacks,
+) -> Union[LLMResult, ChatResult]:
+    """Run the language model on the example."""
+    if isinstance(llm, BaseLLM):
+        try:
+            llm_prompts = _get_prompts(inputs)
+            llm_output = llm.generate(llm_prompts, callbacks=callbacks)
+        except InputFormatError:
+            llm_messages = _get_messages(inputs)
+            buffer_strings = [get_buffer_string(messages) for messages in llm_messages]
+            llm_output = llm.generate(buffer_strings, callbacks=callbacks)
+    elif isinstance(llm, BaseChatModel):
+        try:
+            messages = _get_messages(inputs)
+            llm_output = llm.generate(messages, callbacks=callbacks)
+        except InputFormatError:
+            prompts = _get_prompts(inputs)
+            converted_messages: List[List[BaseMessage]] = [
+                [HumanMessage(content=prompt)] for prompt in prompts
+            ]
+            llm_output = llm.generate(converted_messages, callbacks=callbacks)
+    else:
+        raise ValueError(f"Unsupported LLM type {type(llm)}")
+    return llm_output
+
+
+def run_llm_or_chain(
+    example: Example,
+    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    n_repetitions: int,
+    langchain_tracer: Optional[LangChainTracer] = None,
+) -> Union[List[dict], List[str], List[LLMResult], List[ChatResult]]:
+    """Run the chain synchronously."""
+    if langchain_tracer is not None:
+        previous_example_id = langchain_tracer.example_id
+        langchain_tracer.example_id = example.id
+        callbacks: Optional[List[BaseCallbackHandler]] = [langchain_tracer]
+    else:
+        previous_example_id = None
+        callbacks = None
+    outputs = []
+    for _ in range(n_repetitions):
+        try:
+            if isinstance(llm_or_chain_factory, BaseLanguageModel):
+                output: Any = run_llm(llm_or_chain_factory, example.inputs, callbacks)
+            else:
+                chain = llm_or_chain_factory()
+                output = chain.run(example.inputs, callbacks=callbacks)
+            outputs.append(output)
+        except Exception as e:
+            logger.warning(f"Chain failed for example {example.id}. Error: {e}")
+            outputs.append({"Error": str(e)})
+    if langchain_tracer is not None:
+        langchain_tracer.example_id = previous_example_id
+    return outputs
+
+
+def run_on_examples(
+    examples: Iterator[Example],
+    llm_or_chain_factory: MODEL_OR_CHAIN_FACTORY,
+    *,
+    num_repetitions: int = 1,
+    session_name: Optional[str] = None,
+    verbose: bool = False,
+) -> Dict[str, Any]:
+    """Run the chain on examples and store traces to the specified session name.
+
+    Args:
+        examples: Examples to run model or chain over.
+        llm_or_chain_factory: Language model or Chain constructor to run
+            over the dataset. The Chain constructor is used to permit
+            independent calls on each example without carrying over state.
+        concurrency_level: Number of async workers to run in parallel.
+        num_repetitions: Number of times to run the model on each example.
+            This is useful when testing success rates or generating confidence
+            intervals.
+        session_name: Session name to use when tracing runs.
+        verbose: Whether to print progress.
+    Returns:
+        A dictionary mapping example ids to the model outputs.
+    """
+    results: Dict[str, Any] = {}
+    tracer = LangChainTracer(session_name=session_name) if session_name else None
+    for i, example in enumerate(examples):
+        result = run_llm_or_chain(
+            example,
+            llm_or_chain_factory,
+            num_repetitions,
+            langchain_tracer=tracer,
+        )
+        if verbose:
+            print(f"{i+1} processed", flush=True, end="\r")
+        results[str(example.id)] = result
+    return results
tests/unit_tests/client/test_langchain.py
@@ -12,14 +12,11 @@ from langchain.callbacks.tracers.langchain import LangChainTracer
 from langchain.callbacks.tracers.schemas import TracerSession
 from langchain.chains.base import Chain
 from langchain.client.langchain import (
-    InputFormatError,
     LangChainPlusClient,
     _get_link_stem,
     _is_localhost,
 )
 from langchain.client.models import Dataset, Example
-from tests.unit_tests.llms.fake_chat_model import FakeChatModel
-from tests.unit_tests.llms.fake_llm import FakeLLM
 
 _CREATED_AT = datetime(2015, 1, 1, 0, 0, 0)
 _TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4"
@@ -191,9 +188,9 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
 
     async def mock_arun_chain(
         example: Example,
-        tracer: Any,
         llm_or_chain: Union[BaseLanguageModel, Chain],
         n_repetitions: int,
+        tracer: Any,
     ) -> List[Dict[str, Any]]:
         return [
             {"result": f"Result for example {example.id}"} for _ in range(n_repetitions)
@@ -206,8 +203,8 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         LangChainPlusClient, "read_dataset", new=mock_read_dataset
     ), mock.patch.object(
         LangChainPlusClient, "list_examples", new=mock_list_examples
-    ), mock.patch.object(
-        LangChainPlusClient, "_arun_llm_or_chain", new=mock_arun_chain
+    ), mock.patch(
+        "langchain.client.runner_utils._arun_llm_or_chain", new=mock_arun_chain
     ), mock.patch.object(
         LangChainTracer, "ensure_session", new=mock_ensure_session
     ):
@@ -233,85 +230,3 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         for uuid_ in uuids
     }
     assert results == expected
-
-
-_EXAMPLE_MESSAGE = {
-    "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
-    "type": "human",
-}
-_VALID_MESSAGES = [
-    {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"},
-    {"messages": [], "other_key": "value"},
-    {
-        "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]],
-        "other_key": "value",
-    },
-    {"any_key": [_EXAMPLE_MESSAGE]},
-    {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]]},
-]
-_VALID_PROMPTS = [
-    {"prompts": ["foo", "bar", "baz"], "other_key": "value"},
-    {"prompt": "foo", "other_key": ["bar", "baz"]},
-    {"some_key": "foo"},
-    {"some_key": ["foo", "bar"]},
-]
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    _VALID_MESSAGES,
-)
-def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
-    {"messages": []}
-    LangChainPlusClient._get_messages(inputs)
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    _VALID_PROMPTS,
-)
-def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
-    LangChainPlusClient._get_prompts(inputs)
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    [
-        {"prompts": "foo"},
-        {"prompt": ["foo"]},
-        {"some_key": 3},
-        {"some_key": "foo", "other_key": "bar"},
-    ],
-)
-def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
-    with pytest.raises(InputFormatError):
-        LangChainPlusClient._get_prompts(inputs)
-
-
-@pytest.mark.parametrize(
-    "inputs",
-    [
-        {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
-        {
-            "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
-            "other_key": "value",
-        },
-        {"prompts": "foo"},
-        {},
-    ],
-)
-def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
-    with pytest.raises(InputFormatError):
-        LangChainPlusClient._get_messages(inputs)
-
-
-@pytest.mark.parametrize("inputs", _VALID_PROMPTS + _VALID_MESSAGES)
-def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
-    llm = FakeLLM()
-    LangChainPlusClient.run_llm(llm, inputs, mock.MagicMock())
-
-
-@pytest.mark.parametrize("inputs", _VALID_MESSAGES + _VALID_PROMPTS)
-def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
-    llm = FakeChatModel()
-    LangChainPlusClient.run_llm(llm, inputs, mock.MagicMock())
tests/unit_tests/client/test_runner_utils.py (new file, 95 lines)
@@ -0,0 +1,95 @@
+"""Test the LangChain+ client."""
+from typing import Any, Dict
+from unittest import mock
+
+import pytest
+
+from langchain.client.runner_utils import (
+    InputFormatError,
+    _get_messages,
+    _get_prompts,
+    run_llm,
+)
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
+from tests.unit_tests.llms.fake_llm import FakeLLM
+
+_EXAMPLE_MESSAGE = {
+    "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
+    "type": "human",
+}
+_VALID_MESSAGES = [
+    {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"},
+    {"messages": [], "other_key": "value"},
+    {
+        "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]],
+        "other_key": "value",
+    },
+    {"any_key": [_EXAMPLE_MESSAGE]},
+    {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]]},
+]
+_VALID_PROMPTS = [
+    {"prompts": ["foo", "bar", "baz"], "other_key": "value"},
+    {"prompt": "foo", "other_key": ["bar", "baz"]},
+    {"some_key": "foo"},
+    {"some_key": ["foo", "bar"]},
+]
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    _VALID_MESSAGES,
+)
+def test__get_messages_valid(inputs: Dict[str, Any]) -> None:
+    {"messages": []}
+    _get_messages(inputs)
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    _VALID_PROMPTS,
+)
+def test__get_prompts_valid(inputs: Dict[str, Any]) -> None:
+    _get_prompts(inputs)
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    [
+        {"prompts": "foo"},
+        {"prompt": ["foo"]},
+        {"some_key": 3},
+        {"some_key": "foo", "other_key": "bar"},
+    ],
+)
+def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None:
+    with pytest.raises(InputFormatError):
+        _get_prompts(inputs)
+
+
+@pytest.mark.parametrize(
+    "inputs",
+    [
+        {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"},
+        {
+            "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE],
+            "other_key": "value",
+        },
+        {"prompts": "foo"},
+        {},
+    ],
+)
+def test__get_messages_invalid(inputs: Dict[str, Any]) -> None:
+    with pytest.raises(InputFormatError):
+        _get_messages(inputs)
+
+
+@pytest.mark.parametrize("inputs", _VALID_PROMPTS + _VALID_MESSAGES)
+def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None:
+    llm = FakeLLM()
+    run_llm(llm, inputs, mock.MagicMock())
+
+
+@pytest.mark.parametrize("inputs", _VALID_MESSAGES + _VALID_PROMPTS)
+def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
+    llm = FakeChatModel()
+    run_llm(llm, inputs, mock.MagicMock())