diff --git a/langchain/agents/agent.py b/langchain/agents/agent.py
index 5aa7486a..f73b5f26 100644
--- a/langchain/agents/agent.py
+++ b/langchain/agents/agent.py
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 import yaml
 from pydantic import BaseModel, root_validator
 
+from langchain.agents.agent_types import AgentType
 from langchain.agents.tools import InvalidTool
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
@@ -132,7 +133,11 @@ class BaseSingleActionAgent(BaseModel):
     def dict(self, **kwargs: Any) -> Dict:
         """Return dictionary representation of agent."""
         _dict = super().dict()
-        _dict["_type"] = str(self._agent_type)
+        _type = self._agent_type
+        if isinstance(_type, AgentType):
+            _dict["_type"] = str(_type.value)
+        else:
+            _dict["_type"] = _type
         return _dict
 
     def save(self, file_path: Union[Path, str]) -> None:
@@ -307,6 +312,12 @@
     def input_keys(self) -> List[str]:
         return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
 
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of agent."""
+        _dict = super().dict()
+        del _dict["output_parser"]
+        return _dict
+
     def plan(
         self,
         intermediate_steps: List[Tuple[AgentAction, str]],
@@ -376,6 +387,12 @@
     output_parser: AgentOutputParser
     allowed_tools: Optional[List[str]] = None
 
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return dictionary representation of agent."""
+        _dict = super().dict()
+        del _dict["output_parser"]
+        return _dict
+
     def get_allowed_tools(self) -> Optional[List[str]]:
         return self.allowed_tools
 
diff --git a/langchain/agents/loading.py b/langchain/agents/loading.py
index d7702fbc..359909cd 100644
--- a/langchain/agents/loading.py
+++ b/langchain/agents/loading.py
@@ -1,5 +1,6 @@
 """Functionality for loading agents."""
 import json
+import logging
 from pathlib import Path
 from typing import Any, List, Optional, Union
 
@@ -12,6 +13,8 @@ from langchain.base_language import BaseLanguageModel
 from langchain.chains.loading import load_chain, load_chain_from_config
 from langchain.utilities.loading import try_load_from_hub
 
+logger = logging.getLogger(__file__)
+
 URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
 
 
@@ -61,6 +64,13 @@ def load_agent_from_config(
         config["llm_chain"] = load_chain(config.pop("llm_chain_path"))
     else:
         raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.")
+    if "output_parser" in config:
+        logger.warning(
+            "Currently loading output parsers on agent is not supported, "
+            "will just use the default one."
+        )
+        del config["output_parser"]
+
     combined_config = {**config, **kwargs}
     return agent_cls(**combined_config)  # type: ignore
 
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index 221a0d3e..c1c5d1e8 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -10,6 +10,7 @@ from langchain.llms.base import BaseLLM
 from langchain.llms.cerebriumai import CerebriumAI
 from langchain.llms.cohere import Cohere
 from langchain.llms.deepinfra import DeepInfra
+from langchain.llms.fake import FakeListLLM
 from langchain.llms.forefrontai import ForefrontAI
 from langchain.llms.google_palm import GooglePalm
 from langchain.llms.gooseai import GooseAI
@@ -71,6 +72,7 @@ __all__ = [
     "PredictionGuard",
     "HumanInputLLM",
     "HuggingFaceTextGenInference",
+    "FakeListLLM",
 ]
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
@@ -105,4 +107,5 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "writer": Writer,
     "rwkv": RWKV,
     "huggingface_textgen_inference": HuggingFaceTextGenInference,
+    "fake-list": FakeListLLM,
 }
diff --git a/langchain/llms/fake.py b/langchain/llms/fake.py
index 3df15b9c..15fbab5e 100644
--- a/langchain/llms/fake.py
+++ b/langchain/llms/fake.py
@@ -29,4 +29,4 @@ class FakeListLLM(LLM):
 
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
-        return {}
+        return {"responses": self.responses}
diff --git a/tests/unit_tests/agents/test_serialization.py b/tests/unit_tests/agents/test_serialization.py
new file mode 100644
index 00000000..db68fd2b
--- /dev/null
+++ b/tests/unit_tests/agents/test_serialization.py
@@ -0,0 +1,19 @@
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from langchain.agents.agent_types import AgentType
+from langchain.agents.initialize import initialize_agent, load_agent
+from langchain.llms.fake import FakeListLLM
+
+
+def test_mrkl_serialization() -> None:
+    agent = initialize_agent(
+        [],
+        FakeListLLM(responses=[]),
+        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
+        verbose=True,
+    )
+    with TemporaryDirectory() as tempdir:
+        file = Path(tempdir) / "agent.json"
+        agent.save_agent(file)
+        load_agent(file)
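A minimal usage sketch of the round trip the new unit test exercises, assuming the branch above is installed; the temporary file name, variable names, and the example "_type" value are illustrative, while the imports and calls mirror the diff and the test.

    import json
    from pathlib import Path
    from tempfile import TemporaryDirectory

    from langchain.agents.agent_types import AgentType
    from langchain.agents.initialize import initialize_agent, load_agent
    from langchain.llms.fake import FakeListLLM

    # Build a trivial zero-shot agent around the fake LLM: no tools, no network calls.
    agent = initialize_agent(
        [],
        FakeListLLM(responses=[]),
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    )

    with TemporaryDirectory() as tempdir:
        path = Path(tempdir) / "agent.json"
        agent.save_agent(path)

        # With the dict() changes above, "_type" holds the enum's string value
        # (e.g. "zero-shot-react-description") rather than the enum member's name-style string.
        print(json.loads(path.read_text())["_type"])

        # load_agent rebuilds the agent; if the JSON carried an "output_parser" entry,
        # the loader now logs a warning and falls back to the default parser.
        restored = load_agent(path)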