From de9e5455429bcf08856b553594bbb1dade70a65c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=20=E6=96=B9=E7=91=9E?= Date: Mon, 4 Sep 2023 23:40:58 +0800 Subject: [PATCH] MyScale hot fix on type check (#10180) Previous PR #9353 has incomplete type checks and deprecation warnings. This PR will fix those type check and add deprecation warning to myscale vectorstore --- libs/langchain/langchain/vectorstores/myscale.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/vectorstores/myscale.py b/libs/langchain/langchain/vectorstores/myscale.py index 3c4361fcdc..81812550c7 100644 --- a/libs/langchain/langchain/vectorstores/myscale.py +++ b/libs/langchain/langchain/vectorstores/myscale.py @@ -147,7 +147,12 @@ class MyScale(VectorStore): ) for k in ["id", "vector", "text", "metadata"]: assert k in self.config.column_map - assert self.config.metric in ["ip", "cosine", "l2"] + assert self.config.metric.upper() in ["IP", "COSINE", "L2"] + if self.config.metric in ["ip", "cosine", "l2"]: + logger.warning( + "Lower case metric types will be deprecated " + "the future. Please use one of ('IP', 'Cosine', 'L2')" + ) # initialize the schema dim = len(embedding.embed_query("try this out")) @@ -174,7 +179,9 @@ class MyScale(VectorStore): self.BS = "\\" self.must_escape = ("\\", "'") self._embeddings = embedding - self.dist_order = "ASC" if self.config.metric in ["cosine", "l2"] else "DESC" + self.dist_order = ( + "ASC" if self.config.metric.upper() in ["COSINE", "L2"] else "DESC" + ) # Create a connection to myscale self.client = get_client(