|
|
|
@ -1,8 +1,11 @@
|
|
|
|
|
"""Functionality for loading chains."""
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import tempfile
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
|
|
import requests
|
|
|
|
|
import yaml
|
|
|
|
|
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
@ -10,6 +13,8 @@ from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.llms.loading import load_llm, load_llm_from_config
|
|
|
|
|
from langchain.prompts.loading import load_prompt, load_prompt_from_config
|
|
|
|
|
|
|
|
|
|
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_llm_chain(config: dict) -> LLMChain:
|
|
|
|
|
"""Load LLM chain from config dict."""
|
|
|
|
@ -48,7 +53,16 @@ def load_chain_from_config(config: dict) -> Chain:
|
|
|
|
|
return chain_loader(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_chain(file: Union[str, Path]) -> Chain:
|
|
|
|
|
def load_chain(path: Union[str, Path]) -> Chain:
|
|
|
|
|
"""Unified method for loading a chain from LangChainHub or local fs."""
|
|
|
|
|
if isinstance(path, str) and path.startswith("lc://chains"):
|
|
|
|
|
path = os.path.relpath(path, "lc://chains/")
|
|
|
|
|
return _load_from_hub(path)
|
|
|
|
|
else:
|
|
|
|
|
return _load_chain_from_file(path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_chain_from_file(file: Union[str, Path]) -> Chain:
|
|
|
|
|
"""Load chain from file."""
|
|
|
|
|
# Convert file to Path object.
|
|
|
|
|
if isinstance(file, str):
|
|
|
|
@ -66,3 +80,19 @@ def load_chain(file: Union[str, Path]) -> Chain:
|
|
|
|
|
raise ValueError("File type must be json or yaml")
|
|
|
|
|
# Load the chain from the config now.
|
|
|
|
|
return load_chain_from_config(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_from_hub(path: str) -> Chain:
|
|
|
|
|
"""Load chain from hub."""
|
|
|
|
|
suffix = path.split(".")[-1]
|
|
|
|
|
if suffix not in {"json", "yaml"}:
|
|
|
|
|
raise ValueError("Unsupported file type.")
|
|
|
|
|
full_url = URL_BASE + path
|
|
|
|
|
r = requests.get(full_url)
|
|
|
|
|
if r.status_code != 200:
|
|
|
|
|
raise ValueError(f"Could not find file at {full_url}")
|
|
|
|
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
|
|
|
|
file = tmpdirname + "/chain." + suffix
|
|
|
|
|
with open(file, "wb") as f:
|
|
|
|
|
f.write(r.content)
|
|
|
|
|
return _load_chain_from_file(file)
|
|
|
|
|