From 49694f6a3fe96edc5bdfc0c2f75762cc8f46354c Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:13:15 -0700 Subject: [PATCH] explicitly check openllm return type (#10560) cc @aarnphm --- libs/langchain/langchain/llms/openllm.py | 39 ++++++++++++++---------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/libs/langchain/langchain/llms/openllm.py b/libs/langchain/langchain/llms/openllm.py index d0d70f1494..4677d8b96c 100644 --- a/libs/langchain/langchain/llms/openllm.py +++ b/libs/langchain/langchain/llms/openllm.py @@ -265,16 +265,19 @@ class OpenLLM(LLM): self._identifying_params["model_name"], **copied ) if self._client: - o = self._client.query(prompt, **config.model_dump(flatten=True)) - if isinstance(o, dict) and "text" in o: - return o["text"] - return o + res = self._client.query(prompt, **config.model_dump(flatten=True)) else: assert self._runner is not None - o = self._runner(prompt, **config.model_dump(flatten=True)) - if isinstance(o, dict) and "text" in o: - return o["text"] - return o + res = self._runner(prompt, **config.model_dump(flatten=True)) + if isinstance(res, dict) and "text" in res: + return res["text"] + elif isinstance(res, str): + return res + else: + raise ValueError( + "Expected result to be a dict with key 'text' or a string. " + f"Received {res}" + ) async def _acall( self, @@ -297,12 +300,9 @@ class OpenLLM(LLM): self._identifying_params["model_name"], **copied ) if self._client: - o = await self._client.acall( + res = await self._client.acall( "generate", prompt, **config.model_dump(flatten=True) ) - if isinstance(o, dict) and "text" in o: - return o["text"] - return o else: assert self._runner is not None ( @@ -313,9 +313,16 @@ class OpenLLM(LLM): generated_result = await self._runner.generate.async_run( prompt, **generate_kwargs ) - o = self._runner.llm.postprocess_generate( + res = self._runner.llm.postprocess_generate( prompt, generated_result, **postprocess_kwargs ) - if isinstance(o, dict) and "text" in o: - return o["text"] - return o + + if isinstance(res, dict) and "text" in res: + return res["text"] + elif isinstance(res, str): + return res + else: + raise ValueError( + "Expected result to be a dict with key 'text' or a string. " + f"Received {res}" + )