mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add missing test for retrievers self_query (#8783)
# What - Add missing test for retrievers self_query - Add missing import validation <!-- Thank you for contributing to LangChain! Replace this comment with: - Description: Add missing test for retrievers self_query - Issue: None - Dependencies: None - Tag maintainer: @rlancemartin, @eyurtsev - Twitter handle: @MlopsJ Please make sure you're PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
1bd4890506
commit
d9bc46186d
@ -31,7 +31,13 @@ class QdrantTranslator(Visitor):
|
||||
self.metadata_key = metadata_key
|
||||
|
||||
def visit_operation(self, operation: Operation) -> rest.Filter:
|
||||
try:
|
||||
from qdrant_client.http import models as rest
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Cannot import qdrant_client. Please install with `pip install "
|
||||
"qdrant-client`."
|
||||
) from e
|
||||
|
||||
args = [arg.accept(self) for arg in operation.arguments]
|
||||
operator = {
|
||||
@ -42,7 +48,13 @@ class QdrantTranslator(Visitor):
|
||||
return rest.Filter(**{operator: args})
|
||||
|
||||
def visit_comparison(self, comparison: Comparison) -> rest.FieldCondition:
|
||||
try:
|
||||
from qdrant_client.http import models as rest
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Cannot import qdrant_client. Please install with `pip install "
|
||||
"qdrant-client`."
|
||||
) from e
|
||||
|
||||
self._validate_func(comparison.comparator)
|
||||
attribute = self.metadata_key + "." + comparison.attribute
|
||||
|
@ -0,0 +1,89 @@
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
)
|
||||
from langchain.retrievers.self_query.chroma import ChromaTranslator
|
||||
|
||||
DEFAULT_TRANSLATOR = ChromaTranslator()
|
||||
|
||||
|
||||
def test_visit_comparison() -> None:
|
||||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
|
||||
expected = {"foo": {"$lt": ["1", "2"]}}
|
||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_operation() -> None:
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
Comparison(comparator=Comparator.LT, attribute="abc", value=["1", "2"]),
|
||||
],
|
||||
)
|
||||
expected = {
|
||||
"$and": [
|
||||
{"foo": {"$lt": 2}},
|
||||
{"bar": {"$eq": "baz"}},
|
||||
{"abc": {"$lt": ["1", "2"]}},
|
||||
]
|
||||
}
|
||||
actual = DEFAULT_TRANSLATOR.visit_operation(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_structured_query() -> None:
|
||||
query = "What is the capital of France?"
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=None,
|
||||
)
|
||||
expected: Tuple[str, Dict] = (query, {})
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
|
||||
expected = (
|
||||
query,
|
||||
{"filter": {"foo": {"$lt": ["1", "2"]}}},
|
||||
)
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=comp,
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
Comparison(comparator=Comparator.LT, attribute="abc", value=["1", "2"]),
|
||||
],
|
||||
)
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=op,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{
|
||||
"filter": {
|
||||
"$and": [
|
||||
{"foo": {"$lt": 2}},
|
||||
{"bar": {"$eq": "baz"}},
|
||||
{"abc": {"$lt": ["1", "2"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
@ -1,8 +1,11 @@
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
)
|
||||
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
|
||||
|
||||
@ -31,3 +34,49 @@ def test_visit_operation() -> None:
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_operation(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_structured_query() -> None:
|
||||
query = "What is the capital of France?"
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=None,
|
||||
)
|
||||
expected: Tuple[str, Dict] = (query, {})
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=comp,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{"tql": "SELECT * WHERE (metadata['foo'] < 1 or metadata['foo'] < 2)"},
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
Comparison(comparator=Comparator.LT, attribute="abc", value=["1", "2"]),
|
||||
],
|
||||
)
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=op,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{
|
||||
"tql": "SELECT * WHERE "
|
||||
"(metadata['foo'] < 2 and metadata['bar'] == 'baz' and "
|
||||
"(metadata['abc'] < 1 or metadata['abc'] < 2))"
|
||||
},
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
@ -7,6 +7,7 @@ from langchain.chains.query_constructor.ir import (
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
)
|
||||
from langchain.retrievers.self_query.myscale import MyScaleTranslator
|
||||
|
||||
@ -42,3 +43,44 @@ def test_visit_operation() -> None:
|
||||
expected = "metadata.foo < 2 AND metadata.bar = 'baz'"
|
||||
actual = DEFAULT_TRANSLATOR.visit_operation(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_structured_query() -> None:
|
||||
query = "What is the capital of France?"
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=None,
|
||||
)
|
||||
expected: Tuple[str, Dict] = (query, {})
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=comp,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{"where_str": "metadata.foo < ['1', '2']"},
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
],
|
||||
)
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=op,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{"where_str": "metadata.foo < 2 AND metadata.bar = 'baz'"},
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
@ -1,8 +1,11 @@
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
)
|
||||
from langchain.retrievers.self_query.pinecone import PineconeTranslator
|
||||
|
||||
@ -27,3 +30,45 @@ def test_visit_operation() -> None:
|
||||
expected = {"$and": [{"foo": {"$lt": 2}}, {"bar": {"$eq": "baz"}}]}
|
||||
actual = DEFAULT_TRANSLATOR.visit_operation(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_structured_query() -> None:
|
||||
query = "What is the capital of France?"
|
||||
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=None,
|
||||
)
|
||||
expected: Tuple[str, Dict] = (query, {})
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=comp,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{"filter": {"foo": {"$lt": ["1", "2"]}}},
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
],
|
||||
)
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=op,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{"filter": {"$and": [{"foo": {"$lt": 2}}, {"bar": {"$eq": "baz"}}]}},
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
@ -0,0 +1,88 @@
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
)
|
||||
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
|
||||
|
||||
DEFAULT_TRANSLATOR = WeaviateTranslator()
|
||||
|
||||
|
||||
def test_visit_comparison() -> None:
|
||||
comp = Comparison(comparator=Comparator.EQ, attribute="foo", value="1")
|
||||
expected = {"operator": "Equal", "path": ["foo"], "valueText": "1"}
|
||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_operation() -> None:
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
],
|
||||
)
|
||||
expected = {
|
||||
"operands": [
|
||||
{"operator": "Equal", "path": ["foo"], "valueText": 2},
|
||||
{"operator": "Equal", "path": ["bar"], "valueText": "baz"},
|
||||
],
|
||||
"operator": "And",
|
||||
}
|
||||
actual = DEFAULT_TRANSLATOR.visit_operation(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_structured_query() -> None:
|
||||
query = "What is the capital of France?"
|
||||
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=None,
|
||||
)
|
||||
expected: Tuple[str, Dict] = (query, {})
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
comp = Comparison(comparator=Comparator.EQ, attribute="foo", value="1")
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=comp,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{"where_filter": {"path": ["foo"], "operator": "Equal", "valueText": "1"}},
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
],
|
||||
)
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=op,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{
|
||||
"where_filter": {
|
||||
"operator": "And",
|
||||
"operands": [
|
||||
{"path": ["foo"], "operator": "Equal", "valueText": 2},
|
||||
{"path": ["bar"], "operator": "Equal", "valueText": "baz"},
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
Loading…
Reference in New Issue
Block a user