mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
8021d2a2ab
Thank you for contributing to LangChain! - Oracle AI Vector Search: Oracle AI Vector Search is designed for Artificial Intelligence (AI) workloads and allows you to query data based on semantics, rather than keywords. One of the biggest benefits of Oracle AI Vector Search is that semantic search on unstructured data can be combined with relational search on business data in one single system. This is not only powerful but also significantly more effective because you don't need to add a specialized vector database, eliminating the pain of data fragmentation between multiple systems. - Oracle AI Vector Search is designed for Artificial Intelligence (AI) workloads and allows you to query data based on semantics, rather than keywords. One of the biggest benefits of Oracle AI Vector Search is that semantic search on unstructured data can be combined with relational search on business data in one single system. This is not only powerful but also significantly more effective because you don't need to add a specialized vector database, eliminating the pain of data fragmentation between multiple systems. This pull request adds the following functionality: Oracle AI Vector Search: Vector Store; Oracle AI Vector Search: Document Loader; Oracle AI Vector Search: Document Splitter; Oracle AI Vector Search: Summary; Oracle AI Vector Search: Oracle Embeddings. - We have added unit tests and have our own local unit test suite which verifies all the code is correct. We have made sure to add guides for each of the components and one end-to-end guide that shows how the entire thing runs. - We have made sure that make format and make lint run clean. Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. 
- If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17. --------- Co-authored-by: skmishraoracle <shailendra.mishra@oracle.com> Co-authored-by: hroyofc <harichandan.roy@oracle.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
448 lines
15 KiB
Python
448 lines
15 KiB
Python
# Authors:
|
|
# Harichandan Roy (hroy)
|
|
# David Jiang (ddjiang)
|
|
#
|
|
# -----------------------------------------------------------------------------
|
|
# oracleai.py
|
|
# -----------------------------------------------------------------------------
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import struct
|
|
import time
|
|
import traceback
|
|
from html.parser import HTMLParser
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|
|
|
from langchain_core.document_loaders import BaseLoader
|
|
from langchain_core.documents import Document
|
|
from langchain_text_splitters import TextSplitter
|
|
|
|
if TYPE_CHECKING:
|
|
from oracledb import Connection
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
"""ParseOracleDocMetadata class"""
|
|
|
|
|
|
class ParseOracleDocMetadata(HTMLParser):
    """Collect document metadata from an HTML string.

    Each ``<meta name=... content=...>`` tag is stored as a ``name -> content``
    entry, and the text of the ``<title>`` element is stored under the
    ``"title"`` key. Feed the HTML with ``feed()`` and read the result with
    ``get_metadata()``.
    """

    def __init__(self) -> None:
        # HTMLParser.__init__ already calls reset(); no second call is needed.
        super().__init__()
        # True between a <title> start tag and the text that follows it.
        self.match = False
        self.metadata: Dict[str, Any] = {}

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        """Record ``<meta>`` name/content pairs and arm the title capture."""
        if tag == "meta":
            # Look the attributes up by name so the outcome does not depend on
            # whether ``content`` is written before or after ``name``.
            attr_map = dict(attrs)
            entry = attr_map.get("name")
            if entry and "content" in attr_map:
                self.metadata[entry] = attr_map["content"]
        elif tag == "title":
            self.match = True

    def handle_endtag(self, tag: str) -> None:
        """Disarm the title capture when ``</title>`` is reached."""
        if tag == "title":
            # Without this an empty <title></title> would make the next
            # unrelated text node be stored as the title.
            self.match = False

    def handle_data(self, data: str) -> None:
        """Store the first text chunk seen after a ``<title>`` start tag."""
        if self.match:
            self.metadata["title"] = data
            self.match = False

    def get_metadata(self) -> Dict[str, Any]:
        """Return the metadata dictionary collected so far."""
        return self.metadata
|
|
|
|
|
|
"""OracleDocReader class"""
|
|
|
|
|
|
class OracleDocReader:
    """Helpers that turn a local file into a LangChain ``Document``."""

    @staticmethod
    def generate_object_id(input_string: Union[str, None] = None) -> str:
        """Build a 32-character hexadecimal object id.

        The id is the hex encoding of 16 bytes: a 4-byte big-endian Unix
        timestamp, the first 8 bytes of the SHA-256 digest of
        ``input_string``, and 4 random bytes.

        Args:
            input_string: Text to hash into the id. When ``None``, a random
                16-character alphanumeric string is hashed instead.

        Returns:
            A 32-character lowercase hex string.
        """
        hash_len = 8  # number of digest bytes kept in the id

        if input_string is None:
            input_string = "".join(
                random.choices(
                    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789",
                    k=16,
                )
            )

        timestamp_bin = struct.pack(">I", int(time.time()))  # 4 bytes
        hashval_bin = hashlib.sha256(input_string.encode()).digest()[:hash_len]
        counter_bin = struct.pack(">I", random.getrandbits(32))  # 4 bytes

        # 4 + 8 + 4 = 16 bytes, so the hex form is always exactly 32 characters
        # and needs neither padding nor truncation.
        return (timestamp_bin + hashval_bin + counter_bin).hex()

    @staticmethod
    def read_file(
        conn: "Connection", file_path: str, params: dict
    ) -> "Union[Document, None]":
        """Read a local file and convert it to text inside the database.

        The file is read in binary mode and bound as a BLOB to a PL/SQL block
        that calls ``dbms_vector_chain.utl_to_text`` twice: once with
        ``params`` as the JSON preference (result parsed for metadata) and once
        without (result used as the page content).

        Args:
            conn: Oracle Connection,
            file_path: Path of the local file to read,
            params: Preferences, JSON-serialized and passed to the first
                ``utl_to_text`` call.

        Returns:
            Plain text and metadata as Langchain Document, or ``None`` when
            any step fails (the failure is logged and the file is skipped).

        Raises:
            ImportError: If ``oracledb`` is not installed.
        """
        try:
            import oracledb
        except ImportError as e:
            raise ImportError(
                "Unable to import oracledb, please install with "
                "`pip install -U oracledb`."
            ) from e

        metadata: Dict[str, Any] = {}
        # Bound before the try so the finally clause can tell whether a cursor
        # was actually opened.
        cursor = None
        try:
            oracledb.defaults.fetch_lobs = False
            cursor = conn.cursor()

            with open(file_path, "rb") as f:
                data = f.read()

            mdata = cursor.var(oracledb.DB_TYPE_CLOB)
            text = cursor.var(oracledb.DB_TYPE_CLOB)
            cursor.execute(
                """
                declare
                    input blob;
                begin
                    input := :blob;
                    :mdata := dbms_vector_chain.utl_to_text(input, json(:pref));
                    :text := dbms_vector_chain.utl_to_text(input);
                end;""",
                blob=data,
                pref=json.dumps(params),
                mdata=mdata,
                text=text,
            )

            # The bind variables themselves are never None; the NULL check has
            # to be made on the values they hold.
            mdata_value = mdata.getvalue()
            if mdata_value is not None:
                doc_data = str(mdata_value)
                if doc_data.startswith(("<!DOCTYPE html", "<HTML>")):
                    p = ParseOracleDocMetadata()
                    p.feed(doc_data)
                    metadata = p.get_metadata()

            doc_id = OracleDocReader.generate_object_id(conn.username + "$" + file_path)
            metadata["_oid"] = doc_id
            metadata["_file"] = file_path

            text_value = text.getvalue()
            page_content = "" if text_value is None else str(text_value)
            return Document(page_content=page_content, metadata=metadata)

        except Exception as ex:
            # Best effort by design: a file that cannot be processed is
            # skipped rather than aborting the caller.
            logger.info(f"An exception occurred :: {ex}")
            logger.info(f"Skip processing {file_path}")
            return None
        finally:
            # Close exactly once, on both the success and the failure path.
            if cursor is not None:
                cursor.close()
|
|
|
|
|
|
"""OracleDocLoader class"""
|
|
|
|
|
|
class OracleDocLoader(BaseLoader):
    """Read documents using OracleDocLoader

    Documents can come from a single local file, every regular file in a
    local directory, and/or one column of a database table.

    Args:
        conn: Oracle Connection,
        params: Loader parameters. ``load`` reads the keys ``file``, ``dir``,
            ``owner``, ``tablename``, ``colname`` and ``mdata_cols``.
    """

    # Column data types accepted for the columns listed in ``mdata_cols``.
    _SUPPORTED_MDATA_TYPES = (
        "NUMBER",
        "BINARY_DOUBLE",
        "BINARY_FLOAT",
        "LONG",
        "DATE",
        "TIMESTAMP",
        "VARCHAR2",
    )

    def __init__(self, conn: Connection, params: Dict[str, Any], **kwargs: Any):
        self.conn = conn
        # The JSON round-trip stores a detached copy of the caller's params and
        # fails immediately if they are not JSON-serializable.
        self.params = json.loads(json.dumps(params))
        super().__init__(**kwargs)

    def load(self) -> List[Document]:
        """Load data into LangChain Document objects...

        Sources are processed in the order file, directory, table; each one is
        used only when its parameter is set.

        Returns:
            The loaded documents. If ``file`` is set and cannot be read, the
            (empty) list is returned immediately and the other sources are
            not processed. Unreadable files inside ``dir`` are skipped.

        Raises:
            ImportError: If ``oracledb`` is not installed.
            Exception: If the parameters are missing or invalid, or a database
                call fails. The error is logged before being re-raised.
        """
        try:
            import oracledb
        except ImportError as e:
            raise ImportError(
                "Unable to import oracledb, please install with "
                "`pip install -U oracledb`."
            ) from e

        results: List[Document] = []
        m_params = {"plaintext": "false"}
        try:
            # extract the parameters
            if self.params is None:
                raise Exception("Missing loader parameters")
            self.file = self.params.get("file")
            self.dir = self.params.get("dir")
            self.owner = self.params.get("owner")
            self.tablename = self.params.get("tablename")
            self.colname = self.params.get("colname")

            oracledb.defaults.fetch_lobs = False

            if self.file:
                doc = OracleDocReader.read_file(self.conn, self.file, m_params)
                if doc is None:
                    return results
                results.append(doc)

            if self.dir:
                skip_count = 0
                for file_name in os.listdir(self.dir):
                    file_path = os.path.join(self.dir, file_name)
                    if not os.path.isfile(file_path):
                        continue
                    doc = OracleDocReader.read_file(self.conn, file_path, m_params)
                    if doc is None:
                        skip_count = skip_count + 1
                        logger.info(f"Total skipped: {skip_count}\n")
                    else:
                        results.append(doc)

            if self.tablename:
                results.extend(self._load_table(m_params))

            return results
        except Exception as ex:
            logger.info(f"An exception occurred :: {ex}")
            traceback.print_exc()
            raise

    def _load_table(self, m_params: Dict[str, Any]) -> List[Document]:
        """Convert ``colname`` of every row in ``owner.tablename`` to a Document."""
        # Validate before opening a cursor so this error is what the caller sees.
        if self.owner is None or self.colname is None:
            raise Exception("Missing owner or column name or both.")

        docs: List[Document] = []
        cursor = self.conn.cursor()
        try:
            self.mdata_cols = self.params.get("mdata_cols")
            if self.mdata_cols is not None:
                if len(self.mdata_cols) > 3:
                    raise Exception(
                        "Exceeds the max number of columns "
                        + "you can request for metadata."
                    )

                # execute a query to get column data types
                sql = (
                    "select column_name, data_type from all_tab_columns "
                    + "where owner = :ownername and "
                    + "table_name = :tablename"
                )
                cursor.execute(
                    sql,
                    ownername=self.owner.upper(),
                    tablename=self.tablename.upper(),
                )
                for column_name, data_type in cursor.fetchall():
                    if (
                        column_name in self.mdata_cols
                        and data_type not in self._SUPPORTED_MDATA_TYPES
                    ):
                        raise Exception(
                            "The datatype for the column requested "
                            + "for metadata is not supported."
                        )

            mdata_cols = self.mdata_cols if self.mdata_cols is not None else []
            self.mdata_cols_sql = ", rowid" + "".join(", " + col for col in mdata_cols)

            # [TODO] use bind variables
            # Select list layout: 0 = mdata, 1 = text, 2 = rowid,
            # 3.. = the requested metadata columns, in order.
            sql = (
                "select dbms_vector_chain.utl_to_text(t."
                + self.colname
                + ", json('"
                + json.dumps(m_params)
                + "')) mdata, dbms_vector_chain.utl_to_text(t."
                + self.colname
                + ") text"
                + self.mdata_cols_sql
                + " from "
                + self.owner
                + "."
                + self.tablename
                + " t"
            )

            cursor.execute(sql)
            for row in cursor:
                metadata: Dict[str, Any] = {}

                if row[0] is not None:
                    data = str(row[0])
                    if data.startswith(("<!DOCTYPE html", "<HTML>")):
                        p = ParseOracleDocMetadata()
                        p.feed(data)
                        metadata = p.get_metadata()

                doc_id = OracleDocReader.generate_object_id(
                    self.conn.username
                    + "$"
                    + self.owner
                    + "$"
                    + self.tablename
                    + "$"
                    + self.colname
                    + "$"
                    + str(row[2])
                )
                metadata["_oid"] = doc_id
                metadata["_rowid"] = row[2]

                # process projected metadata cols; they start at index 3,
                # right after the rowid column
                for offset, col in enumerate(mdata_cols):
                    metadata[col] = row[3 + offset]

                page_content = "" if row[1] is None else str(row[1])
                docs.append(Document(page_content=page_content, metadata=metadata))

            return docs
        finally:
            # Release the cursor on success as well as on failure.
            cursor.close()
|
|
|
|
|
|
class OracleTextSplitter(TextSplitter):
    """Splitting text using Oracle chunker."""

    def __init__(self, conn: Connection, params: Dict[str, Any], **kwargs: Any) -> None:
        """Initialize.

        Args:
            conn: Oracle Connection,
            params: Chunking preferences; JSON-serialized and passed to
                ``dbms_vector_chain.utl_to_chunks`` by ``split_text``.
            **kwargs: Forwarded to ``TextSplitter.__init__``.

        Raises:
            ImportError: If ``oracledb`` is not installed.
        """
        self.conn = conn
        self.params = params
        super().__init__(**kwargs)
        try:
            import oracledb
        except ImportError as e:
            raise ImportError(
                "Unable to import oracledb, please install with "
                "`pip install -U oracledb`."
            ) from e

        self._oracledb = oracledb
        # json is part of the standard library, so it needs no import guard.
        self._json = json

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks.

        Args:
            text: The text to split; bound as a CLOB.

        Returns:
            The ``chunk_data`` value of each row returned by
            ``dbms_vector_chain.utl_to_chunks``, in fetch order.

        Raises:
            ImportError: If ``oracledb`` is not installed.
            Exception: Any database or JSON decoding error, logged and then
                re-raised.
        """
        try:
            import oracledb
        except ImportError as e:
            raise ImportError(
                "Unable to import oracledb, please install with "
                "`pip install -U oracledb`."
            ) from e

        splits = []
        # Bound before the try so the finally clause can tell whether a cursor
        # was actually opened.
        cursor = None
        try:
            # returns strings or bytes instead of a locator
            self._oracledb.defaults.fetch_lobs = False

            cursor = self.conn.cursor()
            cursor.setinputsizes(content=oracledb.CLOB)
            cursor.execute(
                "select t.column_value from "
                + "dbms_vector_chain.utl_to_chunks(:content, json(:params)) t",
                content=text,
                params=self._json.dumps(self.params),
            )

            for row in cursor:
                chunk = self._json.loads(row[0])
                splits.append(chunk["chunk_data"])

            return splits

        except Exception as ex:
            logger.info(f"An exception occurred :: {ex}")
            traceback.print_exc()
            raise
        finally:
            # Release the cursor on success as well as on failure.
            if cursor is not None:
                cursor.close()
|