Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Nuno Campos 2023-06-11 23:51:28 +01:00 committed by GitHub
parent 614cff89bc
commit 18af149e91
27 changed files with 810 additions and 71 deletions

View File

@ -4,9 +4,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Sequence, Set
from pydantic import BaseModel
from langchain.callbacks.manager import Callbacks
from langchain.load.serializable import Serializable
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
@ -29,7 +28,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
return tokenizer.encode(text)
class BaseLanguageModel(BaseModel, ABC):
class BaseLanguageModel(Serializable, ABC):
@abstractmethod
def generate_prompt(
self,

View File

@ -204,7 +204,7 @@ def _handle_event(
except Exception as e:
if handler.raise_error:
raise e
logging.warning(f"Error in {event_name} callback: {e}")
logger.warning(f"Error in {event_name} callback: {e}")
async def _ahandle_event_for_handler(

View File

@ -93,7 +93,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order = self._get_execution_order(parent_run_id_)
llm_run = Run(
id=run_id,
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"prompts": prompts},
@ -154,7 +153,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order = self._get_execution_order(parent_run_id_)
chain_run = Run(
id=run_id,
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
inputs=inputs,
@ -216,7 +214,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order = self._get_execution_order(parent_run_id_)
tool_run = Run(
id=run_id,
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"input": input_str},

View File

@ -124,7 +124,10 @@ class Run(RunBase):
def assign_name(cls, values: dict) -> dict:
"""Assign name to the run."""
if "name" not in values:
if "name" in values["serialized"]:
values["name"] = values["serialized"]["name"]
elif "id" in values["serialized"]:
values["name"] = values["serialized"]["id"][-1]
return values
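With this change, a run's name is taken from an explicit "name" key when present and otherwise falls back to the last segment of the serialized "id", which is why the tracer tests below switch to payloads like {"id": ["llm"]}. A minimal sketch with an illustrative payload:

values = {"serialized": {"lc": 1, "type": "constructor",
                         "id": ["langchain", "llms", "openai", "OpenAI"]}}
# assign_name (a root validator) would set values["name"] = "OpenAI" here,
# since neither values nor values["serialized"] carries an explicit "name".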

View File

@ -7,7 +7,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
from pydantic import BaseModel, Field, root_validator, validator
from pydantic import Field, root_validator, validator
import langchain
from langchain.callbacks.base import BaseCallbackManager
@ -18,6 +18,8 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun,
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
@ -25,7 +27,7 @@ def _get_verbosity() -> bool:
return langchain.verbose
class Chain(BaseModel, ABC):
class Chain(Serializable, ABC):
"""Base interface that all chains should implement."""
memory: Optional[BaseMemory] = None
@ -131,7 +133,7 @@ class Chain(BaseModel, ABC):
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
{"name": self.__class__.__name__},
dumpd(self),
inputs,
)
try:
@ -179,7 +181,7 @@ class Chain(BaseModel, ABC):
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
{"name": self.__class__.__name__},
dumpd(self),
inputs,
)
try:

View File

@ -15,6 +15,7 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
@ -34,6 +35,10 @@ class LLMChain(Chain):
llm = LLMChain(llm=OpenAI(), prompt=prompt)
"""
@property
def lc_serializable(self) -> bool:
return True
prompt: BasePromptTemplate
"""Prompt object to use."""
llm: BaseLanguageModel
@ -147,7 +152,7 @@ class LLMChain(Chain):
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chain_start(
{"name": self.__class__.__name__},
dumpd(self),
{"input_list": input_list},
)
try:
@ -167,7 +172,7 @@ class LLMChain(Chain):
callbacks, self.callbacks, self.verbose
)
run_manager = await callback_manager.on_chain_start(
{"name": self.__class__.__name__},
dumpd(self),
{"input_list": input_list},
)
try:

View File

@ -17,6 +17,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.schema import (
AIMessage,
BaseMessage,
@ -70,12 +71,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages, invocation_params=params
dumpd(self), messages, invocation_params=params, options=options
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
@ -109,12 +111,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
"""Top Level call"""
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = await callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages, invocation_params=params
dumpd(self), messages, invocation_params=params, options=options
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(

View File

@ -136,6 +136,10 @@ class ChatOpenAI(BaseChatModel):
openai = ChatOpenAI(model_name="gpt-3.5-turbo")
"""
@property
def lc_serializable(self) -> bool:
return True
client: Any #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""

View File

@ -19,6 +19,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.schema import (
AIMessage,
BaseMessage,
@ -166,6 +167,7 @@ class BaseLLM(BaseLanguageModel, ABC):
)
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
(
existing_prompts,
llm_string,
@ -186,7 +188,7 @@ class BaseLLM(BaseLanguageModel, ABC):
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager = callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, invocation_params=params
dumpd(self), prompts, invocation_params=params, options=options
)
try:
output = (
@ -205,9 +207,10 @@ class BaseLLM(BaseLanguageModel, ABC):
return output
if len(missing_prompts) > 0:
run_manager = callback_manager.on_llm_start(
{"name": self.__class__.__name__},
dumpd(self),
missing_prompts,
invocation_params=params,
options=options,
)
try:
new_results = (
@ -243,6 +246,7 @@ class BaseLLM(BaseLanguageModel, ABC):
"""Run the LLM on the given prompt and input."""
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
(
existing_prompts,
llm_string,
@ -263,7 +267,7 @@ class BaseLLM(BaseLanguageModel, ABC):
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager = await callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, invocation_params=params
dumpd(self), prompts, invocation_params=params, options=options
)
try:
output = (
@ -282,9 +286,10 @@ class BaseLLM(BaseLanguageModel, ABC):
return output
if len(missing_prompts) > 0:
run_manager = await callback_manager.on_llm_start(
{"name": self.__class__.__name__},
dumpd(self),
missing_prompts,
invocation_params=params,
options=options,
)
try:
new_results = (

View File

@ -123,6 +123,14 @@ async def acompletion_with_retry(
class BaseOpenAI(BaseLLM):
"""Wrapper around OpenAI large language models."""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@property
def lc_serializable(self) -> bool:
return True
client: Any #: :meta private:
model_name: str = Field("text-davinci-003", alias="model")
"""Model name to use."""

View File

langchain/load/dump.py (new file, 22 lines)
View File

@ -0,0 +1,22 @@
import json
from typing import Any, Dict
from langchain.load.serializable import Serializable, to_json_not_implemented
def default(obj: Any) -> Any:
if isinstance(obj, Serializable):
return obj.to_json()
else:
return to_json_not_implemented(obj)
def dumps(obj: Any, *, pretty: bool = False) -> str:
if pretty:
return json.dumps(obj, default=default, indent=2)
else:
return json.dumps(obj, default=default)
def dumpd(obj: Any) -> Dict[str, Any]:
return json.loads(dumps(obj))
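A minimal usage sketch of the two new helpers; PromptTemplate is just a convenient serializable object from elsewhere in this commit:

from langchain.load.dump import dumpd, dumps
from langchain.prompts.prompt import PromptTemplate

prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")

# dumps() returns a JSON string; pretty=True indents it (as in the snapshots below).
print(dumps(prompt, pretty=True))

# dumpd() returns the same structure as a plain dict, which is what the
# on_*_start callbacks above now receive instead of {"name": ...}.
payload = dumpd(prompt)
assert payload["type"] == "constructor"
assert payload["id"][-1] == "PromptTemplate"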

langchain/load/load.py (new file, 65 lines)
View File

@ -0,0 +1,65 @@
import importlib
import json
from typing import Any, Dict, Optional
from langchain.load.serializable import Serializable
class Reviver:
def __init__(self, secrets_map: Optional[Dict[str, str]] = None) -> None:
self.secrets_map = secrets_map or dict()
def __call__(self, value: Dict[str, Any]) -> Any:
if (
value.get("lc", None) == 1
and value.get("type", None) == "secret"
and value.get("id", None) is not None
):
[key] = value["id"]
if key in self.secrets_map:
return self.secrets_map[key]
else:
raise KeyError(f'Missing key "{key}" in load(secrets_map)')
if (
value.get("lc", None) == 1
and value.get("type", None) == "not_implemented"
and value.get("id", None) is not None
):
raise NotImplementedError(
"Trying to load an object that doesn't implement "
f"serialization: {value}"
)
if (
value.get("lc", None) == 1
and value.get("type", None) == "constructor"
and value.get("id", None) is not None
):
[*namespace, name] = value["id"]
# Currently, we only support langchain imports.
if namespace[0] != "langchain":
raise ValueError(f"Invalid namespace: {value}")
# The root namespace "langchain" is not a valid identifier.
if len(namespace) == 1:
raise ValueError(f"Invalid namespace: {value}")
mod = importlib.import_module(".".join(namespace))
cls = getattr(mod, name)
# The class must be a subclass of Serializable.
if not issubclass(cls, Serializable):
raise ValueError(f"Invalid namespace: {value}")
# We don't need to recurse on kwargs
# as json.loads will do that for us.
kwargs = value.get("kwargs", dict())
return cls(**kwargs)
return value
def loads(text: str, *, secrets_map: Optional[Dict[str, str]] = None) -> Any:
return json.loads(text, object_hook=Reviver(secrets_map))
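A minimal round-trip sketch for the reviver via loads(), mirroring the load tests added below (the API key is a dummy value and the openai package must be installed):

from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads

llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
# dumps() replaces the key with {"lc": 1, "type": "secret", "id": ["OPENAI_API_KEY"]}.
llm_string = dumps(llm)

# loads() re-imports OpenAI from the serialized "id" path and injects the
# secret from secrets_map; a secret missing from the map raises KeyError.
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
assert llm2 == llm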

View File

@ -0,0 +1,135 @@
from abc import ABC
from typing import Any, Dict, List, Literal, TypedDict, Union, cast
from pydantic import BaseModel, Field
class BaseSerialized(TypedDict):
lc: int
id: List[str]
class SerializedConstructor(BaseSerialized):
type: Literal["constructor"]
kwargs: Dict[str, Any]
class SerializedSecret(BaseSerialized):
type: Literal["secret"]
class SerializedNotImplemented(BaseSerialized):
type: Literal["not_implemented"]
class Serializable(BaseModel, ABC):
@property
def lc_serializable(self) -> bool:
"""
Return whether or not the class is serializable.
"""
return False
@property
def lc_namespace(self) -> List[str]:
"""
Return the namespace of the langchain object.
eg. ["langchain", "llms", "openai"]
"""
return self.__class__.__module__.split(".")
@property
def lc_secrets(self) -> Dict[str, str]:
"""
Return a map of constructor argument names to secret ids.
eg. {"openai_api_key": "OPENAI_API_KEY"}
"""
return dict()
@property
def lc_attributes(self) -> Dict:
"""
Return a map of attribute names to values that should be included in the
serialized kwargs. These attributes must be accepted by the
constructor.
"""
return {}
lc_kwargs: Dict[str, Any] = Field(default_factory=dict, exclude=True)
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.lc_kwargs = kwargs
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
if not self.lc_serializable:
return self.to_json_not_implemented()
secrets = dict()
# Get latest values for kwargs if there is an attribute with same name
lc_kwargs = {
k: getattr(self, k, v)
for k, v in self.lc_kwargs.items()
if not self.__exclude_fields__.get(k, False) # type: ignore
}
# Merge the lc_secrets and lc_attributes from every class in the MRO
for cls in [None, *self.__class__.mro()]:
# Once we get to Serializable, we're done
if cls is Serializable:
break
# Get a reference to self bound to each class in the MRO
this = cast(Serializable, self if cls is None else super(cls, self))
secrets.update(this.lc_secrets)
lc_kwargs.update(this.lc_attributes)
return {
"lc": 1,
"type": "constructor",
"id": [*self.lc_namespace, self.__class__.__name__],
"kwargs": lc_kwargs
if not secrets
else _replace_secrets(lc_kwargs, secrets),
}
def to_json_not_implemented(self) -> SerializedNotImplemented:
return to_json_not_implemented(self)
def _replace_secrets(
root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
result = root.copy()
for path, secret_id in secrets_map.items():
[*parts, last] = path.split(".")
current = result
for part in parts:
if part not in current:
break
current[part] = current[part].copy()
current = current[part]
if last in current:
current[last] = {
"lc": 1,
"type": "secret",
"id": [secret_id],
}
return result
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
_id: List[str] = []
try:
if hasattr(obj, "__name__"):
_id = [*obj.__module__.split("."), obj.__name__]
elif hasattr(obj, "__class__"):
_id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
except Exception:
pass
return {
"lc": 1,
"type": "not_implemented",
"id": _id,
}
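To make these hooks concrete, a subclass opts in by overriding the lc_* properties; this mirrors the Person fixture in the serialization tests added below:

from typing import Dict

from langchain.load.serializable import Serializable

class Person(Serializable):
    secret: str
    you_can_see_me: str = "hello"

    @property
    def lc_serializable(self) -> bool:
        # Opt in to the "constructor" serialization path.
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        # Maps a constructor argument to the secret id used in the output.
        return {"secret": "SECRET"}

    @property
    def lc_attributes(self) -> Dict[str, str]:
        # Extra values to include in the serialized kwargs.
        return {"you_can_see_me": self.you_can_see_me}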

View File

@ -7,9 +7,10 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
import yaml
from pydantic import BaseModel, Extra, Field, root_validator
from pydantic import Extra, Field, root_validator
from langchain.formatting import formatter
from langchain.load.serializable import Serializable
from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue
@ -100,7 +101,7 @@ class StringPromptValue(PromptValue):
return [HumanMessage(content=self.text)]
class BasePromptTemplate(BaseModel, ABC):
class BasePromptTemplate(Serializable, ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
@ -111,6 +112,10 @@ class BasePromptTemplate(BaseModel, ABC):
default_factory=dict
)
@property
def lc_serializable(self) -> bool:
return True
class Config:
"""Configuration for this pydantic object."""

View File

@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
from pydantic import BaseModel, Field
from pydantic import Field
from langchain.load.serializable import Serializable
from langchain.memory.buffer import get_buffer_string
from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
from langchain.prompts.prompt import PromptTemplate
@ -20,7 +21,11 @@ from langchain.schema import (
)
class BaseMessagePromptTemplate(BaseModel, ABC):
class BaseMessagePromptTemplate(Serializable, ABC):
@property
def lc_serializable(self) -> bool:
return True
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""To messages."""
@ -220,7 +225,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@property
def _prompt_type(self) -> str:
raise NotImplementedError
return "chat"
def save(self, file_path: Union[Path, str]) -> None:
raise NotImplementedError

View File

@ -15,6 +15,10 @@ from langchain.prompts.prompt import PromptTemplate
class FewShotPromptTemplate(StringPromptTemplate):
"""Prompt template that contains few shot examples."""
@property
def lc_serializable(self) -> bool:
return False
examples: Optional[List[dict]] = None
"""Examples to format into the prompt.
Either this or example_selector should be provided."""

View File

@ -25,6 +25,12 @@ class PromptTemplate(StringPromptTemplate):
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
"""
@property
def lc_attributes(self) -> Dict[str, Any]:
return {
"template_format": self.template_format,
}
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""

View File

@ -17,6 +17,8 @@ from uuid import UUID
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.load.serializable import Serializable
RUN_KEY = "__run"
@ -55,7 +57,7 @@ class AgentFinish(NamedTuple):
log: str
class Generation(BaseModel):
class Generation(Serializable):
"""Output of a single generation."""
text: str
@ -67,7 +69,7 @@ class Generation(BaseModel):
# TODO: add log probs
class BaseMessage(BaseModel):
class BaseMessage(Serializable):
"""Message object."""
content: str
@ -194,7 +196,7 @@ class LLMResult(BaseModel):
)
class PromptValue(BaseModel, ABC):
class PromptValue(Serializable, ABC):
@abstractmethod
def to_string(self) -> str:
"""Return prompt as string."""
@ -204,7 +206,7 @@ class PromptValue(BaseModel, ABC):
"""Return prompt as messages."""
class BaseMemory(BaseModel, ABC):
class BaseMemory(Serializable, ABC):
"""Base interface for memory in chains."""
class Config:
@ -282,7 +284,7 @@ class BaseChatMessageHistory(ABC):
"""Remove all messages from the store"""
class Document(BaseModel):
class Document(Serializable):
"""Interface for interacting with a document."""
page_content: str
@ -321,7 +323,7 @@ Memory = BaseMemory
T = TypeVar("T")
class BaseOutputParser(BaseModel, ABC, Generic[T]):
class BaseOutputParser(Serializable, ABC, Generic[T]):
"""Class to parse the output of an LLM call.
Output parsers help structure language model responses.

poetry.lock (generated, 31 lines changed)
View File

@ -1417,6 +1417,17 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "colored"
version = "1.4.4"
description = "Simple library for color and formatting to terminal"
category = "dev"
optional = false
python-versions = "*"
files = [
{file = "colored-1.4.4.tar.gz", hash = "sha256:04ff4d4dd514274fe3b99a21bb52fb96f2688c01e93fba7bef37221e7cb56ce0"},
]
[[package]]
name = "coloredlogs"
version = "15.0.1"
@ -9461,6 +9472,22 @@ files = [
[package.dependencies]
mpmath = ">=0.19"
[[package]]
name = "syrupy"
version = "4.0.2"
description = "Pytest Snapshot Test Utility"
category = "dev"
optional = false
python-versions = ">=3.8.1,<4"
files = [
{file = "syrupy-4.0.2-py3-none-any.whl", hash = "sha256:dfd1f0fad298eee753de4f2471d4346412c4435885c4b7beea648d4934c6620a"},
{file = "syrupy-4.0.2.tar.gz", hash = "sha256:3c75ab6866580679b2cb9abe78e74c3e2011fffc6333651c6beb2a78a716ab80"},
]
[package.dependencies]
colored = ">=1.3.92,<2.0.0"
pytest = ">=7.0.0,<8.0.0"
[[package]]
name = "tabulate"
version = "0.9.0"
@ -11428,7 +11455,7 @@ azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-for
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark"]
extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark", "openai"]
llms = ["anthropic", "cohere", "openai", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
openai = ["openai", "tiktoken"]
qdrant = ["qdrant-client"]
@ -11437,4 +11464,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "ecf7086e83cc0ff19e6851c0b63170b082b267c1c1c00f47700fd3a8c8bb46c5"
content-hash = "7a39130af070d4a4fe6b0af5d6b70615c868ab0b1867e404060ff00eacd10f5f"

View File

@ -139,6 +139,7 @@ pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
syrupy = "^4.0.2"
[tool.poetry.group.test_integration]
optional = true
@ -315,7 +316,8 @@ extended_testing = [
"html2text",
"py-trello",
"scikit-learn",
"pyspark"
"pyspark",
"openai"
]
[tool.ruff]
@ -349,7 +351,10 @@ build-backend = "poetry.core.masonry.api"
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5"
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [

View File

@ -13,6 +13,9 @@ from langchain.callbacks.tracers.base import BaseTracer, TracerException
from langchain.callbacks.tracers.schemas import Run
from langchain.schema import LLMResult
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution."""
@ -39,7 +42,7 @@ def test_tracer_llm_run() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]),
error=None,
@ -47,7 +50,7 @@ def test_tracer_llm_run() -> None:
)
tracer = FakeTracer()
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run]
@ -64,7 +67,7 @@ def test_tracer_chat_model_run() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "chat_model"},
serialized=SERIALIZED_CHAT,
inputs=dict(prompts=[""]),
outputs=LLMResult(generations=[[]]),
error=None,
@ -73,7 +76,7 @@ def test_tracer_chat_model_run() -> None:
tracer = FakeTracer()
manager = CallbackManager(handlers=[tracer])
run_manager = manager.on_chat_model_start(
serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
)
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]
@ -100,7 +103,7 @@ def test_tracer_multiple_llm_runs() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]]),
error=None,
@ -110,7 +113,7 @@ def test_tracer_multiple_llm_runs() -> None:
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run] * num_runs
@ -183,7 +186,7 @@ def test_tracer_nested_run() -> None:
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=tool_uuid,
@ -191,7 +194,7 @@ def test_tracer_nested_run() -> None:
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_tool_end("test", run_id=tool_uuid)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
@ -235,7 +238,7 @@ def test_tracer_nested_run() -> None:
extra={},
execution_order=3,
child_execution_order=3,
serialized={"name": "llm"},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]]),
run_type="llm",
@ -251,7 +254,7 @@ def test_tracer_nested_run() -> None:
extra={},
execution_order=4,
child_execution_order=4,
serialized={"name": "llm"},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]]),
run_type="llm",
@ -275,7 +278,7 @@ def test_tracer_llm_run_on_error() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=None,
error=repr(exception),
@ -283,7 +286,7 @@ def test_tracer_llm_run_on_error() -> None:
)
tracer = FakeTracer()
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@ -358,14 +361,14 @@ def test_tracer_nested_runs_on_error() -> None:
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
@ -378,7 +381,7 @@ def test_tracer_nested_runs_on_error() -> None:
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid3,
parent_run_id=tool_uuid,
@ -408,7 +411,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=2,
child_execution_order=2,
serialized={"name": "llm"},
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]], llm_output=None),
@ -422,7 +425,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=3,
child_execution_order=3,
serialized={"name": "llm"},
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]], llm_output=None),
@ -450,7 +453,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=5,
child_execution_order=5,
serialized={"name": "llm"},
serialized=SERIALIZED,
error=repr(exception),
inputs=dict(prompts=[]),
outputs=None,

View File

@ -22,6 +22,9 @@ from langchain.schema import LLMResult
TEST_SESSION_ID = 2023
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}
def load_session(session_name: str) -> TracerSessionV1:
"""Load a tracing session."""
@ -107,7 +110,7 @@ def test_tracer_llm_run() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
@ -116,7 +119,7 @@ def test_tracer_llm_run() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run]
@ -133,7 +136,7 @@ def test_tracer_chat_model_run() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "chat_model"},
serialized=SERIALIZED_CHAT,
prompts=[""],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
@ -144,7 +147,7 @@ def test_tracer_chat_model_run() -> None:
tracer.new_session()
manager = CallbackManager(handlers=[tracer])
run_manager = manager.on_chat_model_start(
serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
)
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]
@ -172,7 +175,7 @@ def test_tracer_multiple_llm_runs() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
@ -183,7 +186,7 @@ def test_tracer_multiple_llm_runs() -> None:
tracer.new_session()
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run] * num_runs
@ -263,7 +266,7 @@ def test_tracer_nested_run() -> None:
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=tool_uuid,
@ -271,7 +274,7 @@ def test_tracer_nested_run() -> None:
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_tool_end("test", run_id=tool_uuid)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
@ -319,7 +322,7 @@ def test_tracer_nested_run() -> None:
extra={},
execution_order=3,
child_execution_order=3,
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
@ -337,7 +340,7 @@ def test_tracer_nested_run() -> None:
extra={},
execution_order=4,
child_execution_order=4,
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
@ -362,7 +365,7 @@ def test_tracer_llm_run_on_error() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
response=None,
session_id=TEST_SESSION_ID,
@ -371,7 +374,7 @@ def test_tracer_llm_run_on_error() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@ -451,14 +454,14 @@ def test_tracer_nested_runs_on_error() -> None:
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
@ -471,7 +474,7 @@ def test_tracer_nested_runs_on_error() -> None:
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid3,
parent_run_id=tool_uuid,
@ -501,7 +504,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=2,
child_execution_order=2,
serialized={"name": "llm"},
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
@ -515,7 +518,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=3,
child_execution_order=3,
serialized={"name": "llm"},
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
@ -547,7 +550,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=5,
child_execution_order=5,
serialized={"name": "llm"},
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=repr(exception),
prompts=[],

View File

@ -0,0 +1,273 @@
# serializer version: 1
# name: test_person
'''
{
"lc": 1,
"type": "constructor",
"id": [
"test_dump",
"Person"
],
"kwargs": {
"secret": {
"lc": 1,
"type": "secret",
"id": [
"SECRET"
]
},
"you_can_see_me": "hello"
}
}
'''
# ---
# name: test_person.1
'''
{
"lc": 1,
"type": "constructor",
"id": [
"test_dump",
"SpecialPerson"
],
"kwargs": {
"another_secret": {
"lc": 1,
"type": "secret",
"id": [
"ANOTHER_SECRET"
]
},
"secret": {
"lc": 1,
"type": "secret",
"id": [
"SECRET"
]
},
"another_visible": "bye",
"you_can_see_me": "hello"
}
}
'''
# ---
# name: test_serialize_llmchain
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.5,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
}
},
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
}
}
}
}
'''
# ---
# name: test_serialize_llmchain_chat
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chat_models",
"openai",
"ChatOpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.5,
"openai_api_key": "hello"
}
},
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"ChatPromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"messages": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"HumanMessagePromptTemplate"
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
}
}
}
}
]
}
}
}
}
'''
# ---
# name: test_serialize_llmchain_with_non_serializable_arg
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.5,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
},
"client": {
"lc": 1,
"type": "not_implemented",
"id": [
"openai",
"api_resources",
"completion",
"Completion"
]
}
}
},
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
}
}
}
}
'''
# ---
# name: test_serialize_openai_llm
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.7,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
}
}
'''
# ---

View File

@ -0,0 +1,103 @@
"""Test for Serializable base class"""
from typing import Any, Dict
import pytest
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.llm import LLMChain
from langchain.chat_models.openai import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.serializable import Serializable
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.prompts.prompt import PromptTemplate
class Person(Serializable):
secret: str
you_can_see_me: str = "hello"
@property
def lc_serializable(self) -> bool:
return True
@property
def lc_secrets(self) -> Dict[str, str]:
return {"secret": "SECRET"}
@property
def lc_attributes(self) -> Dict[str, str]:
return {"you_can_see_me": self.you_can_see_me}
class SpecialPerson(Person):
another_secret: str
another_visible: str = "bye"
# Gets merged with parent class's secrets
@property
def lc_secrets(self) -> Dict[str, str]:
return {"another_secret": "ANOTHER_SECRET"}
# Gets merged with parent class's attributes
@property
def lc_attributes(self) -> Dict[str, str]:
return {"another_visible": self.another_visible}
class NotSerializable:
pass
def test_person(snapshot: Any) -> None:
p = Person(secret="hello")
assert dumps(p, pretty=True) == snapshot
sp = SpecialPerson(another_secret="Wooo", secret="Hmm")
assert dumps(sp, pretty=True) == snapshot
@pytest.mark.requires("openai")
def test_serialize_openai_llm(snapshot: Any) -> None:
llm = OpenAI(
model="davinci",
temperature=0.5,
openai_api_key="hello",
# This is excluded from serialization
callbacks=[LangChainTracer()],
)
llm.temperature = 0.7 # this is reflected in serialization
assert dumps(llm, pretty=True) == snapshot
@pytest.mark.requires("openai")
def test_serialize_llmchain(snapshot: Any) -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
assert dumps(chain, pretty=True) == snapshot
@pytest.mark.requires("openai")
def test_serialize_llmchain_chat(snapshot: Any) -> None:
llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
prompt = ChatPromptTemplate.from_messages(
[HumanMessagePromptTemplate.from_template("hello {name}!")]
)
chain = LLMChain(llm=llm, prompt=prompt)
assert dumps(chain, pretty=True) == snapshot
@pytest.mark.requires("openai")
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
llm = OpenAI(
model="davinci",
temperature=0.5,
openai_api_key="hello",
client=NotSerializable,
)
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
assert dumps(chain, pretty=True) == snapshot

View File

@ -0,0 +1,54 @@
"""Test for Serializable base class"""
import pytest
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.prompts.prompt import PromptTemplate
class NotSerializable:
pass
@pytest.mark.requires("openai")
def test_load_openai_llm() -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
llm_string = dumps(llm)
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
assert llm2 == llm
assert dumps(llm2) == llm_string
assert isinstance(llm2, OpenAI)
@pytest.mark.requires("openai")
def test_load_llmchain() -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
chain_string = dumps(chain)
chain2 = loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
assert chain2 == chain
assert dumps(chain2) == chain_string
assert isinstance(chain2, LLMChain)
assert isinstance(chain2.llm, OpenAI)
assert isinstance(chain2.prompt, PromptTemplate)
@pytest.mark.requires("openai")
def test_load_llmchain_with_non_serializable_arg() -> None:
llm = OpenAI(
model="davinci",
temperature=0.5,
openai_api_key="hello",
client=NotSerializable,
)
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
chain_string = dumps(chain, pretty=True)
with pytest.raises(NotImplementedError):
loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})

View File

@ -72,6 +72,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
"pytest-socket",
"pytest-watcher",
"responses",
"syrupy",
]