commit 18af149e91
parent 614cff89bc
@@ -4,9 +4,8 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any, List, Optional, Sequence, Set
 
-from pydantic import BaseModel
-
 from langchain.callbacks.manager import Callbacks
+from langchain.load.serializable import Serializable
 from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
@@ -29,7 +28,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
     return tokenizer.encode(text)
 
 
-class BaseLanguageModel(BaseModel, ABC):
+class BaseLanguageModel(Serializable, ABC):
     @abstractmethod
     def generate_prompt(
         self,
@@ -204,7 +204,7 @@ def _handle_event(
         except Exception as e:
             if handler.raise_error:
                 raise e
-            logging.warning(f"Error in {event_name} callback: {e}")
+            logger.warning(f"Error in {event_name} callback: {e}")
 
 
 async def _ahandle_event_for_handler(
@@ -93,7 +93,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
         execution_order = self._get_execution_order(parent_run_id_)
         llm_run = Run(
             id=run_id,
-            name=serialized.get("name"),
             parent_run_id=parent_run_id,
             serialized=serialized,
             inputs={"prompts": prompts},
@@ -154,7 +153,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
         execution_order = self._get_execution_order(parent_run_id_)
         chain_run = Run(
             id=run_id,
-            name=serialized.get("name"),
             parent_run_id=parent_run_id,
             serialized=serialized,
             inputs=inputs,
@@ -216,7 +214,6 @@ class BaseTracer(BaseCallbackHandler, ABC):
         execution_order = self._get_execution_order(parent_run_id_)
         tool_run = Run(
             id=run_id,
-            name=serialized.get("name"),
            parent_run_id=parent_run_id,
             serialized=serialized,
             inputs={"input": input_str},
@@ -124,7 +124,10 @@ class Run(RunBase):
     def assign_name(cls, values: dict) -> dict:
         """Assign name to the run."""
         if "name" not in values:
-            values["name"] = values["serialized"]["name"]
+            if "name" in values["serialized"]:
+                values["name"] = values["serialized"]["name"]
+            elif "id" in values["serialized"]:
+                values["name"] = values["serialized"]["id"][-1]
         return values
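The validator now falls back to the last element of the serialized "id" path when no explicit name is given, which is why the tracer tests below can pass payloads like {"id": ["llm"]}. A minimal sketch of that behaviour, with an illustrative payload that is not part of the diff:

    # Hypothetical serialized payload carrying only an "id" path, no "name".
    values = {"serialized": {"id": ["langchain", "llms", "openai", "OpenAI"]}}
    if "name" not in values:
        if "name" in values["serialized"]:
            values["name"] = values["serialized"]["name"]
        elif "id" in values["serialized"]:
            # The run is named after the last path segment.
            values["name"] = values["serialized"]["id"][-1]
    assert values["name"] == "OpenAI"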
@@ -7,7 +7,7 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 import yaml
-from pydantic import BaseModel, Field, root_validator, validator
+from pydantic import Field, root_validator, validator
 
 import langchain
 from langchain.callbacks.base import BaseCallbackManager
@@ -18,6 +18,8 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
+from langchain.load.dump import dumpd
+from langchain.load.serializable import Serializable
 from langchain.schema import RUN_KEY, BaseMemory, RunInfo
@@ -25,7 +27,7 @@ def _get_verbosity() -> bool:
     return langchain.verbose
 
 
-class Chain(BaseModel, ABC):
+class Chain(Serializable, ABC):
     """Base interface that all chains should implement."""
 
     memory: Optional[BaseMemory] = None
@@ -131,7 +133,7 @@ class Chain(BaseModel, ABC):
         )
         new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
         run_manager = callback_manager.on_chain_start(
-            {"name": self.__class__.__name__},
+            dumpd(self),
             inputs,
         )
         try:
@@ -179,7 +181,7 @@ class Chain(BaseModel, ABC):
         )
         new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
         run_manager = await callback_manager.on_chain_start(
-            {"name": self.__class__.__name__},
+            dumpd(self),
             inputs,
         )
         try:
@@ -15,6 +15,7 @@ from langchain.callbacks.manager import (
 )
 from langchain.chains.base import Chain
 from langchain.input import get_colored_text
+from langchain.load.dump import dumpd
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 from langchain.schema import LLMResult, PromptValue
@@ -34,6 +35,10 @@ class LLMChain(Chain):
             llm = LLMChain(llm=OpenAI(), prompt=prompt)
     """
 
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
     prompt: BasePromptTemplate
     """Prompt object to use."""
     llm: BaseLanguageModel
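Opting the chain in via lc_serializable is what lets the new langchain.load.dump helpers (added later in this commit) emit a constructor-style JSON payload for it. A hedged usage sketch; the concrete values mirror the new test_dump snapshot and are illustrative only:

    from langchain.chains import LLMChain
    from langchain.llms import OpenAI
    from langchain.prompts import PromptTemplate
    from langchain.load.dump import dumps

    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    # Produces {"lc": 1, "type": "constructor",
    #           "id": ["langchain", "chains", "llm", "LLMChain"], "kwargs": {...}}
    print(dumps(chain, pretty=True))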
@@ -147,7 +152,7 @@ class LLMChain(Chain):
             callbacks, self.callbacks, self.verbose
         )
         run_manager = callback_manager.on_chain_start(
-            {"name": self.__class__.__name__},
+            dumpd(self),
             {"input_list": input_list},
         )
         try:
@@ -167,7 +172,7 @@ class LLMChain(Chain):
             callbacks, self.callbacks, self.verbose
         )
         run_manager = await callback_manager.on_chain_start(
-            {"name": self.__class__.__name__},
+            dumpd(self),
             {"input_list": input_list},
         )
         try:
@@ -17,6 +17,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
     Callbacks,
 )
+from langchain.load.dump import dumpd
 from langchain.schema import (
     AIMessage,
     BaseMessage,
@@ -70,12 +71,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
 
         params = self.dict()
         params["stop"] = stop
+        options = {"stop": stop}
 
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
         run_manager = callback_manager.on_chat_model_start(
-            {"name": self.__class__.__name__}, messages, invocation_params=params
+            dumpd(self), messages, invocation_params=params, options=options
         )
 
         new_arg_supported = inspect.signature(self._generate).parameters.get(
@@ -109,12 +111,13 @@ class BaseChatModel(BaseLanguageModel, ABC):
         """Top Level call"""
         params = self.dict()
         params["stop"] = stop
+        options = {"stop": stop}
 
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
         run_manager = await callback_manager.on_chat_model_start(
-            {"name": self.__class__.__name__}, messages, invocation_params=params
+            dumpd(self), messages, invocation_params=params, options=options
         )
 
         new_arg_supported = inspect.signature(self._agenerate).parameters.get(
@@ -136,6 +136,10 @@ class ChatOpenAI(BaseChatModel):
             openai = ChatOpenAI(model_name="gpt-3.5-turbo")
     """
 
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
     client: Any  #: :meta private:
     model_name: str = Field(default="gpt-3.5-turbo", alias="model")
     """Model name to use."""
@@ -19,6 +19,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
     Callbacks,
 )
+from langchain.load.dump import dumpd
 from langchain.schema import (
     AIMessage,
     BaseMessage,
@@ -166,6 +167,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         )
         params = self.dict()
         params["stop"] = stop
+        options = {"stop": stop}
         (
             existing_prompts,
             llm_string,
@@ -186,7 +188,7 @@
                 "Asked to cache, but no cache found at `langchain.cache`."
             )
         run_manager = callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, prompts, invocation_params=params
+            dumpd(self), prompts, invocation_params=params, options=options
         )
         try:
             output = (
@@ -205,9 +207,10 @@
             return output
         if len(missing_prompts) > 0:
             run_manager = callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
+                dumpd(self),
                 missing_prompts,
                 invocation_params=params,
+                options=options,
             )
             try:
                 new_results = (
@@ -243,6 +246,7 @@
         """Run the LLM on the given prompt and input."""
         params = self.dict()
         params["stop"] = stop
+        options = {"stop": stop}
         (
             existing_prompts,
             llm_string,
@@ -263,7 +267,7 @@
                 "Asked to cache, but no cache found at `langchain.cache`."
             )
         run_manager = await callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, prompts, invocation_params=params
+            dumpd(self), prompts, invocation_params=params, options=options
         )
         try:
             output = (
@@ -282,9 +286,10 @@
             return output
         if len(missing_prompts) > 0:
             run_manager = await callback_manager.on_llm_start(
-                {"name": self.__class__.__name__},
+                dumpd(self),
                 missing_prompts,
                 invocation_params=params,
+                options=options,
             )
             try:
                 new_results = (
@@ -123,6 +123,14 @@ async def acompletion_with_retry(
 class BaseOpenAI(BaseLLM):
     """Wrapper around OpenAI large language models."""
 
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        return {"openai_api_key": "OPENAI_API_KEY"}
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
     client: Any  #: :meta private:
     model_name: str = Field("text-davinci-003", alias="model")
     """Model name to use."""
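lc_secrets maps a constructor argument to a secret id, so the serializer can replace the real API key with a reference instead of writing it into the payload. A small sketch of the resulting shape (values are illustrative; the exact output is pinned by the new snapshot file below):

    from langchain.llms import OpenAI
    from langchain.load.dump import dumpd

    llm = OpenAI(model="davinci", openai_api_key="sk-placeholder")
    serialized = dumpd(llm)
    # The key is replaced by a secret marker rather than the literal value:
    # serialized["kwargs"]["openai_api_key"] ==
    #     {"lc": 1, "type": "secret", "id": ["OPENAI_API_KEY"]}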
new file: langchain/load/__init__.py (empty)

new file: langchain/load/dump.py
@@ -0,0 +1,22 @@
+import json
+from typing import Any, Dict
+
+from langchain.load.serializable import Serializable, to_json_not_implemented
+
+
+def default(obj: Any) -> Any:
+    if isinstance(obj, Serializable):
+        return obj.to_json()
+    else:
+        return to_json_not_implemented(obj)
+
+
+def dumps(obj: Any, *, pretty: bool = False) -> str:
+    if pretty:
+        return json.dumps(obj, default=default, indent=2)
+    else:
+        return json.dumps(obj, default=default)
+
+
+def dumpd(obj: Any) -> Dict[str, Any]:
+    return json.loads(dumps(obj))
langchain/load/load.py
Normal file
65
langchain/load/load.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
|
||||||
|
|
||||||
|
class Reviver:
|
||||||
|
def __init__(self, secrets_map: Optional[Dict[str, str]] = None) -> None:
|
||||||
|
self.secrets_map = secrets_map or dict()
|
||||||
|
|
||||||
|
def __call__(self, value: Dict[str, Any]) -> Any:
|
||||||
|
if (
|
||||||
|
value.get("lc", None) == 1
|
||||||
|
and value.get("type", None) == "secret"
|
||||||
|
and value.get("id", None) is not None
|
||||||
|
):
|
||||||
|
[key] = value["id"]
|
||||||
|
if key in self.secrets_map:
|
||||||
|
return self.secrets_map[key]
|
||||||
|
else:
|
||||||
|
raise KeyError(f'Missing key "{key}" in load(secrets_map)')
|
||||||
|
|
||||||
|
if (
|
||||||
|
value.get("lc", None) == 1
|
||||||
|
and value.get("type", None) == "not_implemented"
|
||||||
|
and value.get("id", None) is not None
|
||||||
|
):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Trying to load an object that doesn't implement "
|
||||||
|
f"serialization: {value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
value.get("lc", None) == 1
|
||||||
|
and value.get("type", None) == "constructor"
|
||||||
|
and value.get("id", None) is not None
|
||||||
|
):
|
||||||
|
[*namespace, name] = value["id"]
|
||||||
|
|
||||||
|
# Currently, we only support langchain imports.
|
||||||
|
if namespace[0] != "langchain":
|
||||||
|
raise ValueError(f"Invalid namespace: {value}")
|
||||||
|
|
||||||
|
# The root namespace "langchain" is not a valid identifier.
|
||||||
|
if len(namespace) == 1:
|
||||||
|
raise ValueError(f"Invalid namespace: {value}")
|
||||||
|
|
||||||
|
mod = importlib.import_module(".".join(namespace))
|
||||||
|
cls = getattr(mod, name)
|
||||||
|
|
||||||
|
# The class must be a subclass of Serializable.
|
||||||
|
if not issubclass(cls, Serializable):
|
||||||
|
raise ValueError(f"Invalid namespace: {value}")
|
||||||
|
|
||||||
|
# We don't need to recurse on kwargs
|
||||||
|
# as json.loads will do that for us.
|
||||||
|
kwargs = value.get("kwargs", dict())
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def loads(text: str, *, secrets_map: Optional[Dict[str, str]] = None) -> Any:
|
||||||
|
return json.loads(text, object_hook=Reviver(secrets_map))
|
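loads reverses the process with a JSON object_hook: constructor payloads are re-imported and instantiated, secret markers are resolved from secrets_map, and not_implemented payloads raise. A hedged round-trip sketch with illustrative values:

    from langchain.llms import OpenAI
    from langchain.load.dump import dumps
    from langchain.load.load import loads

    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    text = dumps(llm)
    # The serialized form only references OPENAI_API_KEY, so the actual value
    # has to be supplied again at load time via secrets_map.
    revived = loads(text, secrets_map={"OPENAI_API_KEY": "hello"})
    assert isinstance(revived, OpenAI)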
new file: langchain/load/serializable.py
@@ -0,0 +1,135 @@
+from abc import ABC
+from typing import Any, Dict, List, Literal, TypedDict, Union, cast
+
+from pydantic import BaseModel, Field
+
+
+class BaseSerialized(TypedDict):
+    lc: int
+    id: List[str]
+
+
+class SerializedConstructor(BaseSerialized):
+    type: Literal["constructor"]
+    kwargs: Dict[str, Any]
+
+
+class SerializedSecret(BaseSerialized):
+    type: Literal["secret"]
+
+
+class SerializedNotImplemented(BaseSerialized):
+    type: Literal["not_implemented"]
+
+
+class Serializable(BaseModel, ABC):
+    @property
+    def lc_serializable(self) -> bool:
+        """
+        Return whether or not the class is serializable.
+        """
+        return False
+
+    @property
+    def lc_namespace(self) -> List[str]:
+        """
+        Return the namespace of the langchain object.
+        eg. ["langchain", "llms", "openai"]
+        """
+        return self.__class__.__module__.split(".")
+
+    @property
+    def lc_secrets(self) -> Dict[str, str]:
+        """
+        Return a map of constructor argument names to secret ids.
+        eg. {"openai_api_key": "OPENAI_API_KEY"}
+        """
+        return dict()
+
+    @property
+    def lc_attributes(self) -> Dict:
+        """
+        Return a list of attribute names that should be included in the
+        serialized kwargs. These attributes must be accepted by the
+        constructor.
+        """
+        return {}
+
+    lc_kwargs: Dict[str, Any] = Field(default_factory=dict, exclude=True)
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.lc_kwargs = kwargs
+
+    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
+        if not self.lc_serializable:
+            return self.to_json_not_implemented()
+
+        secrets = dict()
+        # Get latest values for kwargs if there is an attribute with same name
+        lc_kwargs = {
+            k: getattr(self, k, v)
+            for k, v in self.lc_kwargs.items()
+            if not self.__exclude_fields__.get(k, False)  # type: ignore
+        }
+
+        # Merge the lc_secrets and lc_attributes from every class in the MRO
+        for cls in [None, *self.__class__.mro()]:
+            # Once we get to Serializable, we're done
+            if cls is Serializable:
+                break
+
+            # Get a reference to self bound to each class in the MRO
+            this = cast(Serializable, self if cls is None else super(cls, self))
+
+            secrets.update(this.lc_secrets)
+            lc_kwargs.update(this.lc_attributes)
+
+        return {
+            "lc": 1,
+            "type": "constructor",
+            "id": [*self.lc_namespace, self.__class__.__name__],
+            "kwargs": lc_kwargs
+            if not secrets
+            else _replace_secrets(lc_kwargs, secrets),
+        }
+
+    def to_json_not_implemented(self) -> SerializedNotImplemented:
+        return to_json_not_implemented(self)
+
+
+def _replace_secrets(
+    root: Dict[Any, Any], secrets_map: Dict[str, str]
+) -> Dict[Any, Any]:
+    result = root.copy()
+    for path, secret_id in secrets_map.items():
+        [*parts, last] = path.split(".")
+        current = result
+        for part in parts:
+            if part not in current:
+                break
+            current[part] = current[part].copy()
+            current = current[part]
+        if last in current:
+            current[last] = {
+                "lc": 1,
+                "type": "secret",
+                "id": [secret_id],
+            }
+    return result
+
+
+def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
+    _id: List[str] = []
+    try:
+        if hasattr(obj, "__name__"):
+            _id = [*obj.__module__.split("."), obj.__name__]
+        elif hasattr(obj, "__class__"):
+            _id = [*obj.__class__.__module__.split("."), obj.__class__.__name__]
+    except Exception:
+        pass
+    return {
+        "lc": 1,
+        "type": "not_implemented",
+        "id": _id,
+    }
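Subclasses opt in by overriding lc_serializable and can expose secrets and extra constructor attributes through lc_secrets and lc_attributes. A minimal sketch of a custom subclass; the field names mirror the Person fixture in the new snapshot file, and the trailing comment is an approximation of its output:

    from typing import Dict
    from langchain.load.serializable import Serializable

    class Person(Serializable):
        secret: str
        you_can_see_me: str = "hello"

        @property
        def lc_serializable(self) -> bool:
            return True

        @property
        def lc_secrets(self) -> Dict[str, str]:
            return {"secret": "SECRET"}

    Person(secret="hush", you_can_see_me="hello").to_json()
    # -> {"lc": 1, "type": "constructor", "id": [<module path>, "Person"],
    #     "kwargs": {"secret": {"lc": 1, "type": "secret", "id": ["SECRET"]},
    #                "you_can_see_me": "hello"}}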
@@ -7,9 +7,10 @@ from pathlib import Path
 from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
 
 import yaml
-from pydantic import BaseModel, Extra, Field, root_validator
+from pydantic import Extra, Field, root_validator
 
 from langchain.formatting import formatter
+from langchain.load.serializable import Serializable
 from langchain.schema import BaseMessage, BaseOutputParser, HumanMessage, PromptValue
@@ -100,7 +101,7 @@ class StringPromptValue(PromptValue):
         return [HumanMessage(content=self.text)]
 
 
-class BasePromptTemplate(BaseModel, ABC):
+class BasePromptTemplate(Serializable, ABC):
     """Base class for all prompt templates, returning a prompt."""
 
     input_variables: List[str]
@@ -111,6 +112,10 @@ class BasePromptTemplate(BaseModel, ABC):
         default_factory=dict
     )
 
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
     class Config:
         """Configuration for this pydantic object."""
@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
 
-from pydantic import BaseModel, Field
+from pydantic import Field
 
+from langchain.load.serializable import Serializable
 from langchain.memory.buffer import get_buffer_string
 from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
 from langchain.prompts.prompt import PromptTemplate
@@ -20,7 +21,11 @@ from langchain.schema import (
 )
 
 
-class BaseMessagePromptTemplate(BaseModel, ABC):
+class BaseMessagePromptTemplate(Serializable, ABC):
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
     @abstractmethod
     def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
         """To messages."""
@@ -220,7 +225,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
 
     @property
     def _prompt_type(self) -> str:
-        raise NotImplementedError
+        return "chat"
 
     def save(self, file_path: Union[Path, str]) -> None:
         raise NotImplementedError
@@ -15,6 +15,10 @@ from langchain.prompts.prompt import PromptTemplate
 class FewShotPromptTemplate(StringPromptTemplate):
     """Prompt template that contains few shot examples."""
 
+    @property
+    def lc_serializable(self) -> bool:
+        return False
+
     examples: Optional[List[dict]] = None
     """Examples to format into the prompt.
     Either this or example_selector should be provided."""
@@ -25,6 +25,12 @@ class PromptTemplate(StringPromptTemplate):
         prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
     """
 
+    @property
+    def lc_attributes(self) -> Dict[str, Any]:
+        return {
+            "template_format": self.template_format,
+        }
+
     input_variables: List[str]
     """A list of the names of the variables the prompt template expects."""
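lc_attributes pulls template_format into the serialized kwargs even when the caller never passed it explicitly, so a revived template keeps the same format. An illustrative sketch (the snapshot below shows the same key in the serialized prompt):

    from langchain.load.dump import dumpd
    from langchain.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("hello {name}!")
    assert dumpd(prompt)["kwargs"]["template_format"] == "f-string"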
@@ -17,6 +17,8 @@ from uuid import UUID
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
+from langchain.load.serializable import Serializable
+
 RUN_KEY = "__run"
@@ -55,7 +57,7 @@ class AgentFinish(NamedTuple):
     log: str
 
 
-class Generation(BaseModel):
+class Generation(Serializable):
     """Output of a single generation."""
 
     text: str
@@ -67,7 +69,7 @@ class Generation(BaseModel):
     # TODO: add log probs
 
 
-class BaseMessage(BaseModel):
+class BaseMessage(Serializable):
     """Message object."""
 
     content: str
@@ -194,7 +196,7 @@ class LLMResult(BaseModel):
     )
 
 
-class PromptValue(BaseModel, ABC):
+class PromptValue(Serializable, ABC):
     @abstractmethod
     def to_string(self) -> str:
         """Return prompt as string."""
@@ -204,7 +206,7 @@ class PromptValue(BaseModel, ABC):
         """Return prompt as messages."""
 
 
-class BaseMemory(BaseModel, ABC):
+class BaseMemory(Serializable, ABC):
     """Base interface for memory in chains."""
 
     class Config:
@@ -282,7 +284,7 @@ class BaseChatMessageHistory(ABC):
         """Remove all messages from the store"""
 
 
-class Document(BaseModel):
+class Document(Serializable):
     """Interface for interacting with a document."""
 
     page_content: str
@@ -321,7 +323,7 @@ Memory = BaseMemory
 T = TypeVar("T")
 
 
-class BaseOutputParser(BaseModel, ABC, Generic[T]):
+class BaseOutputParser(Serializable, ABC, Generic[T]):
     """Class to parse the output of an LLM call.
 
     Output parsers help structure language model responses.
poetry.lock (generated)
@@ -1417,6 +1417,17 @@ files = [
     {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
 ]
 
+[[package]]
+name = "colored"
+version = "1.4.4"
+description = "Simple library for color and formatting to terminal"
+category = "dev"
+optional = false
+python-versions = "*"
+files = [
+    {file = "colored-1.4.4.tar.gz", hash = "sha256:04ff4d4dd514274fe3b99a21bb52fb96f2688c01e93fba7bef37221e7cb56ce0"},
+]
+
 [[package]]
 name = "coloredlogs"
 version = "15.0.1"
@@ -9461,6 +9472,22 @@ files = [
 [package.dependencies]
 mpmath = ">=0.19"
 
+[[package]]
+name = "syrupy"
+version = "4.0.2"
+description = "Pytest Snapshot Test Utility"
+category = "dev"
+optional = false
+python-versions = ">=3.8.1,<4"
+files = [
+    {file = "syrupy-4.0.2-py3-none-any.whl", hash = "sha256:dfd1f0fad298eee753de4f2471d4346412c4435885c4b7beea648d4934c6620a"},
+    {file = "syrupy-4.0.2.tar.gz", hash = "sha256:3c75ab6866580679b2cb9abe78e74c3e2011fffc6333651c6beb2a78a716ab80"},
+]
+
+[package.dependencies]
+colored = ">=1.3.92,<2.0.0"
+pytest = ">=7.0.0,<8.0.0"
+
 [[package]]
 name = "tabulate"
 version = "0.9.0"
@@ -11428,7 +11455,7 @@ azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-for
 cohere = ["cohere"]
 docarray = ["docarray"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark"]
+extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "pyspark", "openai"]
 llms = ["anthropic", "cohere", "openai", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
 openai = ["openai", "tiktoken"]
 qdrant = ["qdrant-client"]
@@ -11437,4 +11464,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "ecf7086e83cc0ff19e6851c0b63170b082b267c1c1c00f47700fd3a8c8bb46c5"
+content-hash = "7a39130af070d4a4fe6b0af5d6b70615c868ab0b1867e404060ff00eacd10f5f"
@@ -139,6 +139,7 @@ pytest-asyncio = "^0.20.3"
 lark = "^1.1.5"
 pytest-mock = "^3.10.0"
 pytest-socket = "^0.6.0"
+syrupy = "^4.0.2"
 
 [tool.poetry.group.test_integration]
 optional = true
@@ -315,7 +316,8 @@ extended_testing = [
     "html2text",
     "py-trello",
     "scikit-learn",
-    "pyspark"
+    "pyspark",
+    "openai"
 ]
 
 [tool.ruff]
@@ -349,7 +351,10 @@ build-backend = "poetry.core.masonry.api"
 # https://docs.pytest.org/en/7.1.x/reference/reference.html
 # --strict-config       any warnings encountered while parsing the `pytest`
 #                       section of the configuration file raise errors.
-addopts = "--strict-markers --strict-config --durations=5"
+#
+# https://github.com/tophat/syrupy
+# --snapshot-warn-unused    Prints a warning on unused snapshots rather than fail the test suite.
+addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused"
 # Registering custom markers.
 # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
 markers = [
@@ -13,6 +13,9 @@ from langchain.callbacks.tracers.base import BaseTracer, TracerException
 from langchain.callbacks.tracers.schemas import Run
 from langchain.schema import LLMResult
 
+SERIALIZED = {"id": ["llm"]}
+SERIALIZED_CHAT = {"id": ["chat_model"]}
+
 
 class FakeTracer(BaseTracer):
     """Fake tracer that records LangChain execution."""
@@ -39,7 +42,7 @@ def test_tracer_llm_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs={"prompts": []},
         outputs=LLMResult(generations=[[]]),
         error=None,
@@ -47,7 +50,7 @@ def test_tracer_llm_run() -> None:
     )
     tracer = FakeTracer()
 
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
     assert tracer.runs == [compare_run]
@@ -64,7 +67,7 @@ def test_tracer_chat_model_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "chat_model"},
+        serialized=SERIALIZED_CHAT,
         inputs=dict(prompts=[""]),
         outputs=LLMResult(generations=[[]]),
         error=None,
@@ -73,7 +76,7 @@ def test_tracer_chat_model_run() -> None:
     tracer = FakeTracer()
     manager = CallbackManager(handlers=[tracer])
     run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
+        serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
     )
     run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
@@ -100,7 +103,7 @@ def test_tracer_multiple_llm_runs() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]]),
         error=None,
@@ -110,7 +113,7 @@ def test_tracer_multiple_llm_runs() -> None:
 
     num_runs = 10
     for _ in range(num_runs):
-        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+        tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
         tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
 
     assert tracer.runs == [compare_run] * num_runs
@@ -183,7 +186,7 @@ def test_tracer_nested_run() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=tool_uuid,
@@ -191,7 +194,7 @@ def test_tracer_nested_run() -> None:
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_tool_end("test", run_id=tool_uuid)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -235,7 +238,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]]),
         run_type="llm",
@@ -251,7 +254,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=4,
         child_execution_order=4,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]]),
         run_type="llm",
@@ -275,7 +278,7 @@ def test_tracer_llm_run_on_error() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=None,
         error=repr(exception),
@@ -283,7 +286,7 @@ def test_tracer_llm_run_on_error() -> None:
     )
     tracer = FakeTracer()
 
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
    tracer.on_llm_error(exception, run_id=uuid)
     assert tracer.runs == [compare_run]
@@ -358,14 +361,14 @@ def test_tracer_nested_runs_on_error() -> None:
         serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -378,7 +381,7 @@ def test_tracer_nested_runs_on_error() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid3,
         parent_run_id=tool_uuid,
@@ -408,7 +411,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=2,
         child_execution_order=2,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         error=None,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]], llm_output=None),
@@ -422,7 +425,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         error=None,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]], llm_output=None),
@@ -450,7 +453,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=5,
         child_execution_order=5,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         error=repr(exception),
         inputs=dict(prompts=[]),
         outputs=None,
@@ -22,6 +22,9 @@ from langchain.schema import LLMResult
 
 TEST_SESSION_ID = 2023
 
+SERIALIZED = {"id": ["llm"]}
+SERIALIZED_CHAT = {"id": ["chat_model"]}
+
 
 def load_session(session_name: str) -> TracerSessionV1:
     """Load a tracing session."""
@@ -107,7 +110,7 @@ def test_tracer_llm_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -116,7 +119,7 @@ def test_tracer_llm_run() -> None:
     tracer = FakeTracer()
 
     tracer.new_session()
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
     assert tracer.runs == [compare_run]
@@ -133,7 +136,7 @@ def test_tracer_chat_model_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "chat_model"},
+        serialized=SERIALIZED_CHAT,
         prompts=[""],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -144,7 +147,7 @@ def test_tracer_chat_model_run() -> None:
     tracer.new_session()
     manager = CallbackManager(handlers=[tracer])
     run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
+        serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
     )
     run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
@@ -172,7 +175,7 @@ def test_tracer_multiple_llm_runs() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -183,7 +186,7 @@ def test_tracer_multiple_llm_runs() -> None:
     tracer.new_session()
     num_runs = 10
     for _ in range(num_runs):
-        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+        tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
         tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
 
     assert tracer.runs == [compare_run] * num_runs
@@ -263,7 +266,7 @@ def test_tracer_nested_run() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=tool_uuid,
@@ -271,7 +274,7 @@ def test_tracer_nested_run() -> None:
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_tool_end("test", run_id=tool_uuid)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -319,7 +322,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -337,7 +340,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=4,
         child_execution_order=4,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -362,7 +365,7 @@ def test_tracer_llm_run_on_error() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=None,
         session_id=TEST_SESSION_ID,
@@ -371,7 +374,7 @@ def test_tracer_llm_run_on_error() -> None:
     tracer = FakeTracer()
 
     tracer.new_session()
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_error(exception, run_id=uuid)
     assert tracer.runs == [compare_run]
@@ -451,14 +454,14 @@ def test_tracer_nested_runs_on_error() -> None:
         serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -471,7 +474,7 @@ def test_tracer_nested_runs_on_error() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid3,
         parent_run_id=tool_uuid,
@@ -501,7 +504,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=2,
         child_execution_order=2,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         session_id=TEST_SESSION_ID,
         error=None,
         prompts=[],
@@ -515,7 +518,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         session_id=TEST_SESSION_ID,
         error=None,
         prompts=[],
@@ -547,7 +550,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=5,
         child_execution_order=5,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         session_id=TEST_SESSION_ID,
         error=repr(exception),
         prompts=[],
new file: tests/unit_tests/load/__snapshots__/test_dump.ambr
@@ -0,0 +1,273 @@
+# serializer version: 1
+# name: test_person
+  '''
+  {
+    "lc": 1,
+    "type": "constructor",
+    "id": [
+      "test_dump",
+      "Person"
+    ],
+    "kwargs": {
+      "secret": {
+        "lc": 1,
+        "type": "secret",
+        "id": [
+          "SECRET"
+        ]
+      },
+      "you_can_see_me": "hello"
+    }
+  }
+  '''
+# ---
+# name: test_person.1
+  '''
+  {
+    "lc": 1,
+    "type": "constructor",
+    "id": [
+      "test_dump",
+      "SpecialPerson"
+    ],
+    "kwargs": {
+      "another_secret": {
+        "lc": 1,
+        "type": "secret",
+        "id": [
+          "ANOTHER_SECRET"
+        ]
+      },
+      "secret": {
+        "lc": 1,
+        "type": "secret",
+        "id": [
+          "SECRET"
+        ]
+      },
+      "another_visible": "bye",
+      "you_can_see_me": "hello"
+    }
+  }
+  '''
+# ---
+# name: test_serialize_llmchain
+  '''
+  {
+    "lc": 1,
+    "type": "constructor",
+    "id": [
+      "langchain",
+      "chains",
+      "llm",
+      "LLMChain"
+    ],
+    "kwargs": {
+      "llm": {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+          "langchain",
+          "llms",
+          "openai",
+          "OpenAI"
+        ],
+        "kwargs": {
+          "model": "davinci",
+          "temperature": 0.5,
+          "openai_api_key": {
+            "lc": 1,
+            "type": "secret",
+            "id": [
+              "OPENAI_API_KEY"
+            ]
+          }
+        }
+      },
+      "prompt": {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+          "langchain",
+          "prompts",
+          "prompt",
+          "PromptTemplate"
+        ],
+        "kwargs": {
+          "input_variables": [
+            "name"
+          ],
+          "template": "hello {name}!",
+          "template_format": "f-string"
+        }
+      }
+    }
+  }
+  '''
+# ---
+# name: test_serialize_llmchain_chat
+  '''
+  {
+    "lc": 1,
+    "type": "constructor",
+    "id": [
+      "langchain",
+      "chains",
+      "llm",
+      "LLMChain"
+    ],
+    "kwargs": {
+      "llm": {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+          "langchain",
+          "chat_models",
+          "openai",
+          "ChatOpenAI"
+        ],
+        "kwargs": {
+          "model": "davinci",
+          "temperature": 0.5,
+          "openai_api_key": "hello"
+        }
+      },
+      "prompt": {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+          "langchain",
+          "prompts",
+          "chat",
+          "ChatPromptTemplate"
+        ],
+        "kwargs": {
+          "input_variables": [
+            "name"
+          ],
+          "messages": [
+            {
+              "lc": 1,
+              "type": "constructor",
+              "id": [
+                "langchain",
+                "prompts",
+                "chat",
+                "HumanMessagePromptTemplate"
+              ],
+              "kwargs": {
+                "prompt": {
+                  "lc": 1,
+                  "type": "constructor",
+                  "id": [
+                    "langchain",
+                    "prompts",
+                    "prompt",
+                    "PromptTemplate"
+                  ],
+                  "kwargs": {
+                    "input_variables": [
+                      "name"
+                    ],
+                    "template": "hello {name}!",
+                    "template_format": "f-string"
+                  }
+                }
+              }
+            }
+          ]
+        }
+      }
+    }
+  }
+  '''
+# ---
+# name: test_serialize_llmchain_with_non_serializable_arg
+  '''
+  {
+    "lc": 1,
+    "type": "constructor",
+    "id": [
+      "langchain",
+      "chains",
+      "llm",
+      "LLMChain"
+    ],
+    "kwargs": {
+      "llm": {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+          "langchain",
+          "llms",
+          "openai",
+          "OpenAI"
+        ],
+        "kwargs": {
+          "model": "davinci",
+          "temperature": 0.5,
+          "openai_api_key": {
+            "lc": 1,
+            "type": "secret",
+            "id": [
+              "OPENAI_API_KEY"
+            ]
+          },
+          "client": {
+            "lc": 1,
+            "type": "not_implemented",
+            "id": [
+              "openai",
+              "api_resources",
+              "completion",
+              "Completion"
+            ]
+          }
+        }
+      },
+      "prompt": {
+        "lc": 1,
+        "type": "constructor",
+        "id": [
+          "langchain",
+          "prompts",
+          "prompt",
+          "PromptTemplate"
+        ],
+        "kwargs": {
+          "input_variables": [
+            "name"
+          ],
+          "template": "hello {name}!",
+          "template_format": "f-string"
+        }
+      }
+    }
+  }
+  '''
+# ---
+# name: test_serialize_openai_llm
+  '''
+  {
+    "lc": 1,
+    "type": "constructor",
+    "id": [
+      "langchain",
+      "llms",
+      "openai",
+      "OpenAI"
+    ],
+    "kwargs": {
+      "model": "davinci",
+      "temperature": 0.7,
+      "openai_api_key": {
+        "lc": 1,
+        "type": "secret",
|
||||||
|
"id": [
|
||||||
|
"OPENAI_API_KEY"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
# ---
|
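The snapshot file above is the pretty-printed output of dumps(..., pretty=True): fields declared through lc_secrets are masked as {"lc": 1, "type": "secret", "id": [...]} placeholders, and constructor arguments that cannot be serialized (such as the raw openai client) are recorded as "not_implemented" stubs. A minimal sketch of reproducing the last snapshot by hand, assuming a dummy API key exactly as the tests below do:

from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps

# The key is emitted as an OPENAI_API_KEY secret placeholder, never as the
# literal string "hello" passed in here.
llm = OpenAI(model="davinci", temperature=0.7, openai_api_key="hello")
print(dumps(llm, pretty=True))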
103
tests/unit_tests/load/test_dump.py
Normal file
@ -0,0 +1,103 @@
"""Test for Serializable base class"""

from typing import Any, Dict

import pytest

from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.llm import LLMChain
from langchain.chat_models.openai import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.serializable import Serializable
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.prompts.prompt import PromptTemplate


class Person(Serializable):
    secret: str

    you_can_see_me: str = "hello"

    @property
    def lc_serializable(self) -> bool:
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"secret": "SECRET"}

    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"you_can_see_me": self.you_can_see_me}


class SpecialPerson(Person):
    another_secret: str

    another_visible: str = "bye"

    # Gets merged with parent class's secrets
    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"another_secret": "ANOTHER_SECRET"}

    # Gets merged with parent class's attributes
    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"another_visible": self.another_visible}


class NotSerializable:
    pass


def test_person(snapshot: Any) -> None:
    p = Person(secret="hello")
    assert dumps(p, pretty=True) == snapshot
    sp = SpecialPerson(another_secret="Wooo", secret="Hmm")
    assert dumps(sp, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_openai_llm(snapshot: Any) -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        # This is excluded from serialization
        callbacks=[LangChainTracer()],
    )
    llm.temperature = 0.7  # this is reflected in serialization
    assert dumps(llm, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain(snapshot: Any) -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain_chat(snapshot: Any) -> None:
    llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = ChatPromptTemplate.from_messages(
        [HumanMessagePromptTemplate.from_template("hello {name}!")]
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        client=NotSerializable,
    )
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot
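A quick usage sketch of the Person/SpecialPerson classes defined above, outside the pytest harness (the output shape follows the test_person snapshots earlier in this diff; the snapshot fixture used by the tests comes from syrupy, the new test dependency added at the end of this change):

from langchain.load.dump import dumps

sp = SpecialPerson(another_secret="Wooo", secret="Hmm")
# Both the inherited secret and the subclass-specific one are masked in the
# output, because lc_secrets declarations are merged down the class hierarchy.
print(dumps(sp, pretty=True))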
54
tests/unit_tests/load/test_load.py
Normal file
@ -0,0 +1,54 @@
"""Test for Serializable base class"""

import pytest

from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.prompts.prompt import PromptTemplate


class NotSerializable:
    pass


@pytest.mark.requires("openai")
def test_load_openai_llm() -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    llm_string = dumps(llm)
    llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})

    assert llm2 == llm
    assert dumps(llm2) == llm_string
    assert isinstance(llm2, OpenAI)


@pytest.mark.requires("openai")
def test_load_llmchain() -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    chain_string = dumps(chain)
    chain2 = loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})

    assert chain2 == chain
    assert dumps(chain2) == chain_string
    assert isinstance(chain2, LLMChain)
    assert isinstance(chain2.llm, OpenAI)
    assert isinstance(chain2.prompt, PromptTemplate)


@pytest.mark.requires("openai")
def test_load_llmchain_with_non_serializable_arg() -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        client=NotSerializable,
    )
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    chain_string = dumps(chain, pretty=True)
    with pytest.raises(NotImplementedError):
        loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
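The load tests above exercise the full dump/load round trip; a standalone sketch of the same flow, assuming the API key is supplied through secrets_map exactly as in the tests:

from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads

llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
llm_string = dumps(llm)  # the key is serialized as an OPENAI_API_KEY placeholder
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})  # placeholder resolved here
assert llm2 == llm and isinstance(llm2, OpenAI)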
@ -72,6 +72,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
"pytest-socket",
"pytest-watcher",
"responses",
"syrupy",
]