langchain/libs/community/langchain_community/llms/openllm.py
Bagatur a0c2281540
infra: update mypy 1.10, ruff 0.5 (#23721)
```python
"""python scripts/update_mypy_ruff.py"""
import glob
import re
import subprocess
import tomllib
from pathlib import Path

import toml

ROOT_DIR = Path(__file__).parents[1]


def main():
    for path in glob.glob(str(ROOT_DIR / "libs/**/pyproject.toml"), recursive=True):
        print(path)
        with open(path, "rb") as f:
            pyproject = tomllib.load(f)
        try:
            pyproject["tool"]["poetry"]["group"]["typing"]["dependencies"]["mypy"] = (
                "^1.10"
            )
            pyproject["tool"]["poetry"]["group"]["lint"]["dependencies"]["ruff"] = (
                "^0.5"
            )
        except KeyError:
            continue
        with open(path, "w") as f:
            toml.dump(pyproject, f)
        cwd = "/".join(path.split("/")[:-1])
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; poetry run mypy . --no-color",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )
        logs = completed.stdout.split("\n")

        to_ignore = {}
        for line in logs:
            # Use a raw string so the regex escapes don't trigger SyntaxWarning,
            # and match each log line only once.
            match = re.match(r"^(.*):(\d+): error:.*\[(.*)\]", line)
            if match:
                path, line_no, error_type = match.groups()
                if (path, line_no) in to_ignore:
                    to_ignore[(path, line_no)].append(error_type)
                else:
                    to_ignore[(path, line_no)] = [error_type]
        print(len(to_ignore))
        for (error_path, line_no), error_types in to_ignore.items():
            all_errors = ", ".join(error_types)
            full_path = f"{cwd}/{error_path}"
            try:
                with open(full_path, "r") as f:
                    file_lines = f.readlines()
            except FileNotFoundError:
                continue
            file_lines[int(line_no) - 1] = (
                file_lines[int(line_no) - 1][:-1] + f"  # type: ignore[{all_errors}]\n"
            )
            with open(full_path, "w") as f:
                f.write("".join(file_lines))

        subprocess.run(
            "poetry run ruff format .; poetry run ruff --select I --fix .",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )


if __name__ == "__main__":
    main()

```
2024-07-03 10:33:27 -07:00
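
For reference, a minimal sketch of the error-parsing step the script relies on (the sample mypy line and file path below are hypothetical):

```python
import re

# Hypothetical mypy error line; the script extracts (path, line number, error code)
# for every such line and later appends "  # type: ignore[<codes>]" to that line.
sample = "langchain_community/llms/openllm.py:212: error: Bad argument type  [arg-type]"
match = re.match(r"^(.*):(\d+): error:.*\[(.*)\]", sample)
if match:
    print(match.groups())  # ('langchain_community/llms/openllm.py', '212', 'arg-type')
```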

from __future__ import annotations

import copy
import json
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Literal,
    Optional,
    TypedDict,
    Union,
    overload,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import PrivateAttr

if TYPE_CHECKING:
    import openllm


ServerType = Literal["http", "grpc"]


class IdentifyingParams(TypedDict):
    """Parameters for identifying a model as a typed dict."""

    model_name: str
    model_id: Optional[str]
    server_url: Optional[str]
    server_type: Optional[ServerType]
    embedded: bool
    llm_kwargs: Dict[str, Any]


logger = logging.getLogger(__name__)


class OpenLLM(LLM):
    """OpenLLM, supporting both in-process model
    instance and remote OpenLLM servers.

    To use, you should have the openllm library installed:

    .. code-block:: bash

        pip install openllm

    Learn more at: https://github.com/bentoml/openllm

    Example running an LLM model locally managed by OpenLLM:

        .. code-block:: python

            from langchain_community.llms import OpenLLM

            llm = OpenLLM(
                model_name='flan-t5',
                model_id='google/flan-t5-large',
            )
            llm.invoke("What is the difference between a duck and a goose?")

    For all available supported models, you can run 'openllm models'.

    If you have an OpenLLM server running, you can also use it remotely:

        .. code-block:: python

            from langchain_community.llms import OpenLLM

            llm = OpenLLM(server_url='http://localhost:3000')
            llm.invoke("What is the difference between a duck and a goose?")
    """

    model_name: Optional[str] = None
    """Model name to use. See 'openllm models' for all available models."""
    model_id: Optional[str] = None
    """Model Id to use. If not provided, will use the default model for the model name.
    See 'openllm models' for all available model variants."""
    server_url: Optional[str] = None
    """Optional server URL that currently runs an LLMServer with 'openllm start'."""
    timeout: int = 30
    """Timeout for the openllm client."""
    server_type: ServerType = "http"
    """Optional server type. Either 'http' or 'grpc'."""
    embedded: bool = True
    """Initialize this LLM instance in the current process by default. Should
    only be set to False when used in conjunction with a BentoML Service."""
    llm_kwargs: Dict[str, Any]
    """Keyword arguments to be passed to openllm.LLM"""

    _runner: Optional[openllm.LLMRunner] = PrivateAttr(default=None)
    _client: Union[openllm.client.HTTPClient, openllm.client.GrpcClient, None] = (
        PrivateAttr(default=None)
    )

    class Config:
        extra = "forbid"

    @overload
    def __init__(
        self,
        model_name: Optional[str] = ...,
        *,
        model_id: Optional[str] = ...,
        embedded: Literal[True, False] = ...,
        **llm_kwargs: Any,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        server_url: str = ...,
        server_type: Literal["grpc", "http"] = ...,
        **llm_kwargs: Any,
    ) -> None: ...

    def __init__(
        self,
        model_name: Optional[str] = None,
        *,
        model_id: Optional[str] = None,
        server_url: Optional[str] = None,
        timeout: int = 30,
        server_type: Literal["grpc", "http"] = "http",
        embedded: bool = True,
        **llm_kwargs: Any,
    ):
        try:
            import openllm
        except ImportError as e:
            raise ImportError(
                "Could not import openllm. Make sure to install it with "
                "'pip install openllm'."
            ) from e
        llm_kwargs = llm_kwargs or {}

        if server_url is not None:
            logger.debug("'server_url' is provided, returning an openllm.Client")
            assert (
                model_id is None and model_name is None
            ), "'server_url' and {'model_id', 'model_name'} are mutually exclusive"
            client_cls = (
                openllm.client.HTTPClient
                if server_type == "http"
                else openllm.client.GrpcClient
            )
            client = client_cls(server_url, timeout)

            super().__init__(
                **{  # type: ignore[arg-type]
                    "server_url": server_url,
                    "timeout": timeout,
                    "server_type": server_type,
                    "llm_kwargs": llm_kwargs,
                }
            )
            self._runner = None  # type: ignore
            self._client = client
        else:
            assert model_name is not None, "Must provide 'model_name' or 'server_url'"
            # Since the LLMs are relatively huge, we don't actually want to run the
            # Runner embedded when running the server. Instead, we only set
            # init_local here so that LangChain users can still use the LLM
            # in-process. For BentoML users, setting embedded=False is the expected
            # behaviour to invoke the runners remotely.
            # We also need to enable ensure_available to download and set up the model.
            runner = openllm.Runner(
                model_name=model_name,
                model_id=model_id,
                init_local=embedded,
                ensure_available=True,
                **llm_kwargs,
            )
            super().__init__(
                **{  # type: ignore[arg-type]
                    "model_name": model_name,
                    "model_id": model_id,
                    "embedded": embedded,
                    "llm_kwargs": llm_kwargs,
                }
            )
            self._client = None  # type: ignore
            self._runner = runner

    @property
    def runner(self) -> openllm.LLMRunner:
        """
        Get the underlying openllm.LLMRunner instance for integration with BentoML.

        Example:

        .. code-block:: python

            llm = OpenLLM(
                model_name='flan-t5',
                model_id='google/flan-t5-large',
                embedded=False,
            )
            tools = load_tools(["serpapi", "llm-math"], llm=llm)
            agent = initialize_agent(
                tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION
            )
            svc = bentoml.Service("langchain-openllm", runners=[llm.runner])

            @svc.api(input=Text(), output=Text())
            def chat(input_text: str):
                return agent.run(input_text)
        """
        if self._runner is None:
            raise ValueError("OpenLLM must be initialized locally with 'model_name'")
        return self._runner

    @property
    def _identifying_params(self) -> IdentifyingParams:
        """Get the identifying parameters."""
        if self._client is not None:
            self.llm_kwargs.update(self._client._config)
            model_name = self._client._metadata.model_dump()["model_name"]
            model_id = self._client._metadata.model_dump()["model_id"]
        else:
            if self._runner is None:
                raise ValueError("Runner must be initialized.")
            model_name = self.model_name
            model_id = self.model_id
            try:
                self.llm_kwargs.update(
                    json.loads(self._runner.identifying_params["configuration"])
                )
            except (TypeError, json.JSONDecodeError):
                pass
        return IdentifyingParams(
            server_url=self.server_url,
            server_type=self.server_type,
            embedded=self.embedded,
            llm_kwargs=self.llm_kwargs,
            model_name=model_name,
            model_id=model_id,
        )

    @property
    def _llm_type(self) -> str:
        return "openllm_client" if self._client else "openllm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        try:
            import openllm
        except ImportError as e:
            raise ImportError(
                "Could not import openllm. Make sure to install it with "
                "'pip install openllm'."
            ) from e
        copied = copy.deepcopy(self.llm_kwargs)
        copied.update(kwargs)
        config = openllm.AutoConfig.for_model(
            self._identifying_params["model_name"], **copied
        )
        if self._client:
            res = (
                self._client.generate(prompt, **config.model_dump(flatten=True))
                .outputs[0]
                .text
            )
        else:
            assert self._runner is not None
            res = self._runner(prompt, **config.model_dump(flatten=True))
        if isinstance(res, dict) and "text" in res:
            return res["text"]
        elif isinstance(res, str):
            return res
        else:
            raise ValueError(
                "Expected result to be a dict with key 'text' or a string. "
                f"Received {res}"
            )

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        try:
            import openllm
        except ImportError as e:
            raise ImportError(
                "Could not import openllm. Make sure to install it with "
                "'pip install openllm'."
            ) from e
        copied = copy.deepcopy(self.llm_kwargs)
        copied.update(kwargs)
        config = openllm.AutoConfig.for_model(
            self._identifying_params["model_name"], **copied
        )
        if self._client:
            async_client = openllm.client.AsyncHTTPClient(self.server_url, self.timeout)
            res = (
                (await async_client.generate(prompt, **config.model_dump(flatten=True)))
                .outputs[0]
                .text
            )
        else:
            assert self._runner is not None
            (
                prompt,
                generate_kwargs,
                postprocess_kwargs,
            ) = self._runner.llm.sanitize_parameters(prompt, **kwargs)
            generated_result = await self._runner.generate.async_run(
                prompt, **generate_kwargs
            )
            res = self._runner.llm.postprocess_generate(
                prompt, generated_result, **postprocess_kwargs
            )
        if isinstance(res, dict) and "text" in res:
            return res["text"]
        elif isinstance(res, str):
            return res
        else:
            raise ValueError(
                "Expected result to be a dict with key 'text' or a string. "
                f"Received {res}"
            )