Harrison/load from hub (#580)

harrison/sql-agent
Harrison Chase 1 year ago committed by GitHub
parent f74ce7a104
commit 3f2ea5c35e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
"""Prompt template classes."""
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.loading import load_prompt
from langchain.prompts.loading import load_from_hub, load_prompt
from langchain.prompts.prompt import Prompt, PromptTemplate
__all__ = [
@ -10,4 +10,5 @@ __all__ = [
"PromptTemplate",
"FewShotPromptTemplate",
"Prompt",
"load_from_hub",
]

@ -1,8 +1,11 @@
"""Load prompts from disk."""
import importlib
import json
import tempfile
from pathlib import Path
from typing import Union
import requests
import yaml
from langchain.prompts.base import BasePromptTemplate
@ -97,7 +100,38 @@ def load_prompt(file: Union[str, Path]) -> BasePromptTemplate:
elif file_path.suffix == ".yaml":
with open(file_path, "r") as f:
config = yaml.safe_load(f)
elif file_path.suffix == ".py":
spec = importlib.util.spec_from_loader(
"prompt", loader=None, origin=str(file_path)
)
if spec is None:
raise ValueError("could not load spec")
helper = importlib.util.module_from_spec(spec)
with open(file_path, "rb") as f:
exec(f.read(), helper.__dict__)
if not isinstance(helper.PROMPT, BasePromptTemplate):
raise ValueError("Did not get object of type BasePromptTemplate.")
return helper.PROMPT
else:
raise ValueError
raise ValueError(f"Got unsupported file type {file_path.suffix}")
# Load the prompt from the config now.
return load_prompt_from_config(config)
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
def load_from_hub(path: str) -> BasePromptTemplate:
"""Load prompt from hub."""
suffix = path.split(".")[-1]
if suffix not in {"py", "json", "yaml"}:
raise ValueError("Unsupported file type.")
full_url = URL_BASE + path
r = requests.get(full_url)
if r.status_code != 200:
raise ValueError(f"Could not find file at {full_url}")
with tempfile.TemporaryDirectory() as tmpdirname:
file = tmpdirname + "/prompt." + suffix
with open(file, "wb") as f:
f.write(r.content)
return load_prompt(file)

Loading…
Cancel
Save