agent serialization (#4642)

Harrison Chase 1 year ago committed by GitHub
parent ef49c659f6
commit fbfa49f2c1

@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import yaml
 from pydantic import BaseModel, root_validator

+from langchain.agents.agent_types import AgentType
 from langchain.agents.tools import InvalidTool
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
@@ -132,7 +133,11 @@ class BaseSingleActionAgent(BaseModel):
     def dict(self, **kwargs: Any) -> Dict:
         """Return dictionary representation of agent."""
         _dict = super().dict()
-        _dict["_type"] = str(self._agent_type)
+        _type = self._agent_type
+        if isinstance(_type, AgentType):
+            _dict["_type"] = str(_type.value)
+        else:
+            _dict["_type"] = _type
         return _dict

     def save(self, file_path: Union[Path, str]) -> None:
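
A minimal sketch (not part of the diff) of what the branch above produces; it reuses initialize_agent, FakeListLLM, and AgentType exactly as the new unit test at the bottom of this change does:

from langchain.agents.agent_types import AgentType
from langchain.agents.initialize import initialize_agent
from langchain.llms.fake import FakeListLLM

# Build an agent executor around a stub LLM; the underlying agent is a
# ZeroShotAgent whose _agent_type is an AgentType enum member.
executor = initialize_agent(
    [],
    FakeListLLM(responses=[]),
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
# The enum is flattened to its string value, so the dict can be dumped
# straight to JSON or YAML.
print(executor.agent.dict()["_type"])  # "zero-shot-react-description"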
@@ -307,6 +312,12 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
     def input_keys(self) -> List[str]:
         return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})

+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of agent."""
+        _dict = super().dict()
+        del _dict["output_parser"]
+        return _dict
+
     def plan(
         self,
         intermediate_steps: List[Tuple[AgentAction, str]],
@@ -376,6 +387,12 @@ class Agent(BaseSingleActionAgent):
     output_parser: AgentOutputParser
     allowed_tools: Optional[List[str]] = None

+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of agent."""
+        _dict = super().dict()
+        del _dict["output_parser"]
+        return _dict
+
     def get_allowed_tools(self) -> Optional[List[str]]:
         return self.allowed_tools

@@ -1,5 +1,6 @@
 """Functionality for loading agents."""
 import json
+import logging
 from pathlib import Path
 from typing import Any, List, Optional, Union
@@ -12,6 +13,8 @@ from langchain.base_language import BaseLanguageModel
 from langchain.chains.loading import load_chain, load_chain_from_config
 from langchain.utilities.loading import try_load_from_hub

+logger = logging.getLogger(__file__)
+
 URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
@@ -61,6 +64,13 @@ def load_agent_from_config(
         config["llm_chain"] = load_chain(config.pop("llm_chain_path"))
     else:
         raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.")
+    if "output_parser" in config:
+        logger.warning(
+            "Currently loading output parsers on agent is not supported, "
+            "will just use the default one."
+        )
+        del config["output_parser"]
+
     combined_config = {**config, **kwargs}
     return agent_cls(**combined_config)  # type: ignore
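
For context, a hedged end-to-end sketch of the path this hunk sits on, assuming the save_agent/load_agent helpers used by the new unit test below: the saved config carries the string _type but no output_parser, and if an older file does contain an output_parser entry, the loader now warns and falls back to the default parser instead of failing.

from pathlib import Path
from tempfile import TemporaryDirectory

from langchain.agents.agent_types import AgentType
from langchain.agents.initialize import initialize_agent, load_agent
from langchain.llms.fake import FakeListLLM

executor = initialize_agent(
    [], FakeListLLM(responses=[]), agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION
)
with TemporaryDirectory() as tempdir:
    path = Path(tempdir) / "agent.json"
    executor.save_agent(path)    # persists the dict() representation shown earlier
    print(path.read_text())      # contains "_type", but no "output_parser"
    reloaded = load_agent(path)  # reconstructed via load_agent_from_config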

@@ -10,6 +10,7 @@ from langchain.llms.base import BaseLLM
 from langchain.llms.cerebriumai import CerebriumAI
 from langchain.llms.cohere import Cohere
 from langchain.llms.deepinfra import DeepInfra
+from langchain.llms.fake import FakeListLLM
 from langchain.llms.forefrontai import ForefrontAI
 from langchain.llms.google_palm import GooglePalm
 from langchain.llms.gooseai import GooseAI
@@ -71,6 +72,7 @@ __all__ = [
     "PredictionGuard",
     "HumanInputLLM",
     "HuggingFaceTextGenInference",
+    "FakeListLLM",
 ]

 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@@ -105,4 +107,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "writer": Writer,
     "rwkv": RWKV,
     "huggingface_textgen_inference": HuggingFaceTextGenInference,
+    "fake-list": FakeListLLM,
 }

@@ -29,4 +29,4 @@ class FakeListLLM(LLM):
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
-        return {}
+        return {"responses": self.responses}
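
The _identifying_params change above, together with the "fake-list" registration in type_to_cls_dict, is what lets a FakeListLLM inside an agent's llm_chain survive deserialization. A hedged sketch, assuming the existing save()/load_llm() helpers on the LLM side:

from pathlib import Path
from tempfile import TemporaryDirectory

from langchain.llms.fake import FakeListLLM
from langchain.llms.loading import load_llm

llm = FakeListLLM(responses=["final answer"])
with TemporaryDirectory() as tempdir:
    path = Path(tempdir) / "llm.json"
    llm.save(path)             # writes _identifying_params plus the "_type" tag
    reloaded = load_llm(path)  # class resolved through type_to_cls_dict["fake-list"]
    assert reloaded.responses == ["final answer"]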

@@ -0,0 +1,19 @@
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from langchain.agents.agent_types import AgentType
+from langchain.agents.initialize import initialize_agent, load_agent
+from langchain.llms.fake import FakeListLLM
+
+
+def test_mrkl_serialization() -> None:
+    agent = initialize_agent(
+        [],
+        FakeListLLM(responses=[]),
+        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+        verbose=True,
+    )
+    with TemporaryDirectory() as tempdir:
+        file = Path(tempdir) / "agent.json"
+        agent.save_agent(file)
+        load_agent(file)