From 18af149e91e62b3ac7728ddea420688d41043734 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sun, 11 Jun 2023 23:51:28 +0100 Subject: [PATCH] nc/load (#5733) Co-authored-by: Harrison Chase --- langchain/base_language.py | 5 +- langchain/callbacks/manager.py | 2 +- langchain/callbacks/tracers/base.py | 3 - langchain/callbacks/tracers/schemas.py | 5 +- langchain/chains/base.py | 10 +- langchain/chains/llm.py | 9 +- langchain/chat_models/base.py | 7 +- langchain/chat_models/openai.py | 4 + langchain/llms/base.py | 13 +- langchain/llms/openai.py | 8 + langchain/load/__init__.py | 0 langchain/load/dump.py | 22 ++ langchain/load/load.py | 65 +++++ langchain/load/serializable.py | 135 +++++++++ langchain/prompts/base.py | 9 +- langchain/prompts/chat.py | 11 +- langchain/prompts/few_shot.py | 4 + langchain/prompts/prompt.py | 6 + langchain/schema.py | 14 +- poetry.lock | 31 +- pyproject.toml | 9 +- .../callbacks/tracers/test_base_tracer.py | 39 +-- .../callbacks/tracers/test_langchain_v1.py | 39 +-- .../load/__snapshots__/test_dump.ambr | 273 ++++++++++++++++++ tests/unit_tests/load/test_dump.py | 103 +++++++ tests/unit_tests/load/test_load.py | 54 ++++ tests/unit_tests/test_dependencies.py | 1 + 27 files changed, 810 insertions(+), 71 deletions(-) create mode 100644 langchain/load/__init__.py create mode 100644 langchain/load/dump.py create mode 100644 langchain/load/load.py create mode 100644 langchain/load/serializable.py create mode 100644 tests/unit_tests/load/__snapshots__/test_dump.ambr create mode 100644 tests/unit_tests/load/test_dump.py create mode 100644 tests/unit_tests/load/test_load.py diff --git a/langchain/base_language.py b/langchain/base_language.py index 2587e8d2..f02e43d6 100644 --- a/langchain/base_language.py +++ b/langchain/base_language.py @@ -4,9 +4,8 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import Any, List, Optional, Sequence, Set -from pydantic import BaseModel - from langchain.callbacks.manager import Callbacks +from langchain.load.serializable import Serializable from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string @@ -29,7 +28,7 @@ def _get_token_ids_default_method(text: str) -> List[int]: return tokenizer.encode(text) -class BaseLanguageModel(BaseModel, ABC): +class BaseLanguageModel(Serializable, ABC): @abstractmethod def generate_prompt( self, diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 2c935003..07600bf9 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -204,7 +204,7 @@ def _handle_event( except Exception as e: if handler.raise_error: raise e - logging.warning(f"Error in {event_name} callback: {e}") + logger.warning(f"Error in {event_name} callback: {e}") async def _ahandle_event_for_handler( diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 4c7ddbac..93df3513 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -93,7 +93,6 @@ class BaseTracer(BaseCallbackHandler, ABC): execution_order = self._get_execution_order(parent_run_id_) llm_run = Run( id=run_id, - name=serialized.get("name"), parent_run_id=parent_run_id, serialized=serialized, inputs={"prompts": prompts}, @@ -154,7 +153,6 @@ class BaseTracer(BaseCallbackHandler, ABC): execution_order = self._get_execution_order(parent_run_id_) chain_run = Run( id=run_id, - name=serialized.get("name"), parent_run_id=parent_run_id, serialized=serialized, inputs=inputs, @@ -216,7 +214,6 @@ 
class BaseTracer(BaseCallbackHandler, ABC): execution_order = self._get_execution_order(parent_run_id_) tool_run = Run( id=run_id, - name=serialized.get("name"), parent_run_id=parent_run_id, serialized=serialized, inputs={"input": input_str}, diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index 4816b8b9..bc8abeae 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -124,7 +124,10 @@ class Run(RunBase): def assign_name(cls, values: dict) -> dict: """Assign name to the run.""" if "name" not in values: - values["name"] = values["serialized"]["name"] + if "name" in values["serialized"]: + values["name"] = values["serialized"]["name"] + elif "id" in values["serialized"]: + values["name"] = values["serialized"]["id"][-1] return values diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 2db63a8f..66354adc 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union import yaml -from pydantic import BaseModel, Field, root_validator, validator +from pydantic import Field, root_validator, validator import langchain from langchain.callbacks.base import BaseCallbackManager @@ -18,6 +18,8 @@ from langchain.callbacks.manager import ( CallbackManagerForChainRun, Callbacks, ) +from langchain.load.dump import dumpd +from langchain.load.serializable import Serializable from langchain.schema import RUN_KEY, BaseMemory, RunInfo @@ -25,7 +27,7 @@ def _get_verbosity() -> bool: return langchain.verbose -class Chain(BaseModel, ABC): +class Chain(Serializable, ABC): """Base interface that all chains should implement.""" memory: Optional[BaseMemory] = None @@ -131,7 +133,7 @@ class Chain(BaseModel, ABC): ) new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") run_manager = callback_manager.on_chain_start( - {"name": self.__class__.__name__}, + dumpd(self), inputs, ) try: @@ -179,7 +181,7 @@ class Chain(BaseModel, ABC): ) new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") run_manager = await callback_manager.on_chain_start( - {"name": self.__class__.__name__}, + dumpd(self), inputs, ) try: diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 18d8f539..4c743530 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -15,6 +15,7 @@ from langchain.callbacks.manager import ( ) from langchain.chains.base import Chain from langchain.input import get_colored_text +from langchain.load.dump import dumpd from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import LLMResult, PromptValue @@ -34,6 +35,10 @@ class LLMChain(Chain): llm = LLMChain(llm=OpenAI(), prompt=prompt) """ + @property + def lc_serializable(self) -> bool: + return True + prompt: BasePromptTemplate """Prompt object to use.""" llm: BaseLanguageModel @@ -147,7 +152,7 @@ class LLMChain(Chain): callbacks, self.callbacks, self.verbose ) run_manager = callback_manager.on_chain_start( - {"name": self.__class__.__name__}, + dumpd(self), {"input_list": input_list}, ) try: @@ -167,7 +172,7 @@ class LLMChain(Chain): callbacks, self.callbacks, self.verbose ) run_manager = await callback_manager.on_chain_start( - {"name": self.__class__.__name__}, + dumpd(self), {"input_list": input_list}, ) try: diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index 05c1e8d5..f3521df6 100644 --- 
a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -17,6 +17,7 @@ from langchain.callbacks.manager import ( CallbackManagerForLLMRun, Callbacks, ) +from langchain.load.dump import dumpd from langchain.schema import ( AIMessage, BaseMessage, @@ -70,12 +71,13 @@ class BaseChatModel(BaseLanguageModel, ABC): params = self.dict() params["stop"] = stop + options = {"stop": stop} callback_manager = CallbackManager.configure( callbacks, self.callbacks, self.verbose ) run_manager = callback_manager.on_chat_model_start( - {"name": self.__class__.__name__}, messages, invocation_params=params + dumpd(self), messages, invocation_params=params, options=options ) new_arg_supported = inspect.signature(self._generate).parameters.get( @@ -109,12 +111,13 @@ class BaseChatModel(BaseLanguageModel, ABC): """Top Level call""" params = self.dict() params["stop"] = stop + options = {"stop": stop} callback_manager = AsyncCallbackManager.configure( callbacks, self.callbacks, self.verbose ) run_manager = await callback_manager.on_chat_model_start( - {"name": self.__class__.__name__}, messages, invocation_params=params + dumpd(self), messages, invocation_params=params, options=options ) new_arg_supported = inspect.signature(self._agenerate).parameters.get( diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index d7c51832..b1dcb9de 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -136,6 +136,10 @@ class ChatOpenAI(BaseChatModel): openai = ChatOpenAI(model_name="gpt-3.5-turbo") """ + @property + def lc_serializable(self) -> bool: + return True + client: Any #: :meta private: model_name: str = Field(default="gpt-3.5-turbo", alias="model") """Model name to use.""" diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 866bdada..9b065142 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -19,6 +19,7 @@ from langchain.callbacks.manager import ( CallbackManagerForLLMRun, Callbacks, ) +from langchain.load.dump import dumpd from langchain.schema import ( AIMessage, BaseMessage, @@ -166,6 +167,7 @@ class BaseLLM(BaseLanguageModel, ABC): ) params = self.dict() params["stop"] = stop + options = {"stop": stop} ( existing_prompts, llm_string, @@ -186,7 +188,7 @@ class BaseLLM(BaseLanguageModel, ABC): "Asked to cache, but no cache found at `langchain.cache`." ) run_manager = callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompts, invocation_params=params + dumpd(self), prompts, invocation_params=params, options=options ) try: output = ( @@ -205,9 +207,10 @@ class BaseLLM(BaseLanguageModel, ABC): return output if len(missing_prompts) > 0: run_manager = callback_manager.on_llm_start( - {"name": self.__class__.__name__}, + dumpd(self), missing_prompts, invocation_params=params, + options=options, ) try: new_results = ( @@ -243,6 +246,7 @@ class BaseLLM(BaseLanguageModel, ABC): """Run the LLM on the given prompt and input.""" params = self.dict() params["stop"] = stop + options = {"stop": stop} ( existing_prompts, llm_string, @@ -263,7 +267,7 @@ class BaseLLM(BaseLanguageModel, ABC): "Asked to cache, but no cache found at `langchain.cache`." 
) run_manager = await callback_manager.on_llm_start( - {"name": self.__class__.__name__}, prompts, invocation_params=params + dumpd(self), prompts, invocation_params=params, options=options ) try: output = ( @@ -282,9 +286,10 @@ class BaseLLM(BaseLanguageModel, ABC): return output if len(missing_prompts) > 0: run_manager = await callback_manager.on_llm_start( - {"name": self.__class__.__name__}, + dumpd(self), missing_prompts, invocation_params=params, + options=options, ) try: new_results = ( diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index bb1c0212..172697a3 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -123,6 +123,14 @@ async def acompletion_with_retry( class BaseOpenAI(BaseLLM): """Wrapper around OpenAI large language models.""" + @property + def lc_secrets(self) -> Dict[str, str]: + return {"openai_api_key": "OPENAI_API_KEY"} + + @property + def lc_serializable(self) -> bool: + return True + client: Any #: :meta private: model_name: str = Field("text-davinci-003", alias="model") """Model name to use.""" diff --git a/langchain/load/__init__.py b/langchain/load/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/langchain/load/dump.py b/langchain/load/dump.py new file mode 100644 index 00000000..d59fb0b2 --- /dev/null +++ b/langchain/load/dump.py @@ -0,0 +1,22 @@ +import json +from typing import Any, Dict + +from langchain.load.serializable import Serializable, to_json_not_implemented + + +def default(obj: Any) -> Any: + if isinstance(obj, Serializable): + return obj.to_json() + else: + return to_json_not_implemented(obj) + + +def dumps(obj: Any, *, pretty: bool = False) -> str: + if pretty: + return json.dumps(obj, default=default, indent=2) + else: + return json.dumps(obj, default=default) + + +def dumpd(obj: Any) -> Dict[str, Any]: + return json.loads(dumps(obj)) diff --git a/langchain/load/load.py b/langchain/load/load.py new file mode 100644 index 00000000..3cac560a --- /dev/null +++ b/langchain/load/load.py @@ -0,0 +1,65 @@ +import importlib +import json +from typing import Any, Dict, Optional + +from langchain.load.serializable import Serializable + + +class Reviver: + def __init__(self, secrets_map: Optional[Dict[str, str]] = None) -> None: + self.secrets_map = secrets_map or dict() + + def __call__(self, value: Dict[str, Any]) -> Any: + if ( + value.get("lc", None) == 1 + and value.get("type", None) == "secret" + and value.get("id", None) is not None + ): + [key] = value["id"] + if key in self.secrets_map: + return self.secrets_map[key] + else: + raise KeyError(f'Missing key "{key}" in load(secrets_map)') + + if ( + value.get("lc", None) == 1 + and value.get("type", None) == "not_implemented" + and value.get("id", None) is not None + ): + raise NotImplementedError( + "Trying to load an object that doesn't implement " + f"serialization: {value}" + ) + + if ( + value.get("lc", None) == 1 + and value.get("type", None) == "constructor" + and value.get("id", None) is not None + ): + [*namespace, name] = value["id"] + + # Currently, we only support langchain imports. + if namespace[0] != "langchain": + raise ValueError(f"Invalid namespace: {value}") + + # The root namespace "langchain" is not a valid identifier. + if len(namespace) == 1: + raise ValueError(f"Invalid namespace: {value}") + + mod = importlib.import_module(".".join(namespace)) + cls = getattr(mod, name) + + # The class must be a subclass of Serializable. 
+ if not issubclass(cls, Serializable): + raise ValueError(f"Invalid namespace: {value}") + + # We don't need to recurse on kwargs + # as json.loads will do that for us. + kwargs = value.get("kwargs", dict()) + return cls(**kwargs) + + return value + + +def loads(text: str, *, secrets_map: Optional[Dict[str, str]] = None) -> Any: + return json.loads(text, object_hook=Reviver(secrets_map)) diff --git a/langchain/load/serializable.py b/langchain/load/serializable.py new file mode 100644 index 00000000..9c8c60bf --- /dev/null +++ b/langchain/load/serializable.py @@ -0,0 +1,135 @@ +from abc import ABC +from typing import Any, Dict, List, Literal, TypedDict, Union, cast + +from pydantic import BaseModel, Field + + +class BaseSerialized(TypedDict): + lc: int + id: List[str] + + +class SerializedConstructor(BaseSerialized): + type: Literal["constructor"] + kwargs: Dict[str, Any] + + +class SerializedSecret(BaseSerialized): + type: Literal["secret"] + + +class SerializedNotImplemented(BaseSerialized): + type: Literal["not_implemented"] + + +class Serializable(BaseModel, ABC): + @property + def lc_serializable(self) -> bool: + """ + Return whether or not the class is serializable. + """ + return False + + @property + def lc_namespace(self) -> List[str]: + """ + Return the namespace of the langchain object. + eg. ["langchain", "llms", "openai"] + """ + return self.__class__.__module__.split(".") + + @property + def lc_secrets(self) -> Dict[str, str]: + """ + Return a map of constructor argument names to secret ids. + eg. {"openai_api_key": "OPENAI_API_KEY"} + """ + return dict() + + @property + def lc_attributes(self) -> Dict: + """ + Return a list of attribute names that should be included in the + serialized kwargs. These attributes must be accepted by the + constructor. 
+ """ + return {} + + lc_kwargs: Dict[str, Any] = Field(default_factory=dict, exclude=True) + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.lc_kwargs = kwargs + + def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]: + if not self.lc_serializable: + return self.to_json_not_implemented() + + secrets = dict() + # Get latest values for kwargs if there is an attribute with same name + lc_kwargs = { + k: getattr(self, k, v) + for k, v in self.lc_kwargs.items() + if not self.__exclude_fields__.get(k, False) # type: ignore + } + + # Merge the lc_secrets and lc_attributes from every class in the MRO + for cls in [None, *self.__class__.mro()]: + # Once we get to Serializable, we're done + if cls is Serializable: + break + + # Get a reference to self bound to each class in the MRO + this = cast(Serializable, self if cls is None else super(cls, self)) + + secrets.update(this.lc_secrets) + lc_kwargs.update(this.lc_attributes) + + return { + "lc": 1, + "type": "constructor", + "id": [*self.lc_namespace, self.__class__.__name__], + "kwargs": lc_kwargs + if not secrets + else _replace_secrets(lc_kwargs, secrets), + } + + def to_json_not_implemented(self) -> SerializedNotImplemented: + return to_json_not_implemented(self) + + +def _replace_secrets( + root: Dict[Any, Any], secrets_map: Dict[str, str] +) -> Dict[Any, Any]: + result = root.copy() + for path, secret_id in secrets_map.items(): + [*parts, last] = path.split(".") + current = result + for part in parts: + if part not in current: + break + current[part] = current[part].copy() + current = current[part] + if last in current: + current[last] = { + "lc": 1, + "type": "secret", + "id": [secret_id], + } + return result + + +def to_json_not_implemented(obj: object) -> SerializedNotImplemented: + _id: List[str] = [] + try: + if hasattr(obj, "__name__"): + _id = [*obj.__module__.split("."), obj.__name__] + elif hasattr(obj, "__class__"): + _id = [*obj.__class__.__module__.split("."), obj.__class__.__name__] + except Exception: + pass + return { + "lc": 1, + "type": "not_implemented", + "id": _id, + } diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 8d31b10e..58e7339b 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -7,9 +7,10 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union import yaml -from pydantic import BaseModel, Extra, Field, root_validator +from pydantic import Extra, Field, root_validator from langchain.formatting import formatter +from langchain.load.serializable import Serializable from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue @@ -100,7 +101,7 @@ class StringPromptValue(PromptValue): return [HumanMessage(content=self.text)] -class BasePromptTemplate(BaseModel, ABC): +class BasePromptTemplate(Serializable, ABC): """Base class for all prompt templates, returning a prompt.""" input_variables: List[str] @@ -111,6 +112,10 @@ class BasePromptTemplate(BaseModel, ABC): default_factory=dict ) + @property + def lc_serializable(self) -> bool: + return True + class Config: """Configuration for this pydantic object.""" diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index 89fb10b0..1edc8f6c 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -5,8 +5,9 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union -from pydantic import 
BaseModel, Field +from pydantic import Field +from langchain.load.serializable import Serializable from langchain.memory.buffer import get_buffer_string from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate from langchain.prompts.prompt import PromptTemplate @@ -20,7 +21,11 @@ from langchain.schema import ( ) -class BaseMessagePromptTemplate(BaseModel, ABC): +class BaseMessagePromptTemplate(Serializable, ABC): + @property + def lc_serializable(self) -> bool: + return True + @abstractmethod def format_messages(self, **kwargs: Any) -> List[BaseMessage]: """To messages.""" @@ -220,7 +225,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): @property def _prompt_type(self) -> str: - raise NotImplementedError + return "chat" def save(self, file_path: Union[Path, str]) -> None: raise NotImplementedError diff --git a/langchain/prompts/few_shot.py b/langchain/prompts/few_shot.py index e17c5354..90122953 100644 --- a/langchain/prompts/few_shot.py +++ b/langchain/prompts/few_shot.py @@ -15,6 +15,10 @@ from langchain.prompts.prompt import PromptTemplate class FewShotPromptTemplate(StringPromptTemplate): """Prompt template that contains few shot examples.""" + @property + def lc_serializable(self) -> bool: + return False + examples: Optional[List[dict]] = None """Examples to format into the prompt. Either this or example_selector should be provided.""" diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index 31f87d43..c8ac2200 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -25,6 +25,12 @@ class PromptTemplate(StringPromptTemplate): prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}") """ + @property + def lc_attributes(self) -> Dict[str, Any]: + return { + "template_format": self.template_format, + } + input_variables: List[str] """A list of the names of the variables the prompt template expects.""" diff --git a/langchain/schema.py b/langchain/schema.py index b74b40a7..b2f76e70 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -17,6 +17,8 @@ from uuid import UUID from pydantic import BaseModel, Extra, Field, root_validator +from langchain.load.serializable import Serializable + RUN_KEY = "__run" @@ -55,7 +57,7 @@ class AgentFinish(NamedTuple): log: str -class Generation(BaseModel): +class Generation(Serializable): """Output of a single generation.""" text: str @@ -67,7 +69,7 @@ class Generation(BaseModel): # TODO: add log probs -class BaseMessage(BaseModel): +class BaseMessage(Serializable): """Message object.""" content: str @@ -194,7 +196,7 @@ class LLMResult(BaseModel): ) -class PromptValue(BaseModel, ABC): +class PromptValue(Serializable, ABC): @abstractmethod def to_string(self) -> str: """Return prompt as string.""" @@ -204,7 +206,7 @@ class PromptValue(BaseModel, ABC): """Return prompt as messages.""" -class BaseMemory(BaseModel, ABC): +class BaseMemory(Serializable, ABC): """Base interface for memory in chains.""" class Config: @@ -282,7 +284,7 @@ class BaseChatMessageHistory(ABC): """Remove all messages from the store""" -class Document(BaseModel): +class Document(Serializable): """Interface for interacting with a document.""" page_content: str @@ -321,7 +323,7 @@ Memory = BaseMemory T = TypeVar("T") -class BaseOutputParser(BaseModel, ABC, Generic[T]): +class BaseOutputParser(Serializable, ABC, Generic[T]): """Class to parse the output of an LLM call. Output parsers help structure language model responses. 
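The diffs above wire the new langchain.load module (Serializable, dumps/dumpd, loads) into the language-model, chain, prompt, and schema base classes. As a hedged sketch of the intended round trip — mirroring the new unit tests added later in this patch, and assuming the openai extra is installed and a placeholder API key is acceptable:

from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.prompts.prompt import PromptTemplate

# Every component here opts in to serialization via lc_serializable.
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)

# dumps() walks the object graph; kwargs named in lc_secrets are replaced by
# {"lc": 1, "type": "secret", "id": ["OPENAI_API_KEY"]} rather than their values.
serialized = dumps(chain, pretty=True)

# loads() re-imports each constructor by its "id" path and re-injects secrets
# from the caller-supplied secrets_map.
chain2 = loads(serialized, secrets_map={"OPENAI_API_KEY": "hello"})
assert chain2 == chain
assert dumps(chain2) == dumps(chain)

Constructor arguments that do not implement serialization are dumped as "not_implemented" markers, and loads() refuses to revive them by raising NotImplementedError, as exercised by the non-serializable-arg tests below.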
diff --git a/poetry.lock b/poetry.lock index 03242e43..9ed0db75 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1417,6 +1417,17 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "colored" +version = "1.4.4" +description = "Simple library for color and formatting to terminal" +category = "dev" +optional = false +python-versions = "*" +files = [ + {file = "colored-1.4.4.tar.gz", hash = "sha256:04ff4d4dd514274fe3b99a21bb52fb96f2688c01e93fba7bef37221e7cb56ce0"}, +] + [[package]] name = "coloredlogs" version = "15.0.1" @@ -9461,6 +9472,22 @@ files = [ [package.dependencies] mpmath = ">=0.19" +[[package]] +name = "syrupy" +version = "4.0.2" +description = "Pytest Snapshot Test Utility" +category = "dev" +optional = false +python-versions = ">=3.8.1,<4" +files = [ + {file = "syrupy-4.0.2-py3-none-any.whl", hash = "sha256:dfd1f0fad298eee753de4f2471d4346412c4435885c4b7beea648d4934c6620a"}, + {file = "syrupy-4.0.2.tar.gz", hash = "sha256:3c75ab6866580679b2cb9abe78e74c3e2011fffc6333651c6beb2a78a716ab80"}, +] + +[package.dependencies] +colored = ">=1.3.92,<2.0.0" +pytest = ">=7.0.0,<8.0.0" + [[package]] name = "tabulate" version = "0.9.0" @@ -11428,7 +11455,7 @@ azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-for cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark"] +extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark", "openai"] llms = ["anthropic", "cohere", "openai", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"] openai = ["openai", "tiktoken"] qdrant = ["qdrant-client"] @@ -11437,4 +11464,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "ecf7086e83cc0ff19e6851c0b63170b082b267c1c1c00f47700fd3a8c8bb46c5" +content-hash = "7a39130af070d4a4fe6b0af5d6b70615c868ab0b1867e404060ff00eacd10f5f" diff --git a/pyproject.toml b/pyproject.toml index d8fbca99..fc6edb5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,7 @@ pytest-asyncio = "^0.20.3" lark = "^1.1.5" pytest-mock = "^3.10.0" pytest-socket = "^0.6.0" +syrupy = "^4.0.2" [tool.poetry.group.test_integration] optional = true @@ -315,7 +316,8 @@ extended_testing = [ "html2text", "py-trello", "scikit-learn", - "pyspark" + "pyspark", + "openai" ] [tool.ruff] @@ -349,7 +351,10 @@ build-backend = "poetry.core.masonry.api" # https://docs.pytest.org/en/7.1.x/reference/reference.html # --strict-config any warnings encountered while parsing the `pytest` # section of the configuration file raise errors. -addopts = "--strict-markers --strict-config --durations=5" +# +# https://github.com/tophat/syrupy +# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. +addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused" # Registering custom markers. 
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers markers = [ diff --git a/tests/unit_tests/callbacks/tracers/test_base_tracer.py b/tests/unit_tests/callbacks/tracers/test_base_tracer.py index 4ff2e342..c0736c62 100644 --- a/tests/unit_tests/callbacks/tracers/test_base_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_base_tracer.py @@ -13,6 +13,9 @@ from langchain.callbacks.tracers.base import BaseTracer, TracerException from langchain.callbacks.tracers.schemas import Run from langchain.schema import LLMResult +SERIALIZED = {"id": ["llm"]} +SERIALIZED_CHAT = {"id": ["chat_model"]} + class FakeTracer(BaseTracer): """Fake tracer that records LangChain execution.""" @@ -39,7 +42,7 @@ def test_tracer_llm_run() -> None: extra={}, execution_order=1, child_execution_order=1, - serialized={"name": "llm"}, + serialized=SERIALIZED, inputs={"prompts": []}, outputs=LLMResult(generations=[[]]), error=None, @@ -47,7 +50,7 @@ def test_tracer_llm_run() -> None: ) tracer = FakeTracer() - tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) assert tracer.runs == [compare_run] @@ -64,7 +67,7 @@ def test_tracer_chat_model_run() -> None: extra={}, execution_order=1, child_execution_order=1, - serialized={"name": "chat_model"}, + serialized=SERIALIZED_CHAT, inputs=dict(prompts=[""]), outputs=LLMResult(generations=[[]]), error=None, @@ -73,7 +76,7 @@ def test_tracer_chat_model_run() -> None: tracer = FakeTracer() manager = CallbackManager(handlers=[tracer]) run_manager = manager.on_chat_model_start( - serialized={"name": "chat_model"}, messages=[[]], run_id=uuid + serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid ) run_manager.on_llm_end(response=LLMResult(generations=[[]])) assert tracer.runs == [compare_run] @@ -100,7 +103,7 @@ def test_tracer_multiple_llm_runs() -> None: extra={}, execution_order=1, child_execution_order=1, - serialized={"name": "llm"}, + serialized=SERIALIZED, inputs=dict(prompts=[]), outputs=LLMResult(generations=[[]]), error=None, @@ -110,7 +113,7 @@ def test_tracer_multiple_llm_runs() -> None: num_runs = 10 for _ in range(num_runs): - tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) assert tracer.runs == [compare_run] * num_runs @@ -183,7 +186,7 @@ def test_tracer_nested_run() -> None: parent_run_id=chain_uuid, ) tracer.on_llm_start( - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid, @@ -191,7 +194,7 @@ def test_tracer_nested_run() -> None: tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) tracer.on_tool_end("test", run_id=tool_uuid) tracer.on_llm_start( - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid, @@ -235,7 +238,7 @@ def test_tracer_nested_run() -> None: extra={}, execution_order=3, child_execution_order=3, - serialized={"name": "llm"}, + serialized=SERIALIZED, inputs=dict(prompts=[]), outputs=LLMResult(generations=[[]]), run_type="llm", @@ -251,7 +254,7 @@ def test_tracer_nested_run() -> None: extra={}, execution_order=4, child_execution_order=4, - serialized={"name": "llm"}, + serialized=SERIALIZED, inputs=dict(prompts=[]), outputs=LLMResult(generations=[[]]), run_type="llm", @@ -275,7 
+278,7 @@ def test_tracer_llm_run_on_error() -> None: extra={}, execution_order=1, child_execution_order=1, - serialized={"name": "llm"}, + serialized=SERIALIZED, inputs=dict(prompts=[]), outputs=None, error=repr(exception), @@ -283,7 +286,7 @@ def test_tracer_llm_run_on_error() -> None: ) tracer = FakeTracer() - tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) tracer.on_llm_error(exception, run_id=uuid) assert tracer.runs == [compare_run] @@ -358,14 +361,14 @@ def test_tracer_nested_runs_on_error() -> None: serialized={"name": "chain"}, inputs={}, run_id=chain_uuid ) tracer.on_llm_start( - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], run_id=llm_uuid1, parent_run_id=chain_uuid, ) tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) tracer.on_llm_start( - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid, @@ -378,7 +381,7 @@ def test_tracer_nested_runs_on_error() -> None: parent_run_id=chain_uuid, ) tracer.on_llm_start( - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], run_id=llm_uuid3, parent_run_id=tool_uuid, @@ -408,7 +411,7 @@ def test_tracer_nested_runs_on_error() -> None: extra={}, execution_order=2, child_execution_order=2, - serialized={"name": "llm"}, + serialized=SERIALIZED, error=None, inputs=dict(prompts=[]), outputs=LLMResult(generations=[[]], llm_output=None), @@ -422,7 +425,7 @@ def test_tracer_nested_runs_on_error() -> None: extra={}, execution_order=3, child_execution_order=3, - serialized={"name": "llm"}, + serialized=SERIALIZED, error=None, inputs=dict(prompts=[]), outputs=LLMResult(generations=[[]], llm_output=None), @@ -450,7 +453,7 @@ def test_tracer_nested_runs_on_error() -> None: extra={}, execution_order=5, child_execution_order=5, - serialized={"name": "llm"}, + serialized=SERIALIZED, error=repr(exception), inputs=dict(prompts=[]), outputs=None, diff --git a/tests/unit_tests/callbacks/tracers/test_langchain_v1.py b/tests/unit_tests/callbacks/tracers/test_langchain_v1.py index ab655ac6..782f3fbc 100644 --- a/tests/unit_tests/callbacks/tracers/test_langchain_v1.py +++ b/tests/unit_tests/callbacks/tracers/test_langchain_v1.py @@ -22,6 +22,9 @@ from langchain.schema import LLMResult TEST_SESSION_ID = 2023 +SERIALIZED = {"id": ["llm"]} +SERIALIZED_CHAT = {"id": ["chat_model"]} + def load_session(session_name: str) -> TracerSessionV1: """Load a tracing session.""" @@ -107,7 +110,7 @@ def test_tracer_llm_run() -> None: extra={}, execution_order=1, child_execution_order=1, - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], response=LLMResult(generations=[[]]), session_id=TEST_SESSION_ID, @@ -116,7 +119,7 @@ def test_tracer_llm_run() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) assert tracer.runs == [compare_run] @@ -133,7 +136,7 @@ def test_tracer_chat_model_run() -> None: extra={}, execution_order=1, child_execution_order=1, - serialized={"name": "chat_model"}, + serialized=SERIALIZED_CHAT, prompts=[""], response=LLMResult(generations=[[]]), session_id=TEST_SESSION_ID, @@ -144,7 +147,7 @@ def test_tracer_chat_model_run() -> None: tracer.new_session() manager = CallbackManager(handlers=[tracer]) run_manager = 
manager.on_chat_model_start( - serialized={"name": "chat_model"}, messages=[[]], run_id=uuid + serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid ) run_manager.on_llm_end(response=LLMResult(generations=[[]])) assert tracer.runs == [compare_run] @@ -172,7 +175,7 @@ def test_tracer_multiple_llm_runs() -> None: extra={}, execution_order=1, child_execution_order=1, - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], response=LLMResult(generations=[[]]), session_id=TEST_SESSION_ID, @@ -183,7 +186,7 @@ def test_tracer_multiple_llm_runs() -> None: tracer.new_session() num_runs = 10 for _ in range(num_runs): - tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid) assert tracer.runs == [compare_run] * num_runs @@ -263,7 +266,7 @@ def test_tracer_nested_run() -> None: parent_run_id=chain_uuid, ) tracer.on_llm_start( - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], run_id=llm_uuid1, parent_run_id=tool_uuid, @@ -271,7 +274,7 @@ def test_tracer_nested_run() -> None: tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) tracer.on_tool_end("test", run_id=tool_uuid) tracer.on_llm_start( - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid, @@ -319,7 +322,7 @@ def test_tracer_nested_run() -> None: extra={}, execution_order=3, child_execution_order=3, - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], response=LLMResult(generations=[[]]), session_id=TEST_SESSION_ID, @@ -337,7 +340,7 @@ def test_tracer_nested_run() -> None: extra={}, execution_order=4, child_execution_order=4, - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], response=LLMResult(generations=[[]]), session_id=TEST_SESSION_ID, @@ -362,7 +365,7 @@ def test_tracer_llm_run_on_error() -> None: extra={}, execution_order=1, child_execution_order=1, - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], response=None, session_id=TEST_SESSION_ID, @@ -371,7 +374,7 @@ def test_tracer_llm_run_on_error() -> None: tracer = FakeTracer() tracer.new_session() - tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid) + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) tracer.on_llm_error(exception, run_id=uuid) assert tracer.runs == [compare_run] @@ -451,14 +454,14 @@ def test_tracer_nested_runs_on_error() -> None: serialized={"name": "chain"}, inputs={}, run_id=chain_uuid ) tracer.on_llm_start( - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], run_id=llm_uuid1, parent_run_id=chain_uuid, ) tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1) tracer.on_llm_start( - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], run_id=llm_uuid2, parent_run_id=chain_uuid, @@ -471,7 +474,7 @@ def test_tracer_nested_runs_on_error() -> None: parent_run_id=chain_uuid, ) tracer.on_llm_start( - serialized={"name": "llm"}, + serialized=SERIALIZED, prompts=[], run_id=llm_uuid3, parent_run_id=tool_uuid, @@ -501,7 +504,7 @@ def test_tracer_nested_runs_on_error() -> None: extra={}, execution_order=2, child_execution_order=2, - serialized={"name": "llm"}, + serialized=SERIALIZED, session_id=TEST_SESSION_ID, error=None, prompts=[], @@ -515,7 +518,7 @@ def test_tracer_nested_runs_on_error() -> None: extra={}, execution_order=3, child_execution_order=3, - serialized={"name": "llm"}, + 
serialized=SERIALIZED, session_id=TEST_SESSION_ID, error=None, prompts=[], @@ -547,7 +550,7 @@ def test_tracer_nested_runs_on_error() -> None: extra={}, execution_order=5, child_execution_order=5, - serialized={"name": "llm"}, + serialized=SERIALIZED, session_id=TEST_SESSION_ID, error=repr(exception), prompts=[], diff --git a/tests/unit_tests/load/__snapshots__/test_dump.ambr b/tests/unit_tests/load/__snapshots__/test_dump.ambr new file mode 100644 index 00000000..e9f75fba --- /dev/null +++ b/tests/unit_tests/load/__snapshots__/test_dump.ambr @@ -0,0 +1,273 @@ +# serializer version: 1 +# name: test_person + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "test_dump", + "Person" + ], + "kwargs": { + "secret": { + "lc": 1, + "type": "secret", + "id": [ + "SECRET" + ] + }, + "you_can_see_me": "hello" + } + } + ''' +# --- +# name: test_person.1 + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "test_dump", + "SpecialPerson" + ], + "kwargs": { + "another_secret": { + "lc": 1, + "type": "secret", + "id": [ + "ANOTHER_SECRET" + ] + }, + "secret": { + "lc": 1, + "type": "secret", + "id": [ + "SECRET" + ] + }, + "another_visible": "bye", + "you_can_see_me": "hello" + } + } + ''' +# --- +# name: test_serialize_llmchain + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "chains", + "llm", + "LLMChain" + ], + "kwargs": { + "llm": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "llms", + "openai", + "OpenAI" + ], + "kwargs": { + "model": "davinci", + "temperature": 0.5, + "openai_api_key": { + "lc": 1, + "type": "secret", + "id": [ + "OPENAI_API_KEY" + ] + } + } + }, + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "name" + ], + "template": "hello {name}!", + "template_format": "f-string" + } + } + } + } + ''' +# --- +# name: test_serialize_llmchain_chat + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "chains", + "llm", + "LLMChain" + ], + "kwargs": { + "llm": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "chat_models", + "openai", + "ChatOpenAI" + ], + "kwargs": { + "model": "davinci", + "temperature": 0.5, + "openai_api_key": "hello" + } + }, + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "ChatPromptTemplate" + ], + "kwargs": { + "input_variables": [ + "name" + ], + "messages": [ + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "chat", + "HumanMessagePromptTemplate" + ], + "kwargs": { + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "name" + ], + "template": "hello {name}!", + "template_format": "f-string" + } + } + } + } + ] + } + } + } + } + ''' +# --- +# name: test_serialize_llmchain_with_non_serializable_arg + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "chains", + "llm", + "LLMChain" + ], + "kwargs": { + "llm": { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "llms", + "openai", + "OpenAI" + ], + "kwargs": { + "model": "davinci", + "temperature": 0.5, + "openai_api_key": { + "lc": 1, + "type": "secret", + "id": [ + "OPENAI_API_KEY" + ] + }, + "client": { + "lc": 1, + "type": "not_implemented", + "id": [ + "openai", + "api_resources", + "completion", + "Completion" + ] + } + } + }, + "prompt": { + "lc": 1, + "type": "constructor", + "id": [ + 
"langchain", + "prompts", + "prompt", + "PromptTemplate" + ], + "kwargs": { + "input_variables": [ + "name" + ], + "template": "hello {name}!", + "template_format": "f-string" + } + } + } + } + ''' +# --- +# name: test_serialize_openai_llm + ''' + { + "lc": 1, + "type": "constructor", + "id": [ + "langchain", + "llms", + "openai", + "OpenAI" + ], + "kwargs": { + "model": "davinci", + "temperature": 0.7, + "openai_api_key": { + "lc": 1, + "type": "secret", + "id": [ + "OPENAI_API_KEY" + ] + } + } + } + ''' +# --- diff --git a/tests/unit_tests/load/test_dump.py b/tests/unit_tests/load/test_dump.py new file mode 100644 index 00000000..45eab8eb --- /dev/null +++ b/tests/unit_tests/load/test_dump.py @@ -0,0 +1,103 @@ +"""Test for Serializable base class""" + +from typing import Any, Dict + +import pytest + +from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.chains.llm import LLMChain +from langchain.chat_models.openai import ChatOpenAI +from langchain.llms.openai import OpenAI +from langchain.load.dump import dumps +from langchain.load.serializable import Serializable +from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain.prompts.prompt import PromptTemplate + + +class Person(Serializable): + secret: str + + you_can_see_me: str = "hello" + + @property + def lc_serializable(self) -> bool: + return True + + @property + def lc_secrets(self) -> Dict[str, str]: + return {"secret": "SECRET"} + + @property + def lc_attributes(self) -> Dict[str, str]: + return {"you_can_see_me": self.you_can_see_me} + + +class SpecialPerson(Person): + another_secret: str + + another_visible: str = "bye" + + # Gets merged with parent class's secrets + @property + def lc_secrets(self) -> Dict[str, str]: + return {"another_secret": "ANOTHER_SECRET"} + + # Gets merged with parent class's attributes + @property + def lc_attributes(self) -> Dict[str, str]: + return {"another_visible": self.another_visible} + + +class NotSerializable: + pass + + +def test_person(snapshot: Any) -> None: + p = Person(secret="hello") + assert dumps(p, pretty=True) == snapshot + sp = SpecialPerson(another_secret="Wooo", secret="Hmm") + assert dumps(sp, pretty=True) == snapshot + + +@pytest.mark.requires("openai") +def test_serialize_openai_llm(snapshot: Any) -> None: + llm = OpenAI( + model="davinci", + temperature=0.5, + openai_api_key="hello", + # This is excluded from serialization + callbacks=[LangChainTracer()], + ) + llm.temperature = 0.7 # this is reflected in serialization + assert dumps(llm, pretty=True) == snapshot + + +@pytest.mark.requires("openai") +def test_serialize_llmchain(snapshot: Any) -> None: + llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") + prompt = PromptTemplate.from_template("hello {name}!") + chain = LLMChain(llm=llm, prompt=prompt) + assert dumps(chain, pretty=True) == snapshot + + +@pytest.mark.requires("openai") +def test_serialize_llmchain_chat(snapshot: Any) -> None: + llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello") + prompt = ChatPromptTemplate.from_messages( + [HumanMessagePromptTemplate.from_template("hello {name}!")] + ) + chain = LLMChain(llm=llm, prompt=prompt) + assert dumps(chain, pretty=True) == snapshot + + +@pytest.mark.requires("openai") +def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None: + llm = OpenAI( + model="davinci", + temperature=0.5, + openai_api_key="hello", + client=NotSerializable, + ) + prompt = PromptTemplate.from_template("hello {name}!") 
+ chain = LLMChain(llm=llm, prompt=prompt) + assert dumps(chain, pretty=True) == snapshot diff --git a/tests/unit_tests/load/test_load.py b/tests/unit_tests/load/test_load.py new file mode 100644 index 00000000..8062d499 --- /dev/null +++ b/tests/unit_tests/load/test_load.py @@ -0,0 +1,54 @@ +"""Test for Serializable base class""" + +import pytest + +from langchain.chains.llm import LLMChain +from langchain.llms.openai import OpenAI +from langchain.load.dump import dumps +from langchain.load.load import loads +from langchain.prompts.prompt import PromptTemplate + + +class NotSerializable: + pass + + +@pytest.mark.requires("openai") +def test_load_openai_llm() -> None: + llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") + llm_string = dumps(llm) + llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"}) + + assert llm2 == llm + assert dumps(llm2) == llm_string + assert isinstance(llm2, OpenAI) + + +@pytest.mark.requires("openai") +def test_load_llmchain() -> None: + llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") + prompt = PromptTemplate.from_template("hello {name}!") + chain = LLMChain(llm=llm, prompt=prompt) + chain_string = dumps(chain) + chain2 = loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"}) + + assert chain2 == chain + assert dumps(chain2) == chain_string + assert isinstance(chain2, LLMChain) + assert isinstance(chain2.llm, OpenAI) + assert isinstance(chain2.prompt, PromptTemplate) + + +@pytest.mark.requires("openai") +def test_load_llmchain_with_non_serializable_arg() -> None: + llm = OpenAI( + model="davinci", + temperature=0.5, + openai_api_key="hello", + client=NotSerializable, + ) + prompt = PromptTemplate.from_template("hello {name}!") + chain = LLMChain(llm=llm, prompt=prompt) + chain_string = dumps(chain, pretty=True) + with pytest.raises(NotImplementedError): + loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"}) diff --git a/tests/unit_tests/test_dependencies.py b/tests/unit_tests/test_dependencies.py index 63342a25..16e930dc 100644 --- a/tests/unit_tests/test_dependencies.py +++ b/tests/unit_tests/test_dependencies.py @@ -72,6 +72,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None: "pytest-socket", "pytest-watcher", "responses", + "syrupy", ]
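For reference, the opt-in contract that Serializable establishes can be summarized with a small sketch adapted from the Person fixture in tests/unit_tests/load/test_dump.py above; the class and field names are illustrative, not part of the library API:

from typing import Dict

from langchain.load.dump import dumps
from langchain.load.serializable import Serializable


class Person(Serializable):
    secret: str
    you_can_see_me: str = "hello"

    @property
    def lc_serializable(self) -> bool:
        # Classes are opted out by default; subclasses opt in explicitly.
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        # Map constructor kwargs to secret ids so raw values never reach the JSON.
        return {"secret": "SECRET"}

    @property
    def lc_attributes(self) -> Dict[str, str]:
        # Extra constructor kwargs to always include in the serialized output.
        return {"you_can_see_me": self.you_can_see_me}


print(dumps(Person(secret="hello"), pretty=True))
# The secret kwarg is emitted as {"lc": 1, "type": "secret", "id": ["SECRET"]},
# while you_can_see_me is emitted verbatim. Subclasses merge lc_secrets and
# lc_attributes with their parents' values, as the SpecialPerson fixture in the
# same test file demonstrates.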