Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00
[simple] added test case and improve self class return type annotation (#3773)
A simple follow-up to https://github.com/hwchase17/langchain/pull/3748: adds a test case and improves the return type annotation where a function returns its own class.
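For readers skimming the diff, this is the annotation pattern being changed, shown as a minimal standalone sketch (hypothetical Widget class, not code from the PR). Quoting the class name works everywhere; once annotations are evaluated lazily, the name can be written directly:

from __future__ import annotations


class Widget:
    @classmethod
    def old_style(cls) -> "Widget":  # string forward reference: works everywhere, but noisy
        return cls()

    @classmethod
    def new_style(cls) -> Widget:  # legal once annotations are lazy (PEP 563)
        return cls()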
parent 0c0f14407c
commit ce4fea983b
@@ -1,4 +1,6 @@
 """DocumentFilter that uses an LLM chain to extract the relevant parts of documents."""
+from __future__ import annotations
+
 from typing import Any, Callable, Dict, Optional, Sequence
 
 from langchain import LLMChain, PromptTemplate
@@ -70,7 +72,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
         prompt: Optional[PromptTemplate] = None,
         get_input: Optional[Callable[[str, Document], str]] = None,
         llm_chain_kwargs: Optional[dict] = None,
-    ) -> "LLMChainExtractor":
+    ) -> LLMChainExtractor:
         """Initialize from LLM."""
         _prompt = prompt if prompt is not None else _get_default_chain_prompt()
         _get_input = get_input if get_input is not None else default_get_input
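The newly added `from __future__ import annotations` is what makes the unquoted return type legal: under PEP 563, every annotation in the module is stored as a string and never evaluated eagerly, so a class body can reference the class's own name before it is bound. A quick self-contained check (hypothetical Extractor class, not LangChain code):

from __future__ import annotations


class Extractor:
    @classmethod
    def build(cls) -> Extractor:  # safe: the annotation is never evaluated at class-body time
        return cls()


# PEP 563 stores the annotation as the string "Extractor":
print(Extractor.build.__annotations__)  # {'return': 'Extractor'}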
@@ -4,6 +4,14 @@ from langchain.retrievers.document_compressors import LLMChainExtractor
 from langchain.schema import Document
 
 
+def test_llm_construction_with_kwargs() -> None:
+    llm_chain_kwargs = {"verbose": True}
+    compressor = LLMChainExtractor.from_llm(
+        ChatOpenAI(), llm_chain_kwargs=llm_chain_kwargs
+    )
+    assert compressor.llm_chain.verbose is True
+
+
 def test_llm_chain_extractor() -> None:
     texts = [
         "The Roman Empire followed the Roman Republic.",
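The new test relies on `from_llm` forwarding `llm_chain_kwargs` into the underlying `LLMChain`. A hedged sketch of that forwarding pattern (a simplified reconstruction, not the verbatim `from_llm` body, which also wires up a default prompt and output parser): extra kwargs are splatted into the `LLMChain` constructor, so `{"verbose": True}` surfaces as `compressor.llm_chain.verbose`.

from typing import Optional

from langchain import LLMChain, PromptTemplate
from langchain.schema import BaseLanguageModel


def build_llm_chain(
    llm: BaseLanguageModel,
    prompt: PromptTemplate,
    llm_chain_kwargs: Optional[dict] = None,
) -> LLMChain:
    # Any caller-supplied kwargs are passed straight through to LLMChain,
    # e.g. verbose=True, which the new test asserts on.
    return LLMChain(llm=llm, prompt=prompt, **(llm_chain_kwargs or {}))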