explicitly check openllm return type (#10560)

cc @aarnphm
pull/10594/head
Bagatur 1 year ago committed by GitHub
parent 85e05fa5d6
commit 49694f6a3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -265,16 +265,19 @@ class OpenLLM(LLM):
self._identifying_params["model_name"], **copied
)
if self._client:
o = self._client.query(prompt, **config.model_dump(flatten=True))
if isinstance(o, dict) and "text" in o:
return o["text"]
return o
res = self._client.query(prompt, **config.model_dump(flatten=True))
else:
assert self._runner is not None
o = self._runner(prompt, **config.model_dump(flatten=True))
if isinstance(o, dict) and "text" in o:
return o["text"]
return o
res = self._runner(prompt, **config.model_dump(flatten=True))
if isinstance(res, dict) and "text" in res:
return res["text"]
elif isinstance(res, str):
return res
else:
raise ValueError(
"Expected result to be a dict with key 'text' or a string. "
f"Received {res}"
)
async def _acall(
self,
@ -297,12 +300,9 @@ class OpenLLM(LLM):
self._identifying_params["model_name"], **copied
)
if self._client:
o = await self._client.acall(
res = await self._client.acall(
"generate", prompt, **config.model_dump(flatten=True)
)
if isinstance(o, dict) and "text" in o:
return o["text"]
return o
else:
assert self._runner is not None
(
@ -313,9 +313,16 @@ class OpenLLM(LLM):
generated_result = await self._runner.generate.async_run(
prompt, **generate_kwargs
)
o = self._runner.llm.postprocess_generate(
res = self._runner.llm.postprocess_generate(
prompt, generated_result, **postprocess_kwargs
)
if isinstance(o, dict) and "text" in o:
return o["text"]
return o
if isinstance(res, dict) and "text" in res:
return res["text"]
elif isinstance(res, str):
return res
else:
raise ValueError(
"Expected result to be a dict with key 'text' or a string. "
f"Received {res}"
)

Loading…
Cancel
Save