Compare commits

...

14 Commits

Author SHA1 Message Date
Chester Curme 402298e376 Merge branch 'master' into cc/retriever_score 1 month ago
Chester Curme 267ee9db4c Merge branch 'master' into cc/retriever_score 1 month ago
Chester Curme bc7af5fd7e bump langchain to 0.1.17rc1 1 month ago
Chester Curme abf1f4c124 bump core to 0.1.47rc1 1 month ago
Chester Curme c9fc0447ec Merge branch 'master' into cc/retriever_score 1 month ago
Chester Curme c262cef1fb update SelfQueryRetriever 1 month ago
Chester Curme 26455d156d Merge branch 'master' into cc/retriever_score 1 month ago
Chester Curme ceea324071 cr 1 month ago
Chester Curme 1544c9d050 update 1 month ago
Chester Curme 7a78068cd4 fix test 1 month ago
Chester Curme d91fd8cdcb update 1 month ago
Chester Curme 1f15b0885d add test 1 month ago
Chester Curme 4f16714195 update VectorStoreRetriever 1 month ago
Chester Curme f8598a7e48 add DocumentSearchHit 1 month ago

@ -1444,7 +1444,11 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
arbitrary_types_allowed = True
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
@ -1472,7 +1476,11 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(

@ -1,10 +1,11 @@
import itertools
import random
import uuid
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, cast
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.documents import DocumentSearchHit
from langchain_community.vectorstores import DatabricksVectorSearch
from tests.integration_tests.vectorstores.fake_embeddings import (
@ -598,6 +599,12 @@ def test_similarity_score_threshold(index_details: dict, threshold: float) -> No
assert len(search_result) == len(fake_texts)
else:
assert len(search_result) == 0
result_with_scores = cast(
List[DocumentSearchHit], retriever.invoke(query, include_score=True)
)
for idx, result in enumerate(result_with_scores):
assert result.score >= threshold
assert result.page_content == search_result[idx].page_content
@pytest.mark.requires("databricks", "databricks.vector_search")

@ -2,8 +2,13 @@
and their transformations.
"""
from langchain_core.documents.base import Document
from langchain_core.documents.base import Document, DocumentSearchHit
from langchain_core.documents.compressor import BaseDocumentCompressor
from langchain_core.documents.transformers import BaseDocumentTransformer
__all__ = ["Document", "BaseDocumentTransformer", "BaseDocumentCompressor"]
__all__ = [
"Document",
"DocumentSearchHit",
"BaseDocumentTransformer",
"BaseDocumentCompressor",
]

@ -30,3 +30,21 @@ class Document(Serializable):
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "document"]
class DocumentSearchHit(Document):
"""Class for storing a document and fields associated with retrieval."""
score: float
"""Score associated with the document's relevance to a query."""
type: Literal["DocumentSearchHit"] = "DocumentSearchHit" # type: ignore[assignment] # noqa: E501
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "document_search_hit"]

@ -157,6 +157,12 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"base",
"Document",
),
("langchain", "schema", "document_search_hit", "DocumentSearchHit"): (
"langchain_core",
"documents",
"base",
"DocumentSearchHit",
),
("langchain", "output_parsers", "fix", "OutputFixingParser"): (
"langchain",
"output_parsers",
@ -666,6 +672,12 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"base",
"Document",
),
("langchain_core", "documents", "base", "DocumentSearchHit"): (
"langchain_core",
"documents",
"base",
"DocumentSearchHit",
),
("langchain_core", "prompts", "chat", "AIMessagePromptTemplate"): (
"langchain_core",
"prompts",

@ -39,6 +39,7 @@ from typing import (
TypeVar,
)
from langchain_core.documents import DocumentSearchHit
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
@ -690,8 +691,17 @@ class VectorStoreRetriever(BaseRetriever):
return values
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
if include_score and self.search_type != "similarity_score_threshold":
raise ValueError(
"include_score is only supported "
"for search_type=similarity_score_threshold"
)
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
elif self.search_type == "similarity_score_threshold":
@ -700,6 +710,11 @@ class VectorStoreRetriever(BaseRetriever):
query, **self.search_kwargs
)
)
if include_score:
return [
DocumentSearchHit(page_content=doc.page_content, score=score)
for doc, score in docs_and_similarities
]
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
@ -710,8 +725,17 @@ class VectorStoreRetriever(BaseRetriever):
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
if include_score and self.search_type != "similarity_score_threshold":
raise ValueError(
"include_score is only supported "
"for search_type=similarity_score_threshold"
)
if self.search_type == "similarity":
docs = await self.vectorstore.asimilarity_search(
query, **self.search_kwargs
@ -722,6 +746,11 @@ class VectorStoreRetriever(BaseRetriever):
query, **self.search_kwargs
)
)
if include_score:
return [
DocumentSearchHit(page_content=doc.page_content, score=score)
for doc, score in docs_and_similarities
]
docs = [doc for doc, _ in docs_and_similarities]
elif self.search_type == "mmr":
docs = await self.vectorstore.amax_marginal_relevance_search(

@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-core"
version = "0.1.46"
version = "0.1.47rc1"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"

@ -1,6 +1,11 @@
from langchain_core.documents import __all__
EXPECTED_ALL = ["Document", "BaseDocumentTransformer", "BaseDocumentCompressor"]
EXPECTED_ALL = [
"Document",
"DocumentSearchHit",
"BaseDocumentTransformer",
"BaseDocumentCompressor",
]
def test_all_imports() -> None:

@ -33,7 +33,7 @@ from langchain_core.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.documents import Document, DocumentSearchHit
from langchain_core.language_models import BaseLanguageModel
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
@ -192,19 +192,43 @@ class SelfQueryRetriever(BaseRetriever):
return new_query, search_kwargs
def _get_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
) -> List[Document]:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
if include_score:
docs_and_scores = self.vectorstore.similarity_search_with_score(
query, **search_kwargs
)
return [
DocumentSearchHit(page_content=doc.page_content, score=score)
for doc, score in docs_and_scores
]
else:
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs
async def _aget_docs_with_query(
self, query: str, search_kwargs: Dict[str, Any]
self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
) -> List[Document]:
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
if include_score:
docs_and_scores = await self.vectorstore.asimilarity_search_with_score(
query, **search_kwargs
)
return [
DocumentSearchHit(page_content=doc.page_content, score=score)
for doc, score in docs_and_scores
]
else:
docs = await self.vectorstore.asearch(
query, self.search_type, **search_kwargs
)
return docs
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
"""Get documents relevant for a query.
@ -220,11 +244,17 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = self._get_docs_with_query(new_query, search_kwargs)
docs = self._get_docs_with_query(
new_query, search_kwargs, include_score=include_score
)
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
include_score: bool = False,
) -> List[Document]:
"""Get documents relevant for a query.
@ -240,7 +270,9 @@ class SelfQueryRetriever(BaseRetriever):
if self.verbose:
logger.info(f"Generated Query: {structured_query}")
new_query, search_kwargs = self._prepare_query(query, structured_query)
docs = await self._aget_docs_with_query(new_query, search_kwargs)
docs = await self._aget_docs_with_query(
new_query, search_kwargs, include_score=include_score
)
return docs
@classmethod

@ -3469,7 +3469,7 @@ files = [
[[package]]
name = "langchain-community"
version = "0.0.32"
version = "0.0.34"
description = "Community contributed LangChain integrations."
optional = false
python-versions = ">=3.8.1,<4.0"
@ -3479,7 +3479,7 @@ develop = true
[package.dependencies]
aiohttp = "^3.8.3"
dataclasses-json = ">= 0.5.7, < 0.7"
langchain-core = "^0.1.41"
langchain-core = "^0.1.45"
langsmith = "^0.1.0"
numpy = "^1"
PyYAML = ">=5.3"
@ -3489,7 +3489,7 @@ tenacity = "^8.1.0"
[package.extras]
cli = ["typer (>=0.9.0,<0.10.0)"]
extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "httpx-sse (>=0.4.0,<0.5.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "premai (>=0.3.25,<0.4.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pyjwt (>=2.8.0,<3.0.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "vdms (>=0.0.20,<0.0.21)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"]
extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "azure-identity (>=1.15.0,<2.0.0)", "azure-search-documents (==11.4.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.6,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "httpx-sse (>=0.4.0,<0.5.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "premai (>=0.3.25,<0.4.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pyjwt (>=2.8.0,<3.0.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "vdms (>=0.0.20,<0.0.21)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"]
[package.source]
type = "directory"
@ -3497,7 +3497,7 @@ url = "../community"
[[package]]
name = "langchain-core"
version = "0.1.42"
version = "0.1.47rc1"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -9410,4 +9410,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "845d36b1258779b2b483ec8758070fc73adad9d94b7d4c93a4145c360d946ac2"
content-hash = "00c1e092e378283d46a322dfb3014e30f5388f334e0a82f49acd2e8ecf5c05d3"

@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain"
version = "0.1.16"
version = "0.1.17rc1"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@ -12,7 +12,7 @@ langchain-server = "langchain.server:main"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.42"
langchain-core = { version = "^0.1.47rc1", allow-prereleases = true }
langchain-text-splitters = ">=0.0.1,<0.1"
langchain-community = ">=0.0.32,<0.1"
langsmith = "^0.1.17"

Loading…
Cancel
Save