load agents from hub (#759)

ankush/async-llmchain
Harrison Chase 1 year ago committed by GitHub
parent 7129f23511
commit 12dc7f26cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,8 +1,11 @@
"""Functionality for loading agents."""
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Union
import requests
import yaml
from langchain.agents.agent import Agent
@ -19,6 +22,8 @@ AGENT_TO_CLASS = {
"conversational-react-description": ConversationalAgent,
}
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/agents/"
def load_agent_from_config(config: dict, **kwargs: Any) -> Agent:
"""Load agent from Config Dict."""
@ -40,7 +45,32 @@ def load_agent_from_config(config: dict, **kwargs: Any) -> Agent:
return agent_cls(**combined_config) # type: ignore
def load_agent(file: Union[str, Path], **kwargs: Any) -> Agent:
def load_agent(path: Union[str, Path], **kwargs: Any) -> Agent:
"""Unified method for loading a agent from LangChainHub or local fs."""
if isinstance(path, str) and path.startswith("lc://agents"):
path = os.path.relpath(path, "lc://agents/")
return _load_from_hub(path, **kwargs)
else:
return _load_agent_from_file(path, **kwargs)
def _load_from_hub(path: str, **kwargs: Any) -> Agent:
"""Load agent from hub."""
suffix = path.split(".")[-1]
if suffix not in {"json", "yaml"}:
raise ValueError("Unsupported file type.")
full_url = URL_BASE + path
r = requests.get(full_url)
if r.status_code != 200:
raise ValueError(f"Could not find file at {full_url}")
with tempfile.TemporaryDirectory() as tmpdirname:
file = tmpdirname + "/agent." + suffix
with open(file, "wb") as f:
f.write(r.content)
return _load_agent_from_file(file)
def _load_agent_from_file(file: Union[str, Path], **kwargs: Any) -> Agent:
"""Load agent from file."""
# Convert file to Path object.
if isinstance(file, str):

Loading…
Cancel
Save