|
|
|
@ -5,7 +5,6 @@ from langchain_community.chat_models import ChatOllama
|
|
|
|
|
from langchain_core.runnables import RunnablePassthrough
|
|
|
|
|
from langchain_core.output_parsers import StrOutputParser
|
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
|
from langchain.text_splitter import CharacterTextSplitter
|
|
|
|
|
from langchain.docstore.document import Document
|
|
|
|
|
import csv
|
|
|
|
|
import json
|
|
|
|
@ -23,24 +22,23 @@ with open('data/noc.csv', newline='') as csvfile:
|
|
|
|
|
for row in csv.DictReader(csvfile)
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
def include_page(page):
|
|
|
|
|
return str(page['code']) not in ['11', '1', '0', '14', '12', '13', '10' ]
|
|
|
|
|
# Filter out duplicate codes
|
|
|
|
|
def include_code(row):
|
|
|
|
|
return row['code'] not in ['11', '1', '0', '14', '12', '13', '10' ]
|
|
|
|
|
|
|
|
|
|
def to_page_content(page):
|
|
|
|
|
return json.dumps(page)
|
|
|
|
|
filtered_noc_codes = [code for code in noc_codes if include_code(code)]
|
|
|
|
|
|
|
|
|
|
filtered_noc_codes = [page for page in noc_codes if include_page(page)]
|
|
|
|
|
def to_page_content(code):
|
|
|
|
|
return json.dumps(code)
|
|
|
|
|
|
|
|
|
|
nested_docs = [[Document(page_content=to_page_content(page)) for page in filtered_noc_codes]]
|
|
|
|
|
documents = [Document(page_content=to_page_content(code)) for code in filtered_noc_codes]
|
|
|
|
|
print('total documents included = ', len(documents))
|
|
|
|
|
|
|
|
|
|
# Sources
|
|
|
|
|
# https://www.youtube.com/watch?v=jENqvjpkwmw
|
|
|
|
|
|
|
|
|
|
model_local = ChatOllama(model="mistral")
|
|
|
|
|
|
|
|
|
|
flattened_docs = [doc for sublist in nested_docs for doc in sublist]
|
|
|
|
|
print('total documents included = ', len(flattened_docs))
|
|
|
|
|
|
|
|
|
|
# TODO don't build the vectors each time, store in a vector database, this needs to be persisted, maybe local redis
|
|
|
|
|
|
|
|
|
|
def load_embeddings():
|
|
|
|
@ -52,16 +50,15 @@ def load_embeddings():
|
|
|
|
|
|
|
|
|
|
def compute_embeddings():
|
|
|
|
|
return Chroma.from_documents(
|
|
|
|
|
documents=flattened_docs,
|
|
|
|
|
documents=documents,
|
|
|
|
|
collection_name="rag-chroma",
|
|
|
|
|
embedding=embeddings.ollama.OllamaEmbeddings(model='nomic-embed-text'),
|
|
|
|
|
persist_directory="./chroma_db"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def load_or_compute_embeddings():
|
|
|
|
|
if os.path.isfile("./chroma_db/chroma.sqlite3"):
|
|
|
|
|
return load_embeddings()
|
|
|
|
|
return compute_embeddings()
|
|
|
|
|
embeddings_exist = os.path.isfile("./chroma_db/chroma.sqlite3")
|
|
|
|
|
return load_embeddings() if embeddings_exist else compute_embeddings()
|
|
|
|
|
|
|
|
|
|
embeddings = load_or_compute_embeddings()
|
|
|
|
|
|
|
|
|
|