|
|
|
@ -1,39 +1,28 @@
|
|
|
|
|
import asyncio
|
|
|
|
|
import logging
|
|
|
|
|
from typing import List, Sequence
|
|
|
|
|
from typing import List, Optional, Sequence
|
|
|
|
|
|
|
|
|
|
from langchain_core.callbacks import (
|
|
|
|
|
AsyncCallbackManagerForRetrieverRun,
|
|
|
|
|
CallbackManagerForRetrieverRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
from langchain_core.language_models import BaseLLM
|
|
|
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
|
|
|
from langchain_core.output_parsers import BaseOutputParser
|
|
|
|
|
from langchain_core.prompts.prompt import PromptTemplate
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
|
|
|
|
|
|
from langchain.chains.llm import LLMChain
|
|
|
|
|
from langchain.output_parsers.pydantic import PydanticOutputParser
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LineList(BaseModel):
|
|
|
|
|
"""List of lines."""
|
|
|
|
|
|
|
|
|
|
lines: List[str] = Field(description="Lines of text")
|
|
|
|
|
"""List of lines."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LineListOutputParser(PydanticOutputParser):
|
|
|
|
|
class LineListOutputParser(BaseOutputParser[List[str]]):
|
|
|
|
|
"""Output parser for a list of lines."""
|
|
|
|
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
super().__init__(pydantic_object=LineList)
|
|
|
|
|
|
|
|
|
|
def parse(self, text: str) -> LineList:
|
|
|
|
|
def parse(self, text: str) -> List[str]:
|
|
|
|
|
lines = text.strip().split("\n")
|
|
|
|
|
return LineList(lines=lines)
|
|
|
|
|
return lines
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Default prompt
|
|
|
|
@ -63,6 +52,7 @@ class MultiQueryRetriever(BaseRetriever):
|
|
|
|
|
llm_chain: LLMChain
|
|
|
|
|
verbose: bool = True
|
|
|
|
|
parser_key: str = "lines"
|
|
|
|
|
"""DEPRECATED. parser_key is no longer used and should not be specified."""
|
|
|
|
|
include_original: bool = False
|
|
|
|
|
"""Whether to include the original query in the list of generated queries."""
|
|
|
|
|
|
|
|
|
@ -70,9 +60,9 @@ class MultiQueryRetriever(BaseRetriever):
|
|
|
|
|
def from_llm(
|
|
|
|
|
cls,
|
|
|
|
|
retriever: BaseRetriever,
|
|
|
|
|
llm: BaseLLM,
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
|
|
|
|
|
parser_key: str = "lines",
|
|
|
|
|
parser_key: Optional[str] = None,
|
|
|
|
|
include_original: bool = False,
|
|
|
|
|
) -> "MultiQueryRetriever":
|
|
|
|
|
"""Initialize from llm using default template.
|
|
|
|
@ -91,7 +81,6 @@ class MultiQueryRetriever(BaseRetriever):
|
|
|
|
|
return cls(
|
|
|
|
|
retriever=retriever,
|
|
|
|
|
llm_chain=llm_chain,
|
|
|
|
|
parser_key=parser_key,
|
|
|
|
|
include_original=include_original,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -129,7 +118,7 @@ class MultiQueryRetriever(BaseRetriever):
|
|
|
|
|
response = await self.llm_chain.acall(
|
|
|
|
|
inputs={"question": question}, callbacks=run_manager.get_child()
|
|
|
|
|
)
|
|
|
|
|
lines = getattr(response["text"], self.parser_key, [])
|
|
|
|
|
lines = response["text"]
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logger.info(f"Generated queries: {lines}")
|
|
|
|
|
return lines
|
|
|
|
@ -189,7 +178,7 @@ class MultiQueryRetriever(BaseRetriever):
|
|
|
|
|
response = self.llm_chain(
|
|
|
|
|
{"question": question}, callbacks=run_manager.get_child()
|
|
|
|
|
)
|
|
|
|
|
lines = getattr(response["text"], self.parser_key, [])
|
|
|
|
|
lines = response["text"]
|
|
|
|
|
if self.verbose:
|
|
|
|
|
logger.info(f"Generated queries: {lines}")
|
|
|
|
|
return lines
|
|
|
|
|