agent serialization (#4642)

textloader_autodetect_encodings
Harrison Chase 1 year ago committed by GitHub
parent ef49c659f6
commit fbfa49f2c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import yaml import yaml
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
from langchain.agents.agent_types import AgentType
from langchain.agents.tools import InvalidTool from langchain.agents.tools import InvalidTool
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.base import BaseCallbackManager
@ -132,7 +133,11 @@ class BaseSingleActionAgent(BaseModel):
def dict(self, **kwargs: Any) -> Dict: def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().dict() _dict = super().dict()
_dict["_type"] = str(self._agent_type) _type = self._agent_type
if isinstance(_type, AgentType):
_dict["_type"] = str(_type.value)
else:
_dict["_type"] = _type
return _dict return _dict
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
@ -307,6 +312,12 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"}) return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
del _dict["output_parser"]
return _dict
def plan( def plan(
self, self,
intermediate_steps: List[Tuple[AgentAction, str]], intermediate_steps: List[Tuple[AgentAction, str]],
@ -376,6 +387,12 @@ class Agent(BaseSingleActionAgent):
output_parser: AgentOutputParser output_parser: AgentOutputParser
allowed_tools: Optional[List[str]] = None allowed_tools: Optional[List[str]] = None
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
del _dict["output_parser"]
return _dict
def get_allowed_tools(self) -> Optional[List[str]]: def get_allowed_tools(self) -> Optional[List[str]]:
return self.allowed_tools return self.allowed_tools

@ -1,5 +1,6 @@
"""Functionality for loading agents.""" """Functionality for loading agents."""
import json import json
import logging
from pathlib import Path from pathlib import Path
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union
@ -12,6 +13,8 @@ from langchain.base_language import BaseLanguageModel
from langchain.chains.loading import load_chain, load_chain_from_config from langchain.chains.loading import load_chain, load_chain_from_config
from langchain.utilities.loading import try_load_from_hub from langchain.utilities.loading import try_load_from_hub
logger = logging.getLogger(__file__)
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/" URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
@ -61,6 +64,13 @@ def load_agent_from_config(
config["llm_chain"] = load_chain(config.pop("llm_chain_path")) config["llm_chain"] = load_chain(config.pop("llm_chain_path"))
else: else:
raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.") raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.")
if "output_parser" in config:
logger.warning(
"Currently loading output parsers on agent is not supported, "
"will just use the default one."
)
del config["output_parser"]
combined_config = {**config, **kwargs} combined_config = {**config, **kwargs}
return agent_cls(**combined_config) # type: ignore return agent_cls(**combined_config) # type: ignore

@ -10,6 +10,7 @@ from langchain.llms.base import BaseLLM
from langchain.llms.cerebriumai import CerebriumAI from langchain.llms.cerebriumai import CerebriumAI
from langchain.llms.cohere import Cohere from langchain.llms.cohere import Cohere
from langchain.llms.deepinfra import DeepInfra from langchain.llms.deepinfra import DeepInfra
from langchain.llms.fake import FakeListLLM
from langchain.llms.forefrontai import ForefrontAI from langchain.llms.forefrontai import ForefrontAI
from langchain.llms.google_palm import GooglePalm from langchain.llms.google_palm import GooglePalm
from langchain.llms.gooseai import GooseAI from langchain.llms.gooseai import GooseAI
@ -71,6 +72,7 @@ __all__ = [
"PredictionGuard", "PredictionGuard",
"HumanInputLLM", "HumanInputLLM",
"HuggingFaceTextGenInference", "HuggingFaceTextGenInference",
"FakeListLLM",
] ]
type_to_cls_dict: Dict[str, Type[BaseLLM]] = { type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@ -105,4 +107,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"writer": Writer, "writer": Writer,
"rwkv": RWKV, "rwkv": RWKV,
"huggingface_textgen_inference": HuggingFaceTextGenInference, "huggingface_textgen_inference": HuggingFaceTextGenInference,
"fake-list": FakeListLLM,
} }

@ -29,4 +29,4 @@ class FakeListLLM(LLM):
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:
return {} return {"responses": self.responses}

@ -0,0 +1,19 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from langchain.agents.agent_types import AgentType
from langchain.agents.initialize import initialize_agent, load_agent
from langchain.llms.fake import FakeListLLM
def test_mrkl_serialization() -> None:
agent = initialize_agent(
[],
FakeListLLM(responses=[]),
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
)
with TemporaryDirectory() as tempdir:
file = Path(tempdir) / "agent.json"
agent.save_agent(file)
load_agent(file)
Loading…
Cancel
Save