mirror of https://github.com/hwchase17/langchain
parent
614cff89bc
commit
18af149e91
@ -0,0 +1,22 @@
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain.load.serializable import Serializable, to_json_not_implemented
|
||||
|
||||
|
||||
def default(obj: Any) -> Any:
    """Fallback serializer for ``json.dumps``.

    Serializable objects emit their ``to_json`` form; everything else is
    recorded as a "not implemented" placeholder instead of failing.
    """
    is_lc_object = isinstance(obj, Serializable)
    return obj.to_json() if is_lc_object else to_json_not_implemented(obj)
|
||||
|
||||
|
||||
def dumps(obj: Any, *, pretty: bool = False) -> str:
    """Serialize *obj* to a JSON string, using :func:`default` for
    langchain objects.

    When *pretty* is true, the output is indented by two spaces
    (``indent=None`` reproduces the compact default).
    """
    indent = 2 if pretty else None
    return json.dumps(obj, default=default, indent=indent)
|
||||
|
||||
|
||||
def dumpd(obj: Any) -> Dict[str, Any]:
    """Serialize *obj* to a JSON-compatible dict by round-tripping
    through :func:`dumps`."""
    serialized = dumps(obj)
    return json.loads(serialized)
|
@ -0,0 +1,65 @@
|
||||
import importlib
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
|
||||
|
||||
class Reviver:
    """``object_hook`` for ``json.loads`` that revives serialized objects.

    Secrets are resolved from ``secrets_map`` (raising ``KeyError`` when
    missing), "not_implemented" markers raise ``NotImplementedError``, and
    "constructor" entries are re-imported and instantiated. Any other dict
    is returned unchanged.
    """

    def __init__(self, secrets_map: Optional[Dict[str, str]] = None) -> None:
        # Map of secret id (e.g. "OPENAI_API_KEY") -> secret value.
        self.secrets_map = secrets_map or dict()

    def __call__(self, value: Dict[str, Any]) -> Any:
        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "secret"
            and value.get("id", None) is not None
        ):
            # A secret id is always a single-element list.
            [key] = value["id"]
            if key in self.secrets_map:
                return self.secrets_map[key]
            else:
                raise KeyError(f'Missing key "{key}" in load(secrets_map)')

        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "not_implemented"
            and value.get("id", None) is not None
        ):
            raise NotImplementedError(
                "Trying to load an object that doesn't implement "
                f"serialization: {value}"
            )

        if (
            value.get("lc", None) == 1
            and value.get("type", None) == "constructor"
            and value.get("id", None) is not None
        ):
            [*namespace, name] = value["id"]

            # Currently, we only support langchain imports. Guard against an
            # empty namespace (an id of length 1) before indexing into it,
            # so a malformed id raises ValueError rather than IndexError.
            if not namespace or namespace[0] != "langchain":
                raise ValueError(f"Invalid namespace: {value}")

            # The root namespace "langchain" is not a valid identifier.
            if len(namespace) == 1:
                raise ValueError(f"Invalid namespace: {value}")

            mod = importlib.import_module(".".join(namespace))
            cls = getattr(mod, name)

            # The class must be a subclass of Serializable.
            if not issubclass(cls, Serializable):
                raise ValueError(f"Invalid namespace: {value}")

            # We don't need to recurse on kwargs
            # as json.loads will do that for us.
            kwargs = value.get("kwargs", dict())
            return cls(**kwargs)

        return value
|
||||
|
||||
|
||||
def loads(text: str, *, secrets_map: Optional[Dict[str, str]] = None) -> Any:
    """Deserialize JSON produced by ``dumps``, reviving langchain objects
    and resolving secret placeholders via *secrets_map*."""
    reviver = Reviver(secrets_map)
    return json.loads(text, object_hook=reviver)
|
@ -0,0 +1,135 @@
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Literal, TypedDict, Union, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BaseSerialized(TypedDict):
    """Fields shared by every serialized representation."""

    # Serialization format version (always 1 in this scheme).
    lc: int
    # Fully qualified path of the object,
    # e.g. ["langchain", "llms", "openai", "OpenAI"].
    id: List[str]
|
||||
|
||||
|
||||
class SerializedConstructor(BaseSerialized):
    """Serialized object that can be revived by calling its constructor."""

    type: Literal["constructor"]
    # Keyword arguments to pass to the constructor on load.
    kwargs: Dict[str, Any]
|
||||
|
||||
|
||||
class SerializedSecret(BaseSerialized):
    """Placeholder for a secret value; ``id`` names the secret to resolve."""

    type: Literal["secret"]
|
||||
|
||||
|
||||
class SerializedNotImplemented(BaseSerialized):
    """Marker for an object that does not implement serialization."""

    type: Literal["not_implemented"]
|
||||
|
||||
|
||||
class Serializable(BaseModel, ABC):
    """Base class for objects that can be serialized to a JSON-compatible
    constructor representation (see ``to_json``)."""

    @property
    def lc_serializable(self) -> bool:
        """
        Return whether or not the class is serializable.
        """
        # Opt-in: subclasses override this to enable serialization.
        return False

    @property
    def lc_namespace(self) -> List[str]:
        """
        Return the namespace of the langchain object.
        eg. ["langchain", "llms", "openai"]
        """
        return self.__class__.__module__.split(".")

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """
        Return a map of constructor argument names to secret ids.
        eg. {"openai_api_key": "OPENAI_API_KEY"}
        """
        return dict()

    @property
    def lc_attributes(self) -> Dict:
        """
        Return a list of attribute names that should be included in the
        serialized kwargs. These attributes must be accepted by the
        constructor.
        """
        return {}

    # Raw constructor kwargs captured in __init__; excluded from pydantic's
    # own .dict()/.json() output via exclude=True.
    lc_kwargs: Dict[str, Any] = Field(default_factory=dict, exclude=True)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Remember how this instance was constructed so it can be replayed.
        self.lc_kwargs = kwargs

    def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
        """Serialize self as a replayable constructor call, or as a
        "not_implemented" marker when ``lc_serializable`` is False."""
        if not self.lc_serializable:
            return self.to_json_not_implemented()

        secrets = dict()
        # Get latest values for kwargs if there is an attribute with same name
        lc_kwargs = {
            k: getattr(self, k, v)
            for k, v in self.lc_kwargs.items()
            if not self.__exclude_fields__.get(k, False)  # type: ignore
        }

        # Merge the lc_secrets and lc_attributes from every class in the MRO
        for cls in [None, *self.__class__.mro()]:
            # Once we get to Serializable, we're done
            if cls is Serializable:
                break

            # Get a reference to self bound to each class in the MRO
            # (the leading None entry means "self itself, most-derived view").
            this = cast(Serializable, self if cls is None else super(cls, self))

            secrets.update(this.lc_secrets)
            lc_kwargs.update(this.lc_attributes)

        return {
            "lc": 1,
            "type": "constructor",
            "id": [*self.lc_namespace, self.__class__.__name__],
            # Swap secret-bearing kwargs for SerializedSecret placeholders.
            "kwargs": lc_kwargs
            if not secrets
            else _replace_secrets(lc_kwargs, secrets),
        }

    def to_json_not_implemented(self) -> SerializedNotImplemented:
        """Serialize self as a "not_implemented" marker."""
        return to_json_not_implemented(self)
|
||||
|
||||
|
||||
def _replace_secrets(
|
||||
root: Dict[Any, Any], secrets_map: Dict[str, str]
|
||||
) -> Dict[Any, Any]:
|
||||
result = root.copy()
|
||||
for path, secret_id in secrets_map.items():
|
||||
[*parts, last] = path.split(".")
|
||||
current = result
|
||||
for part in parts:
|
||||
if part not in current:
|
||||
break
|
||||
current[part] = current[part].copy()
|
||||
current = current[part]
|
||||
if last in current:
|
||||
current[last] = {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [secret_id],
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def to_json_not_implemented(obj: object) -> SerializedNotImplemented:
    """Build a "not_implemented" marker that identifies *obj* as well as
    possible: classes/functions by their own name, instances by their
    class, and anything else with an empty id."""
    ident: List[str] = []
    try:
        # Classes and functions carry __name__ directly; instances are
        # identified through their class instead.
        if hasattr(obj, "__name__"):
            source = obj
        elif hasattr(obj, "__class__"):
            source = obj.__class__
        else:
            source = None
        if source is not None:
            ident = source.__module__.split(".") + [source.__name__]
    except Exception:
        # Best effort only — fall back to an empty identifier.
        pass
    return {
        "lc": 1,
        "type": "not_implemented",
        "id": ident,
    }
|
@ -0,0 +1,273 @@
|
||||
# serializer version: 1
|
||||
# name: test_person
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"test_dump",
|
||||
"Person"
|
||||
],
|
||||
"kwargs": {
|
||||
"secret": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"SECRET"
|
||||
]
|
||||
},
|
||||
"you_can_see_me": "hello"
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_person.1
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"test_dump",
|
||||
"SpecialPerson"
|
||||
],
|
||||
"kwargs": {
|
||||
"another_secret": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"ANOTHER_SECRET"
|
||||
]
|
||||
},
|
||||
"secret": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"SECRET"
|
||||
]
|
||||
},
|
||||
"another_visible": "bye",
|
||||
"you_can_see_me": "hello"
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_llmchain
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"kwargs": {
|
||||
"llm": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model": "davinci",
|
||||
"temperature": 0.5,
|
||||
"openai_api_key": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"OPENAI_API_KEY"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"template": "hello {name}!",
|
||||
"template_format": "f-string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_llmchain_chat
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"kwargs": {
|
||||
"llm": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chat_models",
|
||||
"openai",
|
||||
"ChatOpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model": "davinci",
|
||||
"temperature": 0.5,
|
||||
"openai_api_key": "hello"
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"chat",
|
||||
"ChatPromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"messages": [
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"chat",
|
||||
"HumanMessagePromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"template": "hello {name}!",
|
||||
"template_format": "f-string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_llmchain_with_non_serializable_arg
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"chains",
|
||||
"llm",
|
||||
"LLMChain"
|
||||
],
|
||||
"kwargs": {
|
||||
"llm": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model": "davinci",
|
||||
"temperature": 0.5,
|
||||
"openai_api_key": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"OPENAI_API_KEY"
|
||||
]
|
||||
},
|
||||
"client": {
|
||||
"lc": 1,
|
||||
"type": "not_implemented",
|
||||
"id": [
|
||||
"openai",
|
||||
"api_resources",
|
||||
"completion",
|
||||
"Completion"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"prompt": {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"prompts",
|
||||
"prompt",
|
||||
"PromptTemplate"
|
||||
],
|
||||
"kwargs": {
|
||||
"input_variables": [
|
||||
"name"
|
||||
],
|
||||
"template": "hello {name}!",
|
||||
"template_format": "f-string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
||||
# name: test_serialize_openai_llm
|
||||
'''
|
||||
{
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": [
|
||||
"langchain",
|
||||
"llms",
|
||||
"openai",
|
||||
"OpenAI"
|
||||
],
|
||||
"kwargs": {
|
||||
"model": "davinci",
|
||||
"temperature": 0.7,
|
||||
"openai_api_key": {
|
||||
"lc": 1,
|
||||
"type": "secret",
|
||||
"id": [
|
||||
"OPENAI_API_KEY"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
# ---
|
@ -0,0 +1,103 @@
|
||||
"""Test for Serializable base class"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.load.dump import dumps
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
class Person(Serializable):
    """Serializable fixture with one secret field and one visible attribute."""

    # Declared secret: replaced by a {"type": "secret", "id": ["SECRET"]}
    # placeholder in the serialized output (see lc_secrets below).
    secret: str

    # Exposed via lc_attributes, so it appears in serialized kwargs.
    you_can_see_me: str = "hello"

    @property
    def lc_serializable(self) -> bool:
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"secret": "SECRET"}

    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"you_can_see_me": self.you_can_see_me}
|
||||
|
||||
|
||||
class SpecialPerson(Person):
    """Subclass fixture: its secrets/attributes merge with Person's via
    the MRO walk in Serializable.to_json."""

    another_secret: str

    another_visible: str = "bye"

    # Gets merged with parent class's secrets
    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"another_secret": "ANOTHER_SECRET"}

    # Gets merged with parent class's attributes
    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"another_visible": self.another_visible}
|
||||
|
||||
|
||||
class NotSerializable:
    """Plain class without Serializable support; should serialize as a
    "not_implemented" marker."""

    pass
|
||||
|
||||
|
||||
def test_person(snapshot: Any) -> None:
    """Snapshot the serialized form of Person and of its subclass."""
    person = Person(secret="hello")
    assert dumps(person, pretty=True) == snapshot
    special = SpecialPerson(another_secret="Wooo", secret="Hmm")
    assert dumps(special, pretty=True) == snapshot
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_serialize_openai_llm(snapshot: Any) -> None:
    """Serialized kwargs reflect current attribute values; callbacks are
    excluded from the output."""
    tracer = LangChainTracer()
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        # This is excluded from serialization
        callbacks=[tracer],
    )
    llm.temperature = 0.7  # this is reflected in serialization
    assert dumps(llm, pretty=True) == snapshot
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_serialize_llmchain(snapshot: Any) -> None:
    """Snapshot an LLMChain wrapping an OpenAI llm and a prompt template."""
    chain = LLMChain(
        llm=OpenAI(model="davinci", temperature=0.5, openai_api_key="hello"),
        prompt=PromptTemplate.from_template("hello {name}!"),
    )
    assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_serialize_llmchain_chat(snapshot: Any) -> None:
    """Snapshot an LLMChain built on a chat model and a chat prompt."""
    chat_llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    message = HumanMessagePromptTemplate.from_template("hello {name}!")
    chat_prompt = ChatPromptTemplate.from_messages([message])
    chain = LLMChain(llm=chat_llm, prompt=chat_prompt)
    assert dumps(chain, pretty=True) == snapshot
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
    """A non-serializable constructor arg serializes as "not_implemented"."""
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        client=NotSerializable,
    )
    template = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=template)
    assert dumps(chain, pretty=True) == snapshot
|
@ -0,0 +1,54 @@
|
||||
"""Test for Serializable base class"""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.load.dump import dumps
|
||||
from langchain.load.load import loads
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
class NotSerializable:
    """Plain class without Serializable support; loading its serialized
    "not_implemented" form should raise."""

    pass
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_load_openai_llm() -> None:
    """An OpenAI llm round-trips through dumps/loads with its secret restored."""
    original = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    serialized = dumps(original)
    restored = loads(serialized, secrets_map={"OPENAI_API_KEY": "hello"})

    assert restored == original
    assert dumps(restored) == serialized
    assert isinstance(restored, OpenAI)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_load_llmchain() -> None:
    """An LLMChain round-trips through dumps/loads, restoring nested objects."""
    original = LLMChain(
        llm=OpenAI(model="davinci", temperature=0.5, openai_api_key="hello"),
        prompt=PromptTemplate.from_template("hello {name}!"),
    )
    serialized = dumps(original)
    restored = loads(serialized, secrets_map={"OPENAI_API_KEY": "hello"})

    assert restored == original
    assert dumps(restored) == serialized
    assert isinstance(restored, LLMChain)
    assert isinstance(restored.llm, OpenAI)
    assert isinstance(restored.prompt, PromptTemplate)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_load_llmchain_with_non_serializable_arg() -> None:
    """Loading a chain containing a "not_implemented" object raises."""
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        client=NotSerializable,
    )
    template = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=template)
    serialized = dumps(chain, pretty=True)
    with pytest.raises(NotImplementedError):
        loads(serialized, secrets_map={"OPENAI_API_KEY": "hello"})
|
Loading…
Reference in New Issue