From 0ce5ed24c2a21f2b051a5ee073b076b20c5bfa7f Mon Sep 17 00:00:00 2001
From: Jonathan Dunn
Date: Tue, 5 Mar 2024 14:43:34 -0500
Subject: [PATCH] Added support for local models

---
 installer/client/cli/fabric.py |  8 +++-
 installer/client/cli/utils.py  | 81 +++++++++++++++++++++-------------
 2 files changed, 58 insertions(+), 31 deletions(-)

diff --git a/installer/client/cli/fabric.py b/installer/client/cli/fabric.py
index 34c85dd..ddcb633 100755
--- a/installer/client/cli/fabric.py
+++ b/installer/client/cli/fabric.py
@@ -43,6 +43,8 @@ def main():
     parser.add_argument(
         "--setup", help="Set up your fabric instance", action="store_true"
     )
+    parser.add_argument(
+        '--local', '-L', help="Use local LLM. Default is llama2", action="store_true")
     parser.add_argument(
         "--model", "-m", help="Select the model to use (GPT-4 by default)", default="gpt-4-turbo-preview"
     )
@@ -90,7 +92,11 @@ def main():
         if not os.path.exists(os.path.join(config, "context.md")):
             print("Please create a context.md file in ~/.config/fabric")
             sys.exit()
-    standalone = Standalone(args, args.pattern)
+    standalone = None
+    if args.local:
+        standalone = Standalone(args, args.pattern, local=True)
+    else:
+        standalone = Standalone(args, args.pattern)
     if args.list:
         try:
             direct = sorted(os.listdir(config_patterns_directory))
diff --git a/installer/client/cli/utils.py b/installer/client/cli/utils.py
index 6d806a6..5d0619f 100644
--- a/installer/client/cli/utils.py
+++ b/installer/client/cli/utils.py
@@ -1,6 +1,7 @@
 import requests
 import os
 from openai import OpenAI
+import asyncio
 import pyperclip
 import sys
 import platform
@@ -15,7 +16,7 @@ env_file = os.path.join(config_directory, ".env")
 
 
 class Standalone:
-    def __init__(self, args, pattern="", env_file="~/.config/fabric/.env"):
+    def __init__(self, args, pattern="", env_file="~/.config/fabric/.env", local=False):
         """        Initialize the class with the provided arguments and environment file.
 
         Args:
@@ -44,10 +45,24 @@ class Standalone:
         except FileNotFoundError:
             print("No API key found. Use the --apikey option to set the key")
             sys.exit()
+        self.local = local
         self.config_pattern_directory = config_directory
         self.pattern = pattern
         self.args = args
         self.model = args.model
+        if self.local:
+            if self.args.model == 'gpt-4-turbo-preview':
+                self.args.model = 'llama2'
+
+    async def localChat(self, messages):
+        from ollama import AsyncClient
+        response = await AsyncClient().chat(model=self.args.model, messages=messages)
+        print(response['message']['content'])
+
+    async def localStream(self, messages):
+        from ollama import AsyncClient
+        async for part in await AsyncClient().chat(model=self.args.model, messages=messages, stream=True):
+            print(part['message']['content'], end='', flush=True)
 
     def streamMessage(self, input_data: str, context=""):
         """        Stream a message and handle exceptions.
@@ -87,26 +102,29 @@ class Standalone:
         else:
             messages = [user_message]
         try:
-            stream = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                temperature=0.0,
-                top_p=1,
-                frequency_penalty=0.1,
-                presence_penalty=0.1,
-                stream=True,
-            )
-            for chunk in stream:
-                if chunk.choices[0].delta.content is not None:
-                    char = chunk.choices[0].delta.content
-                    buffer += char
-                    if char not in ["\n", " "]:
-                        print(char, end="")
-                    elif char == " ":
-                        print(" ", end="")  # Explicitly handle spaces
-                    elif char == "\n":
-                        print()  # Handle newlines
-            sys.stdout.flush()
+            if self.local:
+                asyncio.run(self.localStream(messages))
+            else:
+                stream = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=messages,
+                    temperature=0.0,
+                    top_p=1,
+                    frequency_penalty=0.1,
+                    presence_penalty=0.1,
+                    stream=True,
+                )
+                for chunk in stream:
+                    if chunk.choices[0].delta.content is not None:
+                        char = chunk.choices[0].delta.content
+                        buffer += char
+                        if char not in ["\n", " "]:
+                            print(char, end="")
+                        elif char == " ":
+                            print(" ", end="")  # Explicitly handle spaces
+                        elif char == "\n":
+                            print()  # Handle newlines
+                sys.stdout.flush()
         except Exception as e:
             print(f"Error: {e}")
             print(e)
@@ -153,15 +171,18 @@ class Standalone:
         else:
             messages = [user_message]
         try:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=messages,
-                temperature=0.0,
-                top_p=1,
-                frequency_penalty=0.1,
-                presence_penalty=0.1,
-            )
-            print(response.choices[0].message.content)
+            if self.local:
+                asyncio.run(self.localChat(messages))
+            else:
+                response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=messages,
+                    temperature=0.0,
+                    top_p=1,
+                    frequency_penalty=0.1,
+                    presence_penalty=0.1,
+                )
+                print(response.choices[0].message.content)
         except Exception as e:
             print(f"Error: {e}")
             print(e)
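
Usage note: with this patch applied, `fabric --local` (or `-L`) routes both the
streaming and non-streaming paths through the ollama Python client instead of
the OpenAI API, and swaps the default model from gpt-4-turbo-preview to llama2.
Below is a minimal standalone sketch of the ollama calls the patch relies on,
assuming an Ollama server is running locally and the model has already been
pulled (`ollama pull llama2`); the prompt text is illustrative only.

import asyncio
from ollama import AsyncClient


async def main():
    messages = [{"role": "user", "content": "Say hello"}]  # illustrative prompt
    # Single response, mirroring Standalone.localChat
    response = await AsyncClient().chat(model="llama2", messages=messages)
    print(response["message"]["content"])
    # Streamed response, mirroring Standalone.localStream: with stream=True,
    # awaiting chat() yields an async generator of partial messages
    async for part in await AsyncClient().chat(model="llama2", messages=messages, stream=True):
        print(part["message"]["content"], end="", flush=True)

asyncio.run(main())

Because the ollama client is async while streamMessage and sendMessage are
synchronous, the patch defines localChat and localStream as coroutines and
bridges into them with asyncio.run at the call sites.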