From dd95f0892d1390db68ee20f63dcf8b9673e89b7f Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 15 May 2023 09:09:22 -0700 Subject: [PATCH] Harrison/add top k (#4707) Co-authored-by: blc16 --- .../chroma_self_query_retriever.ipynb | 77 ++++++++++++++++--- .../examples/self_query_retriever.ipynb | 36 ++++++++- langchain/chains/query_constructor/base.py | 32 ++++++-- langchain/chains/query_constructor/ir.py | 1 + langchain/chains/query_constructor/prompt.py | 70 +++++++++++++++++ langchain/retrievers/self_query/base.py | 13 +++- 6 files changed, 209 insertions(+), 20 deletions(-) diff --git a/docs/modules/indexes/retrievers/examples/chroma_self_query_retriever.ipynb b/docs/modules/indexes/retrievers/examples/chroma_self_query_retriever.ipynb index b54746a2..b448781c 100644 --- a/docs/modules/indexes/retrievers/examples/chroma_self_query_retriever.ipynb +++ b/docs/modules/indexes/retrievers/examples/chroma_self_query_retriever.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "cb4a5787", "metadata": {}, "outputs": [], @@ -46,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "bcbe04d9", "metadata": {}, "outputs": [ @@ -83,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "86e34dbf", "metadata": {}, "outputs": [], @@ -138,7 +138,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "query='dinosaur' filter=None\n" + "query='dinosaur' filter=None limit=None\n" ] }, { @@ -170,7 +170,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "query=' ' filter=Comparison(comparator=, attribute='rating', value=8.5)\n" + "query=' ' filter=Comparison(comparator=, attribute='rating', value=8.5) limit=None\n" ] }, { @@ -200,7 +200,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "query='women' filter=Comparison(comparator=, attribute='director', value='Greta Gerwig')\n" + "query='women' filter=Comparison(comparator=, attribute='director', value='Greta Gerwig') limit=None\n" ] }, { @@ -229,7 +229,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "query=' ' filter=Operation(operator=, arguments=[Comparison(comparator=, attribute='genre', value='science fiction'), Comparison(comparator=, attribute='rating', value=8.5)])\n" + "query=' ' filter=Operation(operator=, arguments=[Comparison(comparator=, attribute='rating', value=8.5), Comparison(comparator=, attribute='genre', value='science fiction')]) limit=None\n" ] }, { @@ -258,7 +258,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "query='toys' filter=Operation(operator=, arguments=[Comparison(comparator=, attribute='year', value=1990), Comparison(comparator=, attribute='year', value=2005), Comparison(comparator=, attribute='genre', value='animated')])\n" + "query='toys' filter=Operation(operator=, arguments=[Comparison(comparator=, attribute='year', value=1990), Comparison(comparator=, attribute='year', value=2005), Comparison(comparator=, attribute='genre', value='animated')]) limit=None\n" ] }, { @@ -277,11 +277,70 @@ "retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")" ] }, + { + "cell_type": "markdown", + "id": "87513116", + "metadata": {}, + "source": [ + "## Filter k\n", + "\n", + "We can also use the self query retriever to specify `k`: the number of documents to fetch.\n", + "\n", + "We can do this by passing `enable_limit=True` to the constructor." + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, + "id": "73cfca56", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = SelfQueryRetriever.from_llm(\n", + " llm, \n", + " vectorstore, \n", + " document_content_description, \n", + " metadata_field_info, \n", + " enable_limit=True,\n", + " verbose=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "id": "60110338", "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "query='dinosaur' filter=None limit=2\n" + ] + }, + { + "data": { + "text/plain": [ + "[Document(page_content='A bunch of scientists bring back dinosaurs and mayhem breaks loose', metadata={'year': 1993, 'rating': 7.7, 'genre': 'science fiction'}),\n", + " Document(page_content='Toys come alive and have a blast doing so', metadata={'year': 1995, 'genre': 'animated'})]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This example only specifies a relevant query\n", + "retriever.get_relevant_documents(\"what are two movies about dinosaurs\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f15d84b3", + "metadata": {}, "outputs": [], "source": [] } diff --git a/docs/modules/indexes/retrievers/examples/self_query_retriever.ipynb b/docs/modules/indexes/retrievers/examples/self_query_retriever.ipynb index 7668bf34..27e5338e 100644 --- a/docs/modules/indexes/retrievers/examples/self_query_retriever.ipynb +++ b/docs/modules/indexes/retrievers/examples/self_query_retriever.ipynb @@ -295,13 +295,45 @@ "retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")" ] }, + { + "cell_type": "markdown", + "id": "6fe7536c", + "metadata": {}, + "source": [ + "## Filter k\n", + "\n", + "We can also use the self query retriever to specify `k`: the number of documents to fetch.\n", + "\n", + "We can do this by passing `enable_limit=True` to the constructor." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a2937c2", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = SelfQueryRetriever.from_llm(\n", + " llm, \n", + " vectorstore, \n", + " document_content_description, \n", + " metadata_field_info, \n", + " enable_limit=True,\n", + " verbose=True\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "69bbd809", + "id": "83d233aa", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# This example only specifies a relevant query\n", + "retriever.get_relevant_documents(\"What are two movies about dinosaurs\")" + ] } ], "metadata": { diff --git a/langchain/chains/query_constructor/base.py b/langchain/chains/query_constructor/base.py index 834084a1..48adec01 100644 --- a/langchain/chains/query_constructor/base.py +++ b/langchain/chains/query_constructor/base.py @@ -18,6 +18,8 @@ from langchain.chains.query_constructor.prompt import ( DEFAULT_SCHEMA, DEFAULT_SUFFIX, EXAMPLE_PROMPT, + EXAMPLES_WITH_LIMIT, + SCHEMA_WITH_LIMIT, ) from langchain.chains.query_constructor.schema import AttributeInfo from langchain.output_parsers.structured import parse_json_markdown @@ -38,7 +40,11 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): parsed["filter"] = None else: parsed["filter"] = self.ast_parse(parsed["filter"]) - return StructuredQuery(query=parsed["query"], filter=parsed["filter"]) + return StructuredQuery( + query=parsed["query"], + filter=parsed["filter"], + limit=parsed.get("limit"), + ) except Exception as e: raise OutputParserException( f"Parsing text\n{text}\n raised following error:\n{e}" @@ -70,15 +76,25 @@ def _get_prompt( examples: Optional[List] = None, allowed_comparators: Optional[Sequence[Comparator]] = None, allowed_operators: Optional[Sequence[Operator]] = None, + enable_limit: bool = False, ) -> BasePromptTemplate: attribute_str = _format_attribute_info(attribute_info) - examples = examples or DEFAULT_EXAMPLES allowed_comparators = allowed_comparators or list(Comparator) allowed_operators = allowed_operators or list(Operator) - schema = DEFAULT_SCHEMA.format( - allowed_comparators=" | ".join(allowed_comparators), - allowed_operators=" | ".join(allowed_operators), - ) + if enable_limit: + schema = SCHEMA_WITH_LIMIT.format( + allowed_comparators=" | ".join(allowed_comparators), + allowed_operators=" | ".join(allowed_operators), + ) + + examples = examples or EXAMPLES_WITH_LIMIT + else: + schema = DEFAULT_SCHEMA.format( + allowed_comparators=" | ".join(allowed_comparators), + allowed_operators=" | ".join(allowed_operators), + ) + + examples = examples or DEFAULT_EXAMPLES prefix = DEFAULT_PREFIX.format(schema=schema) suffix = DEFAULT_SUFFIX.format( i=len(examples) + 1, content=document_contents, attributes=attribute_str @@ -87,7 +103,7 @@ def _get_prompt( allowed_comparators=allowed_comparators, allowed_operators=allowed_operators ) return FewShotPromptTemplate( - examples=DEFAULT_EXAMPLES, + examples=examples, example_prompt=EXAMPLE_PROMPT, input_variables=["query"], suffix=suffix, @@ -103,6 +119,7 @@ def load_query_constructor_chain( examples: Optional[List] = None, allowed_comparators: Optional[Sequence[Comparator]] = None, allowed_operators: Optional[Sequence[Operator]] = None, + enable_limit: bool = False, **kwargs: Any, ) -> LLMChain: prompt = _get_prompt( @@ -111,5 +128,6 @@ def load_query_constructor_chain( examples=examples, allowed_comparators=allowed_comparators, allowed_operators=allowed_operators, + enable_limit=enable_limit, ) return LLMChain(llm=llm, prompt=prompt, **kwargs) diff --git a/langchain/chains/query_constructor/ir.py b/langchain/chains/query_constructor/ir.py index 8562ec2b..2aca4280 100644 --- a/langchain/chains/query_constructor/ir.py +++ b/langchain/chains/query_constructor/ir.py @@ -81,3 +81,4 @@ class Operation(FilterDirective): class StructuredQuery(Expr): query: str filter: Optional[FilterDirective] + limit: Optional[int] diff --git a/langchain/chains/query_constructor/prompt.py b/langchain/chains/query_constructor/prompt.py index f8cec9e5..ae7530b7 100644 --- a/langchain/chains/query_constructor/prompt.py +++ b/langchain/chains/query_constructor/prompt.py @@ -46,6 +46,16 @@ NO_FILTER_ANSWER = """\ ```\ """ +WITH_LIMIT_ANSWER = """\ +```json +{{ + "query": "love", + "filter": "NO_FILTER", + "limit": 2 +}} +```\ +""" + DEFAULT_EXAMPLES = [ { "i": 1, @@ -61,6 +71,27 @@ DEFAULT_EXAMPLES = [ }, ] +EXAMPLES_WITH_LIMIT = [ + { + "i": 1, + "data_source": SONG_DATA_SOURCE, + "user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre", + "structured_request": FULL_ANSWER, + }, + { + "i": 2, + "data_source": SONG_DATA_SOURCE, + "user_query": "What are songs that were not published on Spotify", + "structured_request": NO_FILTER_ANSWER, + }, + { + "i": 3, + "data_source": SONG_DATA_SOURCE, + "user_query": "What are three songs about love", + "structured_request": WITH_LIMIT_ANSWER, + }, +] + EXAMPLE_PROMPT_TEMPLATE = """\ << Example {i}. >> Data Source: @@ -116,6 +147,45 @@ Make sure that filters are only used as needed. If there are no filters that sho applied return "NO_FILTER" for the filter value.\ """ +SCHEMA_WITH_LIMIT = """\ +<< Structured Request Schema >> +When responding use a markdown code snippet with a JSON object formatted in the \ +following schema: + +```json +{{{{ + "query": string \\ text string to compare to document contents + "filter": string \\ logical condition statement for filtering documents + "limit": int \\ the number of documents to retrieve +}}}} +``` + +The query string should contain only text that is expected to match the contents of \ +documents. Any conditions in the filter should not be mentioned in the query as well. + +A logical condition statement is composed of one or more comparison and logical \ +operation statements. + +A comparison statement takes the form: `comp(attr, val)`: +- `comp` ({allowed_comparators}): comparator +- `attr` (string): name of attribute to apply the comparison to +- `val` (string): is the comparison value + +A logical operation statement takes the form `op(statement1, statement2, ...)`: +- `op` ({allowed_operators}): logical operator +- `statement1`, `statement2`, ... (comparison statements or logical operation \ +statements): one or more statements to apply the operation to + +Make sure that you only use the comparators and logical operators listed above and \ +no others. +Make sure that filters only refer to attributes that exist in the data source. +Make sure that filters take into account the descriptions of attributes and only make \ +comparisons that are feasible given the type of data being stored. +Make sure that filters are only used as needed. If there are no filters that should be \ +applied return "NO_FILTER" for the filter value. +Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it is does not make sense. +""" + DEFAULT_PREFIX = """\ Your goal is to structure the user's query to match the request schema provided below. diff --git a/langchain/retrievers/self_query/base.py b/langchain/retrievers/self_query/base.py index bf5ad303..97eb3f05 100644 --- a/langchain/retrievers/self_query/base.py +++ b/langchain/retrievers/self_query/base.py @@ -68,7 +68,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): Returns: List of relevant documents """ - inputs = self.llm_chain.prep_inputs(query) + inputs = self.llm_chain.prep_inputs({"query": query}) structured_query = cast( StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs) ) @@ -77,6 +77,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): new_query, new_kwargs = self.structured_query_translator.visit_structured_query( structured_query ) + if structured_query.limit is not None: + new_kwargs["k"] = structured_query.limit + search_kwargs = {**self.search_kwargs, **new_kwargs} docs = self.vectorstore.search(query, self.search_type, **search_kwargs) return docs @@ -93,11 +96,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): metadata_field_info: List[AttributeInfo], structured_query_translator: Optional[Visitor] = None, chain_kwargs: Optional[Dict] = None, + enable_limit: bool = False, **kwargs: Any, ) -> "SelfQueryRetriever": if structured_query_translator is None: structured_query_translator = _get_builtin_translator(vectorstore.__class__) chain_kwargs = chain_kwargs or {} + if "allowed_comparators" not in chain_kwargs: chain_kwargs[ "allowed_comparators" @@ -107,7 +112,11 @@ class SelfQueryRetriever(BaseRetriever, BaseModel): "allowed_operators" ] = structured_query_translator.allowed_operators llm_chain = load_query_constructor_chain( - llm, document_contents, metadata_field_info, **chain_kwargs + llm, + document_contents, + metadata_field_info, + enable_limit=enable_limit, + **chain_kwargs, ) return cls( llm_chain=llm_chain,