core[patch]: doc init positional args (#16854)

pull/16956/head
Bagatur 4 months ago committed by GitHub
parent d80c612c92
commit 2a510c71a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import List, Literal
from typing import Any, List, Literal
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field
@ -17,6 +17,10 @@ class Document(Serializable):
"""
type: Literal["Document"] = "Document"
def __init__(self, page_content: str, **kwargs: Any) -> None:
    """Initialize the document from its content plus any extra fields.

    ``page_content`` may be supplied positionally or by keyword.

    Args:
        page_content: Forwarded to the parent initializer under the
            ``page_content`` keyword.
        **kwargs: Additional keyword arguments, forwarded unchanged.
    """
    # Fold the positional value in with the remaining keywords so the parent
    # initializer receives everything by name, with page_content first.
    init_kwargs = {"page_content": page_content, **kwargs}
    super().__init__(**init_kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""

@ -2,7 +2,7 @@
from __future__ import annotations
import asyncio
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Dict, Optional, Sequence, cast
from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
@ -67,7 +67,9 @@ class LLMChainExtractor(BaseDocumentCompressor):
output = self.llm_chain.predict_and_parse(**_input, callbacks=callbacks)
if len(output) == 0:
continue
compressed_docs.append(Document(page_content=output, metadata=doc.metadata))
compressed_docs.append(
Document(page_content=cast(str, output), metadata=doc.metadata)
)
return compressed_docs
async def acompress_documents(

@ -3,7 +3,7 @@ Ensemble retriever that ensemble the results of
multiple retrievers by using weighted Reciprocal Rank Fusion
"""
import asyncio
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
@ -195,7 +195,7 @@ class EnsembleRetriever(BaseRetriever):
# Enforce that retrieved docs are Documents for each list in retriever_docs
for i in range(len(retriever_docs)):
retriever_docs[i] = [
Document(page_content=doc) if not isinstance(doc, Document) else doc
Document(page_content=cast(str, doc)) if isinstance(doc, str) else doc
for doc in retriever_docs[i]
]

@ -13,10 +13,10 @@ def test_hashed_document_hashing() -> None:
def test_hashing_with_missing_content() -> None:
"""Check that TypeError is raised if page_content is missing."""
with pytest.raises(ValueError):
with pytest.raises(TypeError):
_HashedDocument(
metadata={"key": "value"},
)
) # type: ignore
def test_uid_auto_assigned_to_hash() -> None:

Loading…
Cancel
Save