langchain[patch]: expose cohere rerank score, add parent doc param (#16887)

pull/17274/head
Bagatur committed 4 months ago via GitHub
parent 35c1bf339d
commit 02ef9164b5

@@ -484,8 +484,8 @@ class ElasticsearchStore(VectorStore):
             from langchain_community.vectorstores.utils import DistanceStrategy
 
             vectorstore = ElasticsearchStore(
-                "langchain-demo",
                 embedding=OpenAIEmbeddings(),
+                index_name="langchain-demo",
                 es_url="http://localhost:9200",
                 distance_strategy="DOT_PRODUCT"
             )
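A minimal sketch of querying a store configured as in the updated example, assuming a reachable Elasticsearch at localhost:9200, an OPENAI_API_KEY in the environment, and an already-populated index (none of which this diff guarantees):

    from langchain_community.embeddings import OpenAIEmbeddings
    from langchain_community.vectorstores import ElasticsearchStore

    vectorstore = ElasticsearchStore(
        embedding=OpenAIEmbeddings(),
        index_name="langchain-demo",
        es_url="http://localhost:9200",
        distance_strategy="DOT_PRODUCT",
    )
    # Dense retrieval over the indexed documents.
    docs = vectorstore.similarity_search("what is a vector store?", k=4)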

@@ -0,0 +1,3 @@
+from langchain.chains.query_constructor.base import load_query_constructor_runnable
+
+__all__ = ["load_query_constructor_runnable"]
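The new __init__.py above re-exports load_query_constructor_runnable from the package root. A minimal sketch of how it might be used; the ChatOpenAI model and the movie attributes are illustrative assumptions, not part of this diff:

    from langchain.chains.query_constructor import load_query_constructor_runnable
    from langchain.chains.query_constructor.schema import AttributeInfo
    from langchain_openai import ChatOpenAI

    # Filterable metadata attributes of the indexed documents.
    attribute_info = [
        AttributeInfo(name="year", description="Year the movie was released", type="integer"),
        AttributeInfo(name="genre", description="Genre of the movie", type="string"),
    ]
    chain = load_query_constructor_runnable(
        llm=ChatOpenAI(temperature=0),
        # Per the docstring fix below: a description of the page contents.
        document_contents="Brief plot summary of a movie",
        attribute_info=attribute_info,
    )
    structured_query = chain.invoke({"query": "sci-fi movies released after 2010"})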

@@ -323,7 +323,8 @@ def load_query_constructor_runnable(
     Args:
         llm: BaseLanguageModel to use for the chain.
-        document_contents: The contents of the document to be queried.
+        document_contents: Description of the page contents of the document to be
+            queried.
         attribute_info: Sequence of attributes in the document.
         examples: Optional list of examples to use for the chain.
         allowed_comparators: Sequence of allowed comparators. Defaults to all

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Dict, Optional, Sequence
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 from langchain_core.documents import Document
 from langchain_core.pydantic_v1 import Extra, root_validator
@@ -9,23 +10,13 @@ from langchain.callbacks.manager import Callbacks
 from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
 from langchain.utils import get_from_dict_or_env
 
-if TYPE_CHECKING:
-    from cohere import Client
-else:
-    # We do to avoid pydantic annotation issues when actually instantiating
-    # while keeping this import optional
-    try:
-        from cohere import Client
-    except ImportError:
-        pass
-
 
 class CohereRerank(BaseDocumentCompressor):
     """Document compressor that uses `Cohere Rerank API`."""
 
-    client: Client
+    client: Any
     """Cohere client to use for compressing documents."""
-    top_n: int = 3
+    top_n: Optional[int] = 3
     """Number of documents to return."""
     model: str = "rerank-english-v2.0"
     """Model to use for reranking."""
@@ -57,6 +48,42 @@ class CohereRerank(BaseDocumentCompressor):
         values["client"] = cohere.Client(cohere_api_key, client_name=client_name)
         return values
 
+    def rerank(
+        self,
+        documents: Sequence[Union[str, Document, dict]],
+        query: str,
+        *,
+        model: Optional[str] = None,
+        top_n: Optional[int] = -1,
+        max_chunks_per_doc: Optional[int] = None,
+    ) -> List[Dict[str, Any]]:
+        """Returns an ordered list of documents ordered by their relevance to the provided query.
+
+        Args:
+            query: The query to use for reranking.
+            documents: A sequence of documents to rerank.
+            model: The model to use for re-ranking. Default to self.model.
+            top_n : The number of results to return. If None returns all results.
+                Defaults to self.top_n.
+            max_chunks_per_doc : The maximum number of chunks derived from a document.
+        """  # noqa: E501
+        if len(documents) == 0:  # to avoid empty api call
+            return []
+        docs = [
+            doc.page_content if isinstance(doc, Document) else doc for doc in documents
+        ]
+        model = model or self.model
+        top_n = top_n if (top_n is None or top_n > 0) else self.top_n
+        results = self.client.rerank(
+            query, docs, model, top_n=top_n, max_chunks_per_doc=max_chunks_per_doc
+        )
+        result_dicts = []
+        for res in results:
+            result_dicts.append(
+                {"index": res.index, "relevance_score": res.relevance_score}
+            )
+        return result_dicts
+
     def compress_documents(
         self,
         documents: Sequence[Document],
@@ -74,16 +101,10 @@ class CohereRerank(BaseDocumentCompressor):
         Returns:
             A sequence of compressed documents.
         """
-        if len(documents) == 0:  # to avoid empty api call
-            return []
-        doc_list = list(documents)
-        _docs = [d.page_content for d in doc_list]
-        results = self.client.rerank(
-            model=self.model, query=query, documents=_docs, top_n=self.top_n
-        )
-        final_results = []
-        for r in results:
-            doc = doc_list[r.index]
-            doc.metadata["relevance_score"] = r.relevance_score
-            final_results.append(doc)
-        return final_results
+        compressed = []
+        for res in self.rerank(documents, query):
+            doc = documents[res["index"]]
+            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
+            doc_copy.metadata["relevance_score"] = res["relevance_score"]
+            compressed.append(doc_copy)
+        return compressed
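Taken together: the new rerank method exposes the raw Cohere relevance scores, and compress_documents now attaches the score to a copy of each document instead of mutating the input. A minimal sketch, assuming the cohere package is installed and a COHERE_API_KEY is set in the environment (the scores shown are illustrative):

    from langchain.retrievers.document_compressors import CohereRerank
    from langchain_core.documents import Document

    compressor = CohereRerank(top_n=2)
    docs = [
        Document(page_content="Paris is the capital of France."),
        Document(page_content="The Eiffel Tower is in Paris."),
        Document(page_content="Penguins live in Antarctica."),
    ]

    # Raw scores, without constructing new documents:
    compressor.rerank(docs, "Where is the Eiffel Tower?")
    # -> e.g. [{"index": 1, "relevance_score": 0.98}, {"index": 0, "relevance_score": 0.71}]

    # Scores surfaced on copies of the documents themselves:
    for doc in compressor.compress_documents(docs, "Where is the Eiffel Tower?"):
        print(doc.metadata["relevance_score"], doc.page_content)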

@ -1,5 +1,5 @@
import uuid
from typing import List, Optional
from typing import List, Optional, Sequence
from langchain_core.documents import Document
@@ -31,17 +31,16 @@ class ParentDocumentRetriever(MultiVectorRetriever):
         .. code-block:: python
 
             # Imports
-            from langchain_community.vectorstores import Chroma
             from langchain_community.embeddings import OpenAIEmbeddings
+            from langchain_community.vectorstores import Chroma
             from langchain.text_splitter import RecursiveCharacterTextSplitter
             from langchain.storage import InMemoryStore
 
             # This text splitter is used to create the parent documents
-            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
+            parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, add_start_index=True)
             # This text splitter is used to create the child documents
             # It should create documents smaller than the parent
-            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
+            child_splitter = RecursiveCharacterTextSplitter(chunk_size=400, add_start_index=True)
             # The vectorstore to use to index the child chunks
             vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
             # The storage layer for the parent documents
@@ -54,7 +53,7 @@ class ParentDocumentRetriever(MultiVectorRetriever):
                 child_splitter=child_splitter,
                 parent_splitter=parent_splitter,
             )
-    """
+    """  # noqa: E501
 
     child_splitter: TextSplitter
     """The text splitter to use to create child documents."""
@@ -65,6 +64,11 @@ class ParentDocumentRetriever(MultiVectorRetriever):
     """The text splitter to use to create parent documents.
     If none, then the parent documents will be the raw documents passed in."""
 
+    child_metadata_fields: Optional[Sequence[str]] = None
+    """Metadata fields to leave in child documents. If None, leave all parent document
+        metadata.
+    """
+
     def add_documents(
         self,
         documents: List[Document],
@@ -76,7 +80,7 @@ class ParentDocumentRetriever(MultiVectorRetriever):
         Args:
             documents: List of documents to add
             ids: Optional list of ids for documents. If provided should be the same
-                length as the list of documents. Can provided if parent documents
+                length as the list of documents. Can be provided if parent documents
                 are already in the document store and you don't want to re-add
                 to the docstore. If not provided, random UUIDs will be used as
                 ids.
@@ -106,6 +110,11 @@ class ParentDocumentRetriever(MultiVectorRetriever):
         for i, doc in enumerate(documents):
             _id = doc_ids[i]
             sub_docs = self.child_splitter.split_documents([doc])
+            if self.child_metadata_fields is not None:
+                for _doc in sub_docs:
+                    _doc.metadata = {
+                        k: _doc.metadata[k] for k in self.child_metadata_fields
+                    }
             for _doc in sub_docs:
                 _doc.metadata[self.id_key] = _id
             docs.extend(sub_docs)
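A minimal sketch of the new child_metadata_fields parameter, which strips child-chunk metadata down to an allow-list before indexing. The "source" key is an illustrative assumption; note that every listed field must exist on each parent document's metadata, since the dict comprehension above would raise a KeyError otherwise:

    from langchain.retrievers import ParentDocumentRetriever
    from langchain.storage import InMemoryStore
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain_community.embeddings import OpenAIEmbeddings
    from langchain_community.vectorstores import Chroma
    from langchain_core.documents import Document

    retriever = ParentDocumentRetriever(
        vectorstore=Chroma(embedding_function=OpenAIEmbeddings()),
        docstore=InMemoryStore(),
        child_splitter=RecursiveCharacterTextSplitter(chunk_size=400),
        # Only "source" survives on the indexed child chunks;
        # "secret" below is dropped before embedding.
        child_metadata_fields=["source"],
    )
    retriever.add_documents(
        [Document(page_content="lorem ipsum " * 200, metadata={"source": "a.txt", "secret": "x"})]
    )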

@@ -649,7 +649,7 @@ class ChatOpenAI(BaseChatModel):
                 Must be the name of the single provided function or
                 "auto" to automatically determine which function to call
                 (if any).
-            kwargs: Any additional parameters to pass to the
+            **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
         """
@@ -701,22 +701,21 @@ class ChatOpenAI(BaseChatModel):
                 "auto" to automatically determine which function to call
                 (if any), or a dict of the form:
                 {"type": "function", "function": {"name": <<tool_name>>}}.
-            kwargs: Any additional parameters to pass to the
+            **kwargs: Any additional parameters to pass to the
                 :class:`~langchain.runnable.Runnable` constructor.
         """
 
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         if tool_choice is not None:
-            if isinstance(tool_choice, str) and tool_choice not in ("auto", "none"):
+            if isinstance(tool_choice, str) and (tool_choice not in ("auto", "none")):
                 tool_choice = {"type": "function", "function": {"name": tool_choice}}
-            if isinstance(tool_choice, dict) and len(formatted_tools) != 1:
+            if isinstance(tool_choice, dict) and (len(formatted_tools) != 1):
                 raise ValueError(
                     "When specifying `tool_choice`, you must provide exactly one "
                     f"tool. Received {len(formatted_tools)} tools."
                 )
-            if (
-                isinstance(tool_choice, dict)
-                and formatted_tools[0]["function"]["name"]
-                != tool_choice["function"]["name"]
-            ):
+            if isinstance(tool_choice, dict) and (
+                formatted_tools[0]["function"]["name"]
+                != tool_choice["function"]["name"]
+            ):
                 raise ValueError(
@@ -724,7 +723,4 @@ class ChatOpenAI(BaseChatModel):
                 f"provided tool was {formatted_tools[0]['function']['name']}."
             )
             kwargs["tool_choice"] = tool_choice
-        return super().bind(
-            tools=formatted_tools,
-            **kwargs,
-        )
+        return super().bind(tools=formatted_tools, **kwargs)
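For reference, a minimal sketch of bind_tools with a forced tool_choice. The GetWeather schema and the model name are illustrative assumptions, and this assumes the langchain_openai ChatOpenAI with an OPENAI_API_KEY in the environment:

    from langchain_core.pydantic_v1 import BaseModel, Field
    from langchain_openai import ChatOpenAI

    class GetWeather(BaseModel):
        """Get the current weather for a city."""

        city: str = Field(description="Name of the city")

    llm = ChatOpenAI(model="gpt-3.5-turbo-1106")
    # A bare tool name is rewritten to
    # {"type": "function", "function": {"name": "GetWeather"}}; passing it with
    # more than one tool, or with a mismatched name, raises ValueError.
    llm_with_tools = llm.bind_tools([GetWeather], tool_choice="GetWeather")
    msg = llm_with_tools.invoke("What's the weather in Paris?")
    print(msg.additional_kwargs["tool_calls"])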
