mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
4babefcb2f
- **Description:** Modified regular expression to add support for unicode chars and simplify pattern Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
198 lines
6.1 KiB
Python
198 lines
6.1 KiB
Python
import hashlib
|
|
import re
|
|
import shutil
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Sequence, TypedDict
|
|
|
|
from git import Repo
|
|
|
|
from langchain_cli.constants import (
|
|
DEFAULT_GIT_REF,
|
|
DEFAULT_GIT_REPO,
|
|
DEFAULT_GIT_SUBDIRECTORY,
|
|
)
|
|
|
|
|
|
class DependencySource(TypedDict):
|
|
git: str
|
|
ref: Optional[str]
|
|
subdirectory: Optional[str]
|
|
api_path: Optional[str]
|
|
event_metadata: Dict
|
|
|
|
|
|
# use poetry dependency string format
|
|
def parse_dependency_string(
|
|
dep: Optional[str],
|
|
repo: Optional[str],
|
|
branch: Optional[str],
|
|
api_path: Optional[str],
|
|
) -> DependencySource:
|
|
if dep is not None and dep.startswith("git+"):
|
|
if repo is not None or branch is not None:
|
|
raise ValueError(
|
|
"If a dependency starts with git+, you cannot manually specify "
|
|
"a repo or branch."
|
|
)
|
|
# remove git+
|
|
gitstring = dep[4:]
|
|
subdirectory = None
|
|
ref = None
|
|
# first check for #subdirectory= on the end
|
|
if "#subdirectory=" in gitstring:
|
|
gitstring, subdirectory = gitstring.split("#subdirectory=")
|
|
if "#" in subdirectory or "@" in subdirectory:
|
|
raise ValueError(
|
|
"#subdirectory must be the last part of the dependency string"
|
|
)
|
|
|
|
# find first slash after ://
|
|
# find @ or # after that slash
|
|
# remainder is ref
|
|
# if no @ or #, then ref is None
|
|
|
|
# find first slash after ://
|
|
if "://" not in gitstring:
|
|
raise ValueError(
|
|
"git+ dependencies must start with git+https:// or git+ssh://"
|
|
)
|
|
|
|
_, find_slash = gitstring.split("://", 1)
|
|
|
|
if "/" not in find_slash:
|
|
post_slash = find_slash
|
|
ref = None
|
|
else:
|
|
_, post_slash = find_slash.split("/", 1)
|
|
if "@" in post_slash or "#" in post_slash:
|
|
_, ref = re.split(r"[@#]", post_slash, 1)
|
|
|
|
# gitstring is everything before that
|
|
gitstring = gitstring[: -len(ref) - 1] if ref is not None else gitstring
|
|
|
|
return DependencySource(
|
|
git=gitstring,
|
|
ref=ref,
|
|
subdirectory=subdirectory,
|
|
api_path=api_path,
|
|
event_metadata={"dependency_string": dep},
|
|
)
|
|
|
|
elif dep is not None and dep.startswith("https://"):
|
|
raise ValueError("Only git dependencies are supported")
|
|
else:
|
|
# if repo is none, use default, including subdirectory
|
|
base_subdir = Path(DEFAULT_GIT_SUBDIRECTORY) if repo is None else Path()
|
|
subdir = str(base_subdir / dep) if dep is not None else None
|
|
gitstring = (
|
|
DEFAULT_GIT_REPO
|
|
if repo is None
|
|
else f"https://github.com/{repo.strip('/')}.git"
|
|
)
|
|
ref = DEFAULT_GIT_REF if branch is None else branch
|
|
# it's a default git repo dependency
|
|
return DependencySource(
|
|
git=gitstring,
|
|
ref=ref,
|
|
subdirectory=subdir,
|
|
api_path=api_path,
|
|
event_metadata={
|
|
"dependency_string": dep,
|
|
"used_repo_flag": repo is not None,
|
|
"used_branch_flag": branch is not None,
|
|
},
|
|
)
|
|
|
|
|
|
def _list_arg_to_length(arg: Optional[List[str]], num: int) -> Sequence[Optional[str]]:
|
|
if not arg:
|
|
return [None] * num
|
|
elif len(arg) == 1:
|
|
return arg * num
|
|
elif len(arg) == num:
|
|
return arg
|
|
else:
|
|
raise ValueError(f"Argument must be of length 1 or {num}")
|
|
|
|
|
|
def parse_dependencies(
|
|
dependencies: Optional[List[str]],
|
|
repo: List[str],
|
|
branch: List[str],
|
|
api_path: List[str],
|
|
) -> List[DependencySource]:
|
|
num_deps = max(
|
|
len(dependencies) if dependencies is not None else 0, len(repo), len(branch)
|
|
)
|
|
if (
|
|
(dependencies and len(dependencies) != num_deps)
|
|
or (api_path and len(api_path) != num_deps)
|
|
or (repo and len(repo) not in [1, num_deps])
|
|
or (branch and len(branch) not in [1, num_deps])
|
|
):
|
|
raise ValueError(
|
|
"Number of defined repos/branches/api_paths did not match the "
|
|
"number of templates."
|
|
)
|
|
inner_deps = _list_arg_to_length(dependencies, num_deps)
|
|
inner_api_paths = _list_arg_to_length(api_path, num_deps)
|
|
inner_repos = _list_arg_to_length(repo, num_deps)
|
|
inner_branches = _list_arg_to_length(branch, num_deps)
|
|
|
|
return [
|
|
parse_dependency_string(iter_dep, iter_repo, iter_branch, iter_api_path)
|
|
for iter_dep, iter_repo, iter_branch, iter_api_path in zip(
|
|
inner_deps, inner_repos, inner_branches, inner_api_paths
|
|
)
|
|
]
|
|
|
|
|
|
def _get_repo_path(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
|
|
# only based on git for now
|
|
ref_str = ref if ref is not None else ""
|
|
hashed = hashlib.sha256((f"{gitstring}:{ref_str}").encode("utf-8")).hexdigest()[:8]
|
|
|
|
removed_protocol = gitstring.split("://")[-1]
|
|
removed_basename = re.split(r"[/:]", removed_protocol, 1)[-1]
|
|
removed_extras = removed_basename.split("#")[0]
|
|
foldername = re.sub(r"\W", "_", removed_extras)
|
|
|
|
directory_name = f"{foldername}_{hashed}"
|
|
return repo_dir / directory_name
|
|
|
|
|
|
def update_repo(gitstring: str, ref: Optional[str], repo_dir: Path) -> Path:
|
|
# see if path already saved
|
|
repo_path = _get_repo_path(gitstring, ref, repo_dir)
|
|
if repo_path.exists():
|
|
# try pulling
|
|
try:
|
|
repo = Repo(repo_path)
|
|
if repo.active_branch.name != ref:
|
|
raise ValueError()
|
|
repo.remotes.origin.pull()
|
|
except Exception:
|
|
# if it fails, delete and clone again
|
|
shutil.rmtree(repo_path)
|
|
Repo.clone_from(gitstring, repo_path, branch=ref, depth=1)
|
|
else:
|
|
Repo.clone_from(gitstring, repo_path, branch=ref, depth=1)
|
|
|
|
return repo_path
|
|
|
|
|
|
def copy_repo(
|
|
source: Path,
|
|
destination: Path,
|
|
) -> None:
|
|
"""
|
|
Copies a repo, ignoring git folders.
|
|
|
|
Raises FileNotFound error if it can't find source
|
|
"""
|
|
|
|
def ignore_func(_, files):
|
|
return [f for f in files if f == ".git"]
|
|
|
|
shutil.copytree(source, destination, ignore=ignore_func)
|