From ce4fea983bd279401a92f468e5859e924ef6e460 Mon Sep 17 00:00:00 2001
From: Mike Wang <62768671+skcoirz@users.noreply.github.com>
Date: Fri, 28 Apr 2023 21:54:07 -0700
Subject: [PATCH] [simple] added test case and improve self class return type
 annotation (#3773)

A simple follow-up to https://github.com/hwchase17/langchain/pull/3748:
- added a test case
- improved the annotation when a function's return type is the class itself
---
 .../retrievers/document_compressors/chain_extract.py      | 4 +++-
 .../retrievers/document_compressors/test_chain_extract.py | 8 ++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/langchain/retrievers/document_compressors/chain_extract.py b/langchain/retrievers/document_compressors/chain_extract.py
index 175dd76d..db4b5a67 100644
--- a/langchain/retrievers/document_compressors/chain_extract.py
+++ b/langchain/retrievers/document_compressors/chain_extract.py
@@ -1,4 +1,6 @@
 """DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""
+from __future__ import annotations
+
 from typing import Any, Callable, Dict, Optional, Sequence

 from langchain import LLMChain, PromptTemplate
@@ -70,7 +72,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
         prompt: Optional[PromptTemplate] = None,
         get_input: Optional[Callable[[str, Document], str]] = None,
         llm_chain_kwargs: Optional[dict] = None,
-    ) -> "LLMChainExtractor":
+    ) -> LLMChainExtractor:
         """Initialize from LLM."""
         _prompt = prompt if prompt is not None else _get_default_chain_prompt()
         _get_input = get_input if get_input is not None else default_get_input
diff --git a/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py b/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py
index 0fcfebf9..7434f665 100644
--- a/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py
+++ b/tests/integration_tests/retrievers/document_compressors/test_chain_extract.py
@@ -4,6 +4,14 @@ from langchain.retrievers.document_compressors import LLMChainExtractor
 from langchain.schema import Document


+def test_llm_construction_with_kwargs() -> None:
+    llm_chain_kwargs = {"verbose": True}
+    compressor = LLMChainExtractor.from_llm(
+        ChatOpenAI(), llm_chain_kwargs=llm_chain_kwargs
+    )
+    assert compressor.llm_chain.verbose is True
+
+
 def test_llm_chain_extractor() -> None:
     texts = [
         "The Roman Empire followed the Roman Republic.",
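
Note: dropping the quotes around the LLMChainExtractor return annotation works
because of the newly added "from __future__ import annotations" (PEP 563),
which makes all annotations lazily evaluated strings, so a method can name its
own class before the class definition is complete. A minimal standalone sketch
of the same pattern (the Builder class and from_value helper below are
illustrative examples, not part of this patch):

    from __future__ import annotations  # defer annotation evaluation (PEP 563)


    class Builder:
        """Toy class with a classmethod constructor that returns its own type."""

        def __init__(self, value: int) -> None:
            self.value = value

        @classmethod
        def from_value(cls, value: int) -> Builder:  # no quotes needed
            # Without the __future__ import, this annotation would have to be
            # the string "Builder", since the class is not yet defined at the
            # point the method signature is evaluated.
            return cls(value)


    b = Builder.from_value(42)
    assert isinstance(b, Builder)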