Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Branch: searx_updates
Nuno Campos authored 11 months ago, committed by GitHub
parent 614cff89bc
commit 18af149e91

@ -4,9 +4,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Sequence, Set
from pydantic import BaseModel
from langchain.callbacks.manager import Callbacks
from langchain.load.serializable import Serializable
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
@ -29,7 +28,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
return tokenizer.encode(text)
class BaseLanguageModel(BaseModel, ABC):
class BaseLanguageModel(Serializable, ABC):
@abstractmethod
def generate_prompt(
self,

@ -204,7 +204,7 @@ def _handle_event(
except Exception as e:
if handler.raise_error:
raise e
logging.warning(f"Error in {event_name} callback: {e}")
logger.warning(f"Error in {event_name} callback: {e}")
async def _ahandle_event_for_handler(

@ -93,7 +93,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order = self._get_execution_order(parent_run_id_)
llm_run = Run(
id=run_id,
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"prompts": prompts},
@ -154,7 +153,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order = self._get_execution_order(parent_run_id_)
chain_run = Run(
id=run_id,
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
inputs=inputs,
@ -216,7 +214,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order = self._get_execution_order(parent_run_id_)
tool_run = Run(
id=run_id,
name=serialized.get("name"),
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"input": input_str},

@ -124,7 +124,10 @@ class Run(RunBase):
def assign_name(cls, values: dict) -> dict:
"""Assign name to the run."""
if "name" not in values:
values["name"] = values["serialized"]["name"]
if "name" in values["serialized"]:
values["name"] = values["serialized"]["name"]
elif "id" in values["serialized"]:
values["name"] = values["serialized"]["id"][-1]
return values
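With this change a run still gets a readable name when the serialized payload carries only an "id" path (as produced by dumpd) instead of an explicit "name". A minimal sketch of the fallback, with hypothetical values:

# Illustration of the Run.assign_name fallback (values are hypothetical).
def pick_name(serialized: dict) -> str:
    # Prefer an explicit "name"; otherwise use the last segment of the
    # "id" path, which is the class name.
    if "name" in serialized:
        return serialized["name"]
    return serialized["id"][-1]

assert pick_name({"name": "llm"}) == "llm"
assert pick_name({"id": ["langchain", "llms", "openai", "OpenAI"]}) == "OpenAI"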

@ -7,7 +7,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import yaml
from pydantic import BaseModel, Field, root_validator, validator
from pydantic import Field, root_validator, validator
import langchain
from langchain.callbacks.base import BaseCallbackManager
@ -18,6 +18,8 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun,
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
@ -25,7 +27,7 @@ def _get_verbosity() -> bool:
return langchain.verbose
class Chain(BaseModel, ABC):
class Chain(Serializable, ABC):
"""Base interface that all chains should implement."""
memory: Optional[BaseMemory] = None
@ -131,7 +133,7 @@ class Chain(BaseModel, ABC):
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
{"name": self.__class__.__name__},
dumpd(self),
inputs,
)
try:
@ -179,7 +181,7 @@ class Chain(BaseModel, ABC):
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
{"name": self.__class__.__name__},
dumpd(self),
inputs,
)
try:
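Passing dumpd(self) instead of {"name": self.__class__.__name__} means callback handlers now receive the chain's full constructor-style serialization rather than just a class name. A rough sketch of what a handler would see (illustrative only, not part of this commit):

from typing import Any, Dict
from langchain.callbacks.base import BaseCallbackHandler

class InspectSerializedHandler(BaseCallbackHandler):
    """Illustrative handler that inspects the serialized chain payload."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        # With dumpd(self), `serialized` looks like:
        # {"lc": 1, "type": "constructor",
        #  "id": ["langchain", "chains", "llm", "LLMChain"], "kwargs": {...}}
        print(serialized.get("id"))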

@ -15,6 +15,7 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
@ -34,6 +35,10 @@ class LLMChain(Chain):
llm = LLMChain(llm=OpenAI(), prompt=prompt)
"""
@property
def lc_serializable(self) -> bool:
return True
prompt: BasePromptTemplate
"""Prompt object to use."""
llm: BaseLanguageModel
@ -147,7 +152,7 @@ class LLMChain(Chain):
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chain_start(
{"name": self.__class__.__name__},
dumpd(self),
{"input_list": input_list},
)
try:
@ -167,7 +172,7 @@ class LLMChain(Chain):
callbacks, self.callbacks, self.verbose
)
run_manager = await callback_manager.on_chain_start(
{"name": self.__class__.__name__},
dumpd(self),
{"input_list": input_list},
)
try:

@ -17,6 +17,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.schema import (
AIMessage,
BaseMessage,
@ -70,12 +71,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages, invocation_params=params
dumpd(self), messages, invocation_params=params, options=options
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
@ -109,12 +111,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
"""Top Level call"""
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
run_manager = await callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages, invocation_params=params
dumpd(self), messages, invocation_params=params, options=options
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(

@ -136,6 +136,10 @@ class ChatOpenAI(BaseChatModel):
openai = ChatOpenAI(model_name="gpt-3.5-turbo")
"""
@property
def lc_serializable(self) -> bool:
return True
client: Any #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""

@ -19,6 +19,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.schema import (
AIMessage,
BaseMessage,
@ -166,6 +167,7 @@ class BaseLLM(BaseLanguageModel, ABC):
)
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
(
existing_prompts,
llm_string,
@ -186,7 +188,7 @@ class BaseLLM(BaseLanguageModel, ABC):
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager = callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, invocation_params=params
dumpd(self), prompts, invocation_params=params, options=options
)
try:
output = (
@ -205,9 +207,10 @@ class BaseLLM(BaseLanguageModel, ABC):
return output
if len(missing_prompts) > 0:
run_manager = callback_manager.on_llm_start(
{"name": self.__class__.__name__},
dumpd(self),
missing_prompts,
invocation_params=params,
options=options,
)
try:
new_results = (
@ -243,6 +246,7 @@ class BaseLLM(BaseLanguageModel, ABC):
"""Run the LLM on the given prompt and input."""
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
(
existing_prompts,
llm_string,
@ -263,7 +267,7 @@ class BaseLLM(BaseLanguageModel, ABC):
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager = await callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, invocation_params=params
dumpd(self), prompts, invocation_params=params, options=options
)
try:
output = (
@ -282,9 +286,10 @@ class BaseLLM(BaseLanguageModel, ABC):
return output
if len(missing_prompts) > 0:
run_manager = await callback_manager.on_llm_start(
{"name": self.__class__.__name__},
dumpd(self),
missing_prompts,
invocation_params=params,
options=options,
)
try:
new_results = (

@ -123,6 +123,14 @@ async def acompletion_with_retry(
class BaseOpenAI(BaseLLM):
"""Wrapper around OpenAI large language models."""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@property
def lc_serializable(self) -> bool:
return True
client: Any #: :meta private:
model_name: str = Field("text-davinci-003", alias="model")
"""Model name to use."""

@ -0,0 +1,22 @@
import json
from typing import Any, Dict
from langchain.load.serializable import Serializable, to_json_not_implemented
def default(obj: Any) -> Any:
if isinstance(obj, Serializable):
return obj.to_json()
else:
return to_json_not_implemented(obj)
def dumps(obj: Any, *, pretty: bool = False) -> str:
if pretty:
return json.dumps(obj, default=default, indent=2)
else:
return json.dumps(obj, default=default)
def dumpd(obj: Any) -> Dict[str, Any]:
return json.loads(dumps(obj))
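A rough usage sketch for the two helpers above; the constructor arguments and the key value are illustrative, and the resulting structure matches the snapshots added later in this commit:

# Sketch: serializing an object with dumps/dumpd. Assumes `openai` is installed.
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumpd, dumps

llm = OpenAI(model="davinci", temperature=0.7, openai_api_key="hello")

as_json = dumps(llm, pretty=True)  # JSON string
as_dict = dumpd(llm)               # plain dict, e.g. for callbacks

# Constructor arguments listed in lc_secrets are replaced with placeholders,
# so the raw key never appears in the output.
assert as_dict["kwargs"]["openai_api_key"] == {
    "lc": 1,
    "type": "secret",
    "id": ["OPENAI_API_KEY"],
}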

@ -0,0 +1,65 @@
import importlib
import json
from typing import Any, Dict, Optional
from langchain.load.serializable import Serializable
class Reviver:
def __init__(self, secrets_map: Optional[Dict[str, str]] = None) -> None:
self.secrets_map = secrets_map or dict()
def __call__(self, value: Dict[str, Any]) -> Any:
if (
value.get("lc", None) == 1
and value.get("type", None) == "secret"
and value.get("id", None) is not None
):
[key] = value["id"]
if key in self.secrets_map:
return self.secrets_map[key]
else:
raise KeyError(f'Missing key "{key}" in load(secrets_map)')
if (
value.get("lc", None) == 1
and value.get("type", None) == "not_implemented"
and value.get("id", None) is not None
):
raise NotImplementedError(
"Trying to load an object that doesn't implement "
f"serialization: {value}"
)
if (
value.get("lc", None) == 1
and value.get("type", None) == "constructor"
and value.get("id", None) is not None
):
[*namespace, name] = value["id"]
# Currently, we only support langchain imports.
if namespace[0] != "langchain":
raise ValueError(f"Invalid namespace: {value}")
# The root namespace "langchain" is not a valid identifier.
if len(namespace) == 1:
raise ValueError(f"Invalid namespace: {value}")
mod = importlib.import_module(".".join(namespace))
cls = getattr(mod, name)
# The class must be a subclass of Serializable.
if not issubclass(cls, Serializable):
raise ValueError(f"Invalid namespace: {value}")
# We don't need to recurse on kwargs
# as json.loads will do that for us.
kwargs = value.get("kwargs", dict())
return cls(**kwargs)
return value
def loads(text: str, *, secrets_map: Optional[Dict[str, str]] = None) -> Any:
return json.loads(text, object_hook=Reviver(secrets_map))
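The intended round trip, mirroring the tests added later in this commit (the key value is illustrative):

# Sketch: round-tripping an object through dumps/loads.
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads

llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
text = dumps(llm)

# secrets_map restores values that were replaced by secret placeholders.
llm2 = loads(text, secrets_map={"OPENAI_API_KEY": "hello"})
assert llm2 == llm
assert dumps(llm2) == text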

@ -0,0 +1,135 @@
from abc import ABC
from typing import Any, Dict, List, Literal, TypedDict, Union, cast
from pydantic import BaseModel, Field
class BaseSerialized(TypedDict):
lc: int
id: List[str]
class SerializedConstructor(BaseSerialized):
type: Literal["constructor"]
kwargs: Dict[str, Any]
class SerializedSecret(BaseSerialized):
type: Literal["secret"]
class SerializedNotImplemented(BaseSerialized):
type: Literal["not_implemented"]
class Serializable(BaseModel, ABC):
@property
def lc_serializable(self) -> bool:
"""
Return whether or not the class is serializable.
"""
return False
@property
def lc_namespace(self) -> List[str]:
"""
Return the namespace of the langchain object.
e.g. ["langchain", "llms", "openai"]
"""
return self.__class__.__module__.split(".")
@property
def lc_secrets(self) -> Dict[str, str]:
"""
Return a map of constructor argument names to secret ids.
e.g. {"openai_api_key": "OPENAI_API_KEY"}
"""
return dict()
@property
def lc_attributes(self) -> Dict:
"""
Return a map of constructor argument names to attribute values that
should be included in the serialized kwargs. These attributes must be
accepted by the constructor.
"""
return {}
lc_kwargs: Dict[str, Any] = Field(default_factory=dict, exclude=True)
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.lc_kwargs = kwargs
def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
if not self.lc_serializable:
return self.to_json_not_implemented()
secrets = dict()
# Get the latest values for kwargs if there is an attribute with the same name
lc_kwargs = {
k: getattr(self, k, v)
for k, v in self.lc_kwargs.items()
if not self.__exclude_fields__.get(k, False) # type: ignore
}
# Merge the lc_secrets and lc_attributes from every class in the MRO
for cls in [None, *self.__class__.mro()]:
# Once we get to Serializable, we're done
if cls is Serializable:
break
# Get a reference to self bound to each class in the MRO
this = cast(Serializable, self if cls is None else super(cls, self))
secrets.update(this.lc_secrets)
lc_kwargs.update(this.lc_attributes)
return {
"lc": 1,
"type": "constructor",
"id": [*self.lc_namespace, self.__class__.__name__],
"kwargs": lc_kwargs
if not secrets
else _replace_secrets(lc_kwargs, secrets),
}
def to_json_not_implemented(self) -> SerializedNotImplemented:
return to_json_not_implemented(self)
def _replace_secrets(
root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
result = root.copy()
for path, secret_id in secrets_map.items():
[*parts, last] = path.split(".")
current = result
for part in parts:
if part not in current:
break
current[part] = current[part].copy()
current = current[part]
if last in current:
current[last] = {
"lc": 1,
"type": "secret",
"id": [secret_id],
}
return result
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
_id: List[str] = []
try:
if hasattr(obj, "__name__"):
_id = [*obj.__module__.split("."), obj.__name__]
elif hasattr(obj, "__class__"):
_id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
except Exception:
pass
return {
"lc": 1,
"type": "not_implemented",
"id": _id,
}
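A class opts in by returning True from lc_serializable and declaring which constructor arguments are secrets or derived attributes; the Person fixture in the tests below does exactly this. A sketch restating it, with the to_json() output taken from the snapshot (the "id" path reflects the module where the class is defined, here test_dump):

# Sketch of a Serializable subclass (restates the Person test fixture).
from typing import Dict
from langchain.load.serializable import Serializable

class Person(Serializable):
    secret: str
    you_can_see_me: str = "hello"

    @property
    def lc_serializable(self) -> bool:
        return True  # opt in to constructor-style serialization

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"secret": "SECRET"}  # replace `secret` with a placeholder

    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"you_can_see_me": self.you_can_see_me}

Person(secret="hello").to_json()
# {"lc": 1, "type": "constructor", "id": ["test_dump", "Person"],
#  "kwargs": {"secret": {"lc": 1, "type": "secret", "id": ["SECRET"]},
#             "you_can_see_me": "hello"}}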

@ -7,9 +7,10 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
import yaml
from pydantic import BaseModel, Extra, Field, root_validator
from pydantic import Extra, Field, root_validator
from langchain.formatting import formatter
from langchain.load.serializable import Serializable
from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue
@ -100,7 +101,7 @@ class StringPromptValue(PromptValue):
return [HumanMessage(content=self.text)]
class BasePromptTemplate(BaseModel, ABC):
class BasePromptTemplate(Serializable, ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
@ -111,6 +112,10 @@ class BasePromptTemplate(BaseModel, ABC):
default_factory=dict
)
@property
def lc_serializable(self) -> bool:
return True
class Config:
"""Configuration for this pydantic object."""

@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
from pydantic import BaseModel, Field
from pydantic import Field
from langchain.load.serializable import Serializable
from langchain.memory.buffer import get_buffer_string
from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
from langchain.prompts.prompt import PromptTemplate
@ -20,7 +21,11 @@ from langchain.schema import (
)
class BaseMessagePromptTemplate(BaseModel, ABC):
class BaseMessagePromptTemplate(Serializable, ABC):
@property
def lc_serializable(self) -> bool:
return True
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""To messages."""
@ -220,7 +225,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@property
def _prompt_type(self) -> str:
raise NotImplementedError
return "chat"
def save(self, file_path: Union[Path, str]) -> None:
raise NotImplementedError

@ -15,6 +15,10 @@ from langchain.prompts.prompt import PromptTemplate
class FewShotPromptTemplate(StringPromptTemplate):
"""Prompt template that contains few shot examples."""
@property
def lc_serializable(self) -> bool:
return False
examples: Optional[List[dict]] = None
"""Examples to format into the prompt.
Either this or example_selector should be provided."""

@ -25,6 +25,12 @@ class PromptTemplate(StringPromptTemplate):
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
"""
@property
def lc_attributes(self) -> Dict[str, Any]:
return {
"template_format": self.template_format,
}
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""

@ -17,6 +17,8 @@ from uuid import UUID
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.load.serializable import Serializable
RUN_KEY = "__run"
@ -55,7 +57,7 @@ class AgentFinish(NamedTuple):
log: str
class Generation(BaseModel):
class Generation(Serializable):
"""Output of a single generation."""
text: str
@ -67,7 +69,7 @@ class Generation(BaseModel):
# TODO: add log probs
class BaseMessage(BaseModel):
class BaseMessage(Serializable):
"""Message object."""
content: str
@ -194,7 +196,7 @@ class LLMResult(BaseModel):
)
class PromptValue(BaseModel, ABC):
class PromptValue(Serializable, ABC):
@abstractmethod
def to_string(self) -> str:
"""Return prompt as string."""
@ -204,7 +206,7 @@ class PromptValue(BaseModel, ABC):
"""Return prompt as messages."""
class BaseMemory(BaseModel, ABC):
class BaseMemory(Serializable, ABC):
"""Base interface for memory in chains."""
class Config:
@ -282,7 +284,7 @@ class BaseChatMessageHistory(ABC):
"""Remove all messages from the store"""
class Document(BaseModel):
class Document(Serializable):
"""Interface for interacting with a document."""
page_content: str
@ -321,7 +323,7 @@ Memory = BaseMemory
T = TypeVar("T")
class BaseOutputParser(BaseModel, ABC, Generic[T]):
class BaseOutputParser(Serializable, ABC, Generic[T]):
"""Class to parse the output of an LLM call.
Output parsers help structure language model responses.

poetry.lock (generated)

@ -1417,6 +1417,17 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "colored"
version = "1.4.4"
description = "Simple library for color and formatting to terminal"
category = "dev"
optional = false
python-versions = "*"
files = [
{file = "colored-1.4.4.tar.gz", hash = "sha256:04ff4d4dd514274fe3b99a21bb52fb96f2688c01e93fba7bef37221e7cb56ce0"},
]
[[package]]
name = "coloredlogs"
version = "15.0.1"
@ -9461,6 +9472,22 @@ files = [
[package.dependencies]
mpmath = ">=0.19"
[[package]]
name = "syrupy"
version = "4.0.2"
description = "Pytest Snapshot Test Utility"
category = "dev"
optional = false
python-versions = ">=3.8.1,<4"
files = [
{file = "syrupy-4.0.2-py3-none-any.whl", hash = "sha256:dfd1f0fad298eee753de4f2471d4346412c4435885c4b7beea648d4934c6620a"},
{file = "syrupy-4.0.2.tar.gz", hash = "sha256:3c75ab6866580679b2cb9abe78e74c3e2011fffc6333651c6beb2a78a716ab80"},
]
[package.dependencies]
colored = ">=1.3.92,<2.0.0"
pytest = ">=7.0.0,<8.0.0"
[[package]]
name = "tabulate"
version = "0.9.0"
@ -11428,7 +11455,7 @@ azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-for
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark"]
extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark", "openai"]
llms = ["anthropic", "cohere", "openai", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
openai = ["openai", "tiktoken"]
qdrant = ["qdrant-client"]
@ -11437,4 +11464,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "ecf7086e83cc0ff19e6851c0b63170b082b267c1c1c00f47700fd3a8c8bb46c5"
content-hash = "7a39130af070d4a4fe6b0af5d6b70615c868ab0b1867e404060ff00eacd10f5f"

@ -139,6 +139,7 @@ pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
syrupy = "^4.0.2"
[tool.poetry.group.test_integration]
optional = true
@ -315,7 +316,8 @@ extended_testing = [
"html2text",
"py-trello",
"scikit-learn",
"pyspark"
"pyspark",
"openai"
]
[tool.ruff]
@ -349,7 +351,10 @@ build-backend = "poetry.core.masonry.api"
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5"
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [

@ -13,6 +13,9 @@ from langchain.callbacks.tracers.base import BaseTracer, TracerException
from langchain.callbacks.tracers.schemas import Run
from langchain.schema import LLMResult
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution."""
@ -39,7 +42,7 @@ def test_tracer_llm_run() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]),
error=None,
@ -47,7 +50,7 @@ def test_tracer_llm_run() -> None:
)
tracer = FakeTracer()
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run]
@ -64,7 +67,7 @@ def test_tracer_chat_model_run() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "chat_model"},
serialized=SERIALIZED_CHAT,
inputs=dict(prompts=[""]),
outputs=LLMResult(generations=[[]]),
error=None,
@ -73,7 +76,7 @@ def test_tracer_chat_model_run() -> None:
tracer = FakeTracer()
manager = CallbackManager(handlers=[tracer])
run_manager = manager.on_chat_model_start(
serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
)
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]
@ -100,7 +103,7 @@ def test_tracer_multiple_llm_runs() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]]),
error=None,
@ -110,7 +113,7 @@ def test_tracer_multiple_llm_runs() -> None:
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run] * num_runs
@ -183,7 +186,7 @@ def test_tracer_nested_run() -> None:
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=tool_uuid,
@ -191,7 +194,7 @@ def test_tracer_nested_run() -> None:
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_tool_end("test", run_id=tool_uuid)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
@ -235,7 +238,7 @@ def test_tracer_nested_run() -> None:
extra={},
execution_order=3,
child_execution_order=3,
serialized={"name": "llm"},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]]),
run_type="llm",
@ -251,7 +254,7 @@ def test_tracer_nested_run() -> None:
extra={},
execution_order=4,
child_execution_order=4,
serialized={"name": "llm"},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]]),
run_type="llm",
@ -275,7 +278,7 @@ def test_tracer_llm_run_on_error() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
inputs=dict(prompts=[]),
outputs=None,
error=repr(exception),
@ -283,7 +286,7 @@ def test_tracer_llm_run_on_error() -> None:
)
tracer = FakeTracer()
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@ -358,14 +361,14 @@ def test_tracer_nested_runs_on_error() -> None:
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
@ -378,7 +381,7 @@ def test_tracer_nested_runs_on_error() -> None:
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid3,
parent_run_id=tool_uuid,
@ -408,7 +411,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=2,
child_execution_order=2,
serialized={"name": "llm"},
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]], llm_output=None),
@ -422,7 +425,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=3,
child_execution_order=3,
serialized={"name": "llm"},
serialized=SERIALIZED,
error=None,
inputs=dict(prompts=[]),
outputs=LLMResult(generations=[[]], llm_output=None),
@ -450,7 +453,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=5,
child_execution_order=5,
serialized={"name": "llm"},
serialized=SERIALIZED,
error=repr(exception),
inputs=dict(prompts=[]),
outputs=None,

@ -22,6 +22,9 @@ from langchain.schema import LLMResult
TEST_SESSION_ID = 2023
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}
def load_session(session_name: str) -> TracerSessionV1:
"""Load a tracing session."""
@ -107,7 +110,7 @@ def test_tracer_llm_run() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
@ -116,7 +119,7 @@ def test_tracer_llm_run() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run]
@ -133,7 +136,7 @@ def test_tracer_chat_model_run() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "chat_model"},
serialized=SERIALIZED_CHAT,
prompts=[""],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
@ -144,7 +147,7 @@ def test_tracer_chat_model_run() -> None:
tracer.new_session()
manager = CallbackManager(handlers=[tracer])
run_manager = manager.on_chat_model_start(
serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
)
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]
@ -172,7 +175,7 @@ def test_tracer_multiple_llm_runs() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
@ -183,7 +186,7 @@ def test_tracer_multiple_llm_runs() -> None:
tracer.new_session()
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run] * num_runs
@ -263,7 +266,7 @@ def test_tracer_nested_run() -> None:
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=tool_uuid,
@ -271,7 +274,7 @@ def test_tracer_nested_run() -> None:
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_tool_end("test", run_id=tool_uuid)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
@ -319,7 +322,7 @@ def test_tracer_nested_run() -> None:
extra={},
execution_order=3,
child_execution_order=3,
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
@ -337,7 +340,7 @@ def test_tracer_nested_run() -> None:
extra={},
execution_order=4,
child_execution_order=4,
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
@ -362,7 +365,7 @@ def test_tracer_llm_run_on_error() -> None:
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
response=None,
session_id=TEST_SESSION_ID,
@ -371,7 +374,7 @@ def test_tracer_llm_run_on_error() -> None:
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@ -451,14 +454,14 @@ def test_tracer_nested_runs_on_error() -> None:
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
@ -471,7 +474,7 @@ def test_tracer_nested_runs_on_error() -> None:
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized={"name": "llm"},
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid3,
parent_run_id=tool_uuid,
@ -501,7 +504,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=2,
child_execution_order=2,
serialized={"name": "llm"},
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
@ -515,7 +518,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=3,
child_execution_order=3,
serialized={"name": "llm"},
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
@ -547,7 +550,7 @@ def test_tracer_nested_runs_on_error() -> None:
extra={},
execution_order=5,
child_execution_order=5,
serialized={"name": "llm"},
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=repr(exception),
prompts=[],

@ -0,0 +1,273 @@
# serializer version: 1
# name: test_person
'''
{
"lc": 1,
"type": "constructor",
"id": [
"test_dump",
"Person"
],
"kwargs": {
"secret": {
"lc": 1,
"type": "secret",
"id": [
"SECRET"
]
},
"you_can_see_me": "hello"
}
}
'''
# ---
# name: test_person.1
'''
{
"lc": 1,
"type": "constructor",
"id": [
"test_dump",
"SpecialPerson"
],
"kwargs": {
"another_secret": {
"lc": 1,
"type": "secret",
"id": [
"ANOTHER_SECRET"
]
},
"secret": {
"lc": 1,
"type": "secret",
"id": [
"SECRET"
]
},
"another_visible": "bye",
"you_can_see_me": "hello"
}
}
'''
# ---
# name: test_serialize_llmchain
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.5,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
}
},
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
}
}
}
}
'''
# ---
# name: test_serialize_llmchain_chat
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chat_models",
"openai",
"ChatOpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.5,
"openai_api_key": "hello"
}
},
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"ChatPromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"messages": [
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"chat",
"HumanMessagePromptTemplate"
],
"kwargs": {
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
}
}
}
}
]
}
}
}
}
'''
# ---
# name: test_serialize_llmchain_with_non_serializable_arg
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"chains",
"llm",
"LLMChain"
],
"kwargs": {
"llm": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.5,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
},
"client": {
"lc": 1,
"type": "not_implemented",
"id": [
"openai",
"api_resources",
"completion",
"Completion"
]
}
}
},
"prompt": {
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
}
}
}
}
'''
# ---
# name: test_serialize_openai_llm
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"llms",
"openai",
"OpenAI"
],
"kwargs": {
"model": "davinci",
"temperature": 0.7,
"openai_api_key": {
"lc": 1,
"type": "secret",
"id": [
"OPENAI_API_KEY"
]
}
}
}
'''
# ---

@ -0,0 +1,103 @@
"""Test for Serializable base class"""
from typing import Any, Dict
import pytest
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.llm import LLMChain
from langchain.chat_models.openai import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.serializable import Serializable
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.prompts.prompt import PromptTemplate
class Person(Serializable):
secret: str
you_can_see_me: str = "hello"
@property
def lc_serializable(self) -> bool:
return True
@property
def lc_secrets(self) -> Dict[str, str]:
return {"secret": "SECRET"}
@property
def lc_attributes(self) -> Dict[str, str]:
return {"you_can_see_me": self.you_can_see_me}
class SpecialPerson(Person):
another_secret: str
another_visible: str = "bye"
# Gets merged with parent class's secrets
@property
def lc_secrets(self) -> Dict[str, str]:
return {"another_secret": "ANOTHER_SECRET"}
# Gets merged with parent class's attributes
@property
def lc_attributes(self) -> Dict[str, str]:
return {"another_visible": self.another_visible}
class NotSerializable:
pass
def test_person(snapshot: Any) -> None:
p = Person(secret="hello")
assert dumps(p, pretty=True) == snapshot
sp = SpecialPerson(another_secret="Wooo", secret="Hmm")
assert dumps(sp, pretty=True) == snapshot
@pytest.mark.requires("openai")
def test_serialize_openai_llm(snapshot: Any) -> None:
llm = OpenAI(
model="davinci",
temperature=0.5,
openai_api_key="hello",
# This is excluded from serialization
callbacks=[LangChainTracer()],
)
llm.temperature = 0.7 # this is reflected in serialization
assert dumps(llm, pretty=True) == snapshot
@pytest.mark.requires("openai")
def test_serialize_llmchain(snapshot: Any) -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
assert dumps(chain, pretty=True) == snapshot
@pytest.mark.requires("openai")
def test_serialize_llmchain_chat(snapshot: Any) -> None:
llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
prompt = ChatPromptTemplate.from_messages(
[HumanMessagePromptTemplate.from_template("hello {name}!")]
)
chain = LLMChain(llm=llm, prompt=prompt)
assert dumps(chain, pretty=True) == snapshot
@pytest.mark.requires("openai")
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
llm = OpenAI(
model="davinci",
temperature=0.5,
openai_api_key="hello",
client=NotSerializable,
)
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
assert dumps(chain, pretty=True) == snapshot

@ -0,0 +1,54 @@
"""Test for Serializable base class"""
import pytest
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.prompts.prompt import PromptTemplate
class NotSerializable:
pass
@pytest.mark.requires("openai")
def test_load_openai_llm() -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
llm_string = dumps(llm)
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
assert llm2 == llm
assert dumps(llm2) == llm_string
assert isinstance(llm2, OpenAI)
@pytest.mark.requires("openai")
def test_load_llmchain() -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
chain_string = dumps(chain)
chain2 = loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
assert chain2 == chain
assert dumps(chain2) == chain_string
assert isinstance(chain2, LLMChain)
assert isinstance(chain2.llm, OpenAI)
assert isinstance(chain2.prompt, PromptTemplate)
@pytest.mark.requires("openai")
def test_load_llmchain_with_non_serializable_arg() -> None:
llm = OpenAI(
model="davinci",
temperature=0.5,
openai_api_key="hello",
client=NotSerializable,
)
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
chain_string = dumps(chain, pretty=True)
with pytest.raises(NotImplementedError):
loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})

@ -72,6 +72,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
"pytest-socket",
"pytest-watcher",
"responses",
"syrupy",
]
