pull/1111/head
Rasmus Storjohann 3 months ago
parent 1946b62bbc
commit 902697dc09

@ -5,7 +5,6 @@ from langchain_community.chat_models import ChatOllama
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
import csv
import json
@ -23,24 +22,23 @@ with open('data/noc.csv', newline='') as csvfile:
for row in csv.DictReader(csvfile)
]
def include_page(page):
return str(page['code']) not in ['11', '1', '0', '14', '12', '13', '10' ]
# Filter out duplicate codes
def include_code(row):
return row['code'] not in ['11', '1', '0', '14', '12', '13', '10' ]
def to_page_content(page):
return json.dumps(page)
filtered_noc_codes = [code for code in noc_codes if include_code(code)]
filtered_noc_codes = [page for page in noc_codes if include_page(page)]
def to_page_content(code):
return json.dumps(code)
nested_docs = [[Document(page_content=to_page_content(page)) for page in filtered_noc_codes]]
documents = [Document(page_content=to_page_content(code)) for code in filtered_noc_codes]
print('total documents included = ', len(documents))
# Sources
# https://www.youtube.com/watch?v=jENqvjpkwmw
model_local = ChatOllama(model="mistral")
flattened_docs = [doc for sublist in nested_docs for doc in sublist]
print('total documents included = ', len(flattened_docs))
# TODO don't build the vectors each time, store in a vector database, this needs to be persisted, maybe local redis
def load_embeddings():
@ -52,16 +50,15 @@ def load_embeddings():
def compute_embeddings():
return Chroma.from_documents(
documents=flattened_docs,
documents=documents,
collection_name="rag-chroma",
embedding=embeddings.ollama.OllamaEmbeddings(model='nomic-embed-text'),
persist_directory="./chroma_db"
)
def load_or_compute_embeddings():
if os.path.isfile("./chroma_db/chroma.sqlite3"):
return load_embeddings()
return compute_embeddings()
embeddings_exist = os.path.isfile("./chroma_db/chroma.sqlite3")
return load_embeddings() if embeddings_exist else compute_embeddings()
embeddings = load_or_compute_embeddings()

Loading…
Cancel
Save