Harrison/add top k (#4707)

Co-authored-by: blc16 <benlc@umich.edu>
Harrison Chase 2023-05-15 09:09:22 -07:00 committed by GitHub
parent 0551594722
commit dd95f0892d
6 changed files with 209 additions and 20 deletions

@ -32,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "cb4a5787",
"metadata": {},
"outputs": [],
@ -46,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "bcbe04d9",
"metadata": {},
"outputs": [
@ -83,7 +83,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "86e34dbf",
"metadata": {},
"outputs": [],
@ -138,7 +138,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"query='dinosaur' filter=None\n"
"query='dinosaur' filter=None limit=None\n"
]
},
{
@ -170,7 +170,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"query=' ' filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)\n"
"query=' ' filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5) limit=None\n"
]
},
{
@ -200,7 +200,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig')\n"
"query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig') limit=None\n"
]
},
{
@ -229,7 +229,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"query=' ' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='science fiction'), Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)])\n"
"query=' ' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='science fiction')]) limit=None\n"
]
},
{
@ -258,7 +258,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"query='toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='year', value=1990), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated')])\n"
"query='toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='year', value=1990), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated')]) limit=None\n"
]
},
{
@ -277,10 +277,69 @@
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
]
},
{
"cell_type": "markdown",
"id": "87513116",
"metadata": {},
"source": [
"## Filter k\n",
"\n",
"We can also use the self query retriever to specify `k`: the number of documents to fetch.\n",
"\n",
"We can do this by passing `enable_limit=True` to the constructor."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "73cfca56",
"metadata": {},
"outputs": [],
"source": [
"retriever = SelfQueryRetriever.from_llm(\n",
" llm, \n",
" vectorstore, \n",
" document_content_description, \n",
" metadata_field_info, \n",
" enable_limit=True,\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "60110338",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query='dinosaur' filter=None limit=2\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='A bunch of scientists bring back dinosaurs and mayhem breaks loose', metadata={'year': 1993, 'rating': 7.7, 'genre': 'science fiction'}),\n",
" Document(page_content='Toys come alive and have a blast doing so', metadata={'year': 1995, 'genre': 'animated'})]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This example only specifies a relevant query\n",
"retriever.get_relevant_documents(\"what are two movies about dinosaurs\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60110338",
"id": "f15d84b3",
"metadata": {},
"outputs": [],
"source": []

@ -295,13 +295,45 @@
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
]
},
{
"cell_type": "markdown",
"id": "6fe7536c",
"metadata": {},
"source": [
"## Filter k\n",
"\n",
"We can also use the self query retriever to specify `k`: the number of documents to fetch.\n",
"\n",
"We can do this by passing `enable_limit=True` to the constructor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69bbd809",
"id": "3a2937c2",
"metadata": {},
"outputs": [],
"source": []
"source": [
"retriever = SelfQueryRetriever.from_llm(\n",
" llm, \n",
" vectorstore, \n",
" document_content_description, \n",
" metadata_field_info, \n",
" enable_limit=True,\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83d233aa",
"metadata": {},
"outputs": [],
"source": [
"# This example only specifies a relevant query\n",
"retriever.get_relevant_documents(\"What are two movies about dinosaurs\")"
]
}
],
"metadata": {

@ -18,6 +18,8 @@ from langchain.chains.query_constructor.prompt import (
DEFAULT_SCHEMA,
DEFAULT_SUFFIX,
EXAMPLE_PROMPT,
EXAMPLES_WITH_LIMIT,
SCHEMA_WITH_LIMIT,
)
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.output_parsers.structured import parse_json_markdown
@ -38,7 +40,11 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
parsed["filter"] = None
else:
parsed["filter"] = self.ast_parse(parsed["filter"])
return StructuredQuery(query=parsed["query"], filter=parsed["filter"])
return StructuredQuery(
query=parsed["query"],
filter=parsed["filter"],
limit=parsed.get("limit"),
)
except Exception as e:
raise OutputParserException(
f"Parsing text\n{text}\n raised following error:\n{e}"
@ -70,15 +76,25 @@ def _get_prompt(
examples: Optional[List] = None,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
enable_limit: bool = False,
) -> BasePromptTemplate:
attribute_str = _format_attribute_info(attribute_info)
examples = examples or DEFAULT_EXAMPLES
allowed_comparators = allowed_comparators or list(Comparator)
allowed_operators = allowed_operators or list(Operator)
schema = DEFAULT_SCHEMA.format(
allowed_comparators=" | ".join(allowed_comparators),
allowed_operators=" | ".join(allowed_operators),
)
if enable_limit:
schema = SCHEMA_WITH_LIMIT.format(
allowed_comparators=" | ".join(allowed_comparators),
allowed_operators=" | ".join(allowed_operators),
)
examples = examples or EXAMPLES_WITH_LIMIT
else:
schema = DEFAULT_SCHEMA.format(
allowed_comparators=" | ".join(allowed_comparators),
allowed_operators=" | ".join(allowed_operators),
)
examples = examples or DEFAULT_EXAMPLES
prefix = DEFAULT_PREFIX.format(schema=schema)
suffix = DEFAULT_SUFFIX.format(
i=len(examples) + 1, content=document_contents, attributes=attribute_str
@ -87,7 +103,7 @@ def _get_prompt(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
)
return FewShotPromptTemplate(
examples=DEFAULT_EXAMPLES,
examples=examples,
example_prompt=EXAMPLE_PROMPT,
input_variables=["query"],
suffix=suffix,
@ -103,6 +119,7 @@ def load_query_constructor_chain(
examples: Optional[List] = None,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
enable_limit: bool = False,
**kwargs: Any,
) -> LLMChain:
prompt = _get_prompt(
@ -111,5 +128,6 @@ def load_query_constructor_chain(
examples=examples,
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
enable_limit=enable_limit,
)
return LLMChain(llm=llm, prompt=prompt, **kwargs)
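A rough usage sketch of the new flag on the chain itself; `llm`, `document_content_description`, and `metadata_field_info` are assumed to be the objects defined in the notebook above.

```python
from langchain.chains.query_constructor.base import load_query_constructor_chain

# With enable_limit=True the prompt is assembled from SCHEMA_WITH_LIMIT and
# EXAMPLES_WITH_LIMIT, so the model is shown how (and when) to emit a "limit" value.
chain = load_query_constructor_chain(
    llm,
    document_content_description,
    metadata_field_info,
    enable_limit=True,
)
structured_query = chain.predict_and_parse(query="what are two movies about dinosaurs")
# e.g. StructuredQuery(query='dinosaur', filter=None, limit=2)
```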

@ -81,3 +81,4 @@ class Operation(FilterDirective):
class StructuredQuery(Expr):
query: str
filter: Optional[FilterDirective]
limit: Optional[int]
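`StructuredQuery` is a plain pydantic model, so the new field is easy to exercise directly; a small illustrative construction (when `limit` is not supplied it simply stays `None`):

```python
from langchain.chains.query_constructor.ir import (
    Comparator,
    Comparison,
    StructuredQuery,
)

# A structured query asking for at most two documents with rating above 8.5.
sq = StructuredQuery(
    query="dinosaur",
    filter=Comparison(comparator=Comparator.GT, attribute="rating", value=8.5),
    limit=2,  # optional; left as None when the user does not ask for a count
)
```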

@ -46,6 +46,16 @@ NO_FILTER_ANSWER = """\
```\
"""
WITH_LIMIT_ANSWER = """\
```json
{{
"query": "love",
"filter": "NO_FILTER",
"limit": 2
}}
```\
"""
DEFAULT_EXAMPLES = [
{
"i": 1,
@ -61,6 +71,27 @@ DEFAULT_EXAMPLES = [
},
]
EXAMPLES_WITH_LIMIT = [
{
"i": 1,
"data_source": SONG_DATA_SOURCE,
"user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre",
"structured_request": FULL_ANSWER,
},
{
"i": 2,
"data_source": SONG_DATA_SOURCE,
"user_query": "What are songs that were not published on Spotify",
"structured_request": NO_FILTER_ANSWER,
},
{
"i": 3,
"data_source": SONG_DATA_SOURCE,
"user_query": "What are three songs about love",
"structured_request": WITH_LIMIT_ANSWER,
},
]
EXAMPLE_PROMPT_TEMPLATE = """\
<< Example {i}. >>
Data Source:
@ -116,6 +147,45 @@ Make sure that filters are only used as needed. If there are no filters that sho
applied return "NO_FILTER" for the filter value.\
"""
SCHEMA_WITH_LIMIT = """\
<< Structured Request Schema >>
When responding use a markdown code snippet with a JSON object formatted in the \
following schema:
```json
{{{{
"query": string \\ text string to compare to document contents
"filter": string \\ logical condition statement for filtering documents
"limit": int \\ the number of documents to retrieve
}}}}
```
The query string should contain only text that is expected to match the contents of \
documents. Any conditions in the filter should not be mentioned in the query as well.
A logical condition statement is composed of one or more comparison and logical \
operation statements.
A comparison statement takes the form: `comp(attr, val)`:
- `comp` ({allowed_comparators}): comparator
- `attr` (string): name of attribute to apply the comparison to
- `val` (string): is the comparison value
A logical operation statement takes the form `op(statement1, statement2, ...)`:
- `op` ({allowed_operators}): logical operator
- `statement1`, `statement2`, ... (comparison statements or logical operation \
statements): one or more statements to apply the operation to
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \
applied return "NO_FILTER" for the filter value.
Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it does not make sense.
"""
DEFAULT_PREFIX = """\
Your goal is to structure the user's query to match the request schema provided below.
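A note on the quadruple braces in `SCHEMA_WITH_LIMIT` above: the schema text goes through `str.format` twice, once in `_get_prompt` to inject the allowed comparators and operators, and again when the few-shot prompt is rendered, so `{{{{` ends up as a single literal `{` in the text the LLM actually sees. A quick sketch of the first pass:

```python
from langchain.chains.query_constructor.prompt import SCHEMA_WITH_LIMIT

# First pass (done in _get_prompt): fills in the comparator/operator lists and
# collapses "{{{{" to "{{"; the second pass, performed when the FewShotPromptTemplate
# formats the final prompt, turns the remaining "{{" into a literal "{".
schema = SCHEMA_WITH_LIMIT.format(
    allowed_comparators="eq | gt | lt", allowed_operators="and | or | not"
)
```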

@ -68,7 +68,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
Returns:
List of relevant documents
"""
inputs = self.llm_chain.prep_inputs(query)
inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = cast(
StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs)
)
@ -77,6 +77,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query
)
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
search_kwargs = {**self.search_kwargs, **new_kwargs}
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs
@ -93,11 +96,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
metadata_field_info: List[AttributeInfo],
structured_query_translator: Optional[Visitor] = None,
chain_kwargs: Optional[Dict] = None,
enable_limit: bool = False,
**kwargs: Any,
) -> "SelfQueryRetriever":
if structured_query_translator is None:
structured_query_translator = _get_builtin_translator(vectorstore.__class__)
chain_kwargs = chain_kwargs or {}
if "allowed_comparators" not in chain_kwargs:
chain_kwargs[
"allowed_comparators"
@ -107,7 +112,11 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
"allowed_operators"
] = structured_query_translator.allowed_operators
llm_chain = load_query_constructor_chain(
llm, document_contents, metadata_field_info, **chain_kwargs
llm,
document_contents,
metadata_field_info,
enable_limit=enable_limit,
**chain_kwargs,
)
return cls(
llm_chain=llm_chain,
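Taken together, a parsed `limit` simply becomes the `k` passed to the vector store. A rough trace of the retrieval path, assuming the Chroma-backed `retriever` from the notebook above was built with `enable_limit=True`:

```python
question = "what are two movies about dinosaurs"

# The query-constructor chain turns the question into a StructuredQuery,
# e.g. StructuredQuery(query='dinosaur', filter=None, limit=2).
structured_query = retriever.llm_chain.predict_and_parse(query=question)

new_query, new_kwargs = retriever.structured_query_translator.visit_structured_query(
    structured_query
)
if structured_query.limit is not None:
    new_kwargs["k"] = structured_query.limit  # the parsed limit becomes top-k
search_kwargs = {**retriever.search_kwargs, **new_kwargs}

# At most `limit` documents come back from the underlying similarity search.
docs = retriever.vectorstore.search(question, retriever.search_type, **search_kwargs)
```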