fix: unify generation outputs on newer openllm release (#10523)

Update to the newer generation format from OpenLLM, where it returns a
dictionary for one-shot generation.

cc @baskaryan 

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
This commit is contained in:
Aaron Pham 2023-09-13 16:49:16 -04:00 committed by GitHub
parent 201b61d5b3
commit ac9609f58f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -265,10 +265,16 @@ class OpenLLM(LLM):
self._identifying_params["model_name"], **copied
)
if self._client:
return self._client.query(prompt, **config.model_dump(flatten=True))
o = self._client.query(prompt, **config.model_dump(flatten=True))
if isinstance(o, dict) and "text" in o:
return o["text"]
return o
else:
assert self._runner is not None
return self._runner(prompt, **config.model_dump(flatten=True))
o = self._runner(prompt, **config.model_dump(flatten=True))
if isinstance(o, dict) and "text" in o:
return o["text"]
return o
async def _acall(
self,
@ -291,9 +297,12 @@ class OpenLLM(LLM):
self._identifying_params["model_name"], **copied
)
if self._client:
return await self._client.acall(
o = await self._client.acall(
"generate", prompt, **config.model_dump(flatten=True)
)
if isinstance(o, dict) and "text" in o:
return o["text"]
return o
else:
assert self._runner is not None
(
@ -304,6 +313,9 @@ class OpenLLM(LLM):
generated_result = await self._runner.generate.async_run(
prompt, **generate_kwargs
)
return self._runner.llm.postprocess_generate(
o = self._runner.llm.postprocess_generate(
prompt, generated_result, **postprocess_kwargs
)
if isinstance(o, dict) and "text" in o:
return o["text"]
return o