feat: add support for arxiv identifier in ArxivAPIWrapper() (#9318)

- Description: This PR adds support for arXiv identifiers to
`ArxivAPIWrapper`. I modified the `run()` and `load()` functions in
`arxiv.py`, using a regex to recognize whether the query is in the form of
an arXiv identifier (see
[https://info.arxiv.org/help/find/index.html](https://info.arxiv.org/help/find/index.html)).
If so, the wrapper directly searches for the paper corresponding to that
identifier. I also modified and added tests in `test_arxiv.py`.
  - Issue: #9047 
  - Dependencies: N/A
  - Tag maintainer: N/A

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/11143/head
Mincoolee 8 months ago committed by GitHub
parent d3c2ca5656
commit 05b75f3f13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,7 @@
"""Util that calls Arxiv."""
import logging
import os
import re
from typing import Any, Dict, List, Optional
from langchain.pydantic_v1 import BaseModel, root_validator
@ -17,6 +18,9 @@ class ArxivAPIWrapper(BaseModel):
This wrapper will use the Arxiv API to conduct searches and
fetch document summaries. By default, it will return the document summaries
of the top-k results.
If the query is in the form of arxiv identifier
(see https://info.arxiv.org/help/find/index.html), it will return the paper
corresponding to the arxiv identifier.
It limits the Document content by doc_content_chars_max.
Set doc_content_chars_max=None if you don't want to limit the content size.
@ -54,6 +58,18 @@ class ArxivAPIWrapper(BaseModel):
load_all_available_meta: bool = False
doc_content_chars_max: Optional[int] = 4000
def is_arxiv_identifier(self, query: str) -> bool:
    """Check if every whitespace-separated item of a query is an arxiv identifier.

    Only the first ``self.ARXIV_MAX_QUERY_LENGTH`` characters of ``query`` are
    inspected; the truncated text is split on whitespace and each item must
    match the identifier pattern in full.

    Args:
        query: the raw query string, holding one or more candidate
            identifiers separated by whitespace.

    Returns:
        True if the truncated query contains at least one item and every item
        fully matches the identifier pattern. False otherwise, including for
        an empty or whitespace-only query.
    """
    # First alternative: two digits, a month (01-12), a dot, 4-5 digits and an
    # optional "vN" version suffix. Second alternative: seven leading digits.
    # NOTE(review): the trailing ".*" lets the second alternative accept any
    # text after the seven digits -- confirm whether that is intended.
    arxiv_identifier_pattern = r"\d{2}(0[1-9]|1[0-2])\.\d{4,5}(v\d+|)|\d{7}.*"
    query_items = query[: self.ARXIV_MAX_QUERY_LENGTH].split()
    # Without this guard an empty query would be vacuously reported as an
    # identifier, because there would be no item left to reject.
    if not query_items:
        return False
    return all(
        re.fullmatch(arxiv_identifier_pattern, query_item) is not None
        for query_item in query_items
    )
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
@ -88,9 +104,15 @@ class ArxivAPIWrapper(BaseModel):
query: a plaintext search query
""" # noqa: E501
try:
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
).results()
if self.is_arxiv_identifier(query):
results = self.arxiv_search(
id_list=query.split(),
max_results=self.top_k_results,
).results()
else:
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results
).results()
except self.arxiv_exceptions as ex:
return f"Arxiv exception: {ex}"
docs = [
@ -129,9 +151,15 @@ class ArxivAPIWrapper(BaseModel):
try:
# Remove the ":" and "-" from the query, as they can cause search problems
query = query.replace(":", "").replace("-", "")
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.load_max_docs
).results()
if self.is_arxiv_identifier(query):
results = self.arxiv_search(
id_list=query[: self.ARXIV_MAX_QUERY_LENGTH].split(),
max_results=self.load_max_docs,
).results()
else:
results = self.arxiv_search( # type: ignore
query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.load_max_docs
).results()
except self.arxiv_exceptions as ex:
logger.debug("Error on arxiv: %s", ex)
return []

File diff suppressed because it is too large Load Diff

@ -346,6 +346,7 @@ extended_testing = [
"faiss-cpu",
"openapi-schema-pydantic",
"markdownify",
"arxiv",
"dashvector",
"sqlite-vss",
"timescale-vector",

@ -15,13 +15,38 @@ def api_client() -> ArxivAPIWrapper:
return ArxivAPIWrapper()
def test_run_success(api_client: ArxivAPIWrapper) -> None:
"""Test that returns the correct answer"""
def test_run_success_paper_name(api_client: ArxivAPIWrapper) -> None:
"""Test a query of paper name that returns the correct answer"""
output = api_client.run("1605.08386")
output = api_client.run("Heat-bath random walks with Markov bases")
assert "Probability distributions for Markov chains based quantum walks" in output
assert (
"Transformations of random walks on groups via Markov stopping times" in output
)
assert (
"Recurrence of Multidimensional Persistent Random Walks. Fourier and Series "
"Criteria" in output
)
def test_run_success_arxiv_identifier(api_client: ArxivAPIWrapper) -> None:
    """Check that running a single arxiv identifier returns the matching paper."""
    expected_title = "Heat-bath random walks with Markov bases"
    result = api_client.run("1605.08386v1")
    assert expected_title in result
def test_run_success_multiple_arxiv_identifiers(api_client: ArxivAPIWrapper) -> None:
    """Check that running several arxiv identifiers returns every matching paper."""
    expected_titles = (
        "Heat-bath random walks with Markov bases",
        "Scaling Language-Image Pre-training via Masking",
        "Ultra-low mass PBHs in the early universe can explain the PTA signal",
    )
    result = api_client.run("1605.08386v1 2212.00794v2 2308.07912")
    # Each requested identifier must contribute its paper title to the output.
    for title in expected_titles:
        assert title in result
def test_run_returns_several_docs(api_client: ArxivAPIWrapper) -> None:
"""Test that returns several docs"""
@ -43,14 +68,30 @@ def assert_docs(docs: List[Document]) -> None:
assert set(doc.metadata) == {"Published", "Title", "Authors", "Summary"}
def test_load_success(api_client: ArxivAPIWrapper) -> None:
"""Test that returns one document"""
def test_load_success_paper_name(api_client: ArxivAPIWrapper) -> None:
"""Test a query of paper name that returns one document"""
docs = api_client.load("1605.08386")
docs = api_client.load("Heat-bath random walks with Markov bases")
assert len(docs) == 3
assert_docs(docs)
def test_load_success_arxiv_identifier(api_client: ArxivAPIWrapper) -> None:
    """Check that loading a single arxiv identifier yields exactly one document."""
    loaded = api_client.load("1605.08386v1")
    assert len(loaded) == 1
    assert_docs(loaded)
def test_load_success_multiple_arxiv_identifiers(api_client: ArxivAPIWrapper) -> None:
    """Check that loading three arxiv identifiers yields three documents."""
    loaded = api_client.load("1605.08386v1 2212.00794v2 2308.07912")
    assert len(loaded) == 3
    assert_docs(loaded)
def test_load_returns_no_result(api_client: ArxivAPIWrapper) -> None:
"""Test that returns no docs"""

@ -0,0 +1,17 @@
import pytest as pytest
from langchain.utilities import ArxivAPIWrapper
@pytest.mark.requires("arxiv")
def test_is_arxiv_identifier() -> None:
    """Check that is_arxiv_identifier accepts valid identifiers and rejects others."""
    wrapper = ArxivAPIWrapper()
    accepted = (
        "1605.08386v1",
        "0705.0123",
        "2308.07912",
        "9603067 2308.07912 2308.07912",
    )
    rejected = (
        "12345",
        "0705.012",
        "0705.012300",
        "1605.08386w1",
    )
    for query in accepted:
        assert wrapper.is_arxiv_identifier(query)
    for query in rejected:
        assert not wrapper.is_arxiv_identifier(query)

191
poetry.lock generated

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save