From 6c6131506798187786fe7547425dc54379438ce0 Mon Sep 17 00:00:00 2001
From: Aaron Pham <29749331+aarnphm@users.noreply.github.com>
Date: Thu, 12 Oct 2023 20:01:18 -0400
Subject: [PATCH] fix(openllm): update with newer remote client implementation
 (#11740)

cc @baskaryan

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
---
 libs/langchain/langchain/llms/openllm.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/libs/langchain/langchain/llms/openllm.py b/libs/langchain/langchain/llms/openllm.py
index 42238ac3b8..1fc81acb4f 100644
--- a/libs/langchain/langchain/llms/openllm.py
+++ b/libs/langchain/langchain/llms/openllm.py
@@ -85,7 +85,7 @@ class OpenLLM(LLM):
     server_type: ServerType = "http"
     """Optional server type. Either 'http' or 'grpc'."""
     embedded: bool = True
-    """Initialize this LLM instance in current process by default. Should 
+    """Initialize this LLM instance in current process by default. Should
     only set to False when using in conjunction with BentoML Service."""
     llm_kwargs: Dict[str, Any]
     """Keyword arguments to be passed to openllm.LLM"""
@@ -217,9 +217,9 @@ class OpenLLM(LLM):
     def _identifying_params(self) -> IdentifyingParams:
         """Get the identifying parameters."""
         if self._client is not None:
-            self.llm_kwargs.update(self._client.configuration)
-            model_name = self._client.model_name
-            model_id = self._client.model_id
+            self.llm_kwargs.update(self._client._config())
+            model_name = self._client._metadata()["model_name"]
+            model_id = self._client._metadata()["model_id"]
         else:
             if self._runner is None:
                 raise ValueError("Runner must be initialized.")
@@ -265,7 +265,9 @@ class OpenLLM(LLM):
             self._identifying_params["model_name"], **copied
         )
         if self._client:
-            res = self._client.query(prompt, **config.model_dump(flatten=True))
+            res = self._client.generate(
+                prompt, **config.model_dump(flatten=True)
+            ).responses[0]
         else:
             assert self._runner is not None
             res = self._runner(prompt, **config.model_dump(flatten=True))
@@ -300,9 +302,10 @@ class OpenLLM(LLM):
             self._identifying_params["model_name"], **copied
        )
         if self._client:
-            res = await self._client.acall(
-                "generate", prompt, **config.model_dump(flatten=True)
-            )
+            async_client = openllm.client.AsyncHTTPClient(self.server_url)
+            res = (
+                await async_client.generate(prompt, **config.model_dump(flatten=True))
+            ).responses[0]
         else:
             assert self._runner is not None
             (
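
Not part of the patch: a minimal usage sketch of the remote-client code path this change rewires, assuming an OpenLLM HTTP server is already running; the server address and prompt below are placeholder values, not taken from the diff.

    # Sketch only: exercises the path that now calls client.generate(...).responses[0]
    from langchain.llms import OpenLLM

    # Assumes a server started separately, e.g. with `openllm start`, reachable
    # at this placeholder address.
    llm = OpenLLM(server_url="http://localhost:3000", server_type="http")
    print(llm("What is the difference between a fruit and a vegetable?"))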