diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py
index 5f6d25b5f8..83d8330582 100644
--- a/langchain/client/langchain.py
+++ b/langchain/client/langchain.py
@@ -39,7 +39,14 @@ from langchain.client.models import (
     ListRunsQueryParams,
 )
 from langchain.llms.base import BaseLLM
-from langchain.schema import ChatResult, LLMResult, messages_from_dict
+from langchain.schema import (
+    BaseMessage,
+    ChatResult,
+    HumanMessage,
+    LLMResult,
+    get_buffer_string,
+    messages_from_dict,
+)
 from langchain.utils import raise_for_status_with_text, xor_args
 
 if TYPE_CHECKING:
@@ -50,6 +57,10 @@ logger = logging.getLogger(__name__)
 MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]
 
 
+class InputFormatError(Exception):
+    """Raised when input format is invalid."""
+
+
 def _get_link_stem(url: str) -> str:
     scheme = urlsplit(url).scheme
     netloc_prefix = urlsplit(url).netloc.split(":")[0]
@@ -389,6 +400,76 @@ class LangChainPlusClient(BaseSettings):
         raise_for_status_with_text(response)
         return [Example(**dataset) for dataset in response.json()]
 
+    @staticmethod
+    def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
+        """Get prompts from inputs."""
+        if not inputs:
+            raise InputFormatError("Inputs should not be empty.")
+
+        prompts = []
+
+        if "prompt" in inputs:
+            if not isinstance(inputs["prompt"], str):
+                raise InputFormatError(
+                    "Expected string for 'prompt', got"
+                    f" {type(inputs['prompt']).__name__}"
+                )
+            prompts = [inputs["prompt"]]
+        elif "prompts" in inputs:
+            if not isinstance(inputs["prompts"], list) or not all(
+                isinstance(i, str) for i in inputs["prompts"]
+            ):
+                raise InputFormatError(
+                    "Expected list of strings for 'prompts',"
+                    f" got {type(inputs['prompts']).__name__}"
+                )
+            prompts = inputs["prompts"]
+        elif len(inputs) == 1:
+            prompt_ = next(iter(inputs.values()))
+            if isinstance(prompt_, str):
+                prompts = [prompt_]
+            elif isinstance(prompt_, list) and all(isinstance(i, str) for i in prompt_):
+                prompts = prompt_
+            else:
+                raise InputFormatError(
+                    f"LLM Run expects string prompt input. Got {inputs}"
+                )
+        else:
+            raise InputFormatError(
+                f"LLM Run expects 'prompt' or 'prompts' in inputs. Got {inputs}"
+            )
+
+        return prompts
+
+    @staticmethod
+    def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
+        """Get Chat Messages from inputs."""
+        if not inputs:
+            raise InputFormatError("Inputs should not be empty.")
+
+        if "messages" in inputs:
+            single_input = inputs["messages"]
+        elif len(inputs) == 1:
+            single_input = next(iter(inputs.values()))
+        else:
+            raise InputFormatError(
+                f"Chat Run expects 'messages' in inputs. Got {inputs}"
+            )
+        if isinstance(single_input, list) and all(
+            isinstance(i, dict) for i in single_input
+        ):
+            raw_messages = [single_input]
+        elif isinstance(single_input, list) and all(
+            isinstance(i, list) for i in single_input
+        ):
+            raw_messages = single_input
+        else:
+            raise InputFormatError(
+                f"Chat Run expects List[dict] or List[List[dict]] 'messages'"
+                f" input. Got {inputs}"
+            )
+        return [messages_from_dict(batch) for batch in raw_messages]
+
     @staticmethod
     async def _arun_llm(
         llm: BaseLanguageModel,
@@ -396,16 +477,31 @@ class LangChainPlusClient(BaseSettings):
         langchain_tracer: LangChainTracer,
     ) -> Union[LLMResult, ChatResult]:
         if isinstance(llm, BaseLLM):
-            if "prompt" not in inputs:
-                raise ValueError(f"LLM Run requires 'prompt' input. Got {inputs}")
-            llm_prompt: str = inputs["prompt"]
-            llm_output = await llm.agenerate([llm_prompt], callbacks=[langchain_tracer])
+            try:
+                llm_prompts = LangChainPlusClient._get_prompts(inputs)
+                llm_output = await llm.agenerate(
+                    llm_prompts, callbacks=[langchain_tracer]
+                )
+            except InputFormatError:
+                llm_messages = LangChainPlusClient._get_messages(inputs)
+                buffer_strings = [
+                    get_buffer_string(messages) for messages in llm_messages
+                ]
+                llm_output = await llm.agenerate(
+                    buffer_strings, callbacks=[langchain_tracer]
+                )
         elif isinstance(llm, BaseChatModel):
-            if "messages" not in inputs:
-                raise ValueError(f"Chat Run requires 'messages' input. Got {inputs}")
-            raw_messages: List[dict] = inputs["messages"]
-            messages = messages_from_dict(raw_messages)
-            llm_output = await llm.agenerate([messages], callbacks=[langchain_tracer])
+            try:
+                messages = LangChainPlusClient._get_messages(inputs)
+                llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
+            except InputFormatError:
+                prompts = LangChainPlusClient._get_prompts(inputs)
+                converted_messages: List[List[BaseMessage]] = [
+                    [HumanMessage(content=prompt)] for prompt in prompts
+                ]
+                llm_output = await llm.agenerate(
+                    converted_messages, callbacks=[langchain_tracer]
+                )
         else:
             raise ValueError(f"Unsupported LLM type {type(llm)}")
         return llm_output
@@ -562,18 +658,27 @@ class LangChainPlusClient(BaseSettings):
     ) -> Union[LLMResult, ChatResult]:
         """Run the language model on the example."""
         if isinstance(llm, BaseLLM):
-            if "prompt" not in inputs:
-                raise ValueError(f"LLM Run must contain 'prompt' key. Got {inputs}")
-            llm_prompt: str = inputs["prompt"]
-            llm_output = llm.generate([llm_prompt], callbacks=[langchain_tracer])
+            try:
+                llm_prompts = LangChainPlusClient._get_prompts(inputs)
+                llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
+            except InputFormatError:
+                llm_messages = LangChainPlusClient._get_messages(inputs)
+                buffer_strings = [
+                    get_buffer_string(messages) for messages in llm_messages
+                ]
+                llm_output = llm.generate(buffer_strings, callbacks=[langchain_tracer])
         elif isinstance(llm, BaseChatModel):
-            if "messages" not in inputs:
-                raise ValueError(
-                    f"Chat Model Run must contain 'messages' key. Got {inputs}"
Got {inputs}" + try: + messages = LangChainPlusClient._get_messages(inputs) + llm_output = llm.generate(messages, callbacks=[langchain_tracer]) + except InputFormatError: + prompts = LangChainPlusClient._get_prompts(inputs) + converted_messages: List[List[BaseMessage]] = [ + [HumanMessage(content=prompt)] for prompt in prompts + ] + llm_output = llm.generate( + converted_messages, callbacks=[langchain_tracer] ) - raw_messages: List[dict] = inputs["messages"] - messages = messages_from_dict(raw_messages) - llm_output = llm.generate([messages], callbacks=[langchain_tracer]) else: raise ValueError(f"Unsupported LLM type {type(llm)}") return llm_output diff --git a/tests/unit_tests/client/test_langchain.py b/tests/unit_tests/client/test_langchain.py index 07fddcbd27..e712b93f6b 100644 --- a/tests/unit_tests/client/test_langchain.py +++ b/tests/unit_tests/client/test_langchain.py @@ -12,11 +12,14 @@ from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.schemas import TracerSession from langchain.chains.base import Chain from langchain.client.langchain import ( + InputFormatError, LangChainPlusClient, _get_link_stem, _is_localhost, ) from langchain.client.models import Dataset, Example +from tests.unit_tests.llms.fake_chat_model import FakeChatModel +from tests.unit_tests.llms.fake_llm import FakeLLM _CREATED_AT = datetime(2015, 1, 1, 0, 0, 0) _TENANT_ID = "7a3d2b56-cd5b-44e5-846f-7eb6e8144ce4" @@ -230,3 +233,85 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: for uuid_ in uuids } assert results == expected + + +_EXAMPLE_MESSAGE = { + "data": {"content": "Foo", "example": False, "additional_kwargs": {}}, + "type": "human", +} +_VALID_MESSAGES = [ + {"messages": [_EXAMPLE_MESSAGE], "other_key": "value"}, + {"messages": [], "other_key": "value"}, + { + "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]], + "other_key": "value", + }, + {"any_key": [_EXAMPLE_MESSAGE]}, + {"any_key": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], [_EXAMPLE_MESSAGE]]}, +] +_VALID_PROMPTS = [ + {"prompts": ["foo", "bar", "baz"], "other_key": "value"}, + {"prompt": "foo", "other_key": ["bar", "baz"]}, + {"some_key": "foo"}, + {"some_key": ["foo", "bar"]}, +] + + +@pytest.mark.parametrize( + "inputs", + _VALID_MESSAGES, +) +def test__get_messages_valid(inputs: Dict[str, Any]) -> None: + {"messages": []} + LangChainPlusClient._get_messages(inputs) + + +@pytest.mark.parametrize( + "inputs", + _VALID_PROMPTS, +) +def test__get_prompts_valid(inputs: Dict[str, Any]) -> None: + LangChainPlusClient._get_prompts(inputs) + + +@pytest.mark.parametrize( + "inputs", + [ + {"prompts": "foo"}, + {"prompt": ["foo"]}, + {"some_key": 3}, + {"some_key": "foo", "other_key": "bar"}, + ], +) +def test__get_prompts_invalid(inputs: Dict[str, Any]) -> None: + with pytest.raises(InputFormatError): + LangChainPlusClient._get_prompts(inputs) + + +@pytest.mark.parametrize( + "inputs", + [ + {"one_key": [_EXAMPLE_MESSAGE], "other_key": "value"}, + { + "messages": [[_EXAMPLE_MESSAGE, _EXAMPLE_MESSAGE], _EXAMPLE_MESSAGE], + "other_key": "value", + }, + {"prompts": "foo"}, + {}, + ], +) +def test__get_messages_invalid(inputs: Dict[str, Any]) -> None: + with pytest.raises(InputFormatError): + LangChainPlusClient._get_messages(inputs) + + +@pytest.mark.parametrize("inputs", _VALID_PROMPTS + _VALID_MESSAGES) +def test_run_llm_all_formats(inputs: Dict[str, Any]) -> None: + llm = FakeLLM() + LangChainPlusClient.run_llm(llm, inputs, mock.MagicMock()) + + 
+@pytest.mark.parametrize("inputs", _VALID_MESSAGES + _VALID_PROMPTS)
+def test_run_chat_model_all_formats(inputs: Dict[str, Any]) -> None:
+    llm = FakeChatModel()
+    LangChainPlusClient.run_llm(llm, inputs, mock.MagicMock())
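
For reference, a minimal sketch of the input shapes the new helpers accept, derived from the implementation and tests above. Calling the private _get_prompts/_get_messages methods directly is purely illustrative; in practice they are exercised through run_llm and _arun_llm:

    from langchain.client.langchain import LangChainPlusClient

    # LLM-style inputs: "prompt" (str), "prompts" (List[str]), or a single
    # key of any name whose value is a string or a list of strings.
    assert LangChainPlusClient._get_prompts({"prompt": "foo"}) == ["foo"]
    assert LangChainPlusClient._get_prompts({"prompts": ["foo", "bar"]}) == ["foo", "bar"]
    assert LangChainPlusClient._get_prompts({"any_key": "foo"}) == ["foo"]

    # Chat-style inputs: "messages" (or a lone key of any name) holding one
    # batch of message dicts, or a list of such batches.
    example_message = {
        "type": "human",
        "data": {"content": "Foo", "example": False, "additional_kwargs": {}},
    }
    one_batch = LangChainPlusClient._get_messages({"messages": [example_message]})
    two_batches = LangChainPlusClient._get_messages(
        {"any_key": [[example_message], [example_message, example_message]]}
    )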