Implement embedding_search method, modify send_query, and add temperature config

Branch: pull/7/head
Author: Saryev Rustam
parent 2a6c5ad707
commit 70fee6d501

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "talk-codebase"
-version = "0.1.31"
+version = "0.1.32"
 description = "talk-codebase is a powerful tool for querying and analyzing codebases."
 authors = ["Saryev Rustam <rustam1997@gmail.com>"]
 readme = "README.md"

@@ -4,7 +4,7 @@
 import fire
 import questionary
 import yaml
-from talk_codebase.LLM import factory_llm
+from talk_codebase.llm import factory_llm
 from talk_codebase.consts import DEFAULT_CONFIG
 
 config_path = os.path.join(os.path.expanduser("~"), ".talk_codebase_config.yaml")

@@ -19,6 +19,7 @@ DEFAULT_CONFIG = {
     "model_name": "gpt-3.5-turbo-0613",
     "model_path": "models/ggml-gpt4all-j-v1.3-groovy.bin",
     "model_type": MODEL_TYPES["OPENAI"],
+    "temperature": "0.7",
 }
 
 LOADER_MAPPING = {
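For context on the new `"temperature"` key: the default is stored as a string and cast to `float` where the model is built. A minimal sketch of how it could be read alongside the user's `~/.talk_codebase_config.yaml` (the `load_config` helper and the merge order are illustrative assumptions, not code from this commit):

```python
# Illustrative sketch only: merging DEFAULT_CONFIG with the user's YAML config.
# `load_config` is a hypothetical helper; the actual CLI code is not shown in this diff.
import os
import yaml

from talk_codebase.consts import DEFAULT_CONFIG

config_path = os.path.join(os.path.expanduser("~"), ".talk_codebase_config.yaml")


def load_config():
    user_config = {}
    if os.path.exists(config_path):
        with open(config_path) as f:
            user_config = yaml.safe_load(f) or {}
    # Defaults first, user-supplied values override them.
    return {**DEFAULT_CONFIG, **user_config}


config = load_config()
# The default is the string "0.7", so consumers cast it explicitly.
temperature = float(config.get("temperature"))
```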

@@ -10,6 +10,7 @@
 from langchain.chat_models import ChatOpenAI
 from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
 from langchain.llms import GPT4All
 from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.schema import AIMessage, HumanMessage, SystemMessage
 from talk_codebase.consts import MODEL_TYPES
 from talk_codebase.utils import load_files, get_local_vector_store, calculate_cost, StreamStdOut
@@ -29,13 +30,25 @@ class BaseLLM:
     def _create_model(self):
         raise NotImplementedError("Subclasses must implement this method.")
 
-    def send_query(self, question):
+    def embedding_search(self, query, k):
+        return self.vector_store.search(query, k=k, search_type="similarity")
+
+    def send_query(self, query):
         k = self.config.get("k")
-        qa = RetrievalQA.from_chain_type(llm=self.llm, chain_type="stuff",
-                                         retriever=self.vector_store.as_retriever(search_kwargs={"k": int(k)}),
-                                         return_source_documents=True)
-        answer = qa(question)
-        print('\n' + '\n'.join([f'📄 {os.path.abspath(s.metadata["source"])}:' for s in answer["source_documents"]]))
+        docs = self.embedding_search(query, k=int(k))
+        content = "\n".join([f"content: \n```{s.page_content}```" for s in docs])
+        prompt = f"Given the following content, your task is to answer the question. \n{content}"
+        messages = [
+            SystemMessage(content=prompt),
+            HumanMessage(content=query),
+        ]
+        self.llm(messages)
+        file_paths = [os.path.abspath(s.metadata["source"]) for s in docs]
+        print('\n'.join([f'📄 {file_path}:' for file_path in file_paths]))
 
     def _create_vector_store(self, embeddings, index, root_dir):
         index_path = os.path.join(root_dir, f"vector_store/{index}")
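The hunk above replaces the `RetrievalQA` chain with a plain similarity search over the vector store, a hand-built system prompt containing the retrieved chunks, and a direct chat-model call. A self-contained sketch of the same flow, assuming a previously saved FAISS index and an OpenAI key (the index path, query, and `k` are illustrative, not taken from this commit):

```python
# Illustrative end-to-end version of the new send_query flow (not project code).
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import HumanMessage, SystemMessage
from langchain.vectorstores import FAISS

embeddings = OpenAIEmbeddings(openai_api_key="sk-...")               # placeholder key
vector_store = FAISS.load_local("vector_store/openai", embeddings)   # assumed index path
llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", openai_api_key="sk-...", temperature=0.7)

query = "Where is the config file loaded?"

# Plain similarity search instead of a RetrievalQA chain.
docs = vector_store.search(query, k=4, search_type="similarity")

# Pack the retrieved chunks into one system prompt ...
content = "\n".join(f"content: \n```{d.page_content}```" for d in docs)
prompt = f"Given the following content, your task is to answer the question. \n{content}"

# ... and send the question itself as the human message.
answer = llm([SystemMessage(content=prompt), HumanMessage(content=query)])
print(answer.content)

# The source files behind the retrieved chunks can still be listed from metadata.
print("\n".join(f"📄 {d.metadata['source']}:" for d in docs))
```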
@@ -95,10 +108,12 @@ class OpenAILLM(BaseLLM):
         return self._create_vector_store(embeddings, MODEL_TYPES["OPENAI"], root_dir)
 
     def _create_model(self):
-        return ChatOpenAI(model_name=self.config.get("model_name"), openai_api_key=self.config.get("api_key"),
+        return ChatOpenAI(model_name=self.config.get("model_name"),
+                          openai_api_key=self.config.get("api_key"),
                           streaming=True,
                           max_tokens=int(self.config.get("max_tokens")),
-                          callback_manager=CallbackManager([StreamStdOut()]))
+                          callback_manager=CallbackManager([StreamStdOut()]),
+                          temperature=float(self.config.get("temperature")))
 
 
 def factory_llm(root_dir, config):
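For reference, the `temperature` value arrives as a string from the config and is cast to `float` here. A rough equivalent of this constructor using LangChain's stock streaming handler in place of the project's `StreamStdOut` (the config values below are placeholders, not the project's defaults):

```python
# Illustrative only: streaming ChatOpenAI with a temperature taken from config.
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI

config = {"model_name": "gpt-3.5-turbo-0613", "api_key": "sk-...",
          "max_tokens": "1024", "temperature": "0.7"}

llm = ChatOpenAI(model_name=config["model_name"],
                 openai_api_key=config["api_key"],
                 streaming=True,  # emit tokens as they arrive
                 max_tokens=int(config["max_tokens"]),
                 callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
                 # lower temperature -> more deterministic, higher -> more varied output
                 temperature=float(config["temperature"]))
```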