fix(openllm): update with newer remote client implementation (#11740)

cc @baskaryan
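For context, the path this PR touches is the remote-client one, exercised when the wrapper is pointed at a running server instead of loading a model in-process. A minimal sketch, assuming an OpenLLM server is already serving on localhost:3000 (the prompt is illustrative):

```python
from langchain.llms import OpenLLM

# Remote mode: server_url makes the wrapper build an openllm HTTP client
# rather than an in-process runner.
llm = OpenLLM(server_url="http://localhost:3000", server_type="http")
print(llm("What is the difference between a fruit and a vegetable?"))
```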

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

@@ -85,7 +85,7 @@ class OpenLLM(LLM):
     server_type: ServerType = "http"
     """Optional server type. Either 'http' or 'grpc'."""
     embedded: bool = True
     """Initialize this LLM instance in current process by default. Should
     only set to False when using in conjunction with BentoML Service."""
     llm_kwargs: Dict[str, Any]
     """Keyword arguments to be passed to openllm.LLM"""
@@ -217,9 +217,9 @@ class OpenLLM(LLM):
     def _identifying_params(self) -> IdentifyingParams:
         """Get the identifying parameters."""
         if self._client is not None:
-            self.llm_kwargs.update(self._client.configuration)
-            model_name = self._client.model_name
-            model_id = self._client.model_id
+            self.llm_kwargs.update(self._client._config())
+            model_name = self._client._metadata()["model_name"]
+            model_id = self._client._metadata()["model_id"]
         else:
             if self._runner is None:
                 raise ValueError("Runner must be initialized.")
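The newer client no longer exposes `configuration`, `model_name`, and `model_id` as attributes; per the hunk above they are reached through the `_config()` and `_metadata()` helpers (note both are underscore-prefixed, i.e. internal client API). A sketch, again assuming a server on localhost:3000:

```python
import openllm

client = openllm.client.HTTPClient("http://localhost:3000")

# Both helpers are internal to the client; names follow the diff above.
llm_config = client._config()    # generation defaults as a dict
metadata = client._metadata()    # model info as a dict
print(metadata["model_name"], metadata["model_id"])
```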
@@ -265,7 +265,9 @@ class OpenLLM(LLM):
             self._identifying_params["model_name"], **copied
         )
         if self._client:
-            res = self._client.query(prompt, **config.model_dump(flatten=True))
+            res = self._client.generate(
+                prompt, **config.model_dump(flatten=True)
+            ).responses[0]
         else:
             assert self._runner is not None
             res = self._runner(prompt, **config.model_dump(flatten=True))
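On the sync path, `query()` is replaced by `generate()`, which returns a structured result rather than a bare string, so the wrapper now takes the first entry of `.responses`. A sketch (`max_new_tokens` is an assumed generation kwarg, not something this diff pins down):

```python
import openllm

client = openllm.client.HTTPClient("http://localhost:3000")

# generate() returns a result object; completions live in .responses.
result = client.generate("What is a large language model?", max_new_tokens=64)
print(result.responses[0])
```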
@@ -300,9 +302,10 @@ class OpenLLM(LLM):
             self._identifying_params["model_name"], **copied
         )
         if self._client:
-            res = await self._client.acall(
-                "generate", prompt, **config.model_dump(flatten=True)
-            )
+            async_client = openllm.client.AsyncHTTPClient(self.server_url)
+            res = (
+                await async_client.generate(prompt, **config.model_dump(flatten=True))
+            ).responses[0]
         else:
             assert self._runner is not None
             (
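The async path drops `acall("generate", ...)` in favor of a dedicated `AsyncHTTPClient` built from `self.server_url`, with `generate()` awaited and the first completion taken, mirroring the sync path. A standalone sketch under the same assumptions as above:

```python
import asyncio

import openllm


async def main() -> None:
    # Same shape as the new _acall body: async client from the server URL,
    # awaited generate(), first entry of .responses.
    client = openllm.client.AsyncHTTPClient("http://localhost:3000")
    result = await client.generate("Hello,", max_new_tokens=32)
    print(result.responses[0])


asyncio.run(main())
```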
