From 4da43f77e5bf3d25f5b7ece8bcba1ab7c6a9abb2 Mon Sep 17 00:00:00 2001 From: Alec Flett Date: Wed, 26 Jul 2023 08:59:28 -0700 Subject: [PATCH] Add ability to load (deserialize) objects from other namespaces (#7726) I have some Prompt subclasses in my project that I'd like to be able to deserialize in callbacks. Right now `loads()`/`load()` will bail when it encounters my object, but I know I can trust the objects because they're in my own projects. --- libs/langchain/langchain/load/load.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/load/load.py b/libs/langchain/langchain/load/load.py index 5cd8bb5e34..fe3653d550 100644 --- a/libs/langchain/langchain/load/load.py +++ b/libs/langchain/langchain/load/load.py @@ -1,7 +1,7 @@ import importlib import json import os -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from langchain.load.serializable import Serializable @@ -9,8 +9,16 @@ from langchain.load.serializable import Serializable class Reviver: """Reviver for JSON objects.""" - def __init__(self, secrets_map: Optional[Dict[str, str]] = None) -> None: + def __init__( + self, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, + ) -> None: self.secrets_map = secrets_map or dict() + # By default only support langchain, but user can pass in additional namespaces + self.valid_namespaces = ( + ["langchain", *valid_namespaces] if valid_namespaces else ["langchain"] + ) def __call__(self, value: Dict[str, Any]) -> Any: if ( @@ -43,8 +51,7 @@ class Reviver: ): [*namespace, name] = value["id"] - # Currently, we only support langchain imports. - if namespace[0] != "langchain": + if namespace[0] not in self.valid_namespaces: raise ValueError(f"Invalid namespace: {value}") # The root namespace "langchain" is not a valid identifier. @@ -66,14 +73,21 @@ class Reviver: return value -def loads(text: str, *, secrets_map: Optional[Dict[str, str]] = None) -> Any: +def loads( + text: str, + *, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, +) -> Any: """Load a JSON object from a string. Args: text: The string to load. secrets_map: A map of secrets to load. + valid_namespaces: A list of additional namespaces (modules) + to allow to be deserialized. Returns: """ - return json.loads(text, object_hook=Reviver(secrets_map)) + return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))