You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/templates/rag-gpt-crawler/rag_gpt_crawler/chain.py

63 lines
1.6 KiB
Python

import json
from pathlib import Path
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema import Document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableParallel, RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
# Load output from gpt crawler.
# The file is looked up as "output.json" one directory above the directory
# that contains this module. It is expected to hold a JSON sequence of page
# records, each with "html", "title" and "url" keys.
path_to_gptcrawler = Path(__file__).parent.parent / "output.json"
# Decode explicitly as UTF-8: without an encoding, read_text() falls back to
# the locale default, which can mis-decode or reject non-ASCII page text
# (e.g. on Windows). NOTE(review): assumes the crawler writes UTF-8, the
# standard JSON interchange encoding -- confirm against the crawler config.
data = json.loads(path_to_gptcrawler.read_text(encoding="utf-8"))
# One Document per crawled page: the HTML field becomes the content, while
# title and URL travel along as metadata.
docs = [
    Document(
        page_content=page["html"],
        metadata={"title": page["title"], "url": page["url"]},
    )
    for page in data
]
# Split the loaded documents into chunks (chunk_size=1000, chunk_overlap=100).
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)
all_splits = text_splitter.split_documents(docs)
# Add the chunks to the vector DB: a Chroma collection named
# "rag-gpt-builder", embedded with OpenAIEmbeddings.
_embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_documents(
    documents=all_splits,
    collection_name="rag-gpt-builder",
    embedding=_embeddings,
)
# Retriever view over the collection, used as the "context" source of the chain.
retriever = vectorstore.as_retriever()
# RAG prompt: instructs the model to answer from the supplied context only.
# The text is assembled from adjacent string literals, one per prompt line;
# "{context}" and "{question}" are the two template variables.
template = (
    "Answer the question based only on the following context:\n"
    "{context}\n"
    "Question: {question}\n"
)
prompt = ChatPromptTemplate.from_template(template)
# LLM: ChatOpenAI constructed with no explicit arguments.
model = ChatOpenAI()
# RAG chain.
# First step: run two branches on the same input -- the retriever (result
# stored under "context") and a passthrough (stored under "question").
_retrieval_inputs = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
)
# Then pipe that mapping through the prompt, the model and the string parser.
chain = _retrieval_inputs | prompt | model | StrOutputParser()
# Add typing for input.
# NOTE: documented with comments rather than a class docstring on purpose;
# a docstring on a pydantic model may surface in its generated schema
# (presumably as the description) -- confirm before adding one.
class Question(BaseModel):
    # Custom root type: the model wraps a single `str` value rather than an
    # object with named fields. The body must be indented under the class
    # header, otherwise the module fails to parse.
    __root__: str


# Rebind `chain` to the result of with_types(), declaring `Question` as the
# chain's input type.
chain = chain.with_types(input_type=Question)