Self-query with generic query constructor (#3607)

Alternate implementation of #3452 that relies on a generic query-constructor
chain and query language, with a vector store-specific translation layer on
top. Still refactoring and updating examples, but the general structure is
there and seems to work as well as #3452 on the examples.
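
At a high level the two stages compose like this (a minimal sketch using the modules added in this PR; the single metadata field and query are illustrative):

```python
from langchain.llms import OpenAI
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.pinecone import PineconeTranslator

# Stage 1: a generic, store-agnostic chain turns natural language into a
# StructuredQuery expressed in the internal query language.
metadata_field_info = [
    AttributeInfo(name="rating", description="A 1-10 rating for the movie", type="float"),
]
chain = load_query_constructor_chain(
    OpenAI(temperature=0), "Brief summary of a movie", metadata_field_info
)
structured_query = chain.predict_and_parse(query="movies rated above 8.5")

# Stage 2: a store-specific translator visits the IR and emits native search params.
search_query, search_kwargs = PineconeTranslator().visit_structured_query(structured_query)
# e.g. search_kwargs == {"filter": {"rating": {"$gt": 8.5}}}
```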

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Davis Chase 1 year ago committed by GitHub
parent 6d6fd1b9e1
commit 3b609642ae

@ -0,0 +1,330 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "13afcae7",
"metadata": {},
"source": [
"# Self-querying retriever\n",
"In the notebook we'll demo the `SelfQueryRetriever`, which, as the name suggests, has the ability to query itself. Specifically, given any natural language query, the retriever uses a query-constructing LLM chain to write a structured query and then applies that structured query to it's underlying VectorStore. This allows the retriever to not only use the user-input query for semantic similarity comparison with the contents of stored documented, but to also extract filters from the user query on the metadata of stored documents and to execute those filter."
]
},
{
"cell_type": "markdown",
"id": "68e75fb9",
"metadata": {},
"source": [
"## Creating a Pinecone index\n",
"First we'll want to create a Pinecone VectorStore and seed it with some data. We've created a small demo set of documents that contain summaries of movies.\n",
"\n",
"NOTE: The self-query retriever currently only has built-in support for Pinecone VectorStore.\n",
"\n",
"NOTE: The self-query retriever requires you to have `lark` installed (`pip install lark`)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "63a8af5b",
"metadata": {},
"outputs": [],
"source": [
"# !pip install lark"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "3eb9c9a4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/pinecone/index.py:4: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
" from tqdm.autonotebook import tqdm\n"
]
}
],
"source": [
"import os\n",
"\n",
"import pinecone\n",
"\n",
"\n",
"pinecone.init(api_key=os.environ[\"PINECONE_API_KEY\"], environment=os.environ[\"PINECONE_ENV\"])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cb4a5787",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema import Document\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.vectorstores import Pinecone\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"# create new index\n",
"pinecone.create_index(\"langchain-self-retriever-demo\", dimension=1536)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bcbe04d9",
"metadata": {},
"outputs": [],
"source": [
"docs = [\n",
" Document(page_content=\"A bunch of scientists bring back dinosaurs and mayhem breaks loose\", metadata={\"year\": 1993, \"rating\": 7.7, \"genre\": [\"action\", \"science fiction\"]}),\n",
" Document(page_content=\"Leo DiCaprio gets lost in a dream within a dream within a dream within a ...\", metadata={\"year\": 2010, \"director\": \"Christopher Nolan\", \"rating\": 8.2}),\n",
" Document(page_content=\"A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea\", metadata={\"year\": 2006, \"director\": \"Satoshi Kon\", \"rating\": 8.6}),\n",
" Document(page_content=\"A bunch of normal-sized women are supremely wholesome and some men pine after them\", metadata={\"year\": 2019, \"director\": \"Greta Gerwig\", \"rating\": 8.3}),\n",
" Document(page_content=\"Toys come alive and have a blast doing so\", metadata={\"year\": 1995, \"genre\": \"animated\"}),\n",
" Document(page_content=\"Three men walk into the Zone, three men walk out of the Zone\", metadata={\"year\": 1979, \"rating\": 9.9, \"director\": \"Andrei Tarkovsky\", \"genre\": [\"science fiction\", \"thriller\"], \"rating\": 9.9})\n",
"]\n",
"vectorstore = Pinecone.from_documents(\n",
" docs, embeddings, index_name=\"langchain-self-retriever-demo\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5ecaab6d",
"metadata": {},
"source": [
"# Creating our self-querying retriever\n",
"Now we can instantiate our retriever. To do this we'll need to provide some information upfront about the metadata fields that our documents support and a short description of the document contents."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "86e34dbf",
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.retrievers.self_query.base import SelfQueryRetriever\n",
"from langchain.chains.query_constructor.base import AttributeInfo\n",
"\n",
"metadata_field_info=[\n",
" AttributeInfo(\n",
" name=\"genre\",\n",
" description=\"The genre of the movie\", \n",
" type=\"string or list[string]\", \n",
" ),\n",
" AttributeInfo(\n",
" name=\"year\",\n",
" description=\"The year the movie was released\", \n",
" type=\"integer\", \n",
" ),\n",
" AttributeInfo(\n",
" name=\"director\",\n",
" description=\"The name of the movie director\", \n",
" type=\"string\", \n",
" ),\n",
" AttributeInfo(\n",
" name=\"rating\",\n",
" description=\"A 1-10 rating for the movie\",\n",
" type=\"float\"\n",
" ),\n",
"]\n",
"document_content_description = \"Brief summary of a movie\"\n",
"llm = OpenAI(temperature=0)\n",
"retriever = SelfQueryRetriever.from_llm(llm, vectorstore, document_content_description, metadata_field_info, verbose=True)"
]
},
{
"cell_type": "markdown",
"id": "ea9df8d4",
"metadata": {},
"source": [
"# Testing it out\n",
"And now we can try actually using our retriever!"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "38a126e9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query='dinosaur' filter=None\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='A bunch of scientists bring back dinosaurs and mayhem breaks loose', metadata={'genre': ['action', 'science fiction'], 'rating': 7.7, 'year': 1993.0}),\n",
" Document(page_content='Toys come alive and have a blast doing so', metadata={'genre': 'animated', 'year': 1995.0}),\n",
" Document(page_content='A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea', metadata={'director': 'Satoshi Kon', 'rating': 8.6, 'year': 2006.0}),\n",
" Document(page_content='Leo DiCaprio gets lost in a dream within a dream within a dream within a ...', metadata={'director': 'Christopher Nolan', 'rating': 8.2, 'year': 2010.0})]"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This example only specifies a relevant query\n",
"retriever.get_relevant_documents(\"What are some movies about dinosaurs\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fc3f1e6e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query=' ' filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea', metadata={'director': 'Satoshi Kon', 'rating': 8.6, 'year': 2006.0}),\n",
" Document(page_content='Three men walk into the Zone, three men walk out of the Zone', metadata={'director': 'Andrei Tarkovsky', 'genre': ['science fiction', 'thriller'], 'rating': 9.9, 'year': 1979.0})]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This example only specifies a filter\n",
"retriever.get_relevant_documents(\"I want to watch a movie rated higher than 8.5\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b19d4da0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query='women' filter=Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='director', value='Greta Gerwig')\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='A bunch of normal-sized women are supremely wholesome and some men pine after them', metadata={'director': 'Greta Gerwig', 'rating': 8.3, 'year': 2019.0})]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This example specifies a query and a filter\n",
"retriever.get_relevant_documents(\"Has Greta Gerwig directed any movies about women\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f900e40e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query=' ' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='science fiction'), Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5)])\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='Three men walk into the Zone, three men walk out of the Zone', metadata={'director': 'Andrei Tarkovsky', 'genre': ['science fiction', 'thriller'], 'rating': 9.9, 'year': 1979.0})]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This example specifies a composite filter\n",
"retriever.get_relevant_documents(\"What's a highly rated (above 8.5) science fiction film?\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "12a51522",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"query='toys' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.GT: 'gt'>, attribute='year', value=1990.0), Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005.0), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated')])\n"
]
},
{
"data": {
"text/plain": [
"[Document(page_content='Toys come alive and have a blast doing so', metadata={'genre': 'animated', 'year': 1995.0})]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This example specifies a query and composite filter\n",
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69bbd809",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -0,0 +1,114 @@
"""LLM Chain for turning a user text query into a structured query."""
from __future__ import annotations
import json
from typing import Any, Callable, List, Optional, Sequence
from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain
from langchain.chains.query_constructor.ir import (
Comparator,
Operator,
StructuredQuery,
)
from langchain.chains.query_constructor.parser import get_parser
from langchain.chains.query_constructor.prompt import (
DEFAULT_EXAMPLES,
DEFAULT_PREFIX,
DEFAULT_SCHEMA,
DEFAULT_SUFFIX,
EXAMPLE_PROMPT,
)
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.output_parsers.structured import parse_json_markdown
from langchain.schema import BaseLanguageModel, BaseOutputParser, OutputParserException
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):
ast_parse: Callable
"""Callable that parses dict into internal representation of query language."""
def parse(self, text: str) -> StructuredQuery:
try:
expected_keys = ["query", "filter"]
parsed = parse_json_markdown(text, expected_keys)
if len(parsed["query"]) == 0:
parsed["query"] = " "
if parsed["filter"] == "NO_FILTER":
parsed["filter"] = None
else:
parsed["filter"] = self.ast_parse(parsed["filter"])
return StructuredQuery(query=parsed["query"], filter=parsed["filter"])
except Exception as e:
raise OutputParserException(
f"Parsing text\n{text}\n raised following error:\n{e}"
)
@classmethod
def from_components(
cls,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
) -> StructuredQueryOutputParser:
ast_parser = get_parser(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
)
return cls(ast_parse=ast_parser.parse)
def _format_attribute_info(info: List[AttributeInfo]) -> str:
info_dicts = {}
for i in info:
i_dict = dict(i)
info_dicts[i_dict.pop("name")] = i_dict
return json.dumps(info_dicts, indent=2).replace("{", "{{").replace("}", "}}")
def _get_prompt(
document_contents: str,
attribute_info: List[AttributeInfo],
examples: Optional[List] = None,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
) -> BasePromptTemplate:
attribute_str = _format_attribute_info(attribute_info)
examples = examples or DEFAULT_EXAMPLES
allowed_comparators = allowed_comparators or list(Comparator)
allowed_operators = allowed_operators or list(Operator)
schema = DEFAULT_SCHEMA.format(
allowed_comparators=" | ".join(allowed_comparators),
allowed_operators=" | ".join(allowed_operators),
)
prefix = DEFAULT_PREFIX.format(schema=schema)
suffix = DEFAULT_SUFFIX.format(
i=len(examples) + 1, content=document_contents, attributes=attribute_str
)
output_parser = StructuredQueryOutputParser.from_components(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
)
return FewShotPromptTemplate(
examples=examples,
example_prompt=EXAMPLE_PROMPT,
input_variables=["query"],
suffix=suffix,
prefix=prefix,
output_parser=output_parser,
)
def load_query_constructor_chain(
llm: BaseLanguageModel,
document_contents: str,
attribute_info: List[AttributeInfo],
examples: Optional[List] = None,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
**kwargs: Any,
) -> LLMChain:
prompt = _get_prompt(
document_contents,
attribute_info,
examples=examples,
allowed_comparators=allowed_comparators,
allowed_operators=allowed_operators,
)
return LLMChain(llm=llm, prompt=prompt, **kwargs)
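
For reference, the output-parsing half of the chain can be exercised on its own. A minimal sketch (requires `lark`; the raw `text` below is a hypothetical LLM completion in the format the prompt requests):

```python
from langchain.chains.query_constructor.base import StructuredQueryOutputParser

parser = StructuredQueryOutputParser.from_components()
# A markdown code snippet containing a JSON object with "query" and "filter" keys.
text = '```json\n{"query": "dinosaur", "filter": "gt(\\"rating\\", 8.5)"}\n```'
structured_query = parser.parse(text)
# -> StructuredQuery(query='dinosaur',
#        filter=Comparison(comparator=<Comparator.GT: 'gt'>, attribute='rating', value=8.5))
```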

@ -0,0 +1,83 @@
"""Internal representation of a structured query language."""
from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, List, Optional, Sequence
from pydantic import BaseModel
class Visitor(ABC):
"""Defines interface for IR translation using visitor pattern."""
allowed_comparators: Optional[Sequence[Comparator]] = None
allowed_operators: Optional[Sequence[Operator]] = None
@abstractmethod
def visit_operation(self, operation: Operation) -> Any:
"""Translate an Operation."""
@abstractmethod
def visit_comparison(self, comparison: Comparison) -> Any:
"""Translate a Comparison."""
@abstractmethod
def visit_structured_query(self, structured_query: StructuredQuery) -> Any:
"""Translate a StructuredQuery."""
def _to_snake_case(name: str) -> str:
"""Convert a name into snake_case."""
snake_case = ""
for i, char in enumerate(name):
if char.isupper() and i != 0:
snake_case += "_" + char.lower()
else:
snake_case += char.lower()
return snake_case
class Expr(BaseModel):
def accept(self, visitor: Visitor) -> Any:
return getattr(visitor, f"visit_{_to_snake_case(self.__class__.__name__)}")(
self
)
class Operator(str, Enum):
AND = "and"
OR = "or"
NOT = "not"
class Comparator(str, Enum):
EQ = "eq"
GT = "gt"
GTE = "gte"
LT = "lt"
LTE = "lte"
class FilterDirective(Expr, ABC):
"""A filtering expression."""
class Comparison(FilterDirective):
"""A comparison to a value."""
comparator: Comparator
attribute: str
value: Any
class Operation(FilterDirective):
"""A logical operation over other directives."""
operator: Operator
arguments: List[FilterDirective]
class StructuredQuery(Expr):
query: str
filter: Optional[FilterDirective]
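
As a concrete illustration of the IR (a sketch, not part of the diff), the filter expression `and(eq("genre", "animated"), lt("year", 2005))` corresponds to:

```python
from langchain.chains.query_constructor.ir import (
    Comparator,
    Comparison,
    Operation,
    Operator,
    StructuredQuery,
)

movie_filter = Operation(
    operator=Operator.AND,
    arguments=[
        Comparison(comparator=Comparator.EQ, attribute="genre", value="animated"),
        Comparison(comparator=Comparator.LT, attribute="year", value=2005),
    ],
)
sq = StructuredQuery(query="toys", filter=movie_filter)
# Expr.accept dispatches on the snake-cased class name, so
# sq.accept(visitor) calls visitor.visit_structured_query(sq).
```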

@ -0,0 +1,113 @@
from typing import Any, Optional, Sequence, Union
try:
from lark import Lark, Transformer, v_args
except ImportError:
pass
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
FilterDirective,
Operation,
Operator,
)
GRAMMAR = """
?program: func_call
?expr: func_call
| value
func_call: CNAME "(" [args] ")"
?value: SIGNED_NUMBER -> number
| list
| string
| "false" -> false
| "true" -> true
args: expr ("," expr)*
string: ESCAPED_STRING
list: "[" [args] "]"
%import common.CNAME
%import common.SIGNED_NUMBER
%import common.ESCAPED_STRING
%import common.WS
%ignore WS
"""
@v_args(inline=True)
class QueryTransformer(Transformer):
def __init__(
self,
*args: Any,
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.allowed_comparators = allowed_comparators
self.allowed_operators = allowed_operators
def program(self, *items: Any) -> tuple:
return items
def func_call(self, func_name: Any, *args: Any) -> FilterDirective:
func = self._match_func_name(str(func_name))
if isinstance(func, Comparator):
return Comparison(comparator=func, attribute=args[0][0], value=args[0][1])
return Operation(operator=func, arguments=args[0])
def _match_func_name(self, func_name: str) -> Union[Operator, Comparator]:
if func_name in set(Comparator):
if self.allowed_comparators is not None:
if func_name not in self.allowed_comparators:
raise ValueError(
f"Received disallowed comparator {func_name}. Allowed "
f"comparators are {self.allowed_comparators}"
)
return Comparator(func_name)
elif func_name in set(Operator):
if self.allowed_operators is not None:
if func_name not in self.allowed_operators:
raise ValueError(
f"Received disallowed operator {func_name}. Allowed operators"
f" are {self.allowed_operators}"
)
return Operator(func_name)
else:
raise ValueError(
f"Received unrecognized function {func_name}. Valid functions are "
f"{list(Operator) + list(Comparator)}"
)
def args(self, *items: Any) -> tuple:
return items
def false(self) -> bool:
return False
def true(self) -> bool:
return True
def list(self, item: Any) -> list:
return list(item)
def number(self, item: Any) -> float:
return float(item)
def string(self, item: Any) -> str:
# Remove escaped quotes
return str(item).strip("\"'")
def get_parser(
allowed_comparators: Optional[Sequence[Comparator]] = None,
allowed_operators: Optional[Sequence[Operator]] = None,
) -> Lark:
transformer = QueryTransformer(
allowed_comparators=allowed_comparators, allowed_operators=allowed_operators
)
return Lark(GRAMMAR, parser="lalr", transformer=transformer, start="program")
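
A quick sanity check of the grammar (a sketch; requires `lark` to be installed):

```python
from langchain.chains.query_constructor.parser import get_parser

parser = get_parser()
ast = parser.parse('and(eq("genre", "animated"), lt("year", 2005))')
# Numbers are parsed as floats, so this yields:
# Operation(operator=<Operator.AND: 'and'>, arguments=[
#     Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='genre', value='animated'),
#     Comparison(comparator=<Comparator.LT: 'lt'>, attribute='year', value=2005.0)])
```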

@ -0,0 +1,137 @@
# flake8: noqa
from langchain import PromptTemplate
SONG_DATA_SOURCE = """\
```json
{
content: "Lyrics of a song",
attributes: {
"artist": {
"type": "string",
"description": "Name of the song artist"
},
"length": {
"type": "integer",
"description": "Length of the song in seconds"
},
"genre": {
"type": "string",
"description": "The song genre, one of \"pop\", \"rock\" or \"rap\""
}
}
}
```\
""".replace(
"{", "{{"
).replace(
"}", "}}"
)
FULL_ANSWER = """\
```json
{{
"query": "teenager love",
"filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \
lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\gg\"))"
}}"""
NO_FILTER_ANSWER = """\
```json
{{
"query": "",
"filter": "NO_FILTER"
}}
```\
"""
DEFAULT_EXAMPLES = [
{
"i": 1,
"data_source": SONG_DATA_SOURCE,
"user_query": "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre",
"structured_request": FULL_ANSWER,
},
{
"i": 2,
"data_source": SONG_DATA_SOURCE,
"user_query": "What are songs that were not published on Spotify",
"structured_request": NO_FILTER_ANSWER,
},
]
EXAMPLE_PROMPT_TEMPLATE = """\
<< Example {i}. >>
Data Source:
{data_source}
User Query:
{user_query}
Structured Request:
{structured_request}
"""
EXAMPLE_PROMPT = PromptTemplate(
input_variables=["i", "data_source", "user_query", "structured_request"],
template=EXAMPLE_PROMPT_TEMPLATE,
)
DEFAULT_SCHEMA = """\
<< Structured Request Schema >>
When responding use a markdown code snippet with a JSON object formatted in the \
following schema:
```json
{{{{
"query": string \\ text string to compare to document contents
"filter": string \\ logical condition statement for filtering documents
}}}}
```
The query string should contain only text that is expected to match the contents of \
documents. Any conditions in the filter should not be mentioned in the query as well.
A logical condition statement is composed of one or more comparison and logical \
operation statements.
A comparison statement takes the form: `comp(attr, val)`:
- `comp` ({allowed_comparators}): comparator
- `attr` (string): name of attribute to apply the comparison to
- `val` (string): the comparison value
A logical operation statement takes the form `op(statement1, statement2, ...)`:
- `op` ({allowed_operators}): logical operator
- `statement1`, `statement2`, ... (comparison statements or logical operation \
statements): one or more statements to apply the operation to
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \
applied return "NO_FILTER" for the filter value.\
"""
DEFAULT_PREFIX = """\
Your goal is to structure the user's query to match the request schema provided below.
{schema}\
"""
DEFAULT_SUFFIX = """\
<< Example {i}. >>
Data Source:
```json
{{{{
content: {content},
attributes: {attributes}
}}}}
```
User Query:
{{query}}
Structured Request:
"""

@ -0,0 +1,15 @@
from pydantic import BaseModel
class AttributeInfo(BaseModel):
"""Information about a data source attribute."""
name: str
description: str
type: str
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
frozen = True

@ -22,6 +22,27 @@ def _get_sub_string(schema: ResponseSchema) -> str:
)
def parse_json_markdown(text: str, expected_keys: List[str]) -> Any:
if "```json" not in text:
raise OutputParserException(
f"Got invalid return object. Expected markdown code snippet with JSON "
f"object, but got:\n{text}"
)
json_string = text.split("```json")[1].strip().strip("```").strip()
try:
json_obj = json.loads(json_string)
except json.JSONDecodeError as e:
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
for key in expected_keys:
if key not in json_obj:
raise OutputParserException(
f"Got invalid return object. Expected key `{key}` "
f"to be present, but got {json_obj}"
)
return json_obj
class StructuredOutputParser(BaseOutputParser):
response_schemas: List[ResponseSchema]
@ -38,24 +59,8 @@ class StructuredOutputParser(BaseOutputParser):
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
def parse(self, text: str) -> Any:
if "```json" not in text:
raise OutputParserException(
f"Got invalid return object. Expected markdown code snippet with JSON "
f"object, but got:\n{text}"
)
json_string = text.split("```json")[1].strip().strip("```").strip()
try:
json_obj = json.loads(json_string)
except json.JSONDecodeError as e:
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
for schema in self.response_schemas:
if schema.name not in json_obj:
raise OutputParserException(
f"Got invalid return object. Expected key `{schema.name}` "
f"to be present, but got {json_obj}"
)
return json_obj
expected_keys = [rs.name for rs in self.response_schemas]
return parse_json_markdown(text, expected_keys)
@property
def _type(self) -> str:

@ -8,7 +8,7 @@ from langchain.schema import BaseDocumentTransformer, Document
class BaseDocumentCompressor(BaseModel, ABC):
""""""
"""Base abstraction interface for document compression."""
@abstractmethod
def compress_documents(

@ -0,0 +1,116 @@
"""Retriever that generates and executes structured queries over its own data source."""
from typing import Any, Dict, List, Optional, Type, cast
from pydantic import BaseModel, Field, root_validator
from langchain import LLMChain
from langchain.chains.query_constructor.base import (
load_query_constructor_chain,
)
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.schema import BaseLanguageModel, BaseRetriever, Document
from langchain.vectorstores import Pinecone, VectorStore
def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor:
"""Get the translator class corresponding to the vector store class."""
BUILTIN_TRANSLATORS: Dict[Type[VectorStore], Type[Visitor]] = {
Pinecone: PineconeTranslator
}
if vectorstore_cls not in BUILTIN_TRANSLATORS:
raise ValueError(
f"Self query retriever with Vector Store type {vectorstore_cls}"
f" not supported."
)
return BUILTIN_TRANSLATORS[vectorstore_cls]()
class SelfQueryRetriever(BaseRetriever, BaseModel):
"""Retriever that wraps around a vector store and uses an LLM to generate
the vector store queries."""
vectorstore: VectorStore
"""The underlying vector store from which documents will be retrieved."""
llm_chain: LLMChain
"""The LLMChain for generating the vector store queries."""
search_type: str = "similarity"
"""The search type to perform on the vector store."""
search_kwargs: dict = Field(default_factory=dict)
"""Keyword arguments to pass in to the vector store search."""
structured_query_translator: Visitor
"""Translator for turning internal query language into vectorstore search params."""
verbose: bool = False
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator(pre=True)
def validate_translator(cls, values: Dict) -> Dict:
"""Validate translator."""
if "structured_query_translator" not in values:
vectorstore_cls = values["vectorstore"].__class__
values["structured_query_translator"] = _get_builtin_translator(
vectorstore_cls
)
return values
def get_relevant_documents(self, query: str) -> List[Document]:
"""Get documents relevant for a query.
Args:
query: string to find relevant documents for
Returns:
List of relevant documents
"""
inputs = self.llm_chain.prep_inputs(query)
structured_query = cast(
StructuredQuery, self.llm_chain.predict_and_parse(**inputs)
)
if self.verbose:
print(structured_query)
new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
structured_query
)
search_kwargs = {**self.search_kwargs, **new_kwargs}
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
return docs
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
vectorstore: VectorStore,
document_contents: str,
metadata_field_info: List[AttributeInfo],
structured_query_translator: Optional[Visitor] = None,
chain_kwargs: Optional[Dict] = None,
**kwargs: Any,
) -> "SelfQueryRetriever":
if structured_query_translator is None:
structured_query_translator = _get_builtin_translator(vectorstore.__class__)
chain_kwargs = chain_kwargs or {}
if "allowed_comparators" not in chain_kwargs:
chain_kwargs[
"allowed_comparators"
] = structured_query_translator.allowed_comparators
if "allowed_operators" not in chain_kwargs:
chain_kwargs[
"allowed_operators"
] = structured_query_translator.allowed_operators
llm_chain = load_query_constructor_chain(
llm, document_contents, metadata_field_info, **chain_kwargs
)
return cls(
llm_chain=llm_chain,
vectorstore=vectorstore,
structured_query_translator=structured_query_translator,
**kwargs,
)
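
Because `from_llm` accepts any `Visitor`, supporting another vector store only requires writing a translator. A hypothetical sketch for a store with Mongo-style filter syntax (all names here are invented for illustration):

```python
from typing import Tuple

from langchain.chains.query_constructor.ir import (
    Comparison,
    Operation,
    StructuredQuery,
    Visitor,
)

class MyStoreTranslator(Visitor):
    """Hypothetical translator for a store with Mongo-style filters."""

    def visit_comparison(self, comparison: Comparison) -> dict:
        return {comparison.attribute: {f"${comparison.comparator.value}": comparison.value}}

    def visit_operation(self, operation: Operation) -> dict:
        return {f"${operation.operator.value}": [arg.accept(self) for arg in operation.arguments]}

    def visit_structured_query(self, structured_query: StructuredQuery) -> Tuple[str, dict]:
        if structured_query.filter is None:
            return structured_query.query, {}
        return structured_query.query, {"filter": structured_query.filter.accept(self)}

# retriever = SelfQueryRetriever.from_llm(
#     llm, my_vectorstore, document_contents, metadata_field_info,
#     structured_query_translator=MyStoreTranslator(),
# )
```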

@ -0,0 +1,53 @@
"""Logic for converting internal query language to a valid Pinecone query."""
from typing import Dict, Tuple, Union
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class PineconeTranslator(Visitor):
"""Logic for converting internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR]
"""Subset of allowed logical operators."""
def _format_func(self, func: Union[Operator, Comparator]) -> str:
if isinstance(func, Operator) and self.allowed_operators is not None:
if func not in self.allowed_operators:
raise ValueError(
f"Received disallowed operator {func}. Allowed "
f"comparators are {self.allowed_operators}"
)
if isinstance(func, Comparator) and self.allowed_comparators is not None:
if func not in self.allowed_comparators:
raise ValueError(
f"Received disallowed comparator {func}. Allowed "
f"comparators are {self.allowed_comparators}"
)
return f"${func}"
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}
def visit_comparison(self, comparison: Comparison) -> Dict:
return {
comparison.attribute: {
self._format_func(comparison.comparator): comparison.value
}
}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs
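
A quick check of the translation (a sketch; under the notebook's Python 3.9 a str enum formats to its value, so the output matches the filters printed in the notebook above):

```python
from langchain.chains.query_constructor.ir import Comparator, Comparison, StructuredQuery
from langchain.retrievers.self_query.pinecone import PineconeTranslator

sq = StructuredQuery(
    query="science fiction",
    filter=Comparison(comparator=Comparator.GT, attribute="rating", value=8.5),
)
print(PineconeTranslator().visit_structured_query(sq))
# ('science fiction', {'filter': {'rating': {'$gt': 8.5}}})
```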

@ -75,6 +75,32 @@ class VectorStore(ABC):
metadatas = [doc.metadata for doc in documents]
return await self.aadd_texts(texts, metadatas, **kwargs)
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
"""Return docs most similar to query using specified search type."""
if search_type == "similarity":
return self.similarity_search(query, **kwargs)
elif search_type == "mmr":
return self.max_marginal_relevance_search(query, **kwargs)
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity' or 'mmr'."
)
async def asearch(
self, query: str, search_type: str, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query using specified search type."""
if search_type == "similarity":
return await self.asimilarity_search(query, **kwargs)
elif search_type == "mmr":
return await self.amax_marginal_relevance_search(query, **kwargs)
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity' or 'mmr'."
)
@abstractmethod
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any

@ -243,7 +243,7 @@ class DeepLake(VectorStore):
self.ds.summary()
return ids
def search(
def _search_helper(
self,
query: Optional[str] = None,
embedding: Optional[List[float]] = None,
@ -366,7 +366,7 @@ class DeepLake(VectorStore):
Returns:
List of Documents most similar to the query vector.
"""
return self.search(query=query, k=k, **kwargs)
return self._search_helper(query=query, k=k, **kwargs)
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
@ -379,7 +379,7 @@ class DeepLake(VectorStore):
Returns:
List of Documents most similar to the query vector.
"""
return self.search(embedding=embedding, k=k, **kwargs)
return self._search_helper(embedding=embedding, k=k, **kwargs)
def similarity_search_with_score(
self,
@ -401,7 +401,7 @@ class DeepLake(VectorStore):
List[Tuple[Document, float]]: List of documents most similar to the query
text with distance in float.
"""
return self.search(
return self._search_helper(
query=query,
k=k,
filter=filter,
@ -431,7 +431,7 @@ class DeepLake(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
return self.search(
return self._search_helper(
embedding=embedding,
k=k,
fetch_k=fetch_k,
@ -465,7 +465,7 @@ class DeepLake(VectorStore):
raise ValueError(
"For MMR search, you must specify an embedding function on" "creation."
)
return self.search(
return self._search_helper(
query=query,
k=k,
fetch_k=fetch_k,

poetry.lock generated

@ -3483,6 +3483,23 @@ files = [
[package.extras]
data = ["language-data (>=1.1,<2.0)"]
[[package]]
name = "lark"
version = "1.1.5"
description = "a modern parsing library"
category = "main"
optional = true
python-versions = "*"
files = [
{file = "lark-1.1.5-py3-none-any.whl", hash = "sha256:8476f9903e93fbde4f6c327f74d79e9b4bd0ed9294c5dfa3164ab8c581b5de2a"},
{file = "lark-1.1.5.tar.gz", hash = "sha256:4b534eae1f9af5b4ea000bea95776350befe1981658eea3820a01c37e504bb4d"},
]
[package.extras]
atomic-cache = ["atomicwrites"]
nearley = ["js2py"]
regex = ["regex"]
[[package]]
name = "libclang"
version = "16.0.0"
@ -9376,15 +9393,15 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"]
[extras]
all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos"]
azure = ["azure-identity", "azure-cosmos", "openai"]
all = ["aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
azure = ["azure-cosmos", "azure-identity", "openai"]
cohere = ["cohere"]
embeddings = ["sentence-transformers"]
llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
openai = ["openai"]
qdrant = ["qdrant-client"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "b629046308d7f32d4f456972ff669a383c6d349fcf1c89e6e167a74b28cbb458"
content-hash = "2979794d110362d851c1ef78075f6f394c62cbe97f7a331eeacd0d111e823b40"

@ -71,6 +71,7 @@ html2text = {version="^2020.1.16", optional=true}
numexpr = "^2.8.4"
duckduckgo-search = {version="^2.8.6", optional=true}
azure-cosmos = {version="^4.4.0b1", optional=true}
lark = {version="^1.1.5", optional=true}
lancedb = {version = "^0.1", optional = true}
[tool.poetry.group.docs.dependencies]
