"""Parse arXiv references from the documentation.
|
||
|
Generate a page with a table of the arXiv references with links to the documentation pages.
|
||
|
"""
|
||
|
|
||
|
import logging
import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict

from pydantic.v1 import BaseModel, root_validator

# TODO: Parse docstrings for arXiv references.
# TODO: Generate a page with a table of the references with corresponding modules/classes/functions.

logger = logging.getLogger(__name__)

_ROOT_DIR = Path(os.path.abspath(__file__)).parents[2]
DOCS_DIR = _ROOT_DIR / "docs" / "docs"
CODE_DIR = _ROOT_DIR / "libs"
ARXIV_ID_PATTERN = r"https://arxiv\.org/(abs|pdf)/(\d+\.\d+)"
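# Illustration only (hypothetical URL): the arXiv id is captured as group(2),
# so "https://arxiv.org/abs/2305.14283" yields the id "2305.14283".

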
@dataclass
class ArxivPaper:
    """ArXiv paper information."""

    arxiv_id: str
    referencing_docs: list[str]  # TODO: Add the referencing docs
    referencing_api_refs: list[str]  # TODO: Add the referencing API Refs
    title: str
    authors: list[str]
    abstract: str
    url: str
    published_date: str


def search_documentation_for_arxiv_references(docs_dir: Path) -> dict[str, set[str]]:
    """Search the documentation for arXiv references.

    Search for the arXiv references in the documentation pages.
    Note: It finds only the first arXiv reference in a line.

    Args:
        docs_dir: Path to the documentation root folder.

    Returns:
        dict: Dictionary with arxiv_id as key and set of file names as value.
    """
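    # Illustration only (hypothetical id and path): the returned mapping looks
    # like {"2305.14283": {"docs/use_cases/qa"}}.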
    arxiv_url_pattern = re.compile(ARXIV_ID_PATTERN)
    exclude_strings = {"file_path", "metadata", "link", "loader", "PyPDFLoader"}

    # loop over all the files (ipynb, mdx, md) in the docs folder
    files = (
        p.resolve()
        for p in Path(docs_dir).glob("**/*")
        if p.suffix in {".ipynb", ".mdx", ".md"}
    )
    arxiv_id2file_names: dict[str, set[str]] = {}
    for file in files:
        if "-checkpoint.ipynb" in file.name:
            continue
        with open(file, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for line in lines:
            if any(exclude_string in line for exclude_string in exclude_strings):
                continue
            matches = arxiv_url_pattern.search(line)
            if matches:
                arxiv_id = matches.group(2)
                file_name = _get_doc_path(file.parts, file.suffix)
                if arxiv_id not in arxiv_id2file_names:
                    arxiv_id2file_names[arxiv_id] = {file_name}
                else:
                    arxiv_id2file_names[arxiv_id].add(file_name)
    return arxiv_id2file_names


def convert_module_name_and_members_to_urls(
    arxiv_id2module_name_and_members: dict[str, set[str]],
) -> dict[str, set[str]]:
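    """Convert the encoded module names into API Reference URLs.

    Illustration (hypothetical entry): "langchain_community.llms.ollama:class$Ollama"
    becomes "llms/langchain_community.llms.ollama.Ollama.html#langchain_community.llms.ollama.Ollama".
    """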
    arxiv_id2urls = {}
    for arxiv_id, module_name_and_members in arxiv_id2module_name_and_members.items():
        urls = set()
        for module_name_and_member in module_name_and_members:
            module_name, type_and_member = module_name_and_member.split(":")
            if "$" in type_and_member:
                type, member = type_and_member.split("$")
            else:
                type = type_and_member
                member = ""
            _namespace_parts = module_name.split(".")
            if type == "module":
                first_namespace_part = _namespace_parts[0]
                if first_namespace_part.startswith("langchain_"):
                    first_namespace_part = first_namespace_part.replace(
                        "langchain_", ""
                    )
                url = f"{first_namespace_part}_api_reference.html#module-{module_name}"
            elif type in ["class", "function"]:
                second_namespace_part = _namespace_parts[1]
                url = f"{second_namespace_part}/{module_name}.{member}.html#{module_name}.{member}"
            else:
                raise ValueError(
                    f"Unknown type: {type} in the {module_name_and_member}."
                )
            urls.add(url)
        arxiv_id2urls[arxiv_id] = urls
    return arxiv_id2urls


def search_code_for_arxiv_references(code_dir: Path) -> dict[str, set[str]]:
    """Search the code for arXiv references.

    Search for the arXiv references in the code.
    Note: It finds only the first arXiv reference in a line.

    Args:
        code_dir: Path to the code root folder.

    Returns:
        dict: Dictionary with arxiv_id as key and set of module names as value.
          Module names are encoded as:
            <module_name>:module
            <module_name>:class$<ClassName>
            <module_name>:function$<function_name>
    """
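    # Illustration only (hypothetical module): a reference found at module level
    # is encoded as "langchain_community.llms.ollama:module"; one found inside a
    # class as "langchain_community.llms.ollama:class$Ollama".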
    arxiv_url_pattern = re.compile(ARXIV_ID_PATTERN)
    # exclude_strings = {"file_path", "metadata", "link", "loader"}
    class_pattern = re.compile(r"\s*class\s+(\w+).*:")
    function_pattern = re.compile(r"\s*def\s+(\w+)")

    # loop over all the .py files in the code folder
    # (tests and scripts are excluded; ".md" files are excluded as well)
    files = (
        p.resolve()
        for p in Path(code_dir).glob("**/*")
        if p.suffix in {".py"} and "tests" not in p.parts and "scripts" not in p.parts
    )
    arxiv_id2module_name_and_members: dict[str, set[str]] = {}
    for file in files:
        try:
            with open(file, "r", encoding="utf-8") as f:
                module_name = _get_module_name(file.parts)
                class_or_function_started = "module"
                for line in f.readlines():
                    # class line:
                    matches = class_pattern.search(line)
                    if matches:
                        class_name = matches.group(1)
                        class_or_function_started = f"class${class_name}"

                    # function line:
                    # not inside a class!
                    if "class" not in class_or_function_started:
                        matches = function_pattern.search(line)
                        if matches:
                            func_name = matches.group(1)
                            class_or_function_started = f"function${func_name}"

                    # arxiv line:
                    matches = arxiv_url_pattern.search(line)
                    if matches:
                        arxiv_id = matches.group(2)
                        module_name_and_member = (
                            f"{module_name}:{class_or_function_started}"
                        )
                        if arxiv_id not in arxiv_id2module_name_and_members:
                            arxiv_id2module_name_and_members[arxiv_id] = {
                                module_name_and_member
                            }
                        else:
                            arxiv_id2module_name_and_members[arxiv_id].add(
                                module_name_and_member
                            )
        except UnicodeDecodeError:
            # Skip files like 'tests/integration_tests/examples/non-utf8-encoding.py'.
            logger.warning(f"Could not read the file {file}.")

    # handle border cases:
    # 1. If a paper is referenced both by a module and by a member of that module,
    #    e.g. {'langchain_experimental.pal_chain.base:class$PALChain',
    #          'langchain_experimental.pal_chain.base:module'},
    #    remove the redundant ':module' entry.
    for arxiv_id, module_name_and_members in arxiv_id2module_name_and_members.items():
        module_name_and_member_deduplicated = set()
        non_module_members = set()
        for module_name_and_member in module_name_and_members:
            if not module_name_and_member.endswith(":module"):
                module_name_and_member_deduplicated.add(module_name_and_member)
                non_module_members.add(module_name_and_member.split(":")[0])
        for module_name_and_member in module_name_and_members:
            if module_name_and_member.endswith(":module"):
                if module_name_and_member.split(":")[0] in non_module_members:
                    continue
                module_name_and_member_deduplicated.add(module_name_and_member)
        arxiv_id2module_name_and_members[arxiv_id] = module_name_and_member_deduplicated

    # 2. Only modules with 2-part namespaces are parsed into the API Reference
    #    now, so module entries like {'langchain.evaluation.scoring.prompt:module',
    #    'langchain.evaluation.comparison.prompt:module'} are dropped.
    #    TODO: fix this behavior.
    arxiv_id2module_name_and_members_reduced = {}
    for arxiv_id, module_name_and_members in arxiv_id2module_name_and_members.items():
        module_name_and_member_reduced = set()
        removed_modules = set()
        for module_name_and_member in module_name_and_members:
            if module_name_and_member.endswith(":module"):
                if module_name_and_member.split(":")[0].count(".") <= 1:
                    module_name_and_member_reduced.add(module_name_and_member)
                else:
                    removed_modules.add(module_name_and_member)
            else:
                module_name_and_member_reduced.add(module_name_and_member)
        if module_name_and_member_reduced:
            arxiv_id2module_name_and_members_reduced[arxiv_id] = (
                module_name_and_member_reduced
            )
        if removed_modules:
            logger.warning(
                f"{arxiv_id}: Removed the following modules with 3+-part namespaces: {removed_modules}."
            )
    return arxiv_id2module_name_and_members_reduced


def _get_doc_path(file_parts: tuple[str, ...], file_extension) -> str:
    """Get the relative path to the documentation page
    from the absolute path of the file.
    Remove file_extension.
    """
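    # Illustration (hypothetical path): ("/home", "docs", "use_cases", "qa.mdx")
    # with file_extension ".mdx" yields "docs/use_cases/qa".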
    res = []
    for el in file_parts[::-1]:
        res.append(el)
        if el == "docs":
            break
    ret = "/".join(reversed(res))
    return ret[: -len(file_extension)] if ret.endswith(file_extension) else ret


def _get_code_path(file_parts: tuple[str, ...]) -> str:
    """Get the relative path to the code file
    from the absolute path of the file.
    """
    res = []
    for el in file_parts[::-1]:
        res.append(el)
        if el == "libs":
            break
    return "/".join(reversed(res))


def _get_module_name(file_parts: tuple[str, ...]) -> str:
    """Get the module name from the absolute path of the file."""
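    # Illustration (hypothetical path): (..., "libs", "langchain", "langchain",
    # "chains", "base.py") yields "langchain.chains.base".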
    ns_parts = []
    for el in file_parts[::-1]:
        if str(el) == "__init__.py":
            continue
        ns_parts.insert(0, str(el).replace(".py", ""))
        if el.startswith("langchain"):
            break
    return ".".join(ns_parts)


def compound_urls(
    arxiv_id2file_names: dict[str, set[str]], arxiv_id2code_urls: dict[str, set[str]]
) -> dict[str, dict[str, set[str]]]:
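    """Merge the doc-page references and the API Reference links per arXiv id.

    Illustration (hypothetical id): the result looks like
    {"2305.14283": {"api": {...}, "docs": {...}}}.
    """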
    arxiv_id2urls = dict()
    for arxiv_id, code_urls in arxiv_id2code_urls.items():
        arxiv_id2urls[arxiv_id] = {"api": code_urls}
        # add the docs links for the papers that are also referenced in the docs
        if arxiv_id in arxiv_id2file_names:
            arxiv_id2urls[arxiv_id]["docs"] = arxiv_id2file_names[arxiv_id]
    for arxiv_id, file_names in arxiv_id2file_names.items():
        if arxiv_id not in arxiv_id2code_urls:
            arxiv_id2urls[arxiv_id] = {"docs": file_names}
    # reverse sort by the arxiv_id (the newest papers first)
    ret = dict(sorted(arxiv_id2urls.items(), key=lambda item: item[0], reverse=True))
    return ret


def _format_doc_link(doc_paths: list[str]) -> list[str]:
    return [
        f"[{doc_path}](https://python.langchain.com/{doc_path})"
        for doc_path in doc_paths
    ]


def _format_api_ref_link(
    doc_paths: list[str], compact: bool = False
) -> list[str]:  # TODO
    # agents/langchain_core.agents.AgentAction.html#langchain_core.agents.AgentAction
    ret = []
    for doc_path in doc_paths:
        module = doc_path.split("#")[1].replace("module-", "")
        if compact and module.count(".") > 2:
            # langchain_community.llms.oci_data_science_model_deployment_endpoint.OCIModelDeploymentTGI
            # -> langchain_community.llms...OCIModelDeploymentTGI
            module_parts = module.split(".")
            module = f"{module_parts[0]}.{module_parts[1]}...{module_parts[-1]}"
        ret.append(
            f"[{module}](https://api.python.langchain.com/en/latest/{doc_path.split('langchain.com/')[-1]})"
        )
    return ret


def log_results(arxiv_id2urls):
    arxiv_ids = arxiv_id2urls.keys()
    doc_number, api_number = 0, 0
    for urls in arxiv_id2urls.values():
        if "docs" in urls:
            doc_number += len(urls["docs"])
        if "api" in urls:
            api_number += len(urls["api"])
    logger.info(
        f"Found {len(arxiv_ids)} arXiv references in {doc_number} docs and in {api_number} API Refs."
    )


class ArxivAPIWrapper(BaseModel):
    arxiv_search: Any  #: :meta private:
    arxiv_exceptions: Any  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment."""
        try:
            import arxiv

            values["arxiv_search"] = arxiv.Search
            values["arxiv_exceptions"] = (
                arxiv.ArxivError,
                arxiv.UnexpectedEmptyPageError,
                arxiv.HTTPError,
            )
        except ImportError:
            raise ImportError(
                "Could not import arxiv python package. "
                "Please install it with `pip install arxiv`."
            )
        return values

    def get_papers(
        self, arxiv_id2urls: dict[str, dict[str, set[str]]]
    ) -> list[ArxivPaper]:
        """
        Performs an arxiv search and returns information about the papers found.

        Returns an empty list if no ids are given; re-raises the underlying
        arxiv error if the search fails.

        Args:
            arxiv_id2urls: Dictionary with arxiv_id as key and a dictionary
              with sets of doc file names and API Ref urls as value.

        Returns:
            List of ArxivPaper objects.
        """  # noqa: E501

        def cut_authors(authors: list) -> list[str]:
            if len(authors) > 3:
                return [str(a) for a in authors[:3]] + [" et al."]
            else:
                return [str(a) for a in authors]

        if not arxiv_id2urls:
            return []
        try:
            arxiv_ids = list(arxiv_id2urls.keys())
            results = self.arxiv_search(
                id_list=arxiv_ids,
                max_results=len(arxiv_ids),
            ).results()
        except self.arxiv_exceptions as ex:
            raise ex
        papers = [
            ArxivPaper(
                arxiv_id=result.entry_id.split("/")[-1],
                title=result.title,
                authors=cut_authors(result.authors),
                abstract=result.summary,
                url=result.entry_id,
                published_date=str(result.published.date()),
                referencing_docs=urls["docs"] if "docs" in urls else [],
                referencing_api_refs=urls["api"] if "api" in urls else [],
            )
            for result, urls in zip(results, arxiv_id2urls.values())
        ]
        return papers


def generate_arxiv_references_page(file_name: str, papers: list[ArxivPaper]) -> None:
    with open(file_name, "w") as f:
        # Write the page header and the Summary table header
        f.write("""# arXiv

LangChain implements the latest research in the field of Natural Language Processing.
This page contains `arXiv` papers referenced in the LangChain Documentation and API Reference.

## Summary

| arXiv id / Title | Authors | Published date 🔻 | LangChain Documentation and API Reference |
|------------------|---------|-------------------|-------------------------|
""")
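        # Illustration only (hypothetical paper): each summary row renders as
        # "| `2305.14283` [Some Title](https://arxiv.org/abs/2305.14283) | A, B, C | 2023-05-23 | `Docs:` ..."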
        for paper in papers:
            refs = []
            if paper.referencing_docs:
                refs += [
                    "`Docs:` " + ", ".join(_format_doc_link(paper.referencing_docs))
                ]
            if paper.referencing_api_refs:
                refs += [
                    "`API:` "
                    + ", ".join(
                        _format_api_ref_link(paper.referencing_api_refs, compact=True)
                    )
                ]
            refs_str = ", ".join(refs)

            title_link = f"[{paper.title}]({paper.url})"
            f.write(
                f"| {' | '.join([f'`{paper.arxiv_id}` {title_link}', ', '.join(paper.authors), paper.published_date, refs_str])}\n"
            )

        for paper in papers:
            docs_refs = (
                f"- **LangChain Documentation:** {', '.join(_format_doc_link(paper.referencing_docs))}"
                if paper.referencing_docs
                else ""
            )
            api_ref_refs = (
                f"- **LangChain API Reference:** {', '.join(_format_api_ref_link(paper.referencing_api_refs))}"
                if paper.referencing_api_refs
                else ""
            )
            f.write(f"""
## {paper.title}

- **arXiv id:** {paper.arxiv_id}
- **Title:** {paper.title}
- **Authors:** {', '.join(paper.authors)}
- **Published Date:** {paper.published_date}
- **URL:** {paper.url}
{docs_refs}
{api_ref_refs}

**Abstract:** {paper.abstract}
""")

    logger.info(f"Created the {file_name} file with {len(papers)} arXiv references.")


def main():
    # search the code and the documentation for arXiv references:
    arxiv_id2module_name_and_members = search_code_for_arxiv_references(CODE_DIR)
    arxiv_id2code_urls = convert_module_name_and_members_to_urls(
        arxiv_id2module_name_and_members
    )
    arxiv_id2file_names = search_documentation_for_arxiv_references(DOCS_DIR)
    arxiv_id2urls = compound_urls(arxiv_id2file_names, arxiv_id2code_urls)
    log_results(arxiv_id2urls)

    # get the arXiv paper information
    papers = ArxivAPIWrapper().get_papers(arxiv_id2urls)

    # generate the arXiv references page
    output_file = str(DOCS_DIR / "additional_resources" / "arxiv_references.mdx")
    generate_arxiv_references_page(output_file, papers)


if __name__ == "__main__":
    main()