forked from Archives/langchain
parent: 614cff89bc
commit: 18af149e91
@@ -4,9 +4,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Sequence, Set

from pydantic import BaseModel

from langchain.callbacks.manager import Callbacks
from langchain.load.serializable import Serializable
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string


@@ -29,7 +28,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
    return tokenizer.encode(text)


class BaseLanguageModel(BaseModel, ABC):
class BaseLanguageModel(Serializable, ABC):
    @abstractmethod
    def generate_prompt(
        self,
@@ -204,7 +204,7 @@ def _handle_event(
        except Exception as e:
            if handler.raise_error:
                raise e
            logging.warning(f"Error in {event_name} callback: {e}")
            logger.warning(f"Error in {event_name} callback: {e}")


async def _ahandle_event_for_handler(
@@ -93,7 +93,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
        execution_order = self._get_execution_order(parent_run_id_)
        llm_run = Run(
            id=run_id,
            name=serialized.get("name"),
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs={"prompts": prompts},
@@ -154,7 +153,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
        execution_order = self._get_execution_order(parent_run_id_)
        chain_run = Run(
            id=run_id,
            name=serialized.get("name"),
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs=inputs,
@@ -216,7 +214,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
        execution_order = self._get_execution_order(parent_run_id_)
        tool_run = Run(
            id=run_id,
            name=serialized.get("name"),
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs={"input": input_str},
@@ -124,7 +124,10 @@ class Run(RunBase):
    def assign_name(cls, values: dict) -> dict:
        """Assign name to the run."""
        if "name" not in values:
            if "name" in values["serialized"]:
                values["name"] = values["serialized"]["name"]
            elif "id" in values["serialized"]:
                values["name"] = values["serialized"]["id"][-1]
        return values
@@ -7,7 +7,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import yaml
from pydantic import BaseModel, Field, root_validator, validator
from pydantic import Field, root_validator, validator

import langchain
from langchain.callbacks.base import BaseCallbackManager
@@ -18,6 +18,8 @@ from langchain.callbacks.manager import (
    CallbackManagerForChainRun,
    Callbacks,
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo


@@ -25,7 +27,7 @@ def _get_verbosity() -> bool:
    return langchain.verbose


class Chain(BaseModel, ABC):
class Chain(Serializable, ABC):
    """Base interface that all chains should implement."""

    memory: Optional[BaseMemory] = None
@@ -131,7 +133,7 @@ class Chain(BaseModel, ABC):
        )
        new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
        run_manager = callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            dumpd(self),
            inputs,
        )
        try:
@@ -179,7 +181,7 @@ class Chain(BaseModel, ABC):
        )
        new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
        run_manager = await callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            dumpd(self),
            inputs,
        )
        try:
@@ -15,6 +15,7 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
@@ -34,6 +35,10 @@ class LLMChain(Chain):
            llm = LLMChain(llm=OpenAI(), prompt=prompt)
    """

    @property
    def lc_serializable(self) -> bool:
        return True

    prompt: BasePromptTemplate
    """Prompt object to use."""
    llm: BaseLanguageModel
@@ -147,7 +152,7 @@ class LLMChain(Chain):
            callbacks, self.callbacks, self.verbose
        )
        run_manager = callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            dumpd(self),
            {"input_list": input_list},
        )
        try:
@@ -167,7 +172,7 @@ class LLMChain(Chain):
            callbacks, self.callbacks, self.verbose
        )
        run_manager = await callback_manager.on_chain_start(
            {"name": self.__class__.__name__},
            dumpd(self),
            {"input_list": input_list},
        )
        try:
@@ -17,6 +17,7 @@ from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
    Callbacks,
)
from langchain.load.dump import dumpd
from langchain.schema import (
    AIMessage,
    BaseMessage,
@@ -70,12 +71,13 @@ class BaseChatModel(BaseLanguageModel, ABC):

        params = self.dict()
        params["stop"] = stop
        options = {"stop": stop}

        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        run_manager = callback_manager.on_chat_model_start(
            {"name": self.__class__.__name__}, messages, invocation_params=params
            dumpd(self), messages, invocation_params=params, options=options
        )

        new_arg_supported = inspect.signature(self._generate).parameters.get(
@@ -109,12 +111,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
        """Top Level call"""
        params = self.dict()
        params["stop"] = stop
        options = {"stop": stop}

        callback_manager = AsyncCallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        run_manager = await callback_manager.on_chat_model_start(
            {"name": self.__class__.__name__}, messages, invocation_params=params
            dumpd(self), messages, invocation_params=params, options=options
        )

        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
@@ -136,6 +136,10 @@ class ChatOpenAI(BaseChatModel):
            openai = ChatOpenAI(model_name="gpt-3.5-turbo")
    """

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any #: :meta private:
    model_name: str = Field(default="gpt-3.5-turbo", alias="model")
    """Model name to use."""
@@ -19,6 +19,7 @@ from langchain.callbacks.manager import (
    CallbackManagerForLLMRun,
    Callbacks,
)
from langchain.load.dump import dumpd
from langchain.schema import (
    AIMessage,
    BaseMessage,
@@ -166,6 +167,7 @@ class BaseLLM(BaseLanguageModel, ABC):
        )
        params = self.dict()
        params["stop"] = stop
        options = {"stop": stop}
        (
            existing_prompts,
            llm_string,
@@ -186,7 +188,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                "Asked to cache, but no cache found at `langchain.cache`."
            )
        run_manager = callback_manager.on_llm_start(
            {"name": self.__class__.__name__}, prompts, invocation_params=params
            dumpd(self), prompts, invocation_params=params, options=options
        )
        try:
            output = (
@@ -205,9 +207,10 @@ class BaseLLM(BaseLanguageModel, ABC):
            return output
        if len(missing_prompts) > 0:
            run_manager = callback_manager.on_llm_start(
                {"name": self.__class__.__name__},
                dumpd(self),
                missing_prompts,
                invocation_params=params,
                options=options,
            )
            try:
                new_results = (
@@ -243,6 +246,7 @@ class BaseLLM(BaseLanguageModel, ABC):
        """Run the LLM on the given prompt and input."""
        params = self.dict()
        params["stop"] = stop
        options = {"stop": stop}
        (
            existing_prompts,
            llm_string,
@@ -263,7 +267,7 @@ class BaseLLM(BaseLanguageModel, ABC):
                "Asked to cache, but no cache found at `langchain.cache`."
            )
        run_manager = await callback_manager.on_llm_start(
            {"name": self.__class__.__name__}, prompts, invocation_params=params
            dumpd(self), prompts, invocation_params=params, options=options
        )
        try:
            output = (
@@ -282,9 +286,10 @@ class BaseLLM(BaseLanguageModel, ABC):
            return output
        if len(missing_prompts) > 0:
            run_manager = await callback_manager.on_llm_start(
                {"name": self.__class__.__name__},
                dumpd(self),
                missing_prompts,
                invocation_params=params,
                options=options,
            )
            try:
                new_results = (
@@ -123,6 +123,14 @@ async def acompletion_with_retry(
class BaseOpenAI(BaseLLM):
    """Wrapper around OpenAI large language models."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"openai_api_key": "OPENAI_API_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any #: :meta private:
    model_name: str = Field("text-davinci-003", alias="model")
    """Model name to use."""

langchain/load/__init__.py (new file, 0 lines)
langchain/load/dump.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import json
from typing import Any, Dict

from langchain.load.serializable import Serializable, to_json_not_implemented


def default(obj: Any) -> Any:
    if isinstance(obj, Serializable):
        return obj.to_json()
    else:
        return to_json_not_implemented(obj)


def dumps(obj: Any, *, pretty: bool = False) -> str:
    if pretty:
        return json.dumps(obj, default=default, indent=2)
    else:
        return json.dumps(obj, default=default)


def dumpd(obj: Any) -> Dict[str, Any]:
    return json.loads(dumps(obj))
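In rough terms, dumps renders any Serializable (or anything else, via the not_implemented fallback) to a JSON string, and dumpd to the equivalent dict. A minimal usage sketch, mirroring the unit tests added later in this change (the model name and key value are illustrative):

    from langchain.llms.openai import OpenAI
    from langchain.load.dump import dumpd, dumps

    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    as_text = dumps(llm, pretty=True)  # JSON string; the key is emitted as a secret marker, not in plain text
    as_dict = dumpd(llm)               # same structure as a plain dict, e.g. for callback payloads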
langchain/load/load.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import importlib
import json
from typing import Any, Dict, Optional

from langchain.load.serializable import Serializable


class Reviver:
    def __init__(self, secrets_map: Optional[Dict[str, str]] = None) -> None:
        self.secrets_map = secrets_map or dict()

    def __call__(self, value: Dict[str, Any]) -> Any:
        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "secret"
            and value.get("id", None) is not None
        ):
            [key] = value["id"]
            if key in self.secrets_map:
                return self.secrets_map[key]
            else:
                raise KeyError(f'Missing key "{key}" in load(secrets_map)')

        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "not_implemented"
            and value.get("id", None) is not None
        ):
            raise NotImplementedError(
                "Trying to load an object that doesn't implement "
                f"serialization: {value}"
            )

        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "constructor"
            and value.get("id", None) is not None
        ):
            [*namespace, name] = value["id"]

            # Currently, we only support langchain imports.
            if namespace[0] != "langchain":
                raise ValueError(f"Invalid namespace: {value}")

            # The root namespace "langchain" is not a valid identifier.
            if len(namespace) == 1:
                raise ValueError(f"Invalid namespace: {value}")

            mod = importlib.import_module(".".join(namespace))
            cls = getattr(mod, name)

            # The class must be a subclass of Serializable.
            if not issubclass(cls, Serializable):
                raise ValueError(f"Invalid namespace: {value}")

            # We don't need to recurse on kwargs
            # as json.loads will do that for us.
            kwargs = value.get("kwargs", dict())
            return cls(**kwargs)

        return value


def loads(text: str, *, secrets_map: Optional[Dict[str, str]] = None) -> Any:
    return json.loads(text, object_hook=Reviver(secrets_map))
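Loading is the inverse: loads walks the JSON with Reviver as an object_hook, re-imports the named langchain class, and substitutes secrets from the caller-supplied secrets_map. A round-trip sketch along the lines of the new test_load.py below (values are illustrative):

    from langchain.llms.openai import OpenAI
    from langchain.load.dump import dumps
    from langchain.load.load import loads

    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    text = dumps(llm)
    # Secrets are not written into the JSON, so they must be supplied again at load time.
    llm2 = loads(text, secrets_map={"OPENAI_API_KEY": "hello"})
    assert llm2 == llm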
langchain/load/serializable.py (new file, 135 lines)
@@ -0,0 +1,135 @@
from abc import ABC
from typing import Any, Dict, List, Literal, TypedDict, Union, cast

from pydantic import BaseModel, Field


class BaseSerialized(TypedDict):
    lc: int
    id: List[str]


class SerializedConstructor(BaseSerialized):
    type: Literal["constructor"]
    kwargs: Dict[str, Any]


class SerializedSecret(BaseSerialized):
    type: Literal["secret"]


class SerializedNotImplemented(BaseSerialized):
    type: Literal["not_implemented"]


class Serializable(BaseModel, ABC):
    @property
    def lc_serializable(self) -> bool:
        """
        Return whether or not the class is serializable.
        """
        return False

    @property
    def lc_namespace(self) -> List[str]:
        """
        Return the namespace of the langchain object.
        eg. ["langchain", "llms", "openai"]
        """
        return self.__class__.__module__.split(".")

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """
        Return a map of constructor argument names to secret ids.
        eg. {"openai_api_key": "OPENAI_API_KEY"}
        """
        return dict()

    @property
    def lc_attributes(self) -> Dict:
        """
        Return a list of attribute names that should be included in the
        serialized kwargs. These attributes must be accepted by the
        constructor.
        """
        return {}

    lc_kwargs: Dict[str, Any] = Field(default_factory=dict, exclude=True)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.lc_kwargs = kwargs

    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
        if not self.lc_serializable:
            return self.to_json_not_implemented()

        secrets = dict()
        # Get latest values for kwargs if there is an attribute with same name
        lc_kwargs = {
            k: getattr(self, k, v)
            for k, v in self.lc_kwargs.items()
            if not self.__exclude_fields__.get(k, False) # type: ignore
        }

        # Merge the lc_secrets and lc_attributes from every class in the MRO
        for cls in [None, *self.__class__.mro()]:
            # Once we get to Serializable, we're done
            if cls is Serializable:
                break

            # Get a reference to self bound to each class in the MRO
            this = cast(Serializable, self if cls is None else super(cls, self))

            secrets.update(this.lc_secrets)
            lc_kwargs.update(this.lc_attributes)

        return {
            "lc": 1,
            "type": "constructor",
            "id": [*self.lc_namespace, self.__class__.__name__],
            "kwargs": lc_kwargs
            if not secrets
            else _replace_secrets(lc_kwargs, secrets),
        }

    def to_json_not_implemented(self) -> SerializedNotImplemented:
        return to_json_not_implemented(self)


def _replace_secrets(
    root: Dict[Any, Any], secrets_map: Dict[str, str]
) -> Dict[Any, Any]:
    result = root.copy()
    for path, secret_id in secrets_map.items():
        [*parts, last] = path.split(".")
        current = result
        for part in parts:
            if part not in current:
                break
            current[part] = current[part].copy()
            current = current[part]
        if last in current:
            current[last] = {
                "lc": 1,
                "type": "secret",
                "id": [secret_id],
            }
    return result


def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
    _id: List[str] = []
    try:
        if hasattr(obj, "__name__"):
            _id = [*obj.__module__.split("."), obj.__name__]
        elif hasattr(obj, "__class__"):
            _id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
    except Exception:
        pass
    return {
        "lc": 1,
        "type": "not_implemented",
        "id": _id,
    }
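A sketch of how a class opts in to this scheme (the class and field names here are invented for illustration; the real pattern is exercised by the Person fixture in the new test_dump.py further down): subclass Serializable, return True from lc_serializable, and map any sensitive constructor arguments to secret ids via lc_secrets.

    from typing import Dict
    from langchain.load.serializable import Serializable

    class MyComponent(Serializable):
        api_key: str          # replaced by a {"type": "secret"} marker when dumped
        label: str = "demo"   # ordinary constructor kwarg, serialized as-is

        @property
        def lc_serializable(self) -> bool:
            return True

        @property
        def lc_secrets(self) -> Dict[str, str]:
            return {"api_key": "MY_API_KEY"}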
@@ -7,9 +7,10 @@ from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union

import yaml
from pydantic import BaseModel, Extra, Field, root_validator
from pydantic import Extra, Field, root_validator

from langchain.formatting import formatter
from langchain.load.serializable import Serializable
from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue


@@ -100,7 +101,7 @@ class StringPromptValue(PromptValue):
        return [HumanMessage(content=self.text)]


class BasePromptTemplate(BaseModel, ABC):
class BasePromptTemplate(Serializable, ABC):
    """Base class for all prompt templates, returning a prompt."""

    input_variables: List[str]
@@ -111,6 +112,10 @@ class BasePromptTemplate(BaseModel, ABC):
        default_factory=dict
    )

    @property
    def lc_serializable(self) -> bool:
        return True

    class Config:
        """Configuration for this pydantic object."""
@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union

from pydantic import BaseModel, Field
from pydantic import Field

from langchain.load.serializable import Serializable
from langchain.memory.buffer import get_buffer_string
from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
from langchain.prompts.prompt import PromptTemplate
@@ -20,7 +21,11 @@ from langchain.schema import (
)


class BaseMessagePromptTemplate(BaseModel, ABC):
class BaseMessagePromptTemplate(Serializable, ABC):
    @property
    def lc_serializable(self) -> bool:
        return True

    @abstractmethod
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """To messages."""
@@ -220,7 +225,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):

    @property
    def _prompt_type(self) -> str:
        raise NotImplementedError
        return "chat"

    def save(self, file_path: Union[Path, str]) -> None:
        raise NotImplementedError
@@ -15,6 +15,10 @@ from langchain.prompts.prompt import PromptTemplate
class FewShotPromptTemplate(StringPromptTemplate):
    """Prompt template that contains few shot examples."""

    @property
    def lc_serializable(self) -> bool:
        return False

    examples: Optional[List[dict]] = None
    """Examples to format into the prompt.
    Either this or example_selector should be provided."""
@@ -25,6 +25,12 @@ class PromptTemplate(StringPromptTemplate):
            prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
    """

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        return {
            "template_format": self.template_format,
        }

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""
@@ -17,6 +17,8 @@ from uuid import UUID

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.load.serializable import Serializable

RUN_KEY = "__run"


@@ -55,7 +57,7 @@ class AgentFinish(NamedTuple):
    log: str


class Generation(BaseModel):
class Generation(Serializable):
    """Output of a single generation."""

    text: str
@@ -67,7 +69,7 @@ class Generation(BaseModel):
    # TODO: add log probs


class BaseMessage(BaseModel):
class BaseMessage(Serializable):
    """Message object."""

    content: str
@@ -194,7 +196,7 @@ class LLMResult(BaseModel):
    )


class PromptValue(BaseModel, ABC):
class PromptValue(Serializable, ABC):
    @abstractmethod
    def to_string(self) -> str:
        """Return prompt as string."""
@@ -204,7 +206,7 @@ class PromptValue(BaseModel, ABC):
        """Return prompt as messages."""


class BaseMemory(BaseModel, ABC):
class BaseMemory(Serializable, ABC):
    """Base interface for memory in chains."""

    class Config:
@@ -282,7 +284,7 @@ class BaseChatMessageHistory(ABC):
        """Remove all messages from the store"""


class Document(BaseModel):
class Document(Serializable):
    """Interface for interacting with a document."""

    page_content: str
@@ -321,7 +323,7 @@ Memory = BaseMemory
T = TypeVar("T")


class BaseOutputParser(BaseModel, ABC, Generic[T]):
class BaseOutputParser(Serializable, ABC, Generic[T]):
    """Class to parse the output of an LLM call.

    Output parsers help structure language model responses.
poetry.lock (generated; 31 lines changed)
@ -1417,6 +1417,17 @@ files = [
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colored"
|
||||
version = "1.4.4"
|
||||
description = "Simple library for color and formatting to terminal"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "colored-1.4.4.tar.gz", hash = "sha256:04ff4d4dd514274fe3b99a21bb52fb96f2688c01e93fba7bef37221e7cb56ce0"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coloredlogs"
|
||||
version = "15.0.1"
|
||||
@ -9461,6 +9472,22 @@ files = [
|
||||
[package.dependencies]
|
||||
mpmath = ">=0.19"
|
||||
|
||||
[[package]]
|
||||
name = "syrupy"
|
||||
version = "4.0.2"
|
||||
description = "Pytest Snapshot Test Utility"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4"
|
||||
files = [
|
||||
{file = "syrupy-4.0.2-py3-none-any.whl", hash = "sha256:dfd1f0fad298eee753de4f2471d4346412c4435885c4b7beea648d4934c6620a"},
|
||||
{file = "syrupy-4.0.2.tar.gz", hash = "sha256:3c75ab6866580679b2cb9abe78e74c3e2011fffc6333651c6beb2a78a716ab80"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colored = ">=1.3.92,<2.0.0"
|
||||
pytest = ">=7.0.0,<8.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "tabulate"
|
||||
version = "0.9.0"
|
||||
@ -11428,7 +11455,7 @@ azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-for
|
||||
cohere = ["cohere"]
|
||||
docarray = ["docarray"]
|
||||
embeddings = ["sentence-transformers"]
|
||||
extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark"]
|
||||
extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark", "openai"]
|
||||
llms = ["anthropic", "cohere", "openai", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
|
||||
openai = ["openai", "tiktoken"]
|
||||
qdrant = ["qdrant-client"]
|
||||
@ -11437,4 +11464,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "ecf7086e83cc0ff19e6851c0b63170b082b267c1c1c00f47700fd3a8c8bb46c5"
|
||||
content-hash = "7a39130af070d4a4fe6b0af5d6b70615c868ab0b1867e404060ff00eacd10f5f"
|
||||
|
@@ -139,6 +139,7 @@ pytest-asyncio = "^0.20.3"
lark = "^1.1.5"
pytest-mock = "^3.10.0"
pytest-socket = "^0.6.0"
syrupy = "^4.0.2"

[tool.poetry.group.test_integration]
optional = true
@@ -315,7 +316,8 @@ extended_testing = [
    "html2text",
    "py-trello",
    "scikit-learn",
    "pyspark"
    "pyspark",
    "openai"
]

[tool.ruff]
@@ -349,7 +351,10 @@ build-backend = "poetry.core.masonry.api"
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5"
#
# https://github.com/tophat/syrupy
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
@ -13,6 +13,9 @@ from langchain.callbacks.tracers.base import BaseTracer, TracerException
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
SERIALIZED = {"id": ["llm"]}
|
||||
SERIALIZED_CHAT = {"id": ["chat_model"]}
|
||||
|
||||
|
||||
class FakeTracer(BaseTracer):
|
||||
"""Fake tracer that records LangChain execution."""
|
||||
@ -39,7 +42,7 @@ def test_tracer_llm_run() -> None:
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
inputs={"prompts": []},
|
||||
outputs=LLMResult(generations=[[]]),
|
||||
error=None,
|
||||
@ -47,7 +50,7 @@ def test_tracer_llm_run() -> None:
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@ -64,7 +67,7 @@ def test_tracer_chat_model_run() -> None:
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "chat_model"},
|
||||
serialized=SERIALIZED_CHAT,
|
||||
inputs=dict(prompts=[""]),
|
||||
outputs=LLMResult(generations=[[]]),
|
||||
error=None,
|
||||
@ -73,7 +76,7 @@ def test_tracer_chat_model_run() -> None:
|
||||
tracer = FakeTracer()
|
||||
manager = CallbackManager(handlers=[tracer])
|
||||
run_manager = manager.on_chat_model_start(
|
||||
serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
|
||||
serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
|
||||
)
|
||||
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
@ -100,7 +103,7 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=LLMResult(generations=[[]]),
|
||||
error=None,
|
||||
@ -110,7 +113,7 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
@ -183,7 +186,7 @@ def test_tracer_nested_run() -> None:
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=tool_uuid,
|
||||
@ -191,7 +194,7 @@ def test_tracer_nested_run() -> None:
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_tool_end("test", run_id=tool_uuid)
|
||||
tracer.on_llm_start(
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
@ -235,7 +238,7 @@ def test_tracer_nested_run() -> None:
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=LLMResult(generations=[[]]),
|
||||
run_type="llm",
|
||||
@ -251,7 +254,7 @@ def test_tracer_nested_run() -> None:
|
||||
extra={},
|
||||
execution_order=4,
|
||||
child_execution_order=4,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=LLMResult(generations=[[]]),
|
||||
run_type="llm",
|
||||
@ -275,7 +278,7 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=None,
|
||||
error=repr(exception),
|
||||
@ -283,7 +286,7 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@ -358,14 +361,14 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_llm_start(
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
@ -378,7 +381,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid3,
|
||||
parent_run_id=tool_uuid,
|
||||
@ -408,7 +411,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
extra={},
|
||||
execution_order=2,
|
||||
child_execution_order=2,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
error=None,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=LLMResult(generations=[[]], llm_output=None),
|
||||
@ -422,7 +425,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
error=None,
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=LLMResult(generations=[[]], llm_output=None),
|
||||
@ -450,7 +453,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
extra={},
|
||||
execution_order=5,
|
||||
child_execution_order=5,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
error=repr(exception),
|
||||
inputs=dict(prompts=[]),
|
||||
outputs=None,
|
||||
|
@ -22,6 +22,9 @@ from langchain.schema import LLMResult
|
||||
|
||||
TEST_SESSION_ID = 2023
|
||||
|
||||
SERIALIZED = {"id": ["llm"]}
|
||||
SERIALIZED_CHAT = {"id": ["chat_model"]}
|
||||
|
||||
|
||||
def load_session(session_name: str) -> TracerSessionV1:
|
||||
"""Load a tracing session."""
|
||||
@ -107,7 +110,7 @@ def test_tracer_llm_run() -> None:
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
@ -116,7 +119,7 @@ def test_tracer_llm_run() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@ -133,7 +136,7 @@ def test_tracer_chat_model_run() -> None:
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "chat_model"},
|
||||
serialized=SERIALIZED_CHAT,
|
||||
prompts=[""],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
@ -144,7 +147,7 @@ def test_tracer_chat_model_run() -> None:
|
||||
tracer.new_session()
|
||||
manager = CallbackManager(handlers=[tracer])
|
||||
run_manager = manager.on_chat_model_start(
|
||||
serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
|
||||
serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
|
||||
)
|
||||
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
|
||||
assert tracer.runs == [compare_run]
|
||||
@ -172,7 +175,7 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
@ -183,7 +186,7 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
tracer.new_session()
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
@ -263,7 +266,7 @@ def test_tracer_nested_run() -> None:
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=tool_uuid,
|
||||
@ -271,7 +274,7 @@ def test_tracer_nested_run() -> None:
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_tool_end("test", run_id=tool_uuid)
|
||||
tracer.on_llm_start(
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
@ -319,7 +322,7 @@ def test_tracer_nested_run() -> None:
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
@ -337,7 +340,7 @@ def test_tracer_nested_run() -> None:
|
||||
extra={},
|
||||
execution_order=4,
|
||||
child_execution_order=4,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
response=LLMResult(generations=[[]]),
|
||||
session_id=TEST_SESSION_ID,
|
||||
@ -362,7 +365,7 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
extra={},
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
response=None,
|
||||
session_id=TEST_SESSION_ID,
|
||||
@ -371,7 +374,7 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.new_session()
|
||||
tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@ -451,14 +454,14 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_llm_start(
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
@ -471,7 +474,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
run_id=llm_uuid3,
|
||||
parent_run_id=tool_uuid,
|
||||
@ -501,7 +504,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
extra={},
|
||||
execution_order=2,
|
||||
child_execution_order=2,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
@ -515,7 +518,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
extra={},
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=None,
|
||||
prompts=[],
|
||||
@ -547,7 +550,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
extra={},
|
||||
execution_order=5,
|
||||
child_execution_order=5,
|
||||
serialized={"name": "llm"},
|
||||
serialized=SERIALIZED,
|
||||
session_id=TEST_SESSION_ID,
|
||||
error=repr(exception),
|
||||
prompts=[],
|
||||
|
tests/unit_tests/load/__snapshots__/test_dump.ambr (new file, 273 lines)
@ -0,0 +1,273 @@
|
||||
# serializer version: 1
|
||||
# name: test_person
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"test_dump",
|
||||
"Person"
|
||||
],
|
||||
"kwargs": {
|
||||
"secret": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"SECRET"
|
||||
]
|
||||
},
|
||||
"you_can_see_me": "hello"
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_person.1
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"test_dump",
|
||||
"SpecialPerson"
|
||||
],
|
||||
"kwargs": {
|
||||
"another_secret": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"ANOTHER_SECRET"
|
||||
]
|
||||
},
|
||||
"secret": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"SECRET"
|
||||
]
|
||||
},
|
||||
"another_visible": "bye",
|
||||
"you_can_see_me": "hello"
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_llmchain
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"kwargs": {
|
||||
"llm": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model": "davinci",
|
||||
"temperature": 0.5,
|
||||
"openai_api_key": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"OPENAI_API_KEY"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"template": "hello {name}!",
|
||||
"template_format": "f-string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_llmchain_chat
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"kwargs": {
|
||||
"llm": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"openai",
|
||||
"ChatOpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model": "davinci",
|
||||
"temperature": 0.5,
|
||||
"openai_api_key": "hello"
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"chat",
|
||||
"ChatPromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"chat",
|
||||
"HumanMessagePromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"template": "hello {name}!",
|
||||
"template_format": "f-string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_llmchain_with_non_serializable_arg
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"kwargs": {
|
||||
"llm": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model": "davinci",
|
||||
"temperature": 0.5,
|
||||
"openai_api_key": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"OPENAI_API_KEY"
|
||||
]
|
||||
},
|
||||
"client": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"openai",
|
||||
"api_resources",
|
||||
"completion",
|
||||
"Completion"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"template": "hello {name}!",
|
||||
"template_format": "f-string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_openai_llm
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model": "davinci",
|
||||
"temperature": 0.7,
|
||||
"openai_api_key": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"OPENAI_API_KEY"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
tests/unit_tests/load/test_dump.py (new file, 103 lines)
@ -0,0 +1,103 @@
|
||||
"""Test for Serializable base class"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.load.dump import dumps
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
class Person(Serializable):
|
||||
secret: str
|
||||
|
||||
you_can_see_me: str = "hello"
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"secret": "SECRET"}
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, str]:
|
||||
return {"you_can_see_me": self.you_can_see_me}
|
||||
|
||||
|
||||
class SpecialPerson(Person):
|
||||
another_secret: str
|
||||
|
||||
another_visible: str = "bye"
|
||||
|
||||
# Gets merged with parent class's secrets
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"another_secret": "ANOTHER_SECRET"}
|
||||
|
||||
# Gets merged with parent class's attributes
|
||||
@property
|
||||
def lc_attributes(self) -> Dict[str, str]:
|
||||
return {"another_visible": self.another_visible}
|
||||
|
||||
|
||||
class NotSerializable:
|
||||
pass
|
||||
|
||||
|
||||
def test_person(snapshot: Any) -> None:
|
||||
p = Person(secret="hello")
|
||||
assert dumps(p, pretty=True) == snapshot
|
||||
sp = SpecialPerson(another_secret="Wooo", secret="Hmm")
|
||||
assert dumps(sp, pretty=True) == snapshot
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_serialize_openai_llm(snapshot: Any) -> None:
|
||||
llm = OpenAI(
|
||||
model="davinci",
|
||||
temperature=0.5,
|
||||
openai_api_key="hello",
|
||||
# This is excluded from serialization
|
||||
callbacks=[LangChainTracer()],
|
||||
)
|
||||
llm.temperature = 0.7 # this is reflected in serialization
|
||||
assert dumps(llm, pretty=True) == snapshot
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_serialize_llmchain(snapshot: Any) -> None:
|
||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||
prompt = PromptTemplate.from_template("hello {name}!")
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_serialize_llmchain_chat(snapshot: Any) -> None:
|
||||
llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[HumanMessagePromptTemplate.from_template("hello {name}!")]
|
||||
)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
|
||||
llm = OpenAI(
|
||||
model="davinci",
|
||||
temperature=0.5,
|
||||
openai_api_key="hello",
|
||||
client=NotSerializable,
|
||||
)
|
||||
prompt = PromptTemplate.from_template("hello {name}!")
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
assert dumps(chain, pretty=True) == snapshot
|
tests/unit_tests/load/test_load.py (new file, 54 lines)
@ -0,0 +1,54 @@
|
||||
"""Test for Serializable base class"""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.load.dump import dumps
|
||||
from langchain.load.load import loads
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
class NotSerializable:
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_openai_llm() -> None:
|
||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||
llm_string = dumps(llm)
|
||||
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
|
||||
assert llm2 == llm
|
||||
assert dumps(llm2) == llm_string
|
||||
assert isinstance(llm2, OpenAI)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_llmchain() -> None:
|
||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||
prompt = PromptTemplate.from_template("hello {name}!")
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
chain_string = dumps(chain)
|
||||
chain2 = loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
|
||||
assert chain2 == chain
|
||||
assert dumps(chain2) == chain_string
|
||||
assert isinstance(chain2, LLMChain)
|
||||
assert isinstance(chain2.llm, OpenAI)
|
||||
assert isinstance(chain2.prompt, PromptTemplate)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_llmchain_with_non_serializable_arg() -> None:
|
||||
llm = OpenAI(
|
||||
model="davinci",
|
||||
temperature=0.5,
|
||||
openai_api_key="hello",
|
||||
client=NotSerializable,
|
||||
)
|
||||
prompt = PromptTemplate.from_template("hello {name}!")
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
chain_string = dumps(chain, pretty=True)
|
||||
with pytest.raises(NotImplementedError):
|
||||
loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
@@ -72,6 +72,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
        "pytest-socket",
        "pytest-watcher",
        "responses",
        "syrupy",
    ]