added copy to local models and claude

This commit is contained in:
xssdoctor 2024-03-13 20:13:57 -04:00
parent d17dafe46c
commit 3ec5058f8d
2 changed files with 17 additions and 12 deletions

View File

@ -37,8 +37,6 @@ def main():
parser.add_argument(
"--list", "-l", help="List available patterns", action="store_true"
)
parser.add_argument('--clear', help="Clears your persistent model choice so that you can once again use the --model flag",
action="store_true")
parser.add_argument(
"--update", "-u", help="Update patterns. NOTE: This will revert the default model to gpt4-turbo. please run --changeDefaultModel to once again set default model", action="store_true")
parser.add_argument("--pattern", "-p", help="The pattern (prompt) to use")
@ -100,10 +98,6 @@ def main():
if not os.path.exists(os.path.join(config, "context.md")):
print("Please create a context.md file in ~/.config/fabric")
sys.exit()
if args.clear:
Setup().clean_env()
print("Model choice cleared. please restart your session to use the --model flag.")
sys.exit()
standalone = Standalone(args, args.pattern)
if args.list:
try:

View File

@ -38,7 +38,8 @@ class Standalone:
load_dotenv(env_file)
assert 'OPENAI_API_KEY' in os.environ, "Error: OPENAI_API_KEY not found in environment variables. Please run fabric --setup and add a key."
api_key = os.environ['OPENAI_API_KEY']
base_url = os.environ.get('OPENAI_BASE_URL', 'https://api.openai.com/v1/')
base_url = os.environ.get(
'OPENAI_BASE_URL', 'https://api.openai.com/v1/')
self.client = OpenAI(api_key=api_key, base_url=base_url)
self.local = False
self.config_pattern_directory = config_directory
@ -65,6 +66,9 @@ class Standalone:
else:
response = await AsyncClient().chat(model=self.model, messages=messages)
print(response['message']['content'])
copy = self.args.copy
if copy:
pyperclip.copy(response['message']['content'])
async def localStream(self, messages, host=''):
from ollama import AsyncClient
@ -91,7 +95,7 @@ class Standalone:
message = await stream.get_final_message()
async def claudeChat(self, system, user):
async def claudeChat(self, system, user, copy=False):
from anthropic import Anthropic
self.claudeApiKey = os.environ["CLAUDE_API_KEY"]
client = Anthropic(api_key=self.claudeApiKey)
@ -103,6 +107,9 @@ class Standalone:
temperature=0.0, top_p=1.0
)
print(message.content[0].text)
copy = self.args.copy
if copy:
pyperclip.copy(message.content[0].text)
def streamMessage(self, input_data: str, context="", host=''):
""" Stream a message and handle exceptions.
@ -279,14 +286,16 @@ class Standalone:
print("Error: Cannot connect to the OpenAI API Server because the API key is not set. Please run fabric --setup and add a key.")
else:
print(f"Error: {e.message} trying to access {e.request.url}: {getattr(e.__cause__, 'args', [''])}")
print(
f"Error: {e.message} trying to access {e.request.url}: {getattr(e.__cause__, 'args', [''])}")
sys.exit()
except Exception as e:
print(f"Error: {getattr(e.__context__, 'args', [''])[0]}")
sys.exit()
if "/" in models[0] or "\\" in models[0]:
# lmstudio returns full paths to models. Iterate and truncate everything before and including the last slash
gptlist = [item[item.rfind("/") + 1:] if "/" in item else item[item.rfind("\\") + 1:] for item in models]
gptlist = [item[item.rfind(
"/") + 1:] if "/" in item else item[item.rfind("\\") + 1:] for item in models]
else:
# Keep items that start with "gpt"
gptlist = [item for item in models if item.startswith("gpt")]
@ -447,14 +456,16 @@ class Setup:
if getattr(e.__cause__, 'args', [''])[0] == "Illegal header value b'Bearer '":
print("Error: Cannot connect to the OpenAI API Server because the API key is not set. Please run fabric --setup and add a key.")
else:
print(f"Error: {e.message} trying to access {e.request.url}: {getattr(e.__cause__, 'args', [''])}")
print(
f"Error: {e.message} trying to access {e.request.url}: {getattr(e.__cause__, 'args', [''])}")
sys.exit()
except Exception as e:
print(f"Error: {getattr(e.__context__, 'args', [''])[0]}")
sys.exit()
if "/" in models[0] or "\\" in models[0]:
# lmstudio returns full paths to models. Iterate and truncate everything before and including the last slash
self.gptlist = [item[item.rfind("/") + 1:] if "/" in item else item[item.rfind("\\") + 1:] for item in models]
self.gptlist = [item[item.rfind(
"/") + 1:] if "/" in item else item[item.rfind("\\") + 1:] for item in models]
else:
# Keep items that start with "gpt"
self.gptlist = [item for item in models if item.startswith("gpt")]