mongodb[patch]: Added scoring threshold to caching (#19286)

## Description
The semantic cache can return noisy matches when low-similarity results are not filtered out. This change adds a `score_threshold` parameter to the cache constructor; when it is set, lookups only return cached entries whose similarity score meets or exceeds the threshold.
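
For illustration, a minimal sketch of how the new parameter might be passed at construction time. The connection string, database/collection names, the `embedding=` argument, and the choice of `OpenAIEmbeddings` are placeholders and assumptions for this example, not part of the diff:

```python
from langchain_core.globals import set_llm_cache
from langchain_mongodb.cache import MongoDBAtlasSemanticCache
from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works here

# Cache hits are only returned when the vector-search score is >= score_threshold,
# so low-similarity ("noisy") matches no longer count as cache hits.
set_llm_cache(
    MongoDBAtlasSemanticCache(
        connection_string="mongodb+srv://<user>:<password>@<cluster>/",  # placeholder URI
        embedding=OpenAIEmbeddings(),
        database_name="langchain_db",   # placeholder names
        collection_name="llm_cache",
        index_name="default",
        score_threshold=0.75,           # new in this patch
        wait_until_ready=True,
    )
)
```

A useful threshold depends on the embedding model and the Atlas index definition, so treat the value above as a starting point rather than a recommendation.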


- [x] **Add tests and docs**
  1. Added tests that confirm the `score_threshold` query is valid.


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

---------

Co-authored-by: Erick Friis <erick@langchain.dev>

@@ -217,6 +217,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         database_name: str = "default",
         index_name: str = "default",
         wait_until_ready: bool = False,
+        score_threshold: Optional[float] = None,
         **kwargs: Dict[str, Any],
     ):
         """
@@ -237,6 +238,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         """
         client = _generate_mongo_client(connection_string)
         self.collection = client[database_name][collection_name]
+        self.score_threshold = score_threshold
         self._wait_until_ready = wait_until_ready
         super().__init__(
             collection=self.collection,
@@ -247,8 +249,17 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
+        post_filter_pipeline = (
+            [{"$match": {"score": {"$gte": self.score_threshold}}}]
+            if self.score_threshold
+            else None
+        )
         search_response = self.similarity_search_with_score(
-            prompt, 1, pre_filter={self.LLM: {"$eq": llm_string}}
+            prompt,
+            1,
+            pre_filter={self.LLM: {"$eq": llm_string}},
+            post_filter_pipeline=post_filter_pipeline,
         )
         if search_response:
             return_val = search_response[0][0].metadata.get(self.RETURN_VAL)
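
For context, the optional post-filter is a plain MongoDB aggregation stage appended after the vector-search results; a minimal sketch of what the pipeline looks like for an example threshold of 0.75 (the value is illustrative):

```python
# Illustrative post-filter stage: entries whose search score falls below the
# configured threshold are dropped, so the cache reports a miss instead of
# returning a weak match.
post_filter_pipeline = [{"$match": {"score": {"$gte": 0.75}}}]
```

When `score_threshold` is left unset, no `$match` stage is built and lookups behave exactly as before.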

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-mongodb"
-version = "0.1.2"
+version = "0.1.3"
 description = "An integration package connecting MongoDB and LangChain"
 authors = []
 readme = "README.md"

@@ -30,6 +30,7 @@ def llm_cache(cls: Any) -> BaseCache:
             collection_name=COLLECTION,
             database_name=DATABASE,
             index_name=INDEX_NAME,
+            score_threshold=0.5,
             wait_until_ready=True,
         )
     )
@@ -92,13 +93,17 @@ def _execute_test(
     ],
 )
 @pytest.mark.parametrize("cacher", [MongoDBCache, MongoDBAtlasSemanticCache])
+@pytest.mark.parametrize("remove_score", [True, False])
 def test_mongodb_cache(
+    remove_score: bool,
     cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
     prompt: Union[str, List[BaseMessage]],
     llm: Union[str, FakeLLM, FakeChatModel],
     response: List[Generation],
 ) -> None:
     llm_cache(cacher)
+    if remove_score:
+        get_llm_cache().score_threshold = None  # type: ignore
     try:
         _execute_test(prompt, llm, response)
     finally:

@@ -54,6 +54,7 @@ class PatchedMongoDBAtlasSemanticCache(MongoDBAtlasSemanticCache):
     ):
         self.collection = MockCollection()
         self._wait_until_ready = False
+        self.score_threshold = None
         MongoDBAtlasVectorSearch.__init__(
             self,
             self.collection,
@@ -144,13 +145,17 @@ def _execute_test(
 @pytest.mark.parametrize(
     "cacher", [PatchedMongoDBCache, PatchedMongoDBAtlasSemanticCache]
 )
+@pytest.mark.parametrize("remove_score", [True, False])
 def test_mongodb_cache(
+    remove_score: bool,
     cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
     prompt: Union[str, List[BaseMessage]],
     llm: Union[str, FakeLLM, FakeChatModel],
     response: List[Generation],
 ) -> None:
     llm_cache(cacher)
+    if remove_score:
+        get_llm_cache().score_threshold = None  # type: ignore
     try:
         _execute_test(prompt, llm, response)
     finally:
