You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/cli/langchain_cli/utils/git.py

131 lines
4.1 KiB
Python

from typing import Optional, TypedDict
from pathlib import Path
import shutil
import re
from langchain_cli.constants import DEFAULT_GIT_REPO, DEFAULT_GIT_SUBDIRECTORY
import hashlib
from git import Repo
# Where a template dependency is fetched from:
#   git          -- URL of the git repository to clone.
#   ref          -- branch/tag/commit to check out; None means the repo's
#                   main branch is used.
#   subdirectory -- path inside the repository; None means the repo root.
DependencySource = TypedDict(
    "DependencySource",
    {
        "git": str,
        "ref": Optional[str],
        "subdirectory": Optional[str],
    },
)
def _get_main_branch(repo: Repo) -> str:
    """
    Get the name of the main branch of a git repo.

    Asks the ``origin`` remote for its HEAD branch first. If that command
    fails, or does not report a usable branch name, falls back to a local
    ``main`` or ``master`` head.

    From https://stackoverflow.com/questions/69651536/how-to-get-master-main-branch-from-gitpython

    Raises:
        ValueError: if no main branch can be determined.
    """
    try:
        # replace "origin" with your remote name if differs
        show_result = repo.git.remote("show", "origin")
    except Exception:
        # Best effort: the remote may be missing or unreachable. Use the
        # local-branch fallback below instead of failing here.
        show_result = ""
    # The show_result contains a wall of text in the language that is set by
    # your locales, so this pattern only matches English output. Match within
    # a single line ([ \t]* rather than \s*) so an empty value can never
    # capture the following line of output.
    matches = re.search(r"HEAD branch:[ \t]*(.*)", show_result)
    if matches:
        # strip() also drops a trailing "\r" from CRLF output.
        default_branch = matches.group(1).strip()
        # git prints parenthesized placeholders such as "(unknown)" when it
        # cannot tell the remote HEAD; those are not branch names.
        if default_branch and not default_branch.startswith("("):
            return default_branch
    # fallback to main/master
    if "main" in repo.heads:
        return "main"
    if "master" in repo.heads:
        return "master"
    raise ValueError("Could not find main branch")
# use poetry dependency string format
def _parse_dependency_string(package_string: str) -> DependencySource:
    """
    Parse a dependency string into a ``DependencySource``.

    Accepted forms:

    * ``git+<url>[#<ref>][#subdirectory=<path>]`` -- a git repository.
      Each ``#``-separated fragment after the URL is either a ``key=value``
      pair or a bare ``<ref>`` (shorthand for ``ref=<ref>``). Only the
      ``ref`` and ``subdirectory`` keys are read; other keys are accepted
      but ignored.
    * ``https://...`` -- a plain URL dependency; not supported yet.
    * anything else -- the name of a package in the default repository
      (``DEFAULT_GIT_REPO``), located at ``DEFAULT_GIT_SUBDIRECTORY/<name>``.

    Raises:
        ValueError: if a parameter (including the bare ref) is given twice.
        NotImplementedError: for ``https://`` URL dependencies.
    """
    if package_string.startswith("git+"):
        # remove "git+", then split the URL from its "#"-separated params
        gitstring, *params = package_string[len("git+"):].split("#")
        params_dict = {}
        for param in params:
            if not param:
                # ignore empty entries (e.g. a trailing "#")
                continue
            if "=" in param:
                # Split only on the first "=": values such as git ref names
                # may themselves contain "=".
                key, value = param.split("=", 1)
                if key in params_dict:
                    raise ValueError(
                        f"Duplicate parameter {key} in dependency string {package_string}"
                    )
                params_dict[key] = value
            else:
                # a bare fragment is the ref
                if "ref" in params_dict:
                    raise ValueError(
                        f"Duplicate parameter ref in dependency string {package_string}"
                    )
                params_dict["ref"] = param
        return DependencySource(
            git=gitstring,
            ref=params_dict.get("ref"),
            subdirectory=params_dict.get("subdirectory"),
        )
    if package_string.startswith("https://"):
        raise NotImplementedError("url dependencies are not supported yet")
    # it's a default git repo dependency
    return DependencySource(
        git=DEFAULT_GIT_REPO,
        ref=None,
        subdirectory=str(Path(DEFAULT_GIT_SUBDIRECTORY) / package_string),
    )
def _get_repo_path(dependency: DependencySource, repo_dir: Path) -> Path:
    """
    Return the cache folder under ``repo_dir`` for ``dependency``'s repository.

    The folder name is the git URL with its scheme and host removed and every
    character outside ``[a-zA-Z0-9_]`` replaced by ``_``. It is followed by the
    first 8 hex digits of the SHA-256 of the full URL, so distinct URLs that
    sanitize to the same string still get distinct folders.

    Only the git URL is used (not ``ref`` or ``subdirectory``), so every ref
    and subdirectory of one repository maps to the same folder.
    """
    # only based on git for now
    gitstring = dependency["git"]
    hashed = hashlib.sha256(gitstring.encode("utf-8")).hexdigest()[:8]
    # "https://github.com/org/repo.git" -> "github.com/org/repo.git"
    removed_protocol = gitstring.split("://")[-1]
    # Drop the host, which ends at the first "/" (URL form) or ":" (scp-like
    # "git@host:org/repo" form). maxsplit is passed by keyword; passing it
    # positionally is deprecated since Python 3.13.
    removed_host = re.split(r"[/:]", removed_protocol, maxsplit=1)[-1]
    removed_extras = removed_host.split("#")[0]
    foldername = re.sub(r"[^a-zA-Z0-9_]", "_", removed_extras)
    return repo_dir / f"{foldername}_{hashed}"
def update_repo(gitpath: str, repo_dir: Path) -> Path:
    """
    Ensure the repository referenced by ``gitpath`` is cloned and up to date.

    ``gitpath`` is parsed with ``_parse_dependency_string``. The repository is
    cloned into the folder chosen by ``_get_repo_path`` on first use, and that
    clone is reused on later calls. The requested ref is then checked out; when
    no ref is given, the repo's main branch is used. If the checkout leaves
    HEAD on a branch, that branch is pulled.

    Args:
        gitpath: dependency string in the format accepted by
            ``_parse_dependency_string``.
        repo_dir: directory under which cached clones are kept.

    Returns:
        Path to the checked-out repository, or to the dependency's
        ``subdirectory`` inside it when one is given.
    """
    dependency = _parse_dependency_string(gitpath)
    repo_path = _get_repo_path(dependency, repo_dir)

    if repo_path.exists():
        repo = Repo(repo_path)
        # A cached clone may predate the requested ref (a newer branch or
        # tag); fetch so the checkout below can resolve it.
        repo.git.fetch()
    else:
        repo = Repo.clone_from(dependency["git"], repo_path)

    ref = dependency.get("ref") or _get_main_branch(repo)
    repo.git.checkout(ref)
    # `git pull` fails on a detached HEAD (tag or commit-SHA ref). Only a
    # checked-out branch has an upstream to pull from.
    if not repo.head.is_detached:
        repo.git.pull()

    subdirectory = dependency["subdirectory"]
    return repo_path if subdirectory is None else repo_path / subdirectory
def copy_repo(
    source: Path,
    destination: Path,
) -> None:
    """
    Copy the directory tree at ``source`` to a new ``destination`` directory.

    Entries named exactly ``.git`` are skipped at every depth, so the copy
    carries no repository metadata. With copytree's default arguments,
    ``destination`` must not already exist.
    """
    skipped_name = ".git"
    shutil.copytree(
        source,
        destination,
        # copytree calls this once per visited directory with that directory's
        # entry names; the names it returns are not copied.
        ignore=lambda _directory, names: [
            name for name in names if name == skipped_name
        ],
    )