fix send_query

pull/7/head
Saryev Rustam 12 months ago
parent 761210d861
commit 6d72db88bb

@ -1,6 +1,6 @@
[tool.poetry]
name = "talk-codebase"
version = "0.1.35"
version = "0.1.36"
description = "talk-codebase is a powerful tool for querying and analyzing codebases."
authors = ["Saryev Rustam <rustam1997@gmail.com>"]
readme = "README.md"

@ -4,13 +4,13 @@ from typing import Optional
import questionary
from halo import Halo
from langchain import FAISS
from langchain import PromptTemplate, LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.chains import RetrievalQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import HuggingFaceEmbeddings, OpenAIEmbeddings
from langchain.llms import GPT4All
from langchain.schema import HumanMessage, SystemMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from talk_codebase.consts import MODEL_TYPES
from talk_codebase.utils import load_files, get_local_vector_store, calculate_cost, StreamStdOut
@ -33,23 +33,6 @@ class BaseLLM:
def embedding_search(self, query, k):
return self.vector_store.search(query, k=k, search_type="similarity")
def send_query(self, query):
k = self.config.get("k")
docs = self.embedding_search(query, k=int(k))
content = "\n".join([f"content: \n```{s.page_content}```" for s in docs])
prompt = f"Given the following content, your task is to answer the question. \n{content}"
messages = [
SystemMessage(content=prompt),
HumanMessage(content=query),
]
self.llm(messages)
file_paths = [os.path.abspath(s.metadata["source"]) for s in docs]
print('\n'.join([f'📄 {file_path}:' for file_path in file_paths]))
def _create_vector_store(self, embeddings, index, root_dir):
index_path = os.path.join(root_dir, f"vector_store/{index}")
new_db = get_local_vector_store(embeddings, index_path)
@ -101,6 +84,24 @@ class LocalLLM(BaseLLM):
llm = GPT4All(model=self.config.get("model_path"), n_ctx=int(self.config.get("max_tokens")), streaming=True)
return llm
def send_query(self, query):
k = self.config.get("k")
docs = self.embedding_search(query, k=int(k))
content = "\n".join([f"content: \n```{s.page_content}```" for s in docs])
template = """Given the following content, your task is to answer the question.
Content: {content}
Question: {question}
"""
prompt = PromptTemplate(template=template, input_variables=["content", "question"]).partial(content=content)
llm_chain = LLMChain(prompt=prompt, llm=self.llm)
llm_chain.run(query)
file_paths = [os.path.abspath(s.metadata["source"]) for s in docs]
print('\n'.join([f'📄 {file_path}:' for file_path in file_paths]))
class OpenAILLM(BaseLLM):
def _create_store(self, root_dir: str) -> Optional[FAISS]:
@ -115,6 +116,23 @@ class OpenAILLM(BaseLLM):
callback_manager=CallbackManager([StreamStdOut()]),
temperature=float(self.config.get("temperature")))
def send_query(self, query):
k = self.config.get("k")
docs = self.embedding_search(query, k=int(k))
content = "\n".join([f"content: \n```{s.page_content}```" for s in docs])
prompt = f"Given the following content, your task is to answer the question. \n{content}"
messages = [
SystemMessage(content=prompt),
HumanMessage(content=query),
]
self.llm(messages)
file_paths = [os.path.abspath(s.metadata["source"]) for s in docs]
print('\n'.join([f'📄 {file_path}:' for file_path in file_paths]))
def factory_llm(root_dir, config):
if config.get("model_type") == "openai":

Loading…
Cancel
Save