Harrison/myscale self query (#6376)

Co-authored-by: Fangrui Liu <fangruil@moqi.ai>
Co-authored-by: 刘 方瑞 <fangrui.liu@outlook.com>
Co-authored-by: Fangrui.Liu <fangrui.liu@ubc.ca>
master
Harrison Chase 11 months ago committed by GitHub
parent bd8d418a95
commit 9bf5b0defa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,370 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "13afcae7",
"metadata": {},
"source": [
"# Self-querying with MyScale\n",
"\n",
">[MyScale](https://docs.myscale.com/en/) is an integrated vector database. You can access your database in SQL and also from here, LangChain. MyScale can make a use of [various data types and functions for filters](https://blog.myscale.com/2023/06/06/why-integrated-database-solution-can-boost-your-llm-apps/#filter-on-anything-without-constraints). It will boost up your LLM app no matter if you are scaling up your data or expand your system to broader application.\n",
"\n",
"In the notebook we'll demo the `SelfQueryRetriever` wrapped around a MyScale vector store with some extra piece we contributed to LangChain. In short, it can be concluded into 4 points:\n",
"1. Add `contain` comparator to match list of any if there is more than one element matched\n",
"2. Add `timestamp` data type for datetime match (ISO-format, or YYYY-MM-DD)\n",
"3. Add `like` comparator for string pattern search\n",
"4. Add arbitrary function capability"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "68e75fb9",
"metadata": {},
"source": [
"## Creating a MyScale vectorstore\n",
"MyScale has already been integrated to LangChain for a while. So you can follow [this notebook](../../vectorstores/examples/myscale.ipynb) to create your own vectorstore for a self-query retriever.\n",
"\n",
"NOTE: All self-query retrievers requires you to have `lark` installed (`pip install lark`). We use `lark` for grammar definition. Before you proceed to the next step, we also want to remind you that `clickhouse-connect` is also needed to interact with your MyScale backend."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63a8af5b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"! pip install lark clickhouse-connect"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "83811610-7df3-4ede-b268-68a6a83ba9e2",
"metadata": {},
"source": [
"In this tutorial we follow other example's setting and use `OpenAIEmbeddings`. Remember to get a OpenAI API Key for valid accesss to LLMs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd01b61b-7d32-4a55-85d6-b2d2d4f18840",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"\n",
"os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API Key:')\n",
"os.environ['MYSCALE_HOST'] = getpass.getpass('MyScale URL:')\n",
"os.environ['MYSCALE_PORT'] = getpass.getpass('MyScale Port:')\n",
"os.environ['MYSCALE_USERNAME'] = getpass.getpass('MyScale Username:')\n",
"os.environ['MYSCALE_PASSWORD'] = getpass.getpass('MyScale Password:')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb4a5787",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.schema import Document\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.vectorstores import MyScale\n",
"\n",
"embeddings = OpenAIEmbeddings()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "bf7f6fc4",
"metadata": {},
"source": [
"## Create some sample data\n",
"As you can see, the data we created has some difference to other self-query retrievers. We replaced keyword `year` to `date` which gives you a finer control on timestamps. We also altered the type of keyword `gerne` to list of strings, where LLM can use a new `contain` comparator to construct filters. We also provides comparator `like` and arbitrary function support to filters, which will be introduced in next few cells.\n",
"\n",
"Now let's look at the data first."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bcbe04d9",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"docs = [\n",
" Document(page_content=\"A bunch of scientists bring back dinosaurs and mayhem breaks loose\", metadata={\"date\": \"1993-07-02\", \"rating\": 7.7, \"genre\": [\"science fiction\"]}),\n",
" Document(page_content=\"Leo DiCaprio gets lost in a dream within a dream within a dream within a ...\", metadata={\"date\": \"2010-12-30\", \"director\": \"Christopher Nolan\", \"rating\": 8.2}),\n",
" Document(page_content=\"A psychologist / detective gets lost in a series of dreams within dreams within dreams and Inception reused the idea\", metadata={\"date\": \"2006-04-23\", \"director\": \"Satoshi Kon\", \"rating\": 8.6}),\n",
" Document(page_content=\"A bunch of normal-sized women are supremely wholesome and some men pine after them\", metadata={\"date\": \"2019-08-22\", \"director\": \"Greta Gerwig\", \"rating\": 8.3}),\n",
" Document(page_content=\"Toys come alive and have a blast doing so\", metadata={\"date\": \"1995-02-11\", \"genre\": [\"animated\"]}),\n",
" Document(page_content=\"Three men walk into the Zone, three men walk out of the Zone\", metadata={\"date\": \"1979-09-10\", \"rating\": 9.9, \"director\": \"Andrei Tarkovsky\", \"genre\": [\"science fiction\", \"adventure\"], \"rating\": 9.9})\n",
"]\n",
"vectorstore = MyScale.from_documents(\n",
" docs, \n",
" embeddings, \n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5ecaab6d",
"metadata": {},
"source": [
"## Creating our self-querying retriever\n",
"Just like other retrievers... Simple and nice."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "86e34dbf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.retrievers.self_query.base import SelfQueryRetriever\n",
"from langchain.chains.query_constructor.base import AttributeInfo\n",
"\n",
"metadata_field_info=[\n",
" AttributeInfo(\n",
" name=\"genre\",\n",
" description=\"The genres of the movie\", \n",
" type=\"list[string]\", \n",
" ),\n",
" # If you want to include length of a list, just define it as a new column\n",
" # This will teach the LLM to use it as a column when constructing filter.\n",
" AttributeInfo(\n",
" name=\"length(genre)\",\n",
" description=\"The lenth of genres of the movie\", \n",
" type=\"integer\", \n",
" ),\n",
" # Now you can define a column as timestamp. By simply set the type to timestamp.\n",
" AttributeInfo(\n",
" name=\"date\",\n",
" description=\"The date the movie was released\", \n",
" type=\"timestamp\", \n",
" ),\n",
" AttributeInfo(\n",
" name=\"director\",\n",
" description=\"The name of the movie director\", \n",
" type=\"string\", \n",
" ),\n",
" AttributeInfo(\n",
" name=\"rating\",\n",
" description=\"A 1-10 rating for the movie\",\n",
" type=\"float\"\n",
" ),\n",
"]\n",
"document_content_description = \"Brief summary of a movie\"\n",
"llm = OpenAI(temperature=0)\n",
"retriever = SelfQueryRetriever.from_llm(llm, vectorstore, document_content_description, metadata_field_info, verbose=True)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ea9df8d4",
"metadata": {},
"source": [
"## Testing it out with self-query retriever's existing functionalities\n",
"And now we can try actually using our retriever!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38a126e9",
"metadata": {},
"outputs": [],
"source": [
"# This example only specifies a relevant query\n",
"retriever.get_relevant_documents(\"What are some movies about dinosaurs\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc3f1e6e",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# This example only specifies a filter\n",
"retriever.get_relevant_documents(\"I want to watch a movie rated higher than 8.5\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b19d4da0",
"metadata": {},
"outputs": [],
"source": [
"# This example specifies a query and a filter\n",
"retriever.get_relevant_documents(\"Has Greta Gerwig directed any movies about women\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f900e40e",
"metadata": {},
"outputs": [],
"source": [
"# This example specifies a composite filter\n",
"retriever.get_relevant_documents(\"What's a highly rated (above 8.5) science fiction film?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12a51522",
"metadata": {},
"outputs": [],
"source": [
"# This example specifies a query and composite filter\n",
"retriever.get_relevant_documents(\"What's a movie after 1990 but before 2005 that's all about toys, and preferably is animated\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "86371ac8",
"metadata": {},
"source": [
"# Wait a second... What else?\n",
"\n",
"Self-query retriever with MyScale can do more! Let's find out."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d043096",
"metadata": {},
"outputs": [],
"source": [
"# You can use length(genres) to do anything you want\n",
"retriever.get_relevant_documents(\"What's a movie that have more than 1 genres?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d570d33c",
"metadata": {},
"outputs": [],
"source": [
"# Fine-grained datetime? You got it already.\n",
"retriever.get_relevant_documents(\"What's a movie that release after feb 1995?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fbe0b21b",
"metadata": {},
"outputs": [],
"source": [
"# Don't know what your exact filter should be? Use string pattern match!\n",
"retriever.get_relevant_documents(\"What's a movie whose name is like Andrei?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a514104",
"metadata": {},
"outputs": [],
"source": [
"# Contain works for lists: so you can match a list with contain comparator!\n",
"retriever.get_relevant_documents(\"What's a movie who has genres science fiction and adventure?\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "39bd1de1-b9fe-4a98-89da-58d8a7a6ae51",
"metadata": {},
"source": [
"## Filter k\n",
"\n",
"We can also use the self query retriever to specify `k`: the number of documents to fetch.\n",
"\n",
"We can do this by passing `enable_limit=True` to the constructor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bff36b88-b506-4877-9c63-e5a1a8d78e64",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"retriever = SelfQueryRetriever.from_llm(\n",
" llm, \n",
" vectorstore, \n",
" document_content_description, \n",
" metadata_field_info, \n",
" enable_limit=True,\n",
" verbose=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2758d229-4f97-499c-819f-888acaf8ee10",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# This example only specifies a relevant query\n",
"retriever.get_relevant_documents(\"what are two movies about dinosaurs\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "13afcae7",
"metadata": {},
@ -13,12 +14,13 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "68e75fb9",
"metadata": {},
"source": [
"## Creating a Qdrant vectorstore\n",
"First we'll want to create a Chroma VectorStore and seed it with some data. We've created a small demo set of documents that contain summaries of movies.\n",
"First we'll want to create a Qdrant VectorStore and seed it with some data. We've created a small demo set of documents that contain summaries of movies.\n",
"\n",
"NOTE: The self-query retriever requires you to have `lark` installed (`pip install lark`). We also need the `qdrant-client` package."
]
@ -36,6 +38,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "83811610-7df3-4ede-b268-68a6a83ba9e2",
"metadata": {},
@ -124,6 +127,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5ecaab6d",
"metadata": {},
@ -173,6 +177,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ea9df8d4",
"metadata": {},
@ -337,6 +342,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "39bd1de1-b9fe-4a98-89da-58d8a7a6ae51",
"metadata": {},

@ -67,7 +67,7 @@
"1. Environment Variables\n",
"\n",
" Before you run the app, please set the environment variable with `export`:\n",
" `export MYSCALE_URL='<your-endpoints-url>' MYSCALE_PORT=<your-endpoints-port> MYSCALE_USERNAME=<your-username> MYSCALE_PASSWORD=<your-password> ...`\n",
" `export MYSCALE_HOST='<your-endpoints-url>' MYSCALE_PORT=<your-endpoints-port> MYSCALE_USERNAME=<your-username> MYSCALE_PASSWORD=<your-password> ...`\n",
"\n",
" You can easily find your account, password and other info on our SaaS. For details please refer to [this document](https://docs.myscale.com/en/cluster-management/)\n",
"\n",
@ -120,18 +120,10 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "6e104aee",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Inserting data...: 100%|██████████| 42/42 [00:18<00:00, 2.21it/s]\n"
]
}
],
"outputs": [],
"source": [
"for d in docs:\n",
" d.metadata = {\"some\": \"metadata\"}\n",
@ -143,32 +135,10 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "9c608226",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"As Frances Haugen, who is here with us tonight, has shown, we must hold social media platforms accountable for the national experiment theyre conducting on our children for profit. \n",
"\n",
"Its time to strengthen privacy protections, ban targeted advertising to children, demand tech companies stop collecting personal data on our children. \n",
"\n",
"And lets get all Americans the mental health services they need. More people they can turn to for help, and full parity between physical and mental health care. \n",
"\n",
"Third, support our veterans. \n",
"\n",
"Veterans are the best of us. \n",
"\n",
"Ive always believed that we have a sacred obligation to equip all those we send to war and care for them and their families when they come home. \n",
"\n",
"My administration is providing assistance with job training and housing, and now helping lower-income veterans get VA care debt-free. \n",
"\n",
"Our troops in Iraq and Afghanistan faced many dangers.\n"
]
}
],
"outputs": [],
"source": [
"print(docs[0].page_content)"
]
@ -209,18 +179,10 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "232055f6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Inserting data...: 100%|██████████| 42/42 [00:15<00:00, 2.69it/s]\n"
]
}
],
"outputs": [],
"source": [
"from langchain.vectorstores import MyScale, MyScaleSettings\n",
"from langchain.document_loaders import TextLoader\n",
@ -258,21 +220,10 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "ddbcee77",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.252379834651947 {'doc_id': 6, 'some': ''} And Im taking robus...\n",
"0.25022566318511963 {'doc_id': 1, 'some': ''} Groups of citizens b...\n",
"0.2469480037689209 {'doc_id': 8, 'some': ''} And so many families...\n",
"0.2428302764892578 {'doc_id': 0, 'some': 'metadata'} As Frances Haugen, w...\n"
]
}
],
"outputs": [],
"source": [
"meta = docsearch.metadata_column\n",
"output = docsearch.similarity_search_with_relevance_scores(\n",
@ -328,7 +279,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.8.8"
}
},
"nbformat": 4,

@ -71,6 +71,8 @@ class Comparator(str, Enum):
GTE = "gte"
LT = "lt"
LTE = "lte"
CONTAIN = "contain"
LIKE = "like"
class FilterDirective(Expr, ABC):

@ -1,3 +1,4 @@
import datetime
from typing import Any, Optional, Sequence, Union
try:
@ -34,12 +35,14 @@ GRAMMAR = """
?value: SIGNED_INT -> int
| SIGNED_FLOAT -> float
| TIMESTAMP -> timestamp
| list
| string
| ("false" | "False" | "FALSE") -> false
| ("true" | "True" | "TRUE") -> true
args: expr ("," expr)*
TIMESTAMP.2: /["'](\d{4}-[01]\d-[0-3]\d)["']/
string: /'[^']*'/ | ESCAPED_STRING
list: "[" [args] "]"
@ -120,6 +123,10 @@ class QueryTransformer(Transformer):
def float(self, item: Any) -> float:
return float(item)
def timestamp(self, item: Any) -> datetime.date:
item = item.replace("'", '"')
return datetime.datetime.strptime(item, '"%Y-%m-%d"').date()
def string(self, item: Any) -> str:
# Remove escaped quotes
return str(item).strip("\"'")

@ -141,6 +141,8 @@ statements): one or more statements to apply the operation to
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \
@ -179,6 +181,8 @@ statements): one or more statements to apply the operation to
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \

@ -5,15 +5,24 @@ from pydantic import BaseModel, Field, root_validator
from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import Chroma, Pinecone, Qdrant, VectorStore, Weaviate
from langchain.vectorstores import (
Chroma,
MyScale,
Pinecone,
Qdrant,
VectorStore,
Weaviate,
)
def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
@ -24,6 +33,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
Chroma: ChromaTranslator,
Weaviate: WeaviateTranslator,
Qdrant: QdrantTranslator,
MyScale: MyScaleTranslator,
}
if vectorstore_cls not in BUILTIN_TRANSLATORS:
raise ValueError(
@ -32,6 +42,8 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
)
if isinstance(vectorstore, Qdrant):
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
elif isinstance(vectorstore, MyScale):
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
return BUILTIN_TRANSLATORS[vectorstore_cls]()
@ -50,6 +62,8 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
structured_query_translator: Visitor
"""Translator for turning internal query language into vectorstore search params."""
verbose: bool = False
"""Use original query instead of the revised new query from LLM"""
use_original_query: bool = False
class Config:
"""Configuration for this pydantic object."""
@ -65,7 +79,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
)
return values
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(
self, query: str, callbacks: Callbacks = None
) -> List[Document]:
"""Get documents relevant for a query.
Args:
@ -76,7 +92,8 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
"""
inputs = self.llm_chain.prep_inputs({"query": query})
structured_query = cast(
StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs)
StructuredQuery,
self.llm_chain.predict_and_parse(callbacks=callbacks, **inputs),
)
if self.verbose:
print(structured_query)
@ -86,6 +103,9 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
if structured_query.limit is not None:
new_kwargs["k"] = structured_query.limit
if self.use_original_query:
new_query = query
search_kwargs = {**self.search_kwargs, **new_kwargs}
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
return docs
@ -103,6 +123,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
structured_query_translator: Optional[Visitor] = None,
chain_kwargs: Optional[Dict] = None,
enable_limit: bool = False,
use_original_query: bool = False,
**kwargs: Any,
) -> "SelfQueryRetriever":
if structured_query_translator is None:
@ -127,6 +148,7 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
return cls(
llm_chain=llm_chain,
vectorstore=vectorstore,
use_original_query=use_original_query,
structured_query_translator=structured_query_translator,
**kwargs,
)

@ -0,0 +1,106 @@
import datetime
import re
from typing import Any, Callable, Dict, Tuple
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
def DEFAULT_COMPOSER(op_name: str) -> Callable:
def f(*args: Any) -> str:
args_: map[str] = map(str, args)
return f" {op_name} ".join(args_)
return f
def FUNCTION_COMPOSER(op_name: str) -> Callable:
def f(*args: Any) -> str:
args_: map[str] = map(str, args)
return f"{op_name}({','.join(args_)})"
return f
class MyScaleTranslator(Visitor):
"""Logic for converting internal query language elements to valid filters."""
allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.CONTAIN,
Comparator.LIKE,
]
map_dict = {
Operator.AND: DEFAULT_COMPOSER("AND"),
Operator.OR: DEFAULT_COMPOSER("OR"),
Operator.NOT: DEFAULT_COMPOSER("NOT"),
Comparator.EQ: DEFAULT_COMPOSER("="),
Comparator.GT: DEFAULT_COMPOSER(">"),
Comparator.GTE: DEFAULT_COMPOSER(">="),
Comparator.LT: DEFAULT_COMPOSER("<"),
Comparator.LTE: DEFAULT_COMPOSER("<="),
Comparator.CONTAIN: FUNCTION_COMPOSER("has"),
Comparator.LIKE: DEFAULT_COMPOSER("ILIKE"),
}
def __init__(self, metadata_key: str = "metadata") -> None:
super().__init__()
self.metadata_key = metadata_key
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
func = operation.operator
self._validate_func(func)
return self.map_dict[func](*args)
def visit_comparison(self, comparison: Comparison) -> Dict:
regex = "\((.*?)\)"
matched = re.search("\(\w+\)", comparison.attribute)
# If arbitrary function is applied to an attribute
if matched:
attr = re.sub(
regex,
f"({self.metadata_key}.{matched.group(0)[1:-1]})",
comparison.attribute,
)
else:
attr = f"{self.metadata_key}.{comparison.attribute}"
value = comparison.value
comp = comparison.comparator
value = f"'{value}'" if type(value) is str else value
# convert timestamp for datetime objects
if type(value) is datetime.date:
attr = f"parseDateTime32BestEffort({attr})"
value = f"parseDateTime32BestEffort('{value.strftime('%Y-%m-%d')}')"
# string pattern match
if comp is Comparator.LIKE:
value = f"'%{value[1:-1]}%'"
return self.map_dict[comp](attr, value)
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
print(structured_query)
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"where_str": structured_query.filter.accept(self)}
return structured_query.query, kwargs

@ -21,7 +21,7 @@ DB_NAME, COLLECTION_NAME = NAMESPACE.split(".")
# Instantiate as constant instead of pytest fixture to prevent needing to make multiple
# connections.
TEST_CLIENT = MongoClient(CONNECTION_STRING)
TEST_CLIENT: MongoClient = MongoClient(CONNECTION_STRING)
collection = TEST_CLIENT[DB_NAME][COLLECTION_NAME]

@ -0,0 +1,44 @@
from typing import Any, Tuple
import pytest
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
)
from langchain.retrievers.self_query.myscale import MyScaleTranslator
DEFAULT_TRANSLATOR = MyScaleTranslator()
@pytest.mark.parametrize(
"triplet",
[
(Comparator.LT, 2, "metadata.foo < 2"),
(Comparator.LTE, 2, "metadata.foo <= 2"),
(Comparator.GT, 2, "metadata.foo > 2"),
(Comparator.GTE, 2, "metadata.foo >= 2"),
(Comparator.CONTAIN, 2, "has(metadata.foo,2)"),
(Comparator.LIKE, "bar", "metadata.foo ILIKE '%bar%'"),
],
)
def test_visit_comparison(triplet: Tuple[Comparator, Any, str]) -> None:
comparator, value, expected = triplet
comp = Comparison(comparator=comparator, attribute="foo", value=value)
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
assert expected == actual
def test_visit_operation() -> None:
op = Operation(
operator=Operator.AND,
arguments=[
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
],
)
expected = "metadata.foo < 2 AND metadata.bar = 'baz'"
actual = DEFAULT_TRANSLATOR.visit_operation(op)
assert expected == actual
Loading…
Cancel
Save