From d9bc46186d9e10270fd26ab57c339c98b39626bc Mon Sep 17 00:00:00 2001 From: shibuiwilliam Date: Sun, 6 Aug 2023 09:31:41 +0900 Subject: [PATCH] Add missing test for retrievers self_query (#8783) # What - Add missing test for retrievers self_query - Add missing import validation --- .../langchain/retrievers/self_query/qdrant.py | 16 +++- .../retrievers/self_query/test_chroma.py | 89 +++++++++++++++++++ .../retrievers/self_query/test_deeplake.py | 49 ++++++++++ .../retrievers/self_query/test_myscale.py | 44 ++++++++- .../retrievers/self_query/test_pinecone.py | 45 ++++++++++ .../retrievers/self_query/test_weaviate.py | 88 ++++++++++++++++++ 6 files changed, 328 insertions(+), 3 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/retrievers/self_query/test_chroma.py create mode 100644 libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py diff --git a/libs/langchain/langchain/retrievers/self_query/qdrant.py b/libs/langchain/langchain/retrievers/self_query/qdrant.py index e421eef023..5d5d2a0469 100644 --- a/libs/langchain/langchain/retrievers/self_query/qdrant.py +++ b/libs/langchain/langchain/retrievers/self_query/qdrant.py @@ -31,7 +31,13 @@ class QdrantTranslator(Visitor): self.metadata_key = metadata_key def visit_operation(self, operation: Operation) -> rest.Filter: - from qdrant_client.http import models as rest + try: + from qdrant_client.http import models as rest + except ImportError as e: + raise ImportError( + "Cannot import qdrant_client. Please install with `pip install " + "qdrant-client`." + ) from e args = [arg.accept(self) for arg in operation.arguments] operator = { @@ -42,7 +48,13 @@ class QdrantTranslator(Visitor): return rest.Filter(**{operator: args}) def visit_comparison(self, comparison: Comparison) -> rest.FieldCondition: - from qdrant_client.http import models as rest + try: + from qdrant_client.http import models as rest + except ImportError as e: + raise ImportError( + "Cannot import qdrant_client. Please install with `pip install " + "qdrant-client`." + ) from e self._validate_func(comparison.comparator) attribute = self.metadata_key + "." + comparison.attribute diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_chroma.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_chroma.py new file mode 100644 index 0000000000..47c74cb5f1 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_chroma.py @@ -0,0 +1,89 @@ +from typing import Dict, Tuple + +from langchain.chains.query_constructor.ir import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, +) +from langchain.retrievers.self_query.chroma import ChromaTranslator + +DEFAULT_TRANSLATOR = ChromaTranslator() + + +def test_visit_comparison() -> None: + comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"]) + expected = {"foo": {"$lt": ["1", "2"]}} + actual = DEFAULT_TRANSLATOR.visit_comparison(comp) + assert expected == actual + + +def test_visit_operation() -> None: + op = Operation( + operator=Operator.AND, + arguments=[ + Comparison(comparator=Comparator.LT, attribute="foo", value=2), + Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), + Comparison(comparator=Comparator.LT, attribute="abc", value=["1", "2"]), + ], + ) + expected = { + "$and": [ + {"foo": {"$lt": 2}}, + {"bar": {"$eq": "baz"}}, + {"abc": {"$lt": ["1", "2"]}}, + ] + } + actual = DEFAULT_TRANSLATOR.visit_operation(op) + assert expected == actual + + +def test_visit_structured_query() -> None: + query = "What is the capital of France?" + structured_query = StructuredQuery( + query=query, + filter=None, + ) + expected: Tuple[str, Dict] = (query, {}) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual + + comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"]) + expected = ( + query, + {"filter": {"foo": {"$lt": ["1", "2"]}}}, + ) + structured_query = StructuredQuery( + query=query, + filter=comp, + ) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual + + op = Operation( + operator=Operator.AND, + arguments=[ + Comparison(comparator=Comparator.LT, attribute="foo", value=2), + Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), + Comparison(comparator=Comparator.LT, attribute="abc", value=["1", "2"]), + ], + ) + structured_query = StructuredQuery( + query=query, + filter=op, + ) + expected = ( + query, + { + "filter": { + "$and": [ + {"foo": {"$lt": 2}}, + {"bar": {"$eq": "baz"}}, + {"abc": {"$lt": ["1", "2"]}}, + ] + } + }, + ) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_deeplake.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_deeplake.py index f39349da37..055d434d30 100644 --- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_deeplake.py +++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_deeplake.py @@ -1,8 +1,11 @@ +from typing import Dict, Tuple + from langchain.chains.query_constructor.ir import ( Comparator, Comparison, Operation, Operator, + StructuredQuery, ) from langchain.retrievers.self_query.deeplake import DeepLakeTranslator @@ -31,3 +34,49 @@ def test_visit_operation() -> None: ) actual = DEFAULT_TRANSLATOR.visit_operation(op) assert expected == actual + + +def test_visit_structured_query() -> None: + query = "What is the capital of France?" + structured_query = StructuredQuery( + query=query, + filter=None, + ) + expected: Tuple[str, Dict] = (query, {}) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual + + comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"]) + structured_query = StructuredQuery( + query=query, + filter=comp, + ) + expected = ( + query, + {"tql": "SELECT * WHERE (metadata['foo'] < 1 or metadata['foo'] < 2)"}, + ) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual + + op = Operation( + operator=Operator.AND, + arguments=[ + Comparison(comparator=Comparator.LT, attribute="foo", value=2), + Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), + Comparison(comparator=Comparator.LT, attribute="abc", value=["1", "2"]), + ], + ) + structured_query = StructuredQuery( + query=query, + filter=op, + ) + expected = ( + query, + { + "tql": "SELECT * WHERE " + "(metadata['foo'] < 2 and metadata['bar'] == 'baz' and " + "(metadata['abc'] < 1 or metadata['abc'] < 2))" + }, + ) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_myscale.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_myscale.py index d75e2697da..ce54c83349 100644 --- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_myscale.py +++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_myscale.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple +from typing import Any, Dict, Tuple import pytest @@ -7,6 +7,7 @@ from langchain.chains.query_constructor.ir import ( Comparison, Operation, Operator, + StructuredQuery, ) from langchain.retrievers.self_query.myscale import MyScaleTranslator @@ -42,3 +43,44 @@ def test_visit_operation() -> None: expected = "metadata.foo < 2 AND metadata.bar = 'baz'" actual = DEFAULT_TRANSLATOR.visit_operation(op) assert expected == actual + + +def test_visit_structured_query() -> None: + query = "What is the capital of France?" + structured_query = StructuredQuery( + query=query, + filter=None, + ) + expected: Tuple[str, Dict] = (query, {}) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual + + comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"]) + structured_query = StructuredQuery( + query=query, + filter=comp, + ) + expected = ( + query, + {"where_str": "metadata.foo < ['1', '2']"}, + ) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual + + op = Operation( + operator=Operator.AND, + arguments=[ + Comparison(comparator=Comparator.LT, attribute="foo", value=2), + Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), + ], + ) + structured_query = StructuredQuery( + query=query, + filter=op, + ) + expected = ( + query, + {"where_str": "metadata.foo < 2 AND metadata.bar = 'baz'"}, + ) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_pinecone.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_pinecone.py index 2927fccd67..7e818dcec5 100644 --- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_pinecone.py +++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_pinecone.py @@ -1,8 +1,11 @@ +from typing import Dict, Tuple + from langchain.chains.query_constructor.ir import ( Comparator, Comparison, Operation, Operator, + StructuredQuery, ) from langchain.retrievers.self_query.pinecone import PineconeTranslator @@ -27,3 +30,45 @@ def test_visit_operation() -> None: expected = {"$and": [{"foo": {"$lt": 2}}, {"bar": {"$eq": "baz"}}]} actual = DEFAULT_TRANSLATOR.visit_operation(op) assert expected == actual + + +def test_visit_structured_query() -> None: + query = "What is the capital of France?" + + structured_query = StructuredQuery( + query=query, + filter=None, + ) + expected: Tuple[str, Dict] = (query, {}) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual + + comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"]) + structured_query = StructuredQuery( + query=query, + filter=comp, + ) + expected = ( + query, + {"filter": {"foo": {"$lt": ["1", "2"]}}}, + ) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual + + op = Operation( + operator=Operator.AND, + arguments=[ + Comparison(comparator=Comparator.LT, attribute="foo", value=2), + Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), + ], + ) + structured_query = StructuredQuery( + query=query, + filter=op, + ) + expected = ( + query, + {"filter": {"$and": [{"foo": {"$lt": 2}}, {"bar": {"$eq": "baz"}}]}}, + ) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py new file mode 100644 index 0000000000..0a7385af45 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_weaviate.py @@ -0,0 +1,88 @@ +from typing import Dict, Tuple + +from langchain.chains.query_constructor.ir import ( + Comparator, + Comparison, + Operation, + Operator, + StructuredQuery, +) +from langchain.retrievers.self_query.weaviate import WeaviateTranslator + +DEFAULT_TRANSLATOR = WeaviateTranslator() + + +def test_visit_comparison() -> None: + comp = Comparison(comparator=Comparator.EQ, attribute="foo", value="1") + expected = {"operator": "Equal", "path": ["foo"], "valueText": "1"} + actual = DEFAULT_TRANSLATOR.visit_comparison(comp) + assert expected == actual + + +def test_visit_operation() -> None: + op = Operation( + operator=Operator.AND, + arguments=[ + Comparison(comparator=Comparator.EQ, attribute="foo", value=2), + Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), + ], + ) + expected = { + "operands": [ + {"operator": "Equal", "path": ["foo"], "valueText": 2}, + {"operator": "Equal", "path": ["bar"], "valueText": "baz"}, + ], + "operator": "And", + } + actual = DEFAULT_TRANSLATOR.visit_operation(op) + assert expected == actual + + +def test_visit_structured_query() -> None: + query = "What is the capital of France?" + + structured_query = StructuredQuery( + query=query, + filter=None, + ) + expected: Tuple[str, Dict] = (query, {}) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual + + comp = Comparison(comparator=Comparator.EQ, attribute="foo", value="1") + structured_query = StructuredQuery( + query=query, + filter=comp, + ) + expected = ( + query, + {"where_filter": {"path": ["foo"], "operator": "Equal", "valueText": "1"}}, + ) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual + + op = Operation( + operator=Operator.AND, + arguments=[ + Comparison(comparator=Comparator.EQ, attribute="foo", value=2), + Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"), + ], + ) + structured_query = StructuredQuery( + query=query, + filter=op, + ) + expected = ( + query, + { + "where_filter": { + "operator": "And", + "operands": [ + {"path": ["foo"], "operator": "Equal", "valueText": 2}, + {"path": ["bar"], "operator": "Equal", "valueText": "baz"}, + ], + } + }, + ) + actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query) + assert expected == actual