langchain/libs/community/langchain_community/callbacks/promptlayer_callback.py
Bagatur a0c2281540
infra: update mypy 1.10, ruff 0.5 (#23721)
```python
"""python scripts/update_mypy_ruff.py"""
import glob
import tomllib
from pathlib import Path

import toml
import subprocess
import re

ROOT_DIR = Path(__file__).parents[1]

# One mypy diagnostic, e.g. "pkg/mod.py:12: error: Some message  [arg-type]".
# Groups: (relative file path, line number, error code). Raw string so the
# backslashes reach the regex engine instead of being invalid string escapes.
MYPY_ERROR_RE = re.compile(r"^(.*):(\d+): error:.*\[(.*)\]")


def _collect_mypy_errors(mypy_stdout: str) -> dict[tuple[str, str], list[str]]:
    """Group mypy error codes by (file path, line number) from mypy's stdout.

    Only ``error:`` lines carrying a bracketed error code are considered. Each
    code is recorded at most once per line, in first-seen order.
    """
    to_ignore: dict[tuple[str, str], list[str]] = {}
    for log_line in mypy_stdout.split("\n"):
        match = MYPY_ERROR_RE.match(log_line)
        if match is None:
            continue
        error_path, line_no, error_type = match.groups()
        codes = to_ignore.setdefault((error_path, line_no), [])
        if error_type not in codes:
            codes.append(error_type)
    return to_ignore


def _append_type_ignore(source_line: str, codes: list[str]) -> str:
    """Return *source_line* with ``  # type: ignore[<codes>]`` appended.

    Only a trailing newline is stripped before appending, so a final line
    without a newline keeps its last character. The result always ends in "\\n".
    """
    body = source_line.rstrip("\n")
    return f"{body}  # type: ignore[{', '.join(codes)}]\n"


def main() -> None:
    """Bump mypy/ruff pins in every ``libs/**/pyproject.toml`` and silence new errors.

    For each package whose pyproject has both a poetry ``typing`` group and a
    ``lint`` group, this:

    1. Sets ``mypy = "^1.10"`` and ``ruff = "^0.5"`` and rewrites the file.
    2. Re-locks, installs the typing group and runs mypy.
    3. Appends ``# type: ignore[<codes>]`` to every line mypy reports.
    4. Runs ``ruff format`` and ruff's import sorting (rule ``I``).
    """
    pattern = str(ROOT_DIR / "libs/**/pyproject.toml")
    for pyproject_path in glob.glob(pattern, recursive=True):
        print(pyproject_path)
        with open(pyproject_path, "rb") as f:
            pyproject = tomllib.load(f)
        try:
            groups = pyproject["tool"]["poetry"]["group"]
            groups["typing"]["dependencies"]["mypy"] = "^1.10"
            groups["lint"]["dependencies"]["ruff"] = "^0.5"
        except KeyError:
            # Package lacks a typing or lint dependency group: leave it alone.
            continue
        with open(pyproject_path, "w", encoding="utf-8") as f:
            toml.dump(pyproject, f)
        cwd = str(Path(pyproject_path).parent)
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; poetry run mypy . --no-color",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )

        to_ignore = _collect_mypy_errors(completed.stdout)
        print(len(to_ignore))
        for (error_path, line_no), error_types in to_ignore.items():
            # mypy reports paths relative to the directory it was run from.
            full_path = f"{cwd}/{error_path}"
            try:
                with open(full_path, encoding="utf-8") as f:
                    file_lines = f.readlines()
            except FileNotFoundError:
                continue
            index = int(line_no) - 1
            if not 0 <= index < len(file_lines):
                continue
            file_lines[index] = _append_type_ignore(file_lines[index], error_types)
            with open(full_path, "w", encoding="utf-8") as f:
                f.writelines(file_lines)

        # ruff 0.5 requires the explicit `check` subcommand for linting/fixing.
        subprocess.run(
            "poetry run ruff format .; poetry run ruff check --select I --fix .",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )


if __name__ == "__main__":
    main()

```
2024-07-03 10:33:27 -07:00

164 lines
5.4 KiB
Python

"""Callback handler for promptlayer."""
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import (
ChatGeneration,
LLMResult,
)
if TYPE_CHECKING:
    import promptlayer


def _lazy_import_promptlayer() -> promptlayer:
    """Import and return the optional ``promptlayer`` package.

    The import happens inside the function (the module-level import is only for
    type checking), so ``promptlayer`` is needed only once the PromptLayer
    callback handler is actually used.

    Returns:
        The imported ``promptlayer`` module.

    Raises:
        ImportError: If ``promptlayer`` is not installed. The original import
            failure is chained as ``__cause__``.
    """
    try:
        import promptlayer
    except ImportError as e:
        raise ImportError(
            "The PromptLayerCallbackHandler requires the promptlayer package. "
            "Please install it with `pip install promptlayer`."
        ) from e
    return promptlayer
class PromptLayerCallbackHandler(BaseCallbackHandler):
    """Callback handler that logs LLM and chat model calls to PromptLayer.

    Request details are recorded when a run starts and sent to PromptLayer via
    ``promptlayer.utils.promptlayer_api_request`` when the run ends. One request
    is sent per entry of ``LLMResult.generations``, using that entry's first
    generation.
    """

    def __init__(
        self,
        pl_id_callback: Optional[Callable[..., Any]] = None,
        pl_tags: Optional[List[str]] = None,
    ) -> None:
        """Initialize the PromptLayerCallbackHandler.

        Args:
            pl_id_callback: Optional callable, invoked with the request id that
                ``promptlayer_api_request`` returns for each logged generation.
            pl_tags: Tags passed along with every PromptLayer request.

        Raises:
            ImportError: If the ``promptlayer`` package is not installed.
        """
        # Fail fast at construction time if the optional dependency is missing.
        _lazy_import_promptlayer()
        self.pl_id_callback = pl_id_callback
        self.pl_tags = pl_tags or []
        # In-flight runs keyed by run id. Entries are removed when the run ends
        # or errors, so this dict does not grow without bound.
        self.runs: Dict[UUID, Dict[str, Any]] = {}

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Record the chat messages and request metadata for ``run_id``.

        Each message list is converted to role/content dicts. The run name is
        ``serialized["id"]`` joined with dots.
        """
        self.runs[run_id] = {
            "messages": [self._create_message_dicts(m)[0] for m in messages],
            "invocation_params": kwargs.get("invocation_params", {}),
            "name": ".".join(serialized["id"]),
            "request_start_time": datetime.datetime.now().timestamp(),
            "tags": tags,
        }

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Record the prompts and request metadata for ``run_id``.

        The run name is ``serialized["id"]`` joined with dots.
        """
        self.runs[run_id] = {
            "prompts": prompts,
            "invocation_params": kwargs.get("invocation_params", {}),
            "name": ".".join(serialized["id"]),
            "request_start_time": datetime.datetime.now().timestamp(),
            "tags": tags,
        }

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Send the finished run's generations to PromptLayer.

        Runs not previously recorded by ``on_llm_start`` or
        ``on_chat_model_start`` are ignored. The run's entry is removed from
        ``self.runs``. For each entry of ``response.generations`` the first
        generation is logged. If ``pl_id_callback`` was supplied, it is called
        with the id PromptLayer returns.
        """
        from promptlayer.utils import get_api_key, promptlayer_api_request

        # pop (not get) so completed runs are not retained forever.
        run_info = self.runs.pop(run_id, None)
        if not run_info:
            return
        run_info["request_end_time"] = datetime.datetime.now().timestamp()

        # Values that do not change between generations.
        model_params = run_info.get("invocation_params", {})
        is_chat_model = run_info.get("messages", None) is not None
        return_pl_id = self.pl_id_callback is not None

        for i, candidates in enumerate(response.generations):
            if not candidates:
                # Nothing generated for this input; avoid IndexError on [0].
                continue
            generation = candidates[0]

            model_input: Any
            if is_chat_model:
                model_input = run_info.get("messages", [])[i]
            else:
                model_input = [run_info.get("prompts", [])[i]]

            model_response: Any
            if is_chat_model and isinstance(generation, ChatGeneration):
                model_response = [self._convert_message_to_dict(generation.message)]
            else:
                model_response = {
                    "text": generation.text,
                    "llm_output": response.llm_output,
                }

            pl_request_id = promptlayer_api_request(
                run_info.get("name"),
                "langchain",
                model_input,
                model_params,
                self.pl_tags,
                model_response,
                run_info.get("request_start_time"),
                run_info.get("request_end_time"),
                get_api_key(),
                return_pl_id=return_pl_id,
                metadata={
                    "_langchain_run_id": str(run_id),
                    "_langchain_parent_run_id": str(parent_run_id),
                    "_langchain_tags": str(run_info.get("tags", [])),
                },
            )
            if self.pl_id_callback is not None:
                self.pl_id_callback(pl_request_id)

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Drop the recorded state of a run that failed.

        ``on_llm_end`` is never reached for such runs, so without this the
        entry would stay in ``self.runs`` indefinitely. Nothing is logged to
        PromptLayer.
        """
        self.runs.pop(run_id, None)

    def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
        """Convert a message to a ``{"role": ..., "content": ...}`` dict.

        Human, AI and system messages map to "user", "assistant" and "system".
        ``ChatMessage`` keeps its own role. A ``name`` in ``additional_kwargs``
        is copied over.

        Raises:
            ValueError: For any other message type.
        """
        message_dict: Dict[str, Any]
        if isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")
        if "name" in message.additional_kwargs:
            message_dict["name"] = message.additional_kwargs["name"]
        return message_dict

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Convert *messages* to role/content dicts.

        Returns:
            A tuple of (message dicts, params). ``params`` is always an empty
            dict here.
        """
        params: Dict[str, Any] = {}
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        return message_dicts, params