Harrison/load from hub (#580)

harrison/sql-agent
Harrison Chase 1 year ago committed by GitHub
parent f74ce7a104
commit 3f2ea5c35e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
"""Prompt template classes.""" """Prompt template classes."""
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.loading import load_prompt from langchain.prompts.loading import load_from_hub, load_prompt
from langchain.prompts.prompt import Prompt, PromptTemplate from langchain.prompts.prompt import Prompt, PromptTemplate
__all__ = [ __all__ = [
@ -10,4 +10,5 @@ __all__ = [
"PromptTemplate", "PromptTemplate",
"FewShotPromptTemplate", "FewShotPromptTemplate",
"Prompt", "Prompt",
"load_from_hub",
] ]

@ -1,8 +1,11 @@
"""Load prompts from disk.""" """Load prompts from disk."""
import importlib
import json import json
import tempfile
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
import requests
import yaml import yaml
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
@ -97,7 +100,38 @@ def load_prompt(file: Union[str, Path]) -> BasePromptTemplate:
elif file_path.suffix == ".yaml": elif file_path.suffix == ".yaml":
with open(file_path, "r") as f: with open(file_path, "r") as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
elif file_path.suffix == ".py":
spec = importlib.util.spec_from_loader(
"prompt", loader=None, origin=str(file_path)
)
if spec is None:
raise ValueError("could not load spec")
helper = importlib.util.module_from_spec(spec)
with open(file_path, "rb") as f:
exec(f.read(), helper.__dict__)
if not isinstance(helper.PROMPT, BasePromptTemplate):
raise ValueError("Did not get object of type BasePromptTemplate.")
return helper.PROMPT
else: else:
raise ValueError raise ValueError(f"Got unsupported file type {file_path.suffix}")
# Load the prompt from the config now. # Load the prompt from the config now.
return load_prompt_from_config(config) return load_prompt_from_config(config)
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
def load_from_hub(path: str) -> BasePromptTemplate:
"""Load prompt from hub."""
suffix = path.split(".")[-1]
if suffix not in {"py", "json", "yaml"}:
raise ValueError("Unsupported file type.")
full_url = URL_BASE + path
r = requests.get(full_url)
if r.status_code != 200:
raise ValueError(f"Could not find file at {full_url}")
with tempfile.TemporaryDirectory() as tmpdirname:
file = tmpdirname + "/prompt." + suffix
with open(file, "wb") as f:
f.write(r.content)
return load_prompt(file)

Loading…
Cancel
Save