Harrison/add top k (#4707)

Co-authored-by: blc16 <benlc@umich.edu>
This commit is contained in:
Harrison Chase 2023-05-15 09:09:22 -07:00 committed by GitHub
parent 0551594722
commit dd95f0892d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 209 additions and 20 deletions

View File

@ -32,7 +32,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "cb4a5787", "id": "cb4a5787",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -46,7 +46,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"id": "bcbe04d9", "id": "bcbe04d9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -83,7 +83,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"id": "86e34dbf", "id": "86e34dbf",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -138,7 +138,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"query='dinosaur' filter=None\n" "query='dinosaur' filter=None limit=None\n"
] ]
}, },
{ {
@ -170,7 +170,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"query=' ' filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)\n" "query=' ' filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5) limit=None\n"
] ]
}, },
{ {
@ -200,7 +200,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig')\n" "query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig') limit=None\n"
] ]
}, },
{ {
@ -229,7 +229,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"query=' ' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='science fiction'), Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)])\n" "query=' ' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='science fiction')]) limit=None\n"
] ]
}, },
{ {
@ -258,7 +258,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"query='toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='year', value=1990), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated')])\n" "query='toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='year', value=1990), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated')]) limit=None\n"
] ]
}, },
{ {
@ -277,10 +277,69 @@
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")" "retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "87513116",
"metadata": {},
"source": [
"## Filter k\n",
"\n",
"We can also use the self query retriever to specify `k`: the number of documents to fetch.\n",
"\n",
"We can do this by passing `enable_limit=True` to the constructor."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "73cfca56",
"metadata": {},
"outputs": [],
"source": [
"retriever = SelfQueryRetriever.from_llm(\n",
" llm, \n",
" vectorstore, \n",
" document_content_description, \n",
" metadata_field_info, \n",
" enable_limit=True,\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "60110338",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query='dinosaur' filter=None limit=2\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='A bunch of scientists bring back dinosaurs and mayhem breaks loose', metadata={'year': 1993, 'rating': 7.7, 'genre': 'science fiction'}),\n",
" Document(page_content='Toys come alive and have a blast doing so', metadata={'year': 1995, 'genre': 'animated'})]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This example only specifies a relevant query\n",
"retriever.get_relevant_documents(\"what are two movies about dinosaurs\")"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "60110338", "id": "f15d84b3",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []

View File

@ -295,13 +295,45 @@
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")" "retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "6fe7536c",
"metadata": {},
"source": [
"## Filter k\n",
"\n",
"We can also use the self query retriever to specify `k`: the number of documents to fetch.\n",
"\n",
"We can do this by passing `enable_limit=True` to the constructor."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "69bbd809", "id": "3a2937c2",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": [
"retriever = SelfQueryRetriever.from_llm(\n",
" llm, \n",
" vectorstore, \n",
" document_content_description, \n",
" metadata_field_info, \n",
" enable_limit=True,\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "83d233aa",
"metadata": {},
"outputs": [],
"source": [
"# This example only specifies a relevant query\n",
"retriever.get_relevant_documents(\"What are two movies about dinosaurs\")"
]
} }
], ],
"metadata": { "metadata": {

View File

@ -18,6 +18,8 @@ from langchain.chains.query_constructor.prompt import (
DEFAULT_SCHEMA, DEFAULT_SCHEMA,
DEFAULT_SUFFIX, DEFAULT_SUFFIX,
EXAMPLE_PROMPT, EXAMPLE_PROMPT,
EXAMPLES_WITH_LIMIT,
SCHEMA_WITH_LIMIT,
) )
from langchain.chains.query_constructor.schema import AttributeInfo from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.output_parsers.structured import parse_json_markdown from langchain.output_parsers.structured import parse_json_markdown
@ -38,7 +40,11 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
parsed["filter"] = None parsed["filter"] = None
else: else:
parsed["filter"] = self.ast_parse(parsed["filter"]) parsed["filter"] = self.ast_parse(parsed["filter"])
return StructuredQuery(query=parsed["query"], filter=parsed["filter"]) return StructuredQuery(
query=parsed["query"],
filter=parsed["filter"],
limit=parsed.get("limit"),
)
except Exception as e: except Exception as e:
raise OutputParserException( raise OutputParserException(
f"Parsing text\n{text}\n raised following error:\n{e}" f"Parsing text\n{text}\n raised following error:\n{e}"
@ -70,15 +76,25 @@ def _get_prompt(
examples: Optional[List] = None, examples: Optional[List] = None,
allowed_comparators: Optional[Sequence[Comparator]] = None, allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None, allowed_operators: Optional[Sequence[Operator]] = None,
enable_limit: bool = False,
) -> BasePromptTemplate: ) -> BasePromptTemplate:
attribute_str = _format_attribute_info(attribute_info) attribute_str = _format_attribute_info(attribute_info)
examples = examples or DEFAULT_EXAMPLES
allowed_comparators = allowed_comparators or list(Comparator) allowed_comparators = allowed_comparators or list(Comparator)
allowed_operators = allowed_operators or list(Operator) allowed_operators = allowed_operators or list(Operator)
if enable_limit:
schema = SCHEMA_WITH_LIMIT.format(
allowed_comparators=" | ".join(allowed_comparators),
allowed_operators=" | ".join(allowed_operators),
)
examples = examples or EXAMPLES_WITH_LIMIT
else:
schema = DEFAULT_SCHEMA.format( schema = DEFAULT_SCHEMA.format(
allowed_comparators=" | ".join(allowed_comparators), allowed_comparators=" | ".join(allowed_comparators),
allowed_operators=" | ".join(allowed_operators), allowed_operators=" | ".join(allowed_operators),
) )
examples = examples or DEFAULT_EXAMPLES
prefix = DEFAULT_PREFIX.format(schema=schema) prefix = DEFAULT_PREFIX.format(schema=schema)
suffix = DEFAULT_SUFFIX.format( suffix = DEFAULT_SUFFIX.format(
i=len(examples) + 1, content=document_contents, attributes=attribute_str i=len(examples) + 1, content=document_contents, attributes=attribute_str
@ -87,7 +103,7 @@ def _get_prompt(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
) )
return FewShotPromptTemplate( return FewShotPromptTemplate(
examples=DEFAULT_EXAMPLES, examples=examples,
example_prompt=EXAMPLE_PROMPT, example_prompt=EXAMPLE_PROMPT,
input_variables=["query"], input_variables=["query"],
suffix=suffix, suffix=suffix,
@ -103,6 +119,7 @@ def load_query_constructor_chain(
examples: Optional[List] = None, examples: Optional[List] = None,
allowed_comparators: Optional[Sequence[Comparator]] = None, allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None, allowed_operators: Optional[Sequence[Operator]] = None,
enable_limit: bool = False,
**kwargs: Any, **kwargs: Any,
) -> LLMChain: ) -> LLMChain:
prompt = _get_prompt( prompt = _get_prompt(
@ -111,5 +128,6 @@ def load_query_constructor_chain(
examples=examples, examples=examples,
allowed_comparators=allowed_comparators, allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators, allowed_operators=allowed_operators,
enable_limit=enable_limit,
) )
return LLMChain(llm=llm, prompt=prompt, **kwargs) return LLMChain(llm=llm, prompt=prompt, **kwargs)

View File

@ -81,3 +81,4 @@ class Operation(FilterDirective):
class StructuredQuery(Expr): class StructuredQuery(Expr):
query: str query: str
filter: Optional[FilterDirective] filter: Optional[FilterDirective]
limit: Optional[int]

View File

@ -46,6 +46,16 @@ NO_FILTER_ANSWER = """\
```\ ```\
""" """
# Example structured request that carries a "limit" key in addition to
# "query" and "filter"; used as the `structured_request` of a few-shot example.
# NOTE(review): the braces are doubled, presumably so they stay literal when
# the example is rendered through a prompt template -- confirm.
WITH_LIMIT_ANSWER = """\
```json
{{
"query": "love",
"filter": "NO_FILTER",
"limit": 2
}}
```\
"""
DEFAULT_EXAMPLES = [ DEFAULT_EXAMPLES = [
{ {
"i": 1, "i": 1,
@ -61,6 +71,27 @@ DEFAULT_EXAMPLES = [
}, },
] ]
# Few-shot examples used when the query constructor is allowed to emit a
# `limit`. Each entry pairs a natural-language `user_query` with the
# `structured_request` the model is expected to produce for it.
EXAMPLES_WITH_LIMIT = [
    {
        "i": 1,
        "data_source": SONG_DATA_SOURCE,
        "user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre",
        "structured_request": FULL_ANSWER,
    },
    {
        "i": 2,
        "data_source": SONG_DATA_SOURCE,
        "user_query": "What are songs that were not published on Spotify",
        "structured_request": NO_FILTER_ANSWER,
    },
    {
        "i": 3,
        "data_source": SONG_DATA_SOURCE,
        # The count in the question must agree with the `"limit": 2` in
        # WITH_LIMIT_ANSWER; otherwise the example teaches the model a wrong
        # mapping from the requested number of results to `limit`.
        "user_query": "What are two songs about love",
        "structured_request": WITH_LIMIT_ANSWER,
    },
]
EXAMPLE_PROMPT_TEMPLATE = """\ EXAMPLE_PROMPT_TEMPLATE = """\
<< Example {i}. >> << Example {i}. >>
Data Source: Data Source:
@ -116,6 +147,45 @@ Make sure that filters are only used as needed. If there are no filters that sho
applied return "NO_FILTER" for the filter value.\ applied return "NO_FILTER" for the filter value.\
""" """
# Request-schema section of the prompt for the variant that also asks the
# model for a `limit` (number of documents to retrieve).
# Placeholders: {allowed_comparators} and {allowed_operators}, filled in with
# `str.format`. The JSON braces are written four times so that `{{`/`}}`
# remain after that `format` call.
# NOTE(review): presumably a later prompt-template pass reduces them to single
# braces -- confirm.
SCHEMA_WITH_LIMIT = """\
<< Structured Request Schema >>
When responding use a markdown code snippet with a JSON object formatted in the \
following schema:
```json
{{{{
"query": string \\ text string to compare to document contents
"filter": string \\ logical condition statement for filtering documents
"limit": int \\ the number of documents to retrieve
}}}}
```
The query string should contain only text that is expected to match the contents of \
documents. Any conditions in the filter should not be mentioned in the query as well.
A logical condition statement is composed of one or more comparison and logical \
operation statements.
A comparison statement takes the form: `comp(attr, val)`:
- `comp` ({allowed_comparators}): comparator
- `attr` (string): name of attribute to apply the comparison to
- `val` (string): is the comparison value
A logical operation statement takes the form `op(statement1, statement2, ...)`:
- `op` ({allowed_operators}): logical operator
- `statement1`, `statement2`, ... (comparison statements or logical operation \
statements): one or more statements to apply the operation to
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \
applied return "NO_FILTER" for the filter value.
Make sure the `limit` is always an int value. It is an optional parameter so leave it \
blank if it does not make sense.
"""
DEFAULT_PREFIX = """\ DEFAULT_PREFIX = """\
Your goal is to structure the user's query to match the request schema provided below. Your goal is to structure the user's query to match the request schema provided below.

View File

@ -68,7 +68,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
Returns: Returns:
List of relevant documents List of relevant documents
""" """
inputs = self.llm_chain.prep_inputs(query) inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = cast( structured_query = cast(
StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs) StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs)
) )
@ -77,6 +77,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
new_query, new_kwargs = self.structured_query_translator.visit_structured_query( new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query structured_query
) )
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
search_kwargs = {**self.search_kwargs, **new_kwargs} search_kwargs = {**self.search_kwargs, **new_kwargs}
docs = self.vectorstore.search(query, self.search_type, **search_kwargs) docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs return docs
@ -93,11 +96,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
metadata_field_info: List[AttributeInfo], metadata_field_info: List[AttributeInfo],
structured_query_translator: Optional[Visitor] = None, structured_query_translator: Optional[Visitor] = None,
chain_kwargs: Optional[Dict] = None, chain_kwargs: Optional[Dict] = None,
enable_limit: bool = False,
**kwargs: Any, **kwargs: Any,
) -> "SelfQueryRetriever": ) -> "SelfQueryRetriever":
if structured_query_translator is None: if structured_query_translator is None:
structured_query_translator = _get_builtin_translator(vectorstore.__class__) structured_query_translator = _get_builtin_translator(vectorstore.__class__)
chain_kwargs = chain_kwargs or {} chain_kwargs = chain_kwargs or {}
if "allowed_comparators" not in chain_kwargs: if "allowed_comparators" not in chain_kwargs:
chain_kwargs[ chain_kwargs[
"allowed_comparators" "allowed_comparators"
@ -107,7 +112,11 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
"allowed_operators" "allowed_operators"
] = structured_query_translator.allowed_operators ] = structured_query_translator.allowed_operators
llm_chain = load_query_constructor_chain( llm_chain = load_query_constructor_chain(
llm, document_contents, metadata_field_info, **chain_kwargs llm,
document_contents,
metadata_field_info,
enable_limit=enable_limit,
**chain_kwargs,
) )
return cls( return cls(
llm_chain=llm_chain, llm_chain=llm_chain,