Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Nuno Campos 2023-06-11 23:51:28 +01:00 committed by GitHub
parent 614cff89bc
commit 18af149e91
27 changed files with 810 additions and 71 deletions

View File

@@ -4,9 +4,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Sequence, Set
-from pydantic import BaseModel
from langchain.callbacks.manager import Callbacks
+from langchain.load.serializable import Serializable
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
@@ -29,7 +28,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
    return tokenizer.encode(text)
-class BaseLanguageModel(BaseModel, ABC):
+class BaseLanguageModel(Serializable, ABC):
    @abstractmethod
    def generate_prompt(
        self,

View File

@@ -204,7 +204,7 @@ def _handle_event(
        except Exception as e:
            if handler.raise_error:
                raise e
-            logging.warning(f"Error in {event_name} callback: {e}")
+            logger.warning(f"Error in {event_name} callback: {e}")
async def _ahandle_event_for_handler(

View File

@@ -93,7 +93,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
        execution_order = self._get_execution_order(parent_run_id_)
        llm_run = Run(
            id=run_id,
-            name=serialized.get("name"),
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs={"prompts": prompts},
@@ -154,7 +153,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
        execution_order = self._get_execution_order(parent_run_id_)
        chain_run = Run(
            id=run_id,
-            name=serialized.get("name"),
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs=inputs,
@@ -216,7 +214,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
        execution_order = self._get_execution_order(parent_run_id_)
        tool_run = Run(
            id=run_id,
-            name=serialized.get("name"),
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs={"input": input_str},

View File

@@ -124,7 +124,10 @@ class Run(RunBase):
    def assign_name(cls, values: dict) -> dict:
        """Assign name to the run."""
        if "name" not in values:
-            values["name"] = values["serialized"]["name"]
+            if "name" in values["serialized"]:
+                values["name"] = values["serialized"]["name"]
+            elif "id" in values["serialized"]:
+                values["name"] = values["serialized"]["id"][-1]
        return values
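For context, and not part of the diff itself: the new dumpd-style payloads carry an "id" path rather than a "name" key, so the run name now falls back to the last path segment. A rough sketch with made-up payloads:

# Hypothetical serialized values, for illustration only.
legacy = {"serialized": {"name": "llm"}}
new_style = {"serialized": {"lc": 1, "type": "constructor",
                            "id": ["langchain", "llms", "openai", "OpenAI"]}}
# assign_name would pick "llm" for the legacy payload and "OpenAI" for the new one.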

View File

@@ -7,7 +7,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
-from pydantic import BaseModel, Field, root_validator, validator
+from pydantic import Field, root_validator, validator
import langchain
from langchain.callbacks.base import BaseCallbackManager
@@ -18,6 +18,8 @@ from langchain.callbacks.manager import (
    CallbackManagerForChainRun,
    Callbacks,
)
+from langchain.load.dump import dumpd
+from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
@@ -25,7 +27,7 @@ def _get_verbosity() -> bool:
    return langchain.verbose
-class Chain(BaseModel, ABC):
+class Chain(Serializable, ABC):
    """Base interface that all chains should implement."""
    memory: Optional[BaseMemory] = None
@@ -131,7 +133,7 @@ class Chain(BaseModel, ABC):
        )
        new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
        run_manager = callback_manager.on_chain_start(
-            {"name": self.__class__.__name__},
+            dumpd(self),
            inputs,
        )
        try:
@@ -179,7 +181,7 @@
        )
        new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
        run_manager = await callback_manager.on_chain_start(
-            {"name": self.__class__.__name__},
+            dumpd(self),
            inputs,
        )
        try:

View File

@@ -15,6 +15,7 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.input import get_colored_text
+from langchain.load.dump import dumpd
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
@@ -34,6 +35,10 @@ class LLMChain(Chain):
            llm = LLMChain(llm=OpenAI(), prompt=prompt)
    """
+    @property
+    def lc_serializable(self) -> bool:
+        return True
    prompt: BasePromptTemplate
    """Prompt object to use."""
    llm: BaseLanguageModel
@@ -147,7 +152,7 @@ class LLMChain(Chain):
            callbacks, self.callbacks, self.verbose
        )
        run_manager = callback_manager.on_chain_start(
-            {"name": self.__class__.__name__},
+            dumpd(self),
            {"input_list": input_list},
        )
        try:
@@ -167,7 +172,7 @@
            callbacks, self.callbacks, self.verbose
        )
        run_manager = await callback_manager.on_chain_start(
-            {"name": self.__class__.__name__},
+            dumpd(self),
            {"input_list": input_list},
        )
        try:

View File

@@ -17,6 +17,7 @@ from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
    Callbacks,
)
+from langchain.load.dump import dumpd
from langchain.schema import (
    AIMessage,
    BaseMessage,
@@ -70,12 +71,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
        params = self.dict()
        params["stop"] = stop
+        options = {"stop": stop}
        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        run_manager = callback_manager.on_chat_model_start(
-            {"name": self.__class__.__name__}, messages, invocation_params=params
+            dumpd(self), messages, invocation_params=params, options=options
        )
        new_arg_supported = inspect.signature(self._generate).parameters.get(
@@ -109,12 +111,13 @@
        """Top Level call"""
        params = self.dict()
        params["stop"] = stop
+        options = {"stop": stop}
        callback_manager = AsyncCallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        run_manager = await callback_manager.on_chat_model_start(
-            {"name": self.__class__.__name__}, messages, invocation_params=params
+            dumpd(self), messages, invocation_params=params, options=options
        )
        new_arg_supported = inspect.signature(self._agenerate).parameters.get(

View File

@@ -136,6 +136,10 @@ class ChatOpenAI(BaseChatModel):
            openai = ChatOpenAI(model_name="gpt-3.5-turbo")
    """
+    @property
+    def lc_serializable(self) -> bool:
+        return True
    client: Any  #: :meta private:
    model_name: str = Field(default="gpt-3.5-turbo", alias="model")
    """Model name to use."""

View File

@@ -19,6 +19,7 @@ from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
    Callbacks,
)
+from langchain.load.dump import dumpd
from langchain.schema import (
    AIMessage,
    BaseMessage,
@@ -166,6 +167,7 @@ class BaseLLM(BaseLanguageModel, ABC):
        )
        params = self.dict()
        params["stop"] = stop
+        options = {"stop": stop}
        (
            existing_prompts,
            llm_string,
@@ -186,7 +188,7 @@
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            run_manager = callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, prompts, invocation_params=params
+                dumpd(self), prompts, invocation_params=params, options=options
            )
            try:
                output = (
@@ -205,9 +207,10 @@
            return output
        if len(missing_prompts) > 0:
            run_manager = callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
+                dumpd(self),
                missing_prompts,
                invocation_params=params,
+                options=options,
            )
            try:
                new_results = (
@@ -243,6 +246,7 @@
        """Run the LLM on the given prompt and input."""
        params = self.dict()
        params["stop"] = stop
+        options = {"stop": stop}
        (
            existing_prompts,
            llm_string,
@@ -263,7 +267,7 @@
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            run_manager = await callback_manager.on_llm_start(
-                {"name": self.__class__.__name__}, prompts, invocation_params=params
+                dumpd(self), prompts, invocation_params=params, options=options
            )
            try:
                output = (
@@ -282,9 +286,10 @@
            return output
        if len(missing_prompts) > 0:
            run_manager = await callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
+                dumpd(self),
                missing_prompts,
                invocation_params=params,
+                options=options,
            )
            try:
                new_results = (

View File

@@ -123,6 +123,14 @@ async def acompletion_with_retry(
class BaseOpenAI(BaseLLM):
    """Wrapper around OpenAI large language models."""
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"openai_api_key": "OPENAI_API_KEY"}
+    @property
+    def lc_serializable(self) -> bool:
+        return True
    client: Any #: :meta private:
    model_name: str = Field("text-davinci-003", alias="model")
    """Model name to use."""

View File

22
langchain/load/dump.py Normal file
View File

@@ -0,0 +1,22 @@
import json
from typing import Any, Dict

from langchain.load.serializable import Serializable, to_json_not_implemented


def default(obj: Any) -> Any:
    if isinstance(obj, Serializable):
        return obj.to_json()
    else:
        return to_json_not_implemented(obj)


def dumps(obj: Any, *, pretty: bool = False) -> str:
    if pretty:
        return json.dumps(obj, default=default, indent=2)
    else:
        return json.dumps(obj, default=default)


def dumpd(obj: Any) -> Dict[str, Any]:
    return json.loads(dumps(obj))
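A brief usage sketch (not part of the commit; the Greeting class is invented for illustration) of how dumps and dumpd treat a Serializable object:

# Hypothetical example; assumes the langchain.load package added above is importable.
from langchain.load.dump import dumpd, dumps
from langchain.load.serializable import Serializable

class Greeting(Serializable):
    text: str = "hello"

    @property
    def lc_serializable(self) -> bool:
        return True

g = Greeting()
print(dumps(g, pretty=True))   # JSON document with "lc", "type", "id" and "kwargs"
print(dumpd(g)["type"])        # -> "constructor"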

65
langchain/load/load.py Normal file
View File

@@ -0,0 +1,65 @@
import importlib
import json
from typing import Any, Dict, Optional

from langchain.load.serializable import Serializable


class Reviver:
    def __init__(self, secrets_map: Optional[Dict[str, str]] = None) -> None:
        self.secrets_map = secrets_map or dict()

    def __call__(self, value: Dict[str, Any]) -> Any:
        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "secret"
            and value.get("id", None) is not None
        ):
            [key] = value["id"]
            if key in self.secrets_map:
                return self.secrets_map[key]
            else:
                raise KeyError(f'Missing key "{key}" in load(secrets_map)')

        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "not_implemented"
            and value.get("id", None) is not None
        ):
            raise NotImplementedError(
                "Trying to load an object that doesn't implement "
                f"serialization: {value}"
            )

        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "constructor"
            and value.get("id", None) is not None
        ):
            [*namespace, name] = value["id"]

            # Currently, we only support langchain imports.
            if namespace[0] != "langchain":
                raise ValueError(f"Invalid namespace: {value}")

            # The root namespace "langchain" is not a valid identifier.
            if len(namespace) == 1:
                raise ValueError(f"Invalid namespace: {value}")

            mod = importlib.import_module(".".join(namespace))
            cls = getattr(mod, name)

            # The class must be a subclass of Serializable.
            if not issubclass(cls, Serializable):
                raise ValueError(f"Invalid namespace: {value}")

            # We don't need to recurse on kwargs
            # as json.loads will do that for us.
            kwargs = value.get("kwargs", dict())
            return cls(**kwargs)

        return value


def loads(text: str, *, secrets_map: Optional[Dict[str, str]] = None) -> Any:
    return json.loads(text, object_hook=Reviver(secrets_map))
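A rough round-trip sketch (not part of the commit) using the PromptTemplate changes made elsewhere in this diff; no secrets_map is needed because a prompt holds no secrets:

# Hedged sketch: serialize a prompt and revive it from JSON.
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.prompts.prompt import PromptTemplate

prompt = PromptTemplate.from_template("hello {name}!")
text = dumps(prompt, pretty=True)   # "constructor" payload for langchain.prompts.prompt.PromptTemplate
revived = loads(text)               # Reviver re-imports the class and calls it with the stored kwargs
assert isinstance(revived, PromptTemplate)
assert dumps(revived, pretty=True) == text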

View File

@@ -0,0 +1,135 @@
from abc import ABC
from typing import Any, Dict, List, Literal, TypedDict, Union, cast

from pydantic import BaseModel, Field


class BaseSerialized(TypedDict):
    lc: int
    id: List[str]


class SerializedConstructor(BaseSerialized):
    type: Literal["constructor"]
    kwargs: Dict[str, Any]


class SerializedSecret(BaseSerialized):
    type: Literal["secret"]


class SerializedNotImplemented(BaseSerialized):
    type: Literal["not_implemented"]


class Serializable(BaseModel, ABC):
    @property
    def lc_serializable(self) -> bool:
        """
        Return whether or not the class is serializable.
        """
        return False

    @property
    def lc_namespace(self) -> List[str]:
        """
        Return the namespace of the langchain object.
        eg. ["langchain", "llms", "openai"]
        """
        return self.__class__.__module__.split(".")

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """
        Return a map of constructor argument names to secret ids.
        eg. {"openai_api_key": "OPENAI_API_KEY"}
        """
        return dict()

    @property
    def lc_attributes(self) -> Dict:
        """
        Return a list of attribute names that should be included in the
        serialized kwargs. These attributes must be accepted by the
        constructor.
        """
        return {}

    lc_kwargs: Dict[str, Any] = Field(default_factory=dict, exclude=True)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.lc_kwargs = kwargs

    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
        if not self.lc_serializable:
            return self.to_json_not_implemented()

        secrets = dict()
        # Get latest values for kwargs if there is an attribute with same name
        lc_kwargs = {
            k: getattr(self, k, v)
            for k, v in self.lc_kwargs.items()
            if not self.__exclude_fields__.get(k, False)  # type: ignore
        }

        # Merge the lc_secrets and lc_attributes from every class in the MRO
        for cls in [None, *self.__class__.mro()]:
            # Once we get to Serializable, we're done
            if cls is Serializable:
                break

            # Get a reference to self bound to each class in the MRO
            this = cast(Serializable, self if cls is None else super(cls, self))

            secrets.update(this.lc_secrets)
            lc_kwargs.update(this.lc_attributes)

        return {
            "lc": 1,
            "type": "constructor",
            "id": [*self.lc_namespace, self.__class__.__name__],
            "kwargs": lc_kwargs
            if not secrets
            else _replace_secrets(lc_kwargs, secrets),
        }

    def to_json_not_implemented(self) -> SerializedNotImplemented:
        return to_json_not_implemented(self)


def _replace_secrets(
    root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
    result = root.copy()
    for path, secret_id in secrets_map.items():
        [*parts, last] = path.split(".")
        current = result
        for part in parts:
            if part not in current:
                break
            current[part] = current[part].copy()
            current = current[part]
        if last in current:
            current[last] = {
                "lc": 1,
                "type": "secret",
                "id": [secret_id],
            }
    return result


def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
    _id: List[str] = []
    try:
        if hasattr(obj, "__name__"):
            _id = [*obj.__module__.split("."), obj.__name__]
        elif hasattr(obj, "__class__"):
            _id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
    except Exception:
        pass
    return {
        "lc": 1,
        "type": "not_implemented",
        "id": _id,
    }
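To make the secret handling concrete, a small illustration follows (ApiClient and its field names are invented for this sketch, not part of the commit):

# Hypothetical subclass showing how lc_secrets rewrites the emitted kwargs.
from typing import Dict

from langchain.load.serializable import Serializable

class ApiClient(Serializable):
    api_key: str
    endpoint: str = "https://example.com"

    @property
    def lc_serializable(self) -> bool:
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "EXAMPLE_API_KEY"}

doc = ApiClient(api_key="sk-123").to_json()
# doc["kwargs"]["api_key"] == {"lc": 1, "type": "secret", "id": ["EXAMPLE_API_KEY"]}
# The raw "sk-123" value never appears in the serialized output.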

View File

@@ -7,9 +7,10 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
import yaml
-from pydantic import BaseModel, Extra, Field, root_validator
+from pydantic import Extra, Field, root_validator
from langchain.formatting import formatter
+from langchain.load.serializable import Serializable
from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue
@@ -100,7 +101,7 @@ class StringPromptValue(PromptValue):
        return [HumanMessage(content=self.text)]
-class BasePromptTemplate(BaseModel, ABC):
+class BasePromptTemplate(Serializable, ABC):
    """Base class for all prompt templates, returning a prompt."""
    input_variables: List[str]
@@ -111,6 +112,10 @@ class BasePromptTemplate(BaseModel, ABC):
        default_factory=dict
    )
+    @property
+    def lc_serializable(self) -> bool:
+        return True
    class Config:
        """Configuration for this pydantic object."""

View File

@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
-from pydantic import BaseModel, Field
+from pydantic import Field
+from langchain.load.serializable import Serializable
from langchain.memory.buffer import get_buffer_string
from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
from langchain.prompts.prompt import PromptTemplate
@@ -20,7 +21,11 @@ from langchain.schema import (
)
-class BaseMessagePromptTemplate(BaseModel, ABC):
+class BaseMessagePromptTemplate(Serializable, ABC):
+    @property
+    def lc_serializable(self) -> bool:
+        return True
    @abstractmethod
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """To messages."""
@@ -220,7 +225,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
    @property
    def _prompt_type(self) -> str:
-        raise NotImplementedError
+        return "chat"
    def save(self, file_path: Union[Path, str]) -> None:
        raise NotImplementedError

View File

@@ -15,6 +15,10 @@ from langchain.prompts.prompt import PromptTemplate
class FewShotPromptTemplate(StringPromptTemplate):
    """Prompt template that contains few shot examples."""
+    @property
+    def lc_serializable(self) -> bool:
+        return False
    examples: Optional[List[dict]] = None
    """Examples to format into the prompt.
    Either this or example_selector should be provided."""

View File

@@ -25,6 +25,12 @@ class PromptTemplate(StringPromptTemplate):
            prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
    """
+    @property
+    def lc_attributes(self) -> Dict[str, Any]:
+        return {
+            "template_format": self.template_format,
+        }
    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

View File

@@ -17,6 +17,8 @@ from uuid import UUID
from pydantic import BaseModel, Extra, Field, root_validator
+from langchain.load.serializable import Serializable
RUN_KEY = "__run"
@@ -55,7 +57,7 @@ class AgentFinish(NamedTuple):
    log: str
-class Generation(BaseModel):
+class Generation(Serializable):
    """Output of a single generation."""
    text: str
@@ -67,7 +69,7 @@ class Generation(BaseModel):
    # TODO: add log probs
-class BaseMessage(BaseModel):
+class BaseMessage(Serializable):
    """Message object."""
    content: str
@@ -194,7 +196,7 @@ class LLMResult(BaseModel):
    )
-class PromptValue(BaseModel, ABC):
+class PromptValue(Serializable, ABC):
    @abstractmethod
    def to_string(self) -> str:
        """Return prompt as string."""
@@ -204,7 +206,7 @@ class PromptValue(BaseModel, ABC):
        """Return prompt as messages."""
-class BaseMemory(BaseModel, ABC):
+class BaseMemory(Serializable, ABC):
    """Base interface for memory in chains."""
    class Config:
@@ -282,7 +284,7 @@ class BaseChatMessageHistory(ABC):
        """Remove all messages from the store"""
-class Document(BaseModel):
+class Document(Serializable):
    """Interface for interacting with a document."""
    page_content: str
@@ -321,7 +323,7 @@ Memory = BaseMemory
T = TypeVar("T")
-class BaseOutputParser(BaseModel, ABC, Generic[T]):
+class BaseOutputParser(Serializable, ABC, Generic[T]):
    """Class to parse the output of an LLM call.
    Output parsers help structure language model responses.

31
poetry.lock generated
View File

@@ -1417,6 +1417,17 @@ files = [
    {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
+[[package]]
+name = "colored"
+version = "1.4.4"
+description = "Simple library for color and formatting to terminal"
+category = "dev"
+optional = false
+python-versions = "*"
+files = [
+    {file = "colored-1.4.4.tar.gz", hash = "sha256:04ff4d4dd514274fe3b99a21bb52fb96f2688c01e93fba7bef37221e7cb56ce0"},
+]
[[package]]
name = "coloredlogs"
version = "15.0.1"
@@ -9461,6 +9472,22 @@ files = [
[package.dependencies]
mpmath = ">=0.19"
+[[package]]
+name = "syrupy"
+version = "4.0.2"
+description = "Pytest Snapshot Test Utility"
+category = "dev"
+optional = false
+python-versions = ">=3.8.1,<4"
+files = [
+    {file = "syrupy-4.0.2-py3-none-any.whl", hash = "sha256:dfd1f0fad298eee753de4f2471d4346412c4435885c4b7beea648d4934c6620a"},
+    {file = "syrupy-4.0.2.tar.gz", hash = "sha256:3c75ab6866580679b2cb9abe78e74c3e2011fffc6333651c6beb2a78a716ab80"},
+]
+[package.dependencies]
+colored = ">=1.3.92,<2.0.0"
+pytest = ">=7.0.0,<8.0.0"
[[package]]
name = "tabulate"
version = "0.9.0"
@@ -11428,7 +11455,7 @@ azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-for
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
-extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark"]
+extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark", "openai"]
llms = ["anthropic", "cohere", "openai", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
openai = ["openai", "tiktoken"]
qdrant = ["qdrant-client"]
@@ -11437,4 +11464,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
-content-hash = "ecf7086e83cc0ff19e6851c0b63170b082b267c1c1c00f47700fd3a8c8bb46c5"
+content-hash = "7a39130af070d4a4fe6b0af5d6b70615c868ab0b1867e404060ff00eacd10f5f"

View File

@@ -139,6 +139,7 @@ pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
+syrupy = "^4.0.2"
[tool.poetry.group.test_integration]
optional = true
@@ -315,7 +316,8 @@ extended_testing = [
    "html2text",
    "py-trello",
    "scikit-learn",
-    "pyspark"
+    "pyspark",
+    "openai"
]
[tool.ruff]
@@ -349,7 +351,10 @@ build-backend = "poetry.core.masonry.api"
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
-addopts = "--strict-markers --strict-config --durations=5"
+#
+# https://github.com/tophat/syrupy
+# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
+addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [

View File

@@ -13,6 +13,9 @@ from langchain.callbacks.tracers.base import BaseTracer, TracerException
from langchain.callbacks.tracers.schemas import Run
from langchain.schema import LLMResult
+SERIALIZED = {"id": ["llm"]}
+SERIALIZED_CHAT = {"id": ["chat_model"]}
class FakeTracer(BaseTracer):
    """Fake tracer that records LangChain execution."""
@@ -39,7 +42,7 @@ def test_tracer_llm_run() -> None:
        extra={},
        execution_order=1,
        child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        inputs={"prompts": []},
        outputs=LLMResult(generations=[[]]),
        error=None,
@@ -47,7 +50,7 @@
    )
    tracer = FakeTracer()
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run]
@@ -64,7 +67,7 @@ def test_tracer_chat_model_run() -> None:
        extra={},
        execution_order=1,
        child_execution_order=1,
-        serialized={"name": "chat_model"},
+        serialized=SERIALIZED_CHAT,
        inputs=dict(prompts=[""]),
        outputs=LLMResult(generations=[[]]),
        error=None,
@@ -73,7 +76,7 @@
    tracer = FakeTracer()
    manager = CallbackManager(handlers=[tracer])
    run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
+        serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
    )
    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
    assert tracer.runs == [compare_run]
@@ -100,7 +103,7 @@ def test_tracer_multiple_llm_runs() -> None:
        extra={},
        execution_order=1,
        child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        inputs=dict(prompts=[]),
        outputs=LLMResult(generations=[[]]),
        error=None,
@@ -110,7 +113,7 @@
    num_runs = 10
    for _ in range(num_runs):
-        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+        tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run] * num_runs
@@ -183,7 +186,7 @@ def test_tracer_nested_run() -> None:
        parent_run_id=chain_uuid,
    )
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        run_id=llm_uuid1,
        parent_run_id=tool_uuid,
@@ -191,7 +194,7 @@
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
    tracer.on_tool_end("test", run_id=tool_uuid)
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        run_id=llm_uuid2,
        parent_run_id=chain_uuid,
@@ -235,7 +238,7 @@ def test_tracer_nested_run() -> None:
        extra={},
        execution_order=3,
        child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        inputs=dict(prompts=[]),
        outputs=LLMResult(generations=[[]]),
        run_type="llm",
@@ -251,7 +254,7 @@ def test_tracer_nested_run() -> None:
        extra={},
        execution_order=4,
        child_execution_order=4,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        inputs=dict(prompts=[]),
        outputs=LLMResult(generations=[[]]),
        run_type="llm",
@@ -275,7 +278,7 @@ def test_tracer_llm_run_on_error() -> None:
        extra={},
        execution_order=1,
        child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        inputs=dict(prompts=[]),
        outputs=None,
        error=repr(exception),
@@ -283,7 +286,7 @@
    )
    tracer = FakeTracer()
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
    tracer.on_llm_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
@@ -358,14 +361,14 @@ def test_tracer_nested_runs_on_error() -> None:
        serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
    )
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        run_id=llm_uuid1,
        parent_run_id=chain_uuid,
    )
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        run_id=llm_uuid2,
        parent_run_id=chain_uuid,
@@ -378,7 +381,7 @@ def test_tracer_nested_runs_on_error() -> None:
        parent_run_id=chain_uuid,
    )
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        run_id=llm_uuid3,
        parent_run_id=tool_uuid,
@@ -408,7 +411,7 @@ def test_tracer_nested_runs_on_error() -> None:
        extra={},
        execution_order=2,
        child_execution_order=2,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        error=None,
        inputs=dict(prompts=[]),
        outputs=LLMResult(generations=[[]], llm_output=None),
@@ -422,7 +425,7 @@ def test_tracer_nested_runs_on_error() -> None:
        extra={},
        execution_order=3,
        child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        error=None,
        inputs=dict(prompts=[]),
        outputs=LLMResult(generations=[[]], llm_output=None),
@@ -450,7 +453,7 @@ def test_tracer_nested_runs_on_error() -> None:
        extra={},
        execution_order=5,
        child_execution_order=5,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        error=repr(exception),
        inputs=dict(prompts=[]),
        outputs=None,

View File

@@ -22,6 +22,9 @@ from langchain.schema import LLMResult
TEST_SESSION_ID = 2023
+SERIALIZED = {"id": ["llm"]}
+SERIALIZED_CHAT = {"id": ["chat_model"]}
def load_session(session_name: str) -> TracerSessionV1:
    """Load a tracing session."""
@@ -107,7 +110,7 @@ def test_tracer_llm_run() -> None:
        extra={},
        execution_order=1,
        child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
@@ -116,7 +119,7 @@
    tracer = FakeTracer()
    tracer.new_session()
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run]
@@ -133,7 +136,7 @@ def test_tracer_chat_model_run() -> None:
        extra={},
        execution_order=1,
        child_execution_order=1,
-        serialized={"name": "chat_model"},
+        serialized=SERIALIZED_CHAT,
        prompts=[""],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
@@ -144,7 +147,7 @@ def test_tracer_chat_model_run() -> None:
    tracer.new_session()
    manager = CallbackManager(handlers=[tracer])
    run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
+        serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
    )
    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
    assert tracer.runs == [compare_run]
@@ -172,7 +175,7 @@ def test_tracer_multiple_llm_runs() -> None:
        extra={},
        execution_order=1,
        child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
@@ -183,7 +186,7 @@ def test_tracer_multiple_llm_runs() -> None:
    tracer.new_session()
    num_runs = 10
    for _ in range(num_runs):
-        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+        tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
        tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
    assert tracer.runs == [compare_run] * num_runs
@@ -263,7 +266,7 @@ def test_tracer_nested_run() -> None:
        parent_run_id=chain_uuid,
    )
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        run_id=llm_uuid1,
        parent_run_id=tool_uuid,
@@ -271,7 +274,7 @@ def test_tracer_nested_run() -> None:
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
    tracer.on_tool_end("test", run_id=tool_uuid)
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        run_id=llm_uuid2,
        parent_run_id=chain_uuid,
@@ -319,7 +322,7 @@ def test_tracer_nested_run() -> None:
        extra={},
        execution_order=3,
        child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
@@ -337,7 +340,7 @@ def test_tracer_nested_run() -> None:
        extra={},
        execution_order=4,
        child_execution_order=4,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        response=LLMResult(generations=[[]]),
        session_id=TEST_SESSION_ID,
@@ -362,7 +365,7 @@ def test_tracer_llm_run_on_error() -> None:
        extra={},
        execution_order=1,
        child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        response=None,
        session_id=TEST_SESSION_ID,
@@ -371,7 +374,7 @@ def test_tracer_llm_run_on_error() -> None:
    tracer = FakeTracer()
    tracer.new_session()
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
    tracer.on_llm_error(exception, run_id=uuid)
    assert tracer.runs == [compare_run]
@@ -451,14 +454,14 @@ def test_tracer_nested_runs_on_error() -> None:
        serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
    )
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        run_id=llm_uuid1,
        parent_run_id=chain_uuid,
    )
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        run_id=llm_uuid2,
        parent_run_id=chain_uuid,
@@ -471,7 +474,7 @@ def test_tracer_nested_runs_on_error() -> None:
        parent_run_id=chain_uuid,
    )
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        prompts=[],
        run_id=llm_uuid3,
        parent_run_id=tool_uuid,
@@ -501,7 +504,7 @@ def test_tracer_nested_runs_on_error() -> None:
        extra={},
        execution_order=2,
        child_execution_order=2,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        session_id=TEST_SESSION_ID,
        error=None,
        prompts=[],
@@ -515,7 +518,7 @@ def test_tracer_nested_runs_on_error() -> None:
        extra={},
        execution_order=3,
        child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        session_id=TEST_SESSION_ID,
        error=None,
        prompts=[],
@@ -547,7 +550,7 @@ def test_tracer_nested_runs_on_error() -> None:
        extra={},
        execution_order=5,
        child_execution_order=5,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
        session_id=TEST_SESSION_ID,
        error=repr(exception),
        prompts=[],

View File

@@ -0,0 +1,273 @@
# serializer version: 1
# name: test_person
'''
{
"lc": 1,
"type": "constructor",
"id": [
"test_dump",
"Person"
],
"kwargs": {
"secret": {
"lc": 1,
"type": "secret",
"id": [
"SECRET"
]
},
"you_can_see_me": "hello"
}
}
'''
# ---
# name: test_person.1
'''
{
"lc": 1,
"type": "constructor",
"id": [
"test_dump",
"SpecialPerson"
],
"kwargs": {
"another_secret": {
"lc": 1,
"type": "secret",
"id": [
"ANOTHER_SECRET"
]
},
"secret": {
"lc": 1,
"type": "secret",
"id": [
"SECRET"
]
},
"another_visible": "bye",
"you_can_see_me": "hello"
}
}
'''
# ---
# name: test_serialize_llmchain
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.5,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
}
},
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
}
}
}
}
'''
# ---
# name: test_serialize_llmchain_chat
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chat_models",
"openai",
"ChatOpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.5,
"openai_api_key": "hello"
}
},
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"ChatPromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"messages": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"HumanMessagePromptTemplate"
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
}
}
}
}
]
}
}
}
}
'''
# ---
# name: test_serialize_llmchain_with_non_serializable_arg
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.5,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
},
"client": {
"lc": 1,
"type": "not_implemented",
"id": [
"openai",
"api_resources",
"completion",
"Completion"
]
}
}
},
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
}
}
}
}
'''
# ---
# name: test_serialize_openai_llm
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.7,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
}
}
'''
# ---

View File

@@ -0,0 +1,103 @@
"""Test for Serializable base class"""

from typing import Any, Dict

import pytest

from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.llm import LLMChain
from langchain.chat_models.openai import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.serializable import Serializable
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.prompts.prompt import PromptTemplate


class Person(Serializable):
    secret: str

    you_can_see_me: str = "hello"

    @property
    def lc_serializable(self) -> bool:
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"secret": "SECRET"}

    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"you_can_see_me": self.you_can_see_me}


class SpecialPerson(Person):
    another_secret: str

    another_visible: str = "bye"

    # Gets merged with parent class's secrets
    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"another_secret": "ANOTHER_SECRET"}

    # Gets merged with parent class's attributes
    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"another_visible": self.another_visible}


class NotSerializable:
    pass


def test_person(snapshot: Any) -> None:
    p = Person(secret="hello")
    assert dumps(p, pretty=True) == snapshot
    sp = SpecialPerson(another_secret="Wooo", secret="Hmm")
    assert dumps(sp, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_openai_llm(snapshot: Any) -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        # This is excluded from serialization
        callbacks=[LangChainTracer()],
    )
    llm.temperature = 0.7  # this is reflected in serialization
    assert dumps(llm, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain(snapshot: Any) -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain_chat(snapshot: Any) -> None:
    llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = ChatPromptTemplate.from_messages(
        [HumanMessagePromptTemplate.from_template("hello {name}!")]
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        client=NotSerializable,
    )
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot

View File

@@ -0,0 +1,54 @@
"""Test for Serializable base class"""

import pytest

from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.prompts.prompt import PromptTemplate


class NotSerializable:
    pass


@pytest.mark.requires("openai")
def test_load_openai_llm() -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    llm_string = dumps(llm)
    llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})

    assert llm2 == llm
    assert dumps(llm2) == llm_string
    assert isinstance(llm2, OpenAI)


@pytest.mark.requires("openai")
def test_load_llmchain() -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    chain_string = dumps(chain)
    chain2 = loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})

    assert chain2 == chain
    assert dumps(chain2) == chain_string
    assert isinstance(chain2, LLMChain)
    assert isinstance(chain2.llm, OpenAI)
    assert isinstance(chain2.prompt, PromptTemplate)


@pytest.mark.requires("openai")
def test_load_llmchain_with_non_serializable_arg() -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        client=NotSerializable,
    )
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    chain_string = dumps(chain, pretty=True)
    with pytest.raises(NotImplementedError):
        loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})

View File

@@ -72,6 +72,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
        "pytest-socket",
        "pytest-watcher",
        "responses",
+        "syrupy",
    ]