[simple] added test case and improve self class return type annotation (#3773)

a simple follow up of https://github.com/hwchase17/langchain/pull/3748
- added test case
- improve annotation when function return type is class itself.
This commit is contained in:
Mike Wang 2023-04-28 21:54:07 -07:00 committed by GitHub
parent 0c0f14407c
commit ce4fea983b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 1 deletions

View File

@ -1,4 +1,6 @@
"""DocumentFilter that uses an LLM chain to extract the relevant parts of documents.""" """DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""
from __future__ import annotations
from typing import Any, Callable, Dict, Optional, Sequence from typing import Any, Callable, Dict, Optional, Sequence
from langchain import LLMChain, PromptTemplate from langchain import LLMChain, PromptTemplate
@ -70,7 +72,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
prompt: Optional[PromptTemplate] = None, prompt: Optional[PromptTemplate] = None,
get_input: Optional[Callable[[str, Document], str]] = None, get_input: Optional[Callable[[str, Document], str]] = None,
llm_chain_kwargs: Optional[dict] = None, llm_chain_kwargs: Optional[dict] = None,
) -> "LLMChainExtractor": ) -> LLMChainExtractor:
"""Initialize from LLM.""" """Initialize from LLM."""
_prompt = prompt if prompt is not None else _get_default_chain_prompt() _prompt = prompt if prompt is not None else _get_default_chain_prompt()
_get_input = get_input if get_input is not None else default_get_input _get_input = get_input if get_input is not None else default_get_input

View File

@ -4,6 +4,14 @@ from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.schema import Document from langchain.schema import Document
def test_llm_construction_with_kwargs() -> None:
llm_chain_kwargs = {"verbose": True}
compressor = LLMChainExtractor.from_llm(
ChatOpenAI(), llm_chain_kwargs=llm_chain_kwargs
)
assert compressor.llm_chain.verbose is True
def test_llm_chain_extractor() -> None: def test_llm_chain_extractor() -> None:
texts = [ texts = [
"The Roman Empire followed the Roman Republic.", "The Roman Empire followed the Roman Republic.",