mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
parent
0551594722
commit
dd95f0892d
@ -32,7 +32,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 1,
|
||||
"id": "cb4a5787",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -46,7 +46,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"id": "bcbe04d9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -83,7 +83,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"id": "86e34dbf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -138,7 +138,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='dinosaur' filter=None\n"
|
||||
"query='dinosaur' filter=None limit=None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -170,7 +170,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query=' ' filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)\n"
|
||||
"query=' ' filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5) limit=None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -200,7 +200,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig')\n"
|
||||
"query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig') limit=None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -229,7 +229,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query=' ' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='science fiction'), Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)])\n"
|
||||
"query=' ' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='science fiction')]) limit=None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -258,7 +258,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='year', value=1990), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated')])\n"
|
||||
"query='toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='year', value=1990), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated')]) limit=None\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -277,10 +277,69 @@
|
||||
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "87513116",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Filter k\n",
|
||||
"\n",
|
||||
"We can also use the self query retriever to specify `k`: the number of documents to fetch.\n",
|
||||
"\n",
|
||||
"We can do this by passing `enable_limit=True` to the constructor."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "73cfca56",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = SelfQueryRetriever.from_llm(\n",
|
||||
" llm, \n",
|
||||
" vectorstore, \n",
|
||||
" document_content_description, \n",
|
||||
" metadata_field_info, \n",
|
||||
" enable_limit=True,\n",
|
||||
" verbose=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "60110338",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query='dinosaur' filter=None limit=2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='A bunch of scientists bring back dinosaurs and mayhem breaks loose', metadata={'year': 1993, 'rating': 7.7, 'genre': 'science fiction'}),\n",
|
||||
" Document(page_content='Toys come alive and have a blast doing so', metadata={'year': 1995, 'genre': 'animated'})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# This example only specifies a relevant query\n",
|
||||
"retriever.get_relevant_documents(\"what are two movies about dinosaurs\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "60110338",
|
||||
"id": "f15d84b3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
|
@ -295,13 +295,45 @@
|
||||
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6fe7536c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Filter k\n",
|
||||
"\n",
|
||||
"We can also use the self query retriever to specify `k`: the number of documents to fetch.\n",
|
||||
"\n",
|
||||
"We can do this by passing `enable_limit=True` to the constructor."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "69bbd809",
|
||||
"id": "3a2937c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"retriever = SelfQueryRetriever.from_llm(\n",
|
||||
" llm, \n",
|
||||
" vectorstore, \n",
|
||||
" document_content_description, \n",
|
||||
" metadata_field_info, \n",
|
||||
" enable_limit=True,\n",
|
||||
" verbose=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "83d233aa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This example only specifies a relevant query\n",
|
||||
"retriever.get_relevant_documents(\"What are two movies about dinosaurs\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -18,6 +18,8 @@ from langchain.chains.query_constructor.prompt import (
|
||||
DEFAULT_SCHEMA,
|
||||
DEFAULT_SUFFIX,
|
||||
EXAMPLE_PROMPT,
|
||||
EXAMPLES_WITH_LIMIT,
|
||||
SCHEMA_WITH_LIMIT,
|
||||
)
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
from langchain.output_parsers.structured import parse_json_markdown
|
||||
@ -38,7 +40,11 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
|
||||
parsed["filter"] = None
|
||||
else:
|
||||
parsed["filter"] = self.ast_parse(parsed["filter"])
|
||||
return StructuredQuery(query=parsed["query"], filter=parsed["filter"])
|
||||
return StructuredQuery(
|
||||
query=parsed["query"],
|
||||
filter=parsed["filter"],
|
||||
limit=parsed.get("limit"),
|
||||
)
|
||||
except Exception as e:
|
||||
raise OutputParserException(
|
||||
f"Parsing text\n{text}\n raised following error:\n{e}"
|
||||
@ -70,15 +76,25 @@ def _get_prompt(
|
||||
examples: Optional[List] = None,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
enable_limit: bool = False,
|
||||
) -> BasePromptTemplate:
|
||||
attribute_str = _format_attribute_info(attribute_info)
|
||||
examples = examples or DEFAULT_EXAMPLES
|
||||
allowed_comparators = allowed_comparators or list(Comparator)
|
||||
allowed_operators = allowed_operators or list(Operator)
|
||||
if enable_limit:
|
||||
schema = SCHEMA_WITH_LIMIT.format(
|
||||
allowed_comparators=" | ".join(allowed_comparators),
|
||||
allowed_operators=" | ".join(allowed_operators),
|
||||
)
|
||||
|
||||
examples = examples or EXAMPLES_WITH_LIMIT
|
||||
else:
|
||||
schema = DEFAULT_SCHEMA.format(
|
||||
allowed_comparators=" | ".join(allowed_comparators),
|
||||
allowed_operators=" | ".join(allowed_operators),
|
||||
)
|
||||
|
||||
examples = examples or DEFAULT_EXAMPLES
|
||||
prefix = DEFAULT_PREFIX.format(schema=schema)
|
||||
suffix = DEFAULT_SUFFIX.format(
|
||||
i=len(examples) + 1, content=document_contents, attributes=attribute_str
|
||||
@ -87,7 +103,7 @@ def _get_prompt(
|
||||
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
|
||||
)
|
||||
return FewShotPromptTemplate(
|
||||
examples=DEFAULT_EXAMPLES,
|
||||
examples=examples,
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
input_variables=["query"],
|
||||
suffix=suffix,
|
||||
@ -103,6 +119,7 @@ def load_query_constructor_chain(
|
||||
examples: Optional[List] = None,
|
||||
allowed_comparators: Optional[Sequence[Comparator]] = None,
|
||||
allowed_operators: Optional[Sequence[Operator]] = None,
|
||||
enable_limit: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> LLMChain:
|
||||
prompt = _get_prompt(
|
||||
@ -111,5 +128,6 @@ def load_query_constructor_chain(
|
||||
examples=examples,
|
||||
allowed_comparators=allowed_comparators,
|
||||
allowed_operators=allowed_operators,
|
||||
enable_limit=enable_limit,
|
||||
)
|
||||
return LLMChain(llm=llm, prompt=prompt, **kwargs)
|
||||
|
@ -81,3 +81,4 @@ class Operation(FilterDirective):
|
||||
class StructuredQuery(Expr):
|
||||
query: str
|
||||
filter: Optional[FilterDirective]
|
||||
limit: Optional[int]
|
||||
|
@ -46,6 +46,16 @@ NO_FILTER_ANSWER = """\
|
||||
```\
|
||||
"""
|
||||
|
||||
WITH_LIMIT_ANSWER = """\
|
||||
```json
|
||||
{{
|
||||
"query": "love",
|
||||
"filter": "NO_FILTER",
|
||||
"limit": 2
|
||||
}}
|
||||
```\
|
||||
"""
|
||||
|
||||
DEFAULT_EXAMPLES = [
|
||||
{
|
||||
"i": 1,
|
||||
@ -61,6 +71,27 @@ DEFAULT_EXAMPLES = [
|
||||
},
|
||||
]
|
||||
|
||||
EXAMPLES_WITH_LIMIT = [
|
||||
{
|
||||
"i": 1,
|
||||
"data_source": SONG_DATA_SOURCE,
|
||||
"user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre",
|
||||
"structured_request": FULL_ANSWER,
|
||||
},
|
||||
{
|
||||
"i": 2,
|
||||
"data_source": SONG_DATA_SOURCE,
|
||||
"user_query": "What are songs that were not published on Spotify",
|
||||
"structured_request": NO_FILTER_ANSWER,
|
||||
},
|
||||
{
|
||||
"i": 3,
|
||||
"data_source": SONG_DATA_SOURCE,
|
||||
"user_query": "What are three songs about love",
|
||||
"structured_request": WITH_LIMIT_ANSWER,
|
||||
},
|
||||
]
|
||||
|
||||
EXAMPLE_PROMPT_TEMPLATE = """\
|
||||
<< Example {i}. >>
|
||||
Data Source:
|
||||
@ -116,6 +147,45 @@ Make sure that filters are only used as needed. If there are no filters that sho
|
||||
applied return "NO_FILTER" for the filter value.\
|
||||
"""
|
||||
|
||||
SCHEMA_WITH_LIMIT = """\
|
||||
<< Structured Request Schema >>
|
||||
When responding use a markdown code snippet with a JSON object formatted in the \
|
||||
following schema:
|
||||
|
||||
```json
|
||||
{{{{
|
||||
"query": string \\ text string to compare to document contents
|
||||
"filter": string \\ logical condition statement for filtering documents
|
||||
"limit": int \\ the number of documents to retrieve
|
||||
}}}}
|
||||
```
|
||||
|
||||
The query string should contain only text that is expected to match the contents of \
|
||||
documents. Any conditions in the filter should not be mentioned in the query as well.
|
||||
|
||||
A logical condition statement is composed of one or more comparison and logical \
|
||||
operation statements.
|
||||
|
||||
A comparison statement takes the form: `comp(attr, val)`:
|
||||
- `comp` ({allowed_comparators}): comparator
|
||||
- `attr` (string): name of attribute to apply the comparison to
|
||||
- `val` (string): is the comparison value
|
||||
|
||||
A logical operation statement takes the form `op(statement1, statement2, ...)`:
|
||||
- `op` ({allowed_operators}): logical operator
|
||||
- `statement1`, `statement2`, ... (comparison statements or logical operation \
|
||||
statements): one or more statements to apply the operation to
|
||||
|
||||
Make sure that you only use the comparators and logical operators listed above and \
|
||||
no others.
|
||||
Make sure that filters only refer to attributes that exist in the data source.
|
||||
Make sure that filters take into account the descriptions of attributes and only make \
|
||||
comparisons that are feasible given the type of data being stored.
|
||||
Make sure that filters are only used as needed. If there are no filters that should be \
|
||||
applied return "NO_FILTER" for the filter value.
|
||||
Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it is does not make sense.
|
||||
"""
|
||||
|
||||
DEFAULT_PREFIX = """\
|
||||
Your goal is to structure the user's query to match the request schema provided below.
|
||||
|
||||
|
@ -68,7 +68,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
inputs = self.llm_chain.prep_inputs(query)
|
||||
inputs = self.llm_chain.prep_inputs({"query": query})
|
||||
structured_query = cast(
|
||||
StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs)
|
||||
)
|
||||
@ -77,6 +77,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
|
||||
structured_query
|
||||
)
|
||||
if structured_query.limit is not None:
|
||||
new_kwargs["k"] = structured_query.limit
|
||||
|
||||
search_kwargs = {**self.search_kwargs, **new_kwargs}
|
||||
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
|
||||
return docs
|
||||
@ -93,11 +96,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
metadata_field_info: List[AttributeInfo],
|
||||
structured_query_translator: Optional[Visitor] = None,
|
||||
chain_kwargs: Optional[Dict] = None,
|
||||
enable_limit: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> "SelfQueryRetriever":
|
||||
if structured_query_translator is None:
|
||||
structured_query_translator = _get_builtin_translator(vectorstore.__class__)
|
||||
chain_kwargs = chain_kwargs or {}
|
||||
|
||||
if "allowed_comparators" not in chain_kwargs:
|
||||
chain_kwargs[
|
||||
"allowed_comparators"
|
||||
@ -107,7 +112,11 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
"allowed_operators"
|
||||
] = structured_query_translator.allowed_operators
|
||||
llm_chain = load_query_constructor_chain(
|
||||
llm, document_contents, metadata_field_info, **chain_kwargs
|
||||
llm,
|
||||
document_contents,
|
||||
metadata_field_info,
|
||||
enable_limit=enable_limit,
|
||||
**chain_kwargs,
|
||||
)
|
||||
return cls(
|
||||
llm_chain=llm_chain,
|
||||
|
Loading…
Reference in New Issue
Block a user