Harrison/add top k (#4707)

Co-authored-by: blc16 <benlc@umich.edu>
dynamic_agent_tools
Harrison Chase 1 year ago committed by GitHub
parent 0551594722
commit dd95f0892d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -32,7 +32,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "cb4a5787", "id": "cb4a5787",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -46,7 +46,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"id": "bcbe04d9", "id": "bcbe04d9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -83,7 +83,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"id": "86e34dbf", "id": "86e34dbf",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -138,7 +138,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"query='dinosaur' filter=None\n" "query='dinosaur' filter=None limit=None\n"
] ]
}, },
{ {
@ -170,7 +170,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"query=' ' filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)\n" "query=' ' filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5) limit=None\n"
] ]
}, },
{ {
@ -200,7 +200,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig')\n" "query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig') limit=None\n"
] ]
}, },
{ {
@ -229,7 +229,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"query=' ' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='science fiction'), Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)])\n" "query=' ' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='science fiction')]) limit=None\n"
] ]
}, },
{ {
@ -258,7 +258,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"query='toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='year', value=1990), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated')])\n" "query='toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='year', value=1990), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated')]) limit=None\n"
] ]
}, },
{ {
@ -277,11 +277,70 @@
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")" "retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "87513116",
"metadata": {},
"source": [
"## Filter k\n",
"\n",
"We can also use the self query retriever to specify `k`: the number of documents to fetch.\n",
"\n",
"We can do this by passing `enable_limit=True` to the constructor."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 4,
"id": "73cfca56",
"metadata": {},
"outputs": [],
"source": [
"retriever = SelfQueryRetriever.from_llm(\n",
" llm, \n",
" vectorstore, \n",
" document_content_description, \n",
" metadata_field_info, \n",
" enable_limit=True,\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "60110338", "id": "60110338",
"metadata": {}, "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query='dinosaur' filter=None limit=2\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='A bunch of scientists bring back dinosaurs and mayhem breaks loose', metadata={'year': 1993, 'rating': 7.7, 'genre': 'science fiction'}),\n",
" Document(page_content='Toys come alive and have a blast doing so', metadata={'year': 1995, 'genre': 'animated'})]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This example specifies a query and a limit on the number of documents to return\n",
"retriever.get_relevant_documents(\"what are two movies about dinosaurs\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f15d84b3",
"metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []
} }

@ -295,13 +295,45 @@
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")" "retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "6fe7536c",
"metadata": {},
"source": [
"## Filter k\n",
"\n",
"We can also use the self query retriever to specify `k`: the number of documents to fetch.\n",
"\n",
"We can do this by passing `enable_limit=True` to the constructor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a2937c2",
"metadata": {},
"outputs": [],
"source": [
"retriever = SelfQueryRetriever.from_llm(\n",
" llm, \n",
" vectorstore, \n",
" document_content_description, \n",
" metadata_field_info, \n",
" enable_limit=True,\n",
" verbose=True\n",
")"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "69bbd809", "id": "83d233aa",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": [
"# This example specifies a query and a limit on the number of documents to return\n",
"retriever.get_relevant_documents(\"What are two movies about dinosaurs\")"
]
} }
], ],
"metadata": { "metadata": {

@ -18,6 +18,8 @@ from langchain.chains.query_constructor.prompt import (
DEFAULT_SCHEMA, DEFAULT_SCHEMA,
DEFAULT_SUFFIX, DEFAULT_SUFFIX,
EXAMPLE_PROMPT, EXAMPLE_PROMPT,
EXAMPLES_WITH_LIMIT,
SCHEMA_WITH_LIMIT,
) )
from langchain.chains.query_constructor.schema import AttributeInfo from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.output_parsers.structured import parse_json_markdown from langchain.output_parsers.structured import parse_json_markdown
@ -38,7 +40,11 @@ class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
parsed["filter"] = None parsed["filter"] = None
else: else:
parsed["filter"] = self.ast_parse(parsed["filter"]) parsed["filter"] = self.ast_parse(parsed["filter"])
return StructuredQuery(query=parsed["query"], filter=parsed["filter"]) return StructuredQuery(
query=parsed["query"],
filter=parsed["filter"],
limit=parsed.get("limit"),
)
except Exception as e: except Exception as e:
raise OutputParserException( raise OutputParserException(
f"Parsing text\n{text}\n raised following error:\n{e}" f"Parsing text\n{text}\n raised following error:\n{e}"
@ -70,15 +76,25 @@ def _get_prompt(
examples: Optional[List] = None, examples: Optional[List] = None,
allowed_comparators: Optional[Sequence[Comparator]] = None, allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None, allowed_operators: Optional[Sequence[Operator]] = None,
enable_limit: bool = False,
) -> BasePromptTemplate: ) -> BasePromptTemplate:
attribute_str = _format_attribute_info(attribute_info) attribute_str = _format_attribute_info(attribute_info)
examples = examples or DEFAULT_EXAMPLES
allowed_comparators = allowed_comparators or list(Comparator) allowed_comparators = allowed_comparators or list(Comparator)
allowed_operators = allowed_operators or list(Operator) allowed_operators = allowed_operators or list(Operator)
schema = DEFAULT_SCHEMA.format( if enable_limit:
allowed_comparators=" | ".join(allowed_comparators), schema = SCHEMA_WITH_LIMIT.format(
allowed_operators=" | ".join(allowed_operators), allowed_comparators=" | ".join(allowed_comparators),
) allowed_operators=" | ".join(allowed_operators),
)
examples = examples or EXAMPLES_WITH_LIMIT
else:
schema = DEFAULT_SCHEMA.format(
allowed_comparators=" | ".join(allowed_comparators),
allowed_operators=" | ".join(allowed_operators),
)
examples = examples or DEFAULT_EXAMPLES
prefix = DEFAULT_PREFIX.format(schema=schema) prefix = DEFAULT_PREFIX.format(schema=schema)
suffix = DEFAULT_SUFFIX.format( suffix = DEFAULT_SUFFIX.format(
i=len(examples) + 1, content=document_contents, attributes=attribute_str i=len(examples) + 1, content=document_contents, attributes=attribute_str
@ -87,7 +103,7 @@ def _get_prompt(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
) )
return FewShotPromptTemplate( return FewShotPromptTemplate(
examples=DEFAULT_EXAMPLES, examples=examples,
example_prompt=EXAMPLE_PROMPT, example_prompt=EXAMPLE_PROMPT,
input_variables=["query"], input_variables=["query"],
suffix=suffix, suffix=suffix,
@ -103,6 +119,7 @@ def load_query_constructor_chain(
examples: Optional[List] = None, examples: Optional[List] = None,
allowed_comparators: Optional[Sequence[Comparator]] = None, allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None, allowed_operators: Optional[Sequence[Operator]] = None,
enable_limit: bool = False,
**kwargs: Any, **kwargs: Any,
) -> LLMChain: ) -> LLMChain:
prompt = _get_prompt( prompt = _get_prompt(
@ -111,5 +128,6 @@ def load_query_constructor_chain(
examples=examples, examples=examples,
allowed_comparators=allowed_comparators, allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators, allowed_operators=allowed_operators,
enable_limit=enable_limit,
) )
return LLMChain(llm=llm, prompt=prompt, **kwargs) return LLMChain(llm=llm, prompt=prompt, **kwargs)

@ -81,3 +81,4 @@ class Operation(FilterDirective):
class StructuredQuery(Expr): class StructuredQuery(Expr):
query: str query: str
filter: Optional[FilterDirective] filter: Optional[FilterDirective]
limit: Optional[int]

@ -46,6 +46,16 @@ NO_FILTER_ANSWER = """\
```\ ```\
""" """
# Few-shot answer for a request that asks for an explicit number of results.
# It is paired in EXAMPLES_WITH_LIMIT with the query "What are three songs
# about love", so the limit must be 3 for the example to teach the model the
# right mapping from the requested count to the `limit` field.
# Braces are doubled because the text is later passed through `str.format`.
WITH_LIMIT_ANSWER = """\
```json
{{
"query": "love",
"filter": "NO_FILTER",
"limit": 3
}}
```\
"""
DEFAULT_EXAMPLES = [ DEFAULT_EXAMPLES = [
{ {
"i": 1, "i": 1,
@ -61,6 +71,27 @@ DEFAULT_EXAMPLES = [
}, },
] ]
# Few-shot examples used when the `limit` field is enabled in the schema.
# Each entry pairs a user question with the structured request expected for
# it; entries are numbered from 1 in the order listed and all share the song
# data source.
EXAMPLES_WITH_LIMIT = [
    {
        "i": position,
        "data_source": SONG_DATA_SOURCE,
        "user_query": question,
        "structured_request": expected_answer,
    }
    for position, (question, expected_answer) in enumerate(
        (
            (
                "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre",
                FULL_ANSWER,
            ),
            (
                "What are songs that were not published on Spotify",
                NO_FILTER_ANSWER,
            ),
            (
                "What are three songs about love",
                WITH_LIMIT_ANSWER,
            ),
        ),
        start=1,
    )
]
EXAMPLE_PROMPT_TEMPLATE = """\ EXAMPLE_PROMPT_TEMPLATE = """\
<< Example {i}. >> << Example {i}. >>
Data Source: Data Source:
@ -116,6 +147,45 @@ Make sure that filters are only used as needed. If there are no filters that sho
applied return "NO_FILTER" for the filter value.\ applied return "NO_FILTER" for the filter value.\
""" """
# Prompt section describing the structured-request JSON schema, including the
# optional integer `limit` field (number of documents to retrieve).
# This is a `str.format` template with two placeholders:
#   {allowed_comparators} - text listing the permitted comparison functions
#   {allowed_operators}   - text listing the permitted logical operators
# The braces around the JSON object are quadrupled because each `str.format`
# pass halves escaped braces; after one pass they are still `{{ ... }}`.
# `\\` renders as a single backslash, and a trailing `\` joins the next line.
SCHEMA_WITH_LIMIT = """\
<< Structured Request Schema >>
When responding use a markdown code snippet with a JSON object formatted in the \
following schema:
```json
{{{{
"query": string \\ text string to compare to document contents
"filter": string \\ logical condition statement for filtering documents
"limit": int \\ the number of documents to retrieve
}}}}
```
The query string should contain only text that is expected to match the contents of \
documents. Any conditions in the filter should not be mentioned in the query as well.
A logical condition statement is composed of one or more comparison and logical \
operation statements.
A comparison statement takes the form: `comp(attr, val)`:
- `comp` ({allowed_comparators}): comparator
- `attr` (string): name of attribute to apply the comparison to
- `val` (string): is the comparison value
A logical operation statement takes the form `op(statement1, statement2, ...)`:
- `op` ({allowed_operators}): logical operator
- `statement1`, `statement2`, ... (comparison statements or logical operation \
statements): one or more statements to apply the operation to
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \
applied return "NO_FILTER" for the filter value.
Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it does not make sense.
"""
DEFAULT_PREFIX = """\ DEFAULT_PREFIX = """\
Your goal is to structure the user's query to match the request schema provided below. Your goal is to structure the user's query to match the request schema provided below.

@ -68,7 +68,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
Returns: Returns:
List of relevant documents List of relevant documents
""" """
inputs = self.llm_chain.prep_inputs(query) inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = cast( structured_query = cast(
StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs) StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs)
) )
@ -77,6 +77,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
new_query, new_kwargs = self.structured_query_translator.visit_structured_query( new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query structured_query
) )
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
search_kwargs = {**self.search_kwargs, **new_kwargs} search_kwargs = {**self.search_kwargs, **new_kwargs}
docs = self.vectorstore.search(query, self.search_type, **search_kwargs) docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs return docs
@ -93,11 +96,13 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
metadata_field_info: List[AttributeInfo], metadata_field_info: List[AttributeInfo],
structured_query_translator: Optional[Visitor] = None, structured_query_translator: Optional[Visitor] = None,
chain_kwargs: Optional[Dict] = None, chain_kwargs: Optional[Dict] = None,
enable_limit: bool = False,
**kwargs: Any, **kwargs: Any,
) -> "SelfQueryRetriever": ) -> "SelfQueryRetriever":
if structured_query_translator is None: if structured_query_translator is None:
structured_query_translator = _get_builtin_translator(vectorstore.__class__) structured_query_translator = _get_builtin_translator(vectorstore.__class__)
chain_kwargs = chain_kwargs or {} chain_kwargs = chain_kwargs or {}
if "allowed_comparators" not in chain_kwargs: if "allowed_comparators" not in chain_kwargs:
chain_kwargs[ chain_kwargs[
"allowed_comparators" "allowed_comparators"
@ -107,7 +112,11 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
"allowed_operators" "allowed_operators"
] = structured_query_translator.allowed_operators ] = structured_query_translator.allowed_operators
llm_chain = load_query_constructor_chain( llm_chain = load_query_constructor_chain(
llm, document_contents, metadata_field_info, **chain_kwargs llm,
document_contents,
metadata_field_info,
enable_limit=enable_limit,
**chain_kwargs,
) )
return cls( return cls(
llm_chain=llm_chain, llm_chain=llm_chain,

Loading…
Cancel
Save