diff --git a/libs/langchain/langchain/load/load.py b/libs/langchain/langchain/load/load.py index 5cd8bb5e34..fe3653d550 100644 --- a/libs/langchain/langchain/load/load.py +++ b/libs/langchain/langchain/load/load.py @@ -1,7 +1,7 @@ import importlib import json import os -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from langchain.load.serializable import Serializable @@ -9,8 +9,16 @@ from langchain.load.serializable import Serializable class Reviver: """Reviver for JSON objects.""" - def __init__(self, secrets_map: Optional[Dict[str, str]] = None) -> None: + def __init__( + self, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, + ) -> None: self.secrets_map = secrets_map or dict() + # By default only support langchain, but user can pass in additional namespaces + self.valid_namespaces = ( + ["langchain", *valid_namespaces] if valid_namespaces else ["langchain"] + ) def __call__(self, value: Dict[str, Any]) -> Any: if ( @@ -43,8 +51,7 @@ class Reviver: ): [*namespace, name] = value["id"] - # Currently, we only support langchain imports. - if namespace[0] != "langchain": + if namespace[0] not in self.valid_namespaces: raise ValueError(f"Invalid namespace: {value}") # The root namespace "langchain" is not a valid identifier. @@ -66,14 +73,21 @@ class Reviver: return value -def loads(text: str, *, secrets_map: Optional[Dict[str, str]] = None) -> Any: +def loads( + text: str, + *, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, +) -> Any: """Load a JSON object from a string. Args: text: The string to load. secrets_map: A map of secrets to load. + valid_namespaces: A list of additional namespaces (modules) + to allow to be deserialized. Returns: """ - return json.loads(text, object_hook=Reviver(secrets_map)) + return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))