mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
community[patch]: Support for old clients (Thin and Thick) Oracle Vector Store (#22766)
Thank you for contributing to LangChain! - [ ] **PR title**: "package: description" Support for old clients (Thin and Thick) Oracle Vector Store - [ ] **PR message**: ***Delete this entire checklist*** and replace with Support for old clients (Thin and Thick) Oracle Vector Store - [ ] **Add tests and docs**: If you're adding a new integration, please include Have our own local tests --------- Co-authored-by: rohan.aggarwal@oracle.com <rohaagga@phoenix95642.dev3sub2phx.databasede3phx.oraclevcn.com>
This commit is contained in:
parent
232908a46d
commit
86e8224cf1
@ -90,6 +90,22 @@ def _table_exists(client: Connection, table_name: str) -> bool:
|
||||
raise
|
||||
|
||||
|
||||
def _compare_version(version: str, target_version: str) -> bool:
|
||||
# Split both version strings into parts
|
||||
version_parts = [int(part) for part in version.split(".")]
|
||||
target_parts = [int(part) for part in target_version.split(".")]
|
||||
|
||||
# Compare each part
|
||||
for v, t in zip(version_parts, target_parts):
|
||||
if v < t:
|
||||
return True # Current version is less
|
||||
elif v > t:
|
||||
return False # Current version is greater
|
||||
|
||||
# If all parts equal so far, check if version has fewer parts than target_version
|
||||
return len(version_parts) < len(target_parts)
|
||||
|
||||
|
||||
@_handle_exceptions
|
||||
def _index_exists(client: Connection, index_name: str) -> bool:
|
||||
# Check if the index exists
|
||||
@ -401,6 +417,39 @@ class OracleVS(VectorStore):
|
||||
"`pip install -U oracledb`."
|
||||
) from e
|
||||
|
||||
self.insert_mode = "array"
|
||||
|
||||
if client.thin is True:
|
||||
if oracledb.__version__ == "2.1.0":
|
||||
raise Exception(
|
||||
"Oracle DB python thin client driver version 2.1.0 not supported"
|
||||
)
|
||||
elif _compare_version(oracledb.__version__, "2.2.0"):
|
||||
self.insert_mode = "clob"
|
||||
else:
|
||||
self.insert_mode = "array"
|
||||
else:
|
||||
if (_compare_version(oracledb.__version__, "2.1.0")) and (
|
||||
not (
|
||||
_compare_version(
|
||||
".".join(map(str, oracledb.clientversion())), "23.4"
|
||||
)
|
||||
)
|
||||
):
|
||||
raise Exception(
|
||||
"Oracle DB python thick client driver version earlier than "
|
||||
"2.1.0 not supported with client libraries greater than "
|
||||
"equal to 23.4"
|
||||
)
|
||||
|
||||
if _compare_version(".".join(map(str, oracledb.clientversion())), "23.4"):
|
||||
self.insert_mode = "clob"
|
||||
else:
|
||||
self.insert_mode = "array"
|
||||
|
||||
if _compare_version(oracledb.__version__, "2.1.0"):
|
||||
self.insert_mode = "clob"
|
||||
|
||||
try:
|
||||
"""Initialize with oracledb client."""
|
||||
self.client = client
|
||||
@ -520,12 +569,22 @@ class OracleVS(VectorStore):
|
||||
embeddings = self._embed_documents(texts)
|
||||
if not metadatas:
|
||||
metadatas = [{} for _ in texts]
|
||||
docs = [
|
||||
(id_, text, json.dumps(metadata), array.array("f", embedding))
|
||||
for id_, text, metadata, embedding in zip(
|
||||
processed_ids, texts, metadatas, embeddings
|
||||
)
|
||||
]
|
||||
|
||||
docs: List[Tuple[Any, Any, Any, Any]]
|
||||
if self.insert_mode == "clob":
|
||||
docs = [
|
||||
(id_, text, json.dumps(metadata), json.dumps(embedding))
|
||||
for id_, text, metadata, embedding in zip(
|
||||
processed_ids, texts, metadatas, embeddings
|
||||
)
|
||||
]
|
||||
else:
|
||||
docs = [
|
||||
(id_, text, json.dumps(metadata), array.array("f", embedding))
|
||||
for id_, text, metadata, embedding in zip(
|
||||
processed_ids, texts, metadatas, embeddings
|
||||
)
|
||||
]
|
||||
|
||||
with self.client.cursor() as cursor:
|
||||
cursor.executemany(
|
||||
@ -613,7 +672,12 @@ class OracleVS(VectorStore):
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
docs_and_scores = []
|
||||
embedding_arr = array.array("f", embedding)
|
||||
|
||||
embedding_arr: Any
|
||||
if self.insert_mode == "clob":
|
||||
embedding_arr = json.dumps(embedding)
|
||||
else:
|
||||
embedding_arr = array.array("f", embedding)
|
||||
|
||||
query = f"""
|
||||
SELECT id,
|
||||
@ -671,8 +735,13 @@ class OracleVS(VectorStore):
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float, np.ndarray[np.float32, Any]]]:
|
||||
embedding_arr: Any
|
||||
if self.insert_mode == "clob":
|
||||
embedding_arr = json.dumps(embedding)
|
||||
else:
|
||||
embedding_arr = array.array("f", embedding)
|
||||
|
||||
documents = []
|
||||
embedding_arr = array.array("f", embedding)
|
||||
|
||||
query = f"""
|
||||
SELECT id,
|
||||
@ -705,6 +774,7 @@ class OracleVS(VectorStore):
|
||||
page_content=page_content_str, metadata=metadata
|
||||
)
|
||||
distance = result[3]
|
||||
|
||||
# Assuming result[4] is already in the correct format;
|
||||
# adjust if necessary
|
||||
current_embedding = (
|
||||
@ -712,6 +782,7 @@ class OracleVS(VectorStore):
|
||||
if result[4]
|
||||
else np.empty(0, dtype=np.float32)
|
||||
)
|
||||
|
||||
documents.append((document, distance, current_embedding))
|
||||
return documents # type: ignore
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user