add `load()` deserializer function that bypasses need for json serialization (#7626)

There is already a `loads()` function, which takes a JSON string and
revives it using the `Reviver`.

But in the callbacks system, the `serialized` object that is passed in
is already a deserialized, JSON-compatible object. The new `load()`
function lets you call `load(serialized)` directly and bypass the
intermediate JSON encoding.

I found one other place in the code that benefited from this
short-circuiting (string_run_evaluator.py) so I fixed that too.

Tagging @baskaryan for general/utility stuff.

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
pull/8747/head
Alec Flett 1 year ago committed by GitHub
parent 6aee589eec
commit f0b0c72d98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -55,7 +55,7 @@ class Reviver:
raise ValueError(f"Invalid namespace: {value}")
# The root namespace "langchain" is not a valid identifier.
if len(namespace) == 1:
if len(namespace) == 1 and namespace[0] == "langchain":
raise ValueError(f"Invalid namespace: {value}")
mod = importlib.import_module(".".join(namespace))
@ -79,7 +79,8 @@ def loads(
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
) -> Any:
"""Load a JSON object from a string.
"""Revive a LangChain class from a JSON string.
Equivalent to `load(json.loads(text))`.
Args:
text: The string to load.
@ -88,6 +89,38 @@ def loads(
to allow to be deserialized.
Returns:
Revived LangChain objects.
"""
return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))
def load(
    obj: Any,
    *,
    secrets_map: Optional[Dict[str, str]] = None,
    valid_namespaces: Optional[List[str]] = None,
) -> Any:
    """Revive a LangChain class from an already-parsed JSON object.

    Use this instead of `loads` when the JSON has already been decoded,
    e.g. by `json.load` or `orjson.loads`, so no intermediate string
    encoding is needed.

    Args:
        obj: The object to load. Dicts and lists are walked recursively;
            any other value is returned unchanged.
        secrets_map: A map of secrets to load.
        valid_namespaces: A list of additional namespaces (modules)
            to allow to be deserialized.

    Returns:
        Revived LangChain objects.
    """
    revive = Reviver(secrets_map, valid_namespaces)

    def _walk(node: Any) -> Any:
        # Lists are rebuilt element by element.
        if isinstance(node, list):
            return [_walk(item) for item in node]
        # Anything that is neither a list nor a dict passes straight through.
        if not isinstance(node, dict):
            return node
        # Children must be revived before the dict that contains them is
        # handed to the reviver.
        revived_children = {key: _walk(value) for key, value in node.items()}
        return revive(revived_children)

    return _walk(obj)

@ -13,8 +13,8 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.evaluation.schema import StringEvaluator
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.load.dump import dumpd
from langchain.load.load import load
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, messages_from_dict
from langchain.schema.messages import BaseMessage, get_buffer_string
@ -25,7 +25,7 @@ def _get_messages_from_run_dict(messages: List[dict]) -> List[BaseMessage]:
return []
first_message = messages[0]
if "lc" in first_message:
return [loads(dumps(message)) for message in messages]
return [load(dumpd(message)) for message in messages]
else:
return messages_from_dict(messages)

@ -4,8 +4,8 @@ import pytest
from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.load.dump import dumpd, dumps
from langchain.load.load import load, loads
from langchain.prompts.prompt import PromptTemplate
@ -14,7 +14,7 @@ class NotSerializable:
@pytest.mark.requires("openai")
def test_load_openai_llm() -> None:
def test_loads_openai_llm() -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
llm_string = dumps(llm)
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
@ -25,7 +25,7 @@ def test_load_openai_llm() -> None:
@pytest.mark.requires("openai")
def test_load_llmchain() -> None:
def test_loads_llmchain() -> None:
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
prompt = PromptTemplate.from_template("hello {name}!")
chain = LLMChain(llm=llm, prompt=prompt)
@ -40,7 +40,7 @@ def test_load_llmchain() -> None:
@pytest.mark.requires("openai")
def test_load_llmchain_env() -> None:
def test_loads_llmchain_env() -> None:
import os
has_env = "OPENAI_API_KEY" in os.environ
@ -64,7 +64,7 @@ def test_load_llmchain_env() -> None:
@pytest.mark.requires("openai")
def test_load_llmchain_with_non_serializable_arg() -> None:
def test_loads_llmchain_with_non_serializable_arg() -> None:
llm = OpenAI(
model="davinci",
temperature=0.5,
@ -76,3 +76,68 @@ def test_load_llmchain_with_non_serializable_arg() -> None:
chain_string = dumps(chain, pretty=True)
with pytest.raises(NotImplementedError):
loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
@pytest.mark.requires("openai")
def test_load_openai_llm() -> None:
    """Round-trip an `OpenAI` LLM through `dumpd` and `load`."""
    original = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    serialized = dumpd(original)

    revived = load(serialized, secrets_map={"OPENAI_API_KEY": "hello"})

    # The revived object must equal the original and serialize identically.
    assert revived == original
    assert dumpd(revived) == serialized
    assert isinstance(revived, OpenAI)
@pytest.mark.requires("openai")
def test_load_llmchain() -> None:
    """Round-trip an `LLMChain` through `dumpd` and `load`."""
    original = LLMChain(
        llm=OpenAI(model="davinci", temperature=0.5, openai_api_key="hello"),
        prompt=PromptTemplate.from_template("hello {name}!"),
    )
    serialized = dumpd(original)

    revived = load(serialized, secrets_map={"OPENAI_API_KEY": "hello"})

    # The revived chain must equal the original and serialize identically.
    assert revived == original
    assert dumpd(revived) == serialized
    # Nested components must come back as their concrete classes.
    assert isinstance(revived, LLMChain)
    assert isinstance(revived.llm, OpenAI)
    assert isinstance(revived.prompt, PromptTemplate)
@pytest.mark.requires("openai")
def test_load_llmchain_env() -> None:
    """Round-trip an `LLMChain` whose API key comes from the environment.

    No `secrets_map` is passed to `load`. If `OPENAI_API_KEY` is not
    already set, a placeholder value is installed for the duration of
    the test and removed afterwards.
    """
    import os

    has_env = "OPENAI_API_KEY" in os.environ
    if not has_env:
        os.environ["OPENAI_API_KEY"] = "env_variable"

    # try/finally guarantees the placeholder key is removed even when an
    # assertion fails, so it cannot leak into tests that run afterwards.
    try:
        llm = OpenAI(model="davinci", temperature=0.5)
        prompt = PromptTemplate.from_template("hello {name}!")
        chain = LLMChain(llm=llm, prompt=prompt)
        chain_obj = dumpd(chain)
        chain2 = load(chain_obj)

        assert chain2 == chain
        assert dumpd(chain2) == chain_obj
        assert isinstance(chain2, LLMChain)
        assert isinstance(chain2.llm, OpenAI)
        assert isinstance(chain2.prompt, PromptTemplate)
    finally:
        if not has_env:
            del os.environ["OPENAI_API_KEY"]
@pytest.mark.requires("openai")
def test_load_llmchain_with_non_serializable_arg() -> None:
    """`load` must raise when the dumped chain holds a non-serializable arg."""
    chain = LLMChain(
        llm=OpenAI(
            model="davinci",
            temperature=0.5,
            openai_api_key="hello",
            client=NotSerializable,
        ),
        prompt=PromptTemplate.from_template("hello {name}!"),
    )
    serialized = dumpd(chain)

    # Dumping succeeds; the failure is expected only when reviving.
    with pytest.raises(NotImplementedError):
        load(serialized, secrets_map={"OPENAI_API_KEY": "hello"})

Loading…
Cancel
Save