diff --git a/langchain/vectorstores/weaviate.py b/langchain/vectorstores/weaviate.py index a7398ae7..b4ed3253 100644 --- a/langchain/vectorstores/weaviate.py +++ b/langchain/vectorstores/weaviate.py @@ -195,6 +195,8 @@ class Weaviate(VectorStore): query_obj = self._client.query.get(self._index_name, self._query_attrs) if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("additional"): + query_obj = query_obj.with_additional(kwargs.get("additional")) result = query_obj.with_near_text(content).with_limit(k).do() if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") @@ -212,6 +214,8 @@ class Weaviate(VectorStore): query_obj = self._client.query.get(self._index_name, self._query_attrs) if kwargs.get("where_filter"): query_obj = query_obj.with_where(kwargs.get("where_filter")) + if kwargs.get("additional"): + query_obj = query_obj.with_additional(kwargs.get("additional")) result = query_obj.with_near_vector(vector).with_limit(k).do() if "errors" in result: raise ValueError(f"Error during query: {result['errors']}") diff --git a/tests/integration_tests/vectorstores/test_weaviate.py b/tests/integration_tests/vectorstores/test_weaviate.py index 127695ec..a9e1736a 100644 --- a/tests/integration_tests/vectorstores/test_weaviate.py +++ b/tests/integration_tests/vectorstores/test_weaviate.py @@ -81,6 +81,28 @@ class TestWeaviate: ) assert output == [Document(page_content="foo", metadata={"page": 0})] + @pytest.mark.vcr(ignore_localhost=True) + def test_similarity_search_with_metadata_and_additional( + self, weaviate_url: str, embedding_openai: OpenAIEmbeddings + ) -> None: + """Test end to end construction and search with metadata and additional.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = Weaviate.from_texts( + texts, embedding_openai, metadatas=metadatas, weaviate_url=weaviate_url + ) + output = docsearch.similarity_search( + "foo", + k=1, + additional=["certainty"], + ) + assert output == [ + Document( + page_content="foo", + metadata={"page": 0, "_additional": {"certainty": 1}}, + ) + ] + @pytest.mark.vcr(ignore_localhost=True) def test_similarity_search_with_uuids( self, weaviate_url: str, embedding_openai: OpenAIEmbeddings