diff --git a/libs/partners/mongodb/langchain_mongodb/cache.py b/libs/partners/mongodb/langchain_mongodb/cache.py index b55840efa9..3dbde907b3 100644 --- a/libs/partners/mongodb/langchain_mongodb/cache.py +++ b/libs/partners/mongodb/langchain_mongodb/cache.py @@ -217,6 +217,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch): database_name: str = "default", index_name: str = "default", wait_until_ready: bool = False, + score_threshold: Optional[float] = None, **kwargs: Dict[str, Any], ): """ @@ -237,6 +238,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch): """ client = _generate_mongo_client(connection_string) self.collection = client[database_name][collection_name] + self.score_threshold = score_threshold self._wait_until_ready = wait_until_ready super().__init__( collection=self.collection, @@ -247,8 +249,17 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch): def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]: """Look up based on prompt and llm_string.""" + post_filter_pipeline = ( + [{"$match": {"score": {"$gte": self.score_threshold}}}] + if self.score_threshold + else None + ) + search_response = self.similarity_search_with_score( - prompt, 1, pre_filter={self.LLM: {"$eq": llm_string}} + prompt, + 1, + pre_filter={self.LLM: {"$eq": llm_string}}, + post_filter_pipeline=post_filter_pipeline, ) if search_response: return_val = search_response[0][0].metadata.get(self.RETURN_VAL) diff --git a/libs/partners/mongodb/pyproject.toml b/libs/partners/mongodb/pyproject.toml index dcc95c73d9..5efd814bd0 100644 --- a/libs/partners/mongodb/pyproject.toml +++ b/libs/partners/mongodb/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-mongodb" -version = "0.1.2" +version = "0.1.3" description = "An integration package connecting MongoDB and LangChain" authors = [] readme = "README.md" diff --git a/libs/partners/mongodb/tests/integration_tests/test_cache.py b/libs/partners/mongodb/tests/integration_tests/test_cache.py index 27e5846ac3..8d99f5e059 100644 --- a/libs/partners/mongodb/tests/integration_tests/test_cache.py +++ b/libs/partners/mongodb/tests/integration_tests/test_cache.py @@ -30,6 +30,7 @@ def llm_cache(cls: Any) -> BaseCache: collection_name=COLLECTION, database_name=DATABASE, index_name=INDEX_NAME, + score_threshold=0.5, wait_until_ready=True, ) ) @@ -92,13 +93,17 @@ def _execute_test( ], ) @pytest.mark.parametrize("cacher", [MongoDBCache, MongoDBAtlasSemanticCache]) +@pytest.mark.parametrize("remove_score", [True, False]) def test_mongodb_cache( + remove_score: bool, cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache], prompt: Union[str, List[BaseMessage]], llm: Union[str, FakeLLM, FakeChatModel], response: List[Generation], ) -> None: llm_cache(cacher) + if remove_score: + get_llm_cache().score_threshold = None # type: ignore try: _execute_test(prompt, llm, response) finally: diff --git a/libs/partners/mongodb/tests/unit_tests/test_cache.py b/libs/partners/mongodb/tests/unit_tests/test_cache.py index 6b1932ad8c..22cef171d3 100644 --- a/libs/partners/mongodb/tests/unit_tests/test_cache.py +++ b/libs/partners/mongodb/tests/unit_tests/test_cache.py @@ -54,6 +54,7 @@ class PatchedMongoDBAtlasSemanticCache(MongoDBAtlasSemanticCache): ): self.collection = MockCollection() self._wait_until_ready = False + self.score_threshold = None MongoDBAtlasVectorSearch.__init__( self, self.collection, @@ -144,13 +145,17 @@ def _execute_test( @pytest.mark.parametrize( "cacher", [PatchedMongoDBCache, PatchedMongoDBAtlasSemanticCache] ) +@pytest.mark.parametrize("remove_score", [True, False]) def test_mongodb_cache( + remove_score: bool, cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache], prompt: Union[str, List[BaseMessage]], llm: Union[str, FakeLLM, FakeChatModel], response: List[Generation], ) -> None: llm_cache(cacher) + if remove_score: + get_llm_cache().score_threshold = None # type: ignore try: _execute_test(prompt, llm, response) finally: