From 3b609642ae2f83cb2f503e5a2b10db85bbc6203b Mon Sep 17 00:00:00 2001 From: Davis Chase <130488702+dev2049@users.noreply.github.com> Date: Thu, 27 Apr 2023 08:36:00 -0700 Subject: [PATCH] Self-query with generic query constructor (#3607) Alternate implementation of #3452 that relies on a generic query constructor chain and language and then has a vector store-specific translation layer. Still refactoring and updating examples but general structure is there and seems to work as well as #3452 on examples --------- Co-authored-by: Harrison Chase --- .../examples/self_query_retriever.ipynb | 330 ++++++++++++++++++ .../chains/query_constructor/__init__.py | 0 langchain/chains/query_constructor/base.py | 114 ++++++ langchain/chains/query_constructor/ir.py | 83 +++++ langchain/chains/query_constructor/parser.py | 113 ++++++ langchain/chains/query_constructor/prompt.py | 137 ++++++++ langchain/chains/query_constructor/schema.py | 15 + langchain/output_parsers/structured.py | 41 ++- .../retrievers/document_compressors/base.py | 2 +- langchain/retrievers/self_query/__init__.py | 0 langchain/retrievers/self_query/base.py | 116 ++++++ langchain/retrievers/self_query/pinecone.py | 53 +++ langchain/vectorstores/base.py | 26 ++ langchain/vectorstores/deeplake.py | 12 +- poetry.lock | 25 +- pyproject.toml | 1 + 16 files changed, 1039 insertions(+), 29 deletions(-) create mode 100644 docs/modules/indexes/retrievers/examples/self_query_retriever.ipynb create mode 100644 langchain/chains/query_constructor/__init__.py create mode 100644 langchain/chains/query_constructor/base.py create mode 100644 langchain/chains/query_constructor/ir.py create mode 100644 langchain/chains/query_constructor/parser.py create mode 100644 langchain/chains/query_constructor/prompt.py create mode 100644 langchain/chains/query_constructor/schema.py create mode 100644 langchain/retrievers/self_query/__init__.py create mode 100644 langchain/retrievers/self_query/base.py create mode 100644 
langchain/retrievers/self_query/pinecone.py diff --git a/docs/modules/indexes/retrievers/examples/self_query_retriever.ipynb b/docs/modules/indexes/retrievers/examples/self_query_retriever.ipynb new file mode 100644 index 00000000..665adf9f --- /dev/null +++ b/docs/modules/indexes/retrievers/examples/self_query_retriever.ipynb @@ -0,0 +1,330 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "13afcae7", + "metadata": {}, + "source": [ + "# Self-querying retriever\n", + "In the notebook we'll demo the `SelfQueryRetriever`, which, as the name suggests, has the ability to query itself. Specifically, given any natural language query, the retriever uses a query-constructing LLM chain to write a structured query and then applies that structured query to its underlying VectorStore. This allows the retriever to not only use the user-input query for semantic similarity comparison with the contents of stored documents, but to also extract filters from the user query on the metadata of stored documents and to execute those filters." + ] + }, + { + "cell_type": "markdown", + "id": "68e75fb9", + "metadata": {}, + "source": [ + "## Creating a Pinecone index\n", + "First we'll want to create a Pinecone VectorStore and seed it with some data. 
We've created a small demo set of documents that contain summaries of movies.\n", + "\n", + "NOTE: The self-query retriever currently only has built-in support for Pinecone VectorStore.\n", + "\n", + "NOTE: The self-query retriever requires you to have `lark` installed (`pip install lark`)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "63a8af5b", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install lark" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "3eb9c9a4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pinecone/index.py:4: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", + " from tqdm.autonotebook import tqdm\n" + ] + } + ], + "source": [ + "import os\n", + "\n", + "import pinecone\n", + "\n", + "\n", + "pinecone.init(api_key=os.environ[\"PINECONE_API_KEY\"], environment=os.environ[\"PINECONE_ENV\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "cb4a5787", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.schema import Document\n", + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.vectorstores import Pinecone\n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "# create new index\n", + "pinecone.create_index(\"langchain-self-retriever-demo\", dimension=1536)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "bcbe04d9", + "metadata": {}, + "outputs": [], + "source": [ + "docs = [\n", + " Document(page_content=\"A bunch of scientists bring back dinosaurs and mayhem breaks loose\", metadata={\"year\": 1993, \"rating\": 7.7, \"genre\": [\"action\", \"science fiction\"]}),\n", + " Document(page_content=\"Leo DiCaprio gets lost in a dream within a dream within a dream within a 
...\", metadata={\"year\": 2010, \"director\": \"Christopher Nolan\", \"rating\": 8.2}),\n", + " Document(page_content=\"A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea\", metadata={\"year\": 2006, \"director\": \"Satoshi Kon\", \"rating\": 8.6}),\n", + " Document(page_content=\"A bunch of normal-sized women are supremely wholesome and some men pine after them\", metadata={\"year\": 2019, \"director\": \"Greta Gerwig\", \"rating\": 8.3}),\n", + " Document(page_content=\"Toys come alive and have a blast doing so\", metadata={\"year\": 1995, \"genre\": \"animated\"}),\n", + " Document(page_content=\"Three men walk into the Zone, three men walk out of the Zone\", metadata={\"year\": 1979, \"rating\": 9.9, \"director\": \"Andrei Tarkovsky\", \"genre\": [\"science fiction\", \"thriller\"], \"rating\": 9.9})\n", + "]\n", + "vectorstore = Pinecone.from_documents(\n", + " docs, embeddings, index_name=\"langchain-self-retriever-demo\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "5ecaab6d", + "metadata": {}, + "source": [ + "# Creating our self-querying retriever\n", + "Now we can instantiate our retriever. To do this we'll need to provide some information upfront about the metadata fields that our documents support and a short description of the document contents." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "86e34dbf", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.llms import OpenAI\n", + "from langchain.retrievers.self_query.base import SelfQueryRetriever\n", + "from langchain.chains.query_constructor.base import AttributeInfo\n", + "\n", + "metadata_field_info=[\n", + " AttributeInfo(\n", + " name=\"genre\",\n", + " description=\"The genre of the movie\", \n", + " type=\"string or list[string]\", \n", + " ),\n", + " AttributeInfo(\n", + " name=\"year\",\n", + " description=\"The year the movie was released\", \n", + " type=\"integer\", \n", + " ),\n", + " AttributeInfo(\n", + " name=\"director\",\n", + " description=\"The name of the movie director\", \n", + " type=\"string\", \n", + " ),\n", + " AttributeInfo(\n", + " name=\"rating\",\n", + " description=\"A 1-10 rating for the movie\",\n", + " type=\"float\"\n", + " ),\n", + "]\n", + "document_content_description = \"Brief summary of a movie\"\n", + "llm = OpenAI(temperature=0)\n", + "retriever = SelfQueryRetriever.from_llm(llm, vectorstore, document_content_description, metadata_field_info, verbose=True)" + ] + }, + { + "cell_type": "markdown", + "id": "ea9df8d4", + "metadata": {}, + "source": [ + "# Testing it out\n", + "And now we can try actually using our retriever!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "38a126e9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "query='dinosaur' filter=None\n" + ] + }, + { + "data": { + "text/plain": [ + "[Document(page_content='A bunch of scientists bring back dinosaurs and mayhem breaks loose', metadata={'genre': ['action', 'science fiction'], 'rating': 7.7, 'year': 1993.0}),\n", + " Document(page_content='Toys come alive and have a blast doing so', metadata={'genre': 'animated', 'year': 1995.0}),\n", + " Document(page_content='A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea', metadata={'director': 'Satoshi Kon', 'rating': 8.6, 'year': 2006.0}),\n", + " Document(page_content='Leo DiCaprio gets lost in a dream within a dream within a dream within a ...', metadata={'director': 'Christopher Nolan', 'rating': 8.2, 'year': 2010.0})]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This example only specifies a relevant query\n", + "retriever.get_relevant_documents(\"What are some movies about dinosaurs\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fc3f1e6e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "query=' ' filter=Comparison(comparator=, attribute='rating', value=8.5)\n" + ] + }, + { + "data": { + "text/plain": [ + "[Document(page_content='A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea', metadata={'director': 'Satoshi Kon', 'rating': 8.6, 'year': 2006.0}),\n", + " Document(page_content='Three men walk into the Zone, three men walk out of the Zone', metadata={'director': 'Andrei Tarkovsky', 'genre': ['science fiction', 'thriller'], 'rating': 9.9, 'year': 1979.0})]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "# This example only specifies a filter\n", + "retriever.get_relevant_documents(\"I want to watch a movie rated higher than 8.5\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b19d4da0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "query='women' filter=Comparison(comparator=, attribute='director', value='Greta Gerwig')\n" + ] + }, + { + "data": { + "text/plain": [ + "[Document(page_content='A bunch of normal-sized women are supremely wholesome and some men pine after them', metadata={'director': 'Greta Gerwig', 'rating': 8.3, 'year': 2019.0})]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This example specifies a query and a filter\n", + "retriever.get_relevant_documents(\"Has Greta Gerwig directed any movies about women\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f900e40e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "query=' ' filter=Operation(operator=, arguments=[Comparison(comparator=, attribute='genre', value='science fiction'), Comparison(comparator=, attribute='rating', value=8.5)])\n" + ] + }, + { + "data": { + "text/plain": [ + "[Document(page_content='Three men walk into the Zone, three men walk out of the Zone', metadata={'director': 'Andrei Tarkovsky', 'genre': ['science fiction', 'thriller'], 'rating': 9.9, 'year': 1979.0})]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This example specifies a composite filter\n", + "retriever.get_relevant_documents(\"What's a highly rated (above 8.5) science fiction film?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "12a51522", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "query='toys' 
filter=Operation(operator=, arguments=[Comparison(comparator=, attribute='year', value=1990.0), Comparison(comparator=, attribute='year', value=2005.0), Comparison(comparator=, attribute='genre', value='animated')])\n" + ] + }, + { + "data": { + "text/plain": [ + "[Document(page_content='Toys come alive and have a blast doing so', metadata={'genre': 'animated', 'year': 1995.0})]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# This example specifies a query and composite filter\n", + "retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69bbd809", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/query_constructor/__init__.py b/langchain/chains/query_constructor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/langchain/chains/query_constructor/base.py b/langchain/chains/query_constructor/base.py new file mode 100644 index 00000000..1ff730fc --- /dev/null +++ b/langchain/chains/query_constructor/base.py @@ -0,0 +1,114 @@ +"""LLM Chain for turning a user text query into a structured query.""" +from __future__ import annotations + +import json +from typing import Any, Callable, List, Optional, Sequence + +from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain +from langchain.chains.query_constructor.ir import ( + Comparator, + Operator, + StructuredQuery, +) 
+from langchain.chains.query_constructor.parser import get_parser +from langchain.chains.query_constructor.prompt import ( + DEFAULT_EXAMPLES, + DEFAULT_PREFIX, + DEFAULT_SCHEMA, + DEFAULT_SUFFIX, + EXAMPLE_PROMPT, +) +from langchain.chains.query_constructor.schema import AttributeInfo +from langchain.output_parsers.structured import parse_json_markdown +from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException + + +class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]): + ast_parse: Callable + """Callable that parses dict into internal representation of query language.""" + + def parse(self, text: str) -> StructuredQuery: + try: + expected_keys = ["query", "filter"] + parsed = parse_json_markdown(text, expected_keys) + if len(parsed["query"]) == 0: + parsed["query"] = " " + if parsed["filter"] == "NO_FILTER": + parsed["filter"] = None + else: + parsed["filter"] = self.ast_parse(parsed["filter"]) + return StructuredQuery(query=parsed["query"], filter=parsed["filter"]) + except Exception as e: + raise OutputParserException( + f"Parsing text\n{text}\n raised following error:\n{e}" + ) + + @classmethod + def from_components( + cls, + allowed_comparators: Optional[Sequence[Comparator]] = None, + allowed_operators: Optional[Sequence[Operator]] = None, + ) -> StructuredQueryOutputParser: + ast_parser = get_parser( + allowed_comparators=allowed_comparators, allowed_operators=allowed_operators + ) + return cls(ast_parse=ast_parser.parse) + + +def _format_attribute_info(info: List[AttributeInfo]) -> str: + info_dicts = {} + for i in info: + i_dict = dict(i) + info_dicts[i_dict.pop("name")] = i_dict + return json.dumps(info_dicts, indent=2).replace("{", "{{").replace("}", "}}") + + +def _get_prompt( + document_contents: str, + attribute_info: List[AttributeInfo], + examples: Optional[List] = None, + allowed_comparators: Optional[Sequence[Comparator]] = None, + allowed_operators: Optional[Sequence[Operator]] = None, +) -> 
BasePromptTemplate: + attribute_str = _format_attribute_info(attribute_info) + examples = examples or DEFAULT_EXAMPLES + allowed_comparators = allowed_comparators or list(Comparator) + allowed_operators = allowed_operators or list(Operator) + schema = DEFAULT_SCHEMA.format( + allowed_comparators=" | ".join(allowed_comparators), + allowed_operators=" | ".join(allowed_operators), + ) + prefix = DEFAULT_PREFIX.format(schema=schema) + suffix = DEFAULT_SUFFIX.format( + i=len(examples) + 1, content=document_contents, attributes=attribute_str + ) + output_parser = StructuredQueryOutputParser.from_components( + allowed_comparators=allowed_comparators, allowed_operators=allowed_operators + ) + return FewShotPromptTemplate( + examples=DEFAULT_EXAMPLES, + example_prompt=EXAMPLE_PROMPT, + input_variables=["query"], + suffix=suffix, + prefix=prefix, + output_parser=output_parser, + ) + + +def load_query_constructor_chain( + llm: BaseLanguageModel, + document_contents: str, + attribute_info: List[AttributeInfo], + examples: Optional[List] = None, + allowed_comparators: Optional[Sequence[Comparator]] = None, + allowed_operators: Optional[Sequence[Operator]] = None, + **kwargs: Any, +) -> LLMChain: + prompt = _get_prompt( + document_contents, + attribute_info, + examples=examples, + allowed_comparators=allowed_comparators, + allowed_operators=allowed_operators, + ) + return LLMChain(llm=llm, prompt=prompt, **kwargs) diff --git a/langchain/chains/query_constructor/ir.py b/langchain/chains/query_constructor/ir.py new file mode 100644 index 00000000..8562ec2b --- /dev/null +++ b/langchain/chains/query_constructor/ir.py @@ -0,0 +1,83 @@ +"""Internal representation of a structured query language.""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, List, Optional, Sequence + +from pydantic import BaseModel + + +class Visitor(ABC): + """Defines interface for IR translation using visitor pattern.""" + + 
allowed_comparators: Optional[Sequence[Comparator]] = None + allowed_operators: Optional[Sequence[Operator]] = None + + @abstractmethod + def visit_operation(self, operation: Operation) -> Any: + """Translate an Operation.""" + + @abstractmethod + def visit_comparison(self, comparison: Comparison) -> Any: + """Translate a Comparison.""" + + @abstractmethod + def visit_structured_query(self, structured_query: StructuredQuery) -> Any: + """Translate a StructuredQuery.""" + + +def _to_snake_case(name: str) -> str: + """Convert a name into snake_case.""" + snake_case = "" + for i, char in enumerate(name): + if char.isupper() and i != 0: + snake_case += "_" + char.lower() + else: + snake_case += char.lower() + return snake_case + + +class Expr(BaseModel): + def accept(self, visitor: Visitor) -> Any: + return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")( + self + ) + + +class Operator(str, Enum): + AND = "and" + OR = "or" + NOT = "not" + + +class Comparator(str, Enum): + EQ = "eq" + GT = "gt" + GTE = "gte" + LT = "lt" + LTE = "lte" + + +class FilterDirective(Expr, ABC): + """A filtering expression.""" + + +class Comparison(FilterDirective): + """A comparison to a value.""" + + comparator: Comparator + attribute: str + value: Any + + +class Operation(FilterDirective): + """A logical operation over other directives.""" + + operator: Operator + arguments: List[FilterDirective] + + +class StructuredQuery(Expr): + query: str + filter: Optional[FilterDirective] diff --git a/langchain/chains/query_constructor/parser.py b/langchain/chains/query_constructor/parser.py new file mode 100644 index 00000000..e59ec057 --- /dev/null +++ b/langchain/chains/query_constructor/parser.py @@ -0,0 +1,113 @@ +from typing import Any, Optional, Sequence, Union + +try: + from lark import Lark, Transformer, v_args +except ImportError: + pass + +from langchain.chains.query_constructor.ir import ( + Comparator, + Comparison, + FilterDirective, + Operation, + Operator, +) + 
+GRAMMAR = """ + ?program: func_call + ?expr: func_call + | value + + func_call: CNAME "(" [args] ")" + + ?value: SIGNED_NUMBER -> number + | list + | string + | "false" -> false + | "true" -> true + + args: expr ("," expr)* + string: ESCAPED_STRING + list: "[" [args] "]" + + %import common.CNAME + %import common.SIGNED_NUMBER + %import common.ESCAPED_STRING + %import common.WS + %ignore WS +""" + + +@v_args(inline=True) +class QueryTransformer(Transformer): + def __init__( + self, + *args: Any, + allowed_comparators: Optional[Sequence[Comparator]] = None, + allowed_operators: Optional[Sequence[Operator]], + **kwargs: Any, + ): + super().__init__(*args, **kwargs) + self.allowed_comparators = allowed_comparators + self.allowed_operators = allowed_operators + + def program(self, *items: Any) -> tuple: + return items + + def func_call(self, func_name: Any, *args: Any) -> FilterDirective: + func = self._match_func_name(str(func_name)) + if isinstance(func, Comparator): + return Comparison(comparator=func, attribute=args[0][0], value=args[0][1]) + return Operation(operator=func, arguments=args[0]) + + def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]: + if func_name in set(Comparator): + if self.allowed_comparators is not None: + if func_name not in self.allowed_comparators: + raise ValueError( + f"Received disallowed comparator {func_name}. Allowed " + f"comparators are {self.allowed_comparators}" + ) + return Comparator(func_name) + elif func_name in set(Operator): + if self.allowed_operators is not None: + if func_name not in self.allowed_operators: + raise ValueError( + f"Received disallowed operator {func_name}. Allowed operators" + f" are {self.allowed_operators}" + ) + return Operator(func_name) + else: + raise ValueError( + f"Received unrecognized function {func_name}. 
Valid functions are " + f"{list(Operator) + list(Comparator)}" + ) + + def args(self, *items: Any) -> tuple: + return items + + def false(self) -> bool: + return False + + def true(self) -> bool: + return True + + def list(self, item: Any) -> list: + return list(item) + + def number(self, item: Any) -> float: + return float(item) + + def string(self, item: Any) -> str: + # Remove escaped quotes + return str(item).strip("\"'") + + +def get_parser( + allowed_comparators: Optional[Sequence[Comparator]] = None, + allowed_operators: Optional[Sequence[Operator]] = None, +) -> Lark: + transformer = QueryTransformer( + allowed_comparators=allowed_comparators, allowed_operators=allowed_operators + ) + return Lark(GRAMMAR, parser="lalr", transformer=transformer, start="program") diff --git a/langchain/chains/query_constructor/prompt.py b/langchain/chains/query_constructor/prompt.py new file mode 100644 index 00000000..89d8a60a --- /dev/null +++ b/langchain/chains/query_constructor/prompt.py @@ -0,0 +1,137 @@ +# flake8: noqa +from langchain import PromptTemplate + +SONG_DATA_SOURCE = """\ +```json +{ + content: "Lyrics of a song", + attributes: { + "artist": { + "type": "string", + "description": "Name of the song artist" + }, + "length": { + "type": "integer", + "description": "Length of the song in seconds" + }, + "genre": { + "type": "string", + "description": "The song genre, one of \"pop\", \"rock\" or \"rap\"" + } + } +} +```\ +""".replace( + "{", "{{" +).replace( + "}", "}}" +) + +FULL_ANSWER = """\ +```json +{{ + "query": "teenager love", + "filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \ +lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\gg\"))" +}}""" + +NO_FILTER_ANSWER = """\ +```json +{{ + "query": "", + "filter": "NO_FILTER" +}} +```\ +""" + +DEFAULT_EXAMPLES = [ + { + "i": 1, + "data_source": SONG_DATA_SOURCE, + "user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the 
dance pop genre", + "structured_request": FULL_ANSWER, + }, + { + "i": 2, + "data_source": SONG_DATA_SOURCE, + "user_query": "What are songs that were not published on Spotify", + "structured_request": NO_FILTER_ANSWER, + }, +] + +EXAMPLE_PROMPT_TEMPLATE = """\ +<< Example {i}. >> +Data Source: +{data_source} + +User Query: +{user_query} + +Structured Request: +{structured_request} +""" + +EXAMPLE_PROMPT = PromptTemplate( + input_variables=["i", "data_source", "user_query", "structured_request"], + template=EXAMPLE_PROMPT_TEMPLATE, +) + + +DEFAULT_SCHEMA = """\ +<< Structured Request Schema >> +When responding use a markdown code snippet with a JSON object formatted in the \ +following schema: + +```json +{{{{ + "query": string \\ text string to compare to document contents + "filter": string \\ logical condition statement for filtering documents +}}}} +``` + +The query string should contain only text that is expected to match the contents of \ +documents. Any conditions in the filter should not be mentioned in the query as well. + +A logical condition statement is composed of one or more comparison and logical \ +operation statements. + +A comparison statement takes the form: `comp(attr, val)`: +- `comp` ({allowed_comparators}): comparator +- `attr` (string): name of attribute to apply the comparison to +- `val` (string): is the comparison value + +A logical operation statement takes the form `op(statement1, statement2, ...)`: +- `op` ({allowed_operators}): logical operator +- `statement1`, `statement2`, ... (comparison statements or logical operation \ +statements): one or more statements to appy the operation to + +Make sure that you only use the comparators and logical operators listed above and \ +no others. +Make sure that filters only refer to attributes that exist in the data source. +Make sure that filters take into account the descriptions of attributes and only make \ +comparisons that are feasible given the type of data being stored. 
+Make sure that filters are only used as needed. If there are no filters that should be \ +applied return "NO_FILTER" for the filter value.\ +""" + +DEFAULT_PREFIX = """\ +Your goal is to structure the user's query to match the request schema provided below. + +{schema}\ +""" + +DEFAULT_SUFFIX = """\ +<< Example {i}. >> +Data Source: +```json +{{{{ + content: {content}, + attributes: {attributes} +}}}} +``` + +User Query: +{{query}} + +Structured Request: +""" diff --git a/langchain/chains/query_constructor/schema.py b/langchain/chains/query_constructor/schema.py new file mode 100644 index 00000000..557ad5ea --- /dev/null +++ b/langchain/chains/query_constructor/schema.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel + + +class AttributeInfo(BaseModel): + """Information about a data source attribute.""" + + name: str + description: str + type: str + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + frozen = True diff --git a/langchain/output_parsers/structured.py b/langchain/output_parsers/structured.py index af9b80bc..345950f9 100644 --- a/langchain/output_parsers/structured.py +++ b/langchain/output_parsers/structured.py @@ -22,6 +22,27 @@ def _get_sub_string(schema: ResponseSchema) -> str: ) +def parse_json_markdown(text: str, expected_keys: List[str]) -> Any: + if "```json" not in text: + raise OutputParserException( + f"Got invalid return object. Expected markdown code snippet with JSON " + f"object, but got:\n{text}" + ) + + json_string = text.split("```json")[1].strip().strip("```").strip() + try: + json_obj = json.loads(json_string) + except json.JSONDecodeError as e: + raise OutputParserException(f"Got invalid JSON object. Error: {e}") + for key in expected_keys: + if key not in json_obj: + raise OutputParserException( + f"Got invalid return object. 
Expected key `{key}` " + f"to be present, but got {json_obj}" + ) + return json_obj + + class StructuredOutputParser(BaseOutputParser): response_schemas: List[ResponseSchema] @@ -38,24 +59,8 @@ class StructuredOutputParser(BaseOutputParser): return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str) def parse(self, text: str) -> Any: - if "```json" not in text: - raise OutputParserException( - f"Got invalid return object. Expected markdown code snippet with JSON " - f"object, but got:\n{text}" - ) - - json_string = text.split("```json")[1].strip().strip("```").strip() - try: - json_obj = json.loads(json_string) - except json.JSONDecodeError as e: - raise OutputParserException(f"Got invalid JSON object. Error: {e}") - for schema in self.response_schemas: - if schema.name not in json_obj: - raise OutputParserException( - f"Got invalid return object. Expected key `{schema.name}` " - f"to be present, but got {json_obj}" - ) - return json_obj + expected_keys = [rs.name for rs in self.response_schemas] + return parse_json_markdown(text, expected_keys) @property def _type(self) -> str: diff --git a/langchain/retrievers/document_compressors/base.py b/langchain/retrievers/document_compressors/base.py index b42d95ea..6c697f55 100644 --- a/langchain/retrievers/document_compressors/base.py +++ b/langchain/retrievers/document_compressors/base.py @@ -8,7 +8,7 @@ from langchain.schema import BaseDocumentTransformer, Document class BaseDocumentCompressor(BaseModel, ABC): - """""" + """Base abstraction interface for document compression.""" @abstractmethod def compress_documents( diff --git a/langchain/retrievers/self_query/__init__.py b/langchain/retrievers/self_query/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/langchain/retrievers/self_query/base.py b/langchain/retrievers/self_query/base.py new file mode 100644 index 00000000..a9d7cad4 --- /dev/null +++ b/langchain/retrievers/self_query/base.py @@ -0,0 +1,116 @@ +"""Retriever that generates and 
executes structured queries over its own data source.""" +from typing import Any, Dict, List, Optional, Type, cast + +from pydantic import BaseModel, Field, root_validator + +from langchain import LLMChain +from langchain.chains.query_constructor.base import ( + load_query_constructor_chain, +) +from langchain.chains.query_constructor.ir import StructuredQuery, Visitor +from langchain.chains.query_constructor.schema import AttributeInfo +from langchain.retrievers.self_query.pinecone import PineconeTranslator +from langchain.schema import BaseLanguageModel, BaseRetriever, Document +from langchain.vectorstores import Pinecone, VectorStore + + +def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor: + """Get the translator class corresponding to the vector store class.""" + BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = { + Pinecone: PineconeTranslator + } + if vectorstore_cls not in BUILTIN_TRANSLATORS: + raise ValueError( + f"Self query retriever with Vector Store type {vectorstore_cls}" + f" not supported." 
+ ) + return BUILTIN_TRANSLATORS[vectorstore_cls]() + + +class SelfQueryRetriever(BaseRetriever, BaseModel): + """Retriever that wraps around a vector store and uses an LLM to generate + the vector store queries.""" + + vectorstore: VectorStore + """The underlying vector store from which documents will be retrieved.""" + llm_chain: LLMChain + """The LLMChain for generating the vector store queries.""" + search_type: str = "similarity" + """The search type to perform on the vector store.""" + search_kwargs: dict = Field(default_factory=dict) + """Keyword arguments to pass in to the vector store search.""" + structured_query_translator: Visitor + """Translator for turning internal query language into vectorstore search params.""" + verbose: bool = False + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @root_validator(pre=True) + def validate_translator(cls, values: Dict) -> Dict: + """Validate translator.""" + if "structured_query_translator" not in values: + vectorstore_cls = values["vectorstore"].__class__ + values["structured_query_translator"] = _get_builtin_translator( + vectorstore_cls + ) + return values + + def get_relevant_documents(self, query: str) -> List[Document]: + """Get documents relevant for a query. 
+ + Args: + query: string to find relevant documents for + + Returns: + List of relevant documents + """ + inputs = self.llm_chain.prep_inputs(query) + structured_query = cast( + StructuredQuery, self.llm_chain.predict_and_parse(**inputs) + ) + if self.verbose: + print(structured_query) + new_query, new_kwargs = self.structured_query_translator.visit_structured_query( + structured_query + ) + search_kwargs = {**self.search_kwargs, **new_kwargs} + docs = self.vectorstore.search(query, self.search_type, **search_kwargs) + return docs + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + vectorstore: VectorStore, + document_contents: str, + metadata_field_info: List[AttributeInfo], + structured_query_translator: Optional[Visitor] = None, + chain_kwargs: Optional[Dict] = None, + **kwargs: Any, + ) -> "SelfQueryRetriever": + if structured_query_translator is None: + structured_query_translator = _get_builtin_translator(vectorstore.__class__) + chain_kwargs = chain_kwargs or {} + if "allowed_comparators" not in chain_kwargs: + chain_kwargs[ + "allowed_comparators" + ] = structured_query_translator.allowed_comparators + if "allowed_operators" not in chain_kwargs: + chain_kwargs[ + "allowed_operators" + ] = structured_query_translator.allowed_operators + llm_chain = load_query_constructor_chain( + llm, document_contents, metadata_field_info, **chain_kwargs + ) + return cls( + llm_chain=llm_chain, + vectorstore=vectorstore, + structured_query_translator=structured_query_translator, + **kwargs, + ) diff --git a/langchain/retrievers/self_query/pinecone.py b/langchain/retrievers/self_query/pinecone.py new file mode 100644 index 00000000..c4e0b844 --- /dev/null +++ b/langchain/retrievers/self_query/pinecone.py @@ -0,0 +1,53 @@ +"""Logic for converting internal query language to a valid Pinecone query.""" +from typing import Dict, Tuple, Union + +from 
class PineconeTranslator(Visitor):
    """Logic for converting internal query language elements to valid filters."""

    allowed_operators = [Operator.AND, Operator.OR]
    """Subset of allowed logical operators."""

    def _format_func(self, func: Union[Operator, Comparator]) -> str:
        """Check ``func`` against the allow-lists and return it prefixed with ``$``.

        Args:
            func: Operator or comparator to render.

        Returns:
            ``"$"`` followed by the formatted ``func``.

        Raises:
            ValueError: If ``func`` is an operator missing from
                ``allowed_operators``, or a comparator missing from
                ``allowed_comparators``. A list set to ``None`` disables the
                corresponding check.
        """
        if isinstance(func, Operator) and self.allowed_operators is not None:
            if func not in self.allowed_operators:
                # The message previously said "comparators" while listing the
                # allowed operators.
                raise ValueError(
                    f"Received disallowed operator {func}. Allowed "
                    f"operators are {self.allowed_operators}"
                )
        # ``allowed_comparators`` is not defined in this class; presumably it
        # is inherited from ``Visitor`` -- TODO confirm.
        if isinstance(func, Comparator) and self.allowed_comparators is not None:
            if func not in self.allowed_comparators:
                raise ValueError(
                    f"Received disallowed comparator {func}. Allowed "
                    f"comparators are {self.allowed_comparators}"
                )
        # NOTE(review): this relies on ``format(func)`` producing the bare
        # operator/comparator name; confirm that holds for these types on
        # every supported Python version.
        return f"${func}"

    def visit_operation(self, operation: Operation) -> Dict:
        """Translate an operation into ``{"$<operator>": [translated arguments]}``."""
        args = [arg.accept(self) for arg in operation.arguments]
        return {self._format_func(operation.operator): args}

    def visit_comparison(self, comparison: Comparison) -> Dict:
        """Translate a comparison into ``{attribute: {"$<comparator>": value}}``."""
        return {
            comparison.attribute: {
                self._format_func(comparison.comparator): comparison.value
            }
        }

    def visit_structured_query(
        self, structured_query: StructuredQuery
    ) -> Tuple[str, dict]:
        """Split a structured query into its query text and search kwargs.

        Returns:
            ``(query, kwargs)``. ``kwargs`` holds the translated filter under
            ``"filter"`` when the structured query has one, and is empty
            otherwise.
        """
        if structured_query.filter is None:
            kwargs = {}
        else:
            kwargs = {"filter": structured_query.filter.accept(self)}
        return structured_query.query, kwargs
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
    """Run the search variant selected by ``search_type``.

    ``"similarity"`` dispatches to ``similarity_search`` and ``"mmr"`` to
    ``max_marginal_relevance_search``; ``kwargs`` are forwarded unchanged.

    Raises:
        ValueError: If ``search_type`` is neither ``"similarity"`` nor ``"mmr"``.
    """
    if search_type == "similarity":
        return self.similarity_search(query, **kwargs)
    if search_type == "mmr":
        return self.max_marginal_relevance_search(query, **kwargs)
    raise ValueError(
        f"search_type of {search_type} not allowed. Expected "
        "search_type to be 'similarity' or 'mmr'."
    )


async def asearch(
    self, query: str, search_type: str, **kwargs: Any
) -> List[Document]:
    """Async counterpart of ``search``.

    ``"similarity"`` awaits ``asimilarity_search`` and ``"mmr"`` awaits
    ``amax_marginal_relevance_search``; ``kwargs`` are forwarded unchanged.

    Raises:
        ValueError: If ``search_type`` is neither ``"similarity"`` nor ``"mmr"``.
    """
    if search_type == "similarity":
        return await self.asimilarity_search(query, **kwargs)
    if search_type == "mmr":
        return await self.amax_marginal_relevance_search(query, **kwargs)
    raise ValueError(
        f"search_type of {search_type} not allowed. Expected "
        "search_type to be 'similarity' or 'mmr'."
    )
""" - return self.search(embedding=embedding, k=k, **kwargs) + return self._search_helper(embedding=embedding, k=k, **kwargs) def similarity_search_with_score( self, @@ -401,7 +401,7 @@ class DeepLake(VectorStore): List[Tuple[Document, float]]: List of documents most similar to the query text with distance in float. """ - return self.search( + return self._search_helper( query=query, k=k, filter=filter, @@ -431,7 +431,7 @@ class DeepLake(VectorStore): Returns: List of Documents selected by maximal marginal relevance. """ - return self.search( + return self._search_helper( embedding=embedding, k=k, fetch_k=fetch_k, @@ -465,7 +465,7 @@ class DeepLake(VectorStore): raise ValueError( "For MMR search, you must specify an embedding function on" "creation." ) - return self.search( + return self._search_helper( query=query, k=k, fetch_k=fetch_k, diff --git a/poetry.lock b/poetry.lock index 5e814181..6db2e611 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3483,6 +3483,23 @@ files = [ [package.extras] data = ["language-data (>=1.1,<2.0)"] +[[package]] +name = "lark" +version = "1.1.5" +description = "a modern parsing library" +category = "main" +optional = true +python-versions = "*" +files = [ + {file = "lark-1.1.5-py3-none-any.whl", hash = "sha256:8476f9903e93fbde4f6c327f74d79e9b4bd0ed9294c5dfa3164ab8c581b5de2a"}, + {file = "lark-1.1.5.tar.gz", hash = "sha256:4b534eae1f9af5b4ea000bea95776350befe1981658eea3820a01c37e504bb4d"}, +] + +[package.extras] +atomic-cache = ["atomicwrites"] +nearley = ["js2py"] +regex = ["regex"] + [[package]] name = "libclang" version = "16.0.0" @@ -9376,15 +9393,15 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", 
"torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos"] -azure = ["azure-identity", "azure-cosmos", "openai"] +all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] +azure = ["azure-cosmos", "azure-identity", "openai"] cohere = ["cohere"] embeddings = ["sentence-transformers"] -llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"] +llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"] openai = ["openai"] qdrant = ["qdrant-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "b629046308d7f32d4f456972ff669a383c6d349fcf1c89e6e167a74b28cbb458" +content-hash = "2979794d110362d851c1ef78075f6f394c62cbe97f7a331eeacd0d111e823b40" diff --git a/pyproject.toml b/pyproject.toml index a1635d3d..0dc6e005 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ html2text = {version="^2020.1.16", optional=true} numexpr = "^2.8.4" duckduckgo-search = 
{version="^2.8.6", optional=true} azure-cosmos = {version="^4.4.0b1", optional=true} +lark = {version="^1.1.5", optional=true} lancedb = {version = "^0.1", optional = true} [tool.poetry.group.docs.dependencies]