Merge pull request #1213 from thatlukinhasguy1/main

Make the API use FastAPI instead of Flask
This commit is contained in:
Tekky 2023-11-05 19:16:12 +01:00 committed by GitHub
commit d5a499d064
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 134 additions and 204 deletions

View File

@ -1,163 +1,137 @@
import typing from fastapi import FastAPI, Response, Request
from .. import BaseProvider from typing import List, Union, Any, Dict, AnyStr
import g4f; g4f.debug.logging = True from ._tokenizer import tokenize
from .. import BaseProvider
import time import time
import json import json
import random import random
import string import string
import logging import uvicorn
import nest_asyncio
from typing import Union import g4f
from loguru import logger
from waitress import serve
from ._logging import hook_logging
from ._tokenizer import tokenize
from flask_cors import CORS
from werkzeug.serving import WSGIRequestHandler
from werkzeug.exceptions import default_exceptions
from werkzeug.middleware.proxy_fix import ProxyFix
from flask import (
Flask,
jsonify,
make_response,
request,
)
class Api: class Api:
__default_ip = '127.0.0.1'
__default_port = 1337
def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False, def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False,
list_ignored_providers:typing.List[typing.Union[str, BaseProvider]]=None) -> None: list_ignored_providers: List[Union[str, BaseProvider]] = None) -> None:
self.engine = engine self.engine = engine
self.debug = debug self.debug = debug
self.sentry = sentry self.sentry = sentry
self.list_ignored_providers = list_ignored_providers self.list_ignored_providers = list_ignored_providers
self.log_level = logging.DEBUG if debug else logging.WARN
hook_logging(level=self.log_level, format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s') self.app = FastAPI()
self.logger = logging.getLogger('waitress') nest_asyncio.apply()
self.app = Flask(__name__) JSONObject = Dict[AnyStr, Any]
self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_port=1) JSONArray = List[Any]
self.app.after_request(self.__after_request) JSONStructure = Union[JSONArray, JSONObject]
def run(self, bind_str, threads=8): @self.app.get("/")
host, port = self.__parse_bind(bind_str) async def read_root():
return Response(content=json.dumps({"info": "g4f API"}, indent=4), media_type="application/json")
CORS(self.app, resources={r'/v1/*': {'supports_credentials': True, 'expose_headers': [ @self.app.get("/v1")
'Content-Type', async def read_root_v1():
'Authorization', return Response(content=json.dumps({"info": "Go to /v1/chat/completions or /v1/models."}, indent=4), media_type="application/json")
'X-Requested-With',
'Accept',
'Origin',
'Access-Control-Request-Method',
'Access-Control-Request-Headers',
'Content-Disposition'], 'max_age': 600}})
self.app.route('/v1/models', methods=['GET'])(self.models) @self.app.get("/v1/models")
self.app.route('/v1/models/<model_id>', methods=['GET'])(self.model_info) async def models():
model_list = [{
'id': model,
'object': 'model',
'created': 0,
'owned_by': 'g4f'} for model in g4f.Model.__all__()]
self.app.route('/v1/chat/completions', methods=['POST'])(self.chat_completions) return Response(content=json.dumps({
self.app.route('/v1/completions', methods=['POST'])(self.completions) 'object': 'list',
'data': model_list}, indent=4), media_type="application/json")
for ex in default_exceptions: @self.app.get("/v1/models/{model_name}")
self.app.register_error_handler(ex, self.__handle_error) async def model_info(model_name: str):
if not self.debug:
self.logger.warning(f'Serving on http://{host}:{port}')
WSGIRequestHandler.protocol_version = 'HTTP/1.1'
serve(self.app, host=host, port=port, ident=None, threads=threads)
def __handle_error(self, e: Exception):
self.logger.error(e)
return make_response(jsonify({
'code': e.code,
'message': str(e.original_exception if self.debug and hasattr(e, 'original_exception') else e.name)}), 500)
@staticmethod
def __after_request(resp):
resp.headers['X-Server'] = f'g4f/{g4f.version}'
return resp
def __parse_bind(self, bind_str):
sections = bind_str.split(':', 2)
if len(sections) < 2:
try: try:
port = int(sections[0]) model_info = (g4f.ModelUtils.convert[model_name])
return self.__default_ip, port
except ValueError:
return sections[0], self.__default_port
return sections[0], int(sections[1]) return Response(content=json.dumps({
'id': model_name,
'object': 'model',
'created': 0,
'owned_by': model_info.base_provider
}, indent=4), media_type="application/json")
except:
return Response(content=json.dumps({"error": "The model does not exist."}, indent=4), media_type="application/json")
async def home(self): @self.app.post("/v1/chat/completions")
return 'Hello world | https://127.0.0.1:1337/v1' async def chat_completions(request: Request, item: JSONStructure = None):
item_data = {
async def chat_completions(self): 'model': 'gpt-3.5-turbo',
model = request.json.get('model', 'gpt-3.5-turbo') 'stream': False,
stream = request.json.get('stream', False)
messages = request.json.get('messages')
logger.info(f'model: {model}, stream: {stream}, request: {messages[-1]["content"]}')
config = None
proxy = None
try:
config = json.load(open("config.json","r",encoding="utf-8"))
proxy = config["proxy"]
except Exception:
pass
if proxy != None:
response = self.engine.ChatCompletion.create(model=model,
stream=stream, messages=messages,
ignored=self.list_ignored_providers,
proxy=proxy)
else:
response = self.engine.ChatCompletion.create(model=model,
stream=stream, messages=messages,
ignored=self.list_ignored_providers)
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
completion_timestamp = int(time.time())
if not stream:
prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages]))
completion_tokens, _ = tokenize(response)
return {
'id': f'chatcmpl-{completion_id}',
'object': 'chat.completion',
'created': completion_timestamp,
'model': model,
'choices': [
{
'index': 0,
'message': {
'role': 'assistant',
'content': response,
},
'finish_reason': 'stop',
}
],
'usage': {
'prompt_tokens': prompt_tokens,
'completion_tokens': completion_tokens,
'total_tokens': prompt_tokens + completion_tokens,
},
} }
def streaming(): item_data.update(item or {})
model = item_data.get('model')
stream = item_data.get('stream')
messages = item_data.get('messages')
try: try:
for chunk in response: response = g4f.ChatCompletion.create(model=model, stream=stream, messages=messages)
completion_data = { except:
return Response(content=json.dumps({"error": "An error occurred while generating the response."}, indent=4), media_type="application/json")
completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
completion_timestamp = int(time.time())
if not stream:
prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages]))
completion_tokens, _ = tokenize(response)
json_data = {
'id': f'chatcmpl-{completion_id}',
'object': 'chat.completion',
'created': completion_timestamp,
'model': model,
'choices': [
{
'index': 0,
'message': {
'role': 'assistant',
'content': response,
},
'finish_reason': 'stop',
}
],
'usage': {
'prompt_tokens': prompt_tokens,
'completion_tokens': completion_tokens,
'total_tokens': prompt_tokens + completion_tokens,
},
}
return Response(content=json.dumps(json_data, indent=4), media_type="application/json")
def streaming():
try:
for chunk in response:
completion_data = {
'id': f'chatcmpl-{completion_id}',
'object': 'chat.completion.chunk',
'created': completion_timestamp,
'model': model,
'choices': [
{
'index': 0,
'delta': {
'content': chunk,
},
'finish_reason': None,
}
],
}
content = json.dumps(completion_data, separators=(',', ':'))
yield f'data: {content}\n\n'
time.sleep(0.03)
end_completion_data = {
'id': f'chatcmpl-{completion_id}', 'id': f'chatcmpl-{completion_id}',
'object': 'chat.completion.chunk', 'object': 'chat.completion.chunk',
'created': completion_timestamp, 'created': completion_timestamp,
@ -165,63 +139,24 @@ class Api:
'choices': [ 'choices': [
{ {
'index': 0, 'index': 0,
'delta': { 'delta': {},
'content': chunk, 'finish_reason': 'stop',
},
'finish_reason': None,
} }
], ],
} }
content = json.dumps(completion_data, separators=(',', ':')) content = json.dumps(end_completion_data, separators=(',', ':'))
yield f'data: {content}\n\n' yield f'data: {content}\n\n'
time.sleep(0.03)
end_completion_data = { except GeneratorExit:
'id': f'chatcmpl-{completion_id}', pass
'object': 'chat.completion.chunk',
'created': completion_timestamp,
'model': model,
'choices': [
{
'index': 0,
'delta': {},
'finish_reason': 'stop',
}
],
}
content = json.dumps(end_completion_data, separators=(',', ':')) return Response(content=json.dumps(streaming(), indent=4), media_type="application/json")
yield f'data: {content}\n\n'
logger.success(f'model: {model}, stream: {stream}') @self.app.post("/v1/completions")
async def completions():
except GeneratorExit: return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
pass
return self.app.response_class(streaming(), mimetype='text/event-stream')
async def completions(self):
return 'not working yet', 500
async def model_info(self, model_name):
model_info = (g4f.ModelUtils.convert[model_name])
return jsonify({
'id' : model_name,
'object' : 'model',
'created' : 0,
'owned_by' : model_info.base_provider
})
async def models(self):
model_list = [{
'id' : model,
'object' : 'model',
'created' : 0,
'owned_by' : 'g4f'} for model in g4f.Model.__all__()]
return jsonify({
'object': 'list',
'data': model_list})
def run(self, ip):
split_ip = ip.split(":")
uvicorn.run(app=self.app, host=split_ip[0], port=int(split_ip[1]), use_colors=False)

View File

@ -3,4 +3,4 @@ import g4f.api
if __name__ == "__main__": if __name__ == "__main__":
print(f'Starting server... [g4f v-{g4f.version}]') print(f'Starting server... [g4f v-{g4f.version}]')
g4f.api.Api(g4f).run('127.0.0.1:1337', 8) g4f.api.Api(engine = g4f, debug = True).run(ip = "127.0.0.1:1337")

View File

@ -7,11 +7,9 @@ from g4f import Provider
from g4f.api import Api from g4f.api import Api
from g4f.gui.run import gui_parser, run_gui_args from g4f.gui.run import gui_parser, run_gui_args
def run_gui(args): def run_gui(args):
print("Running GUI...") print("Running GUI...")
def main(): def main():
IgnoredProviders = Enum("ignore_providers", {key: key for key in Provider.__all__}) IgnoredProviders = Enum("ignore_providers", {key: key for key in Provider.__all__})
parser = argparse.ArgumentParser(description="Run gpt4free") parser = argparse.ArgumentParser(description="Run gpt4free")
@ -19,22 +17,19 @@ def main():
api_parser=subparsers.add_parser("api") api_parser=subparsers.add_parser("api")
api_parser.add_argument("--bind", default="127.0.0.1:1337", help="The bind string.") api_parser.add_argument("--bind", default="127.0.0.1:1337", help="The bind string.")
api_parser.add_argument("--debug", type=bool, default=False, help="Enable verbose logging") api_parser.add_argument("--debug", type=bool, default=False, help="Enable verbose logging")
api_parser.add_argument("--num-threads", type=int, default=8, help="The number of threads.")
api_parser.add_argument("--ignored-providers", nargs="+", choices=[provider.name for provider in IgnoredProviders], api_parser.add_argument("--ignored-providers", nargs="+", choices=[provider.name for provider in IgnoredProviders],
default=[], help="List of providers to ignore when processing request.") default=[], help="List of providers to ignore when processing request.")
subparsers.add_parser("gui", parents=[gui_parser()], add_help=False) subparsers.add_parser("gui", parents=[gui_parser()], add_help=False)
args = parser.parse_args() args = parser.parse_args()
if args.mode == "api": if args.mode == "api":
controller=Api(g4f, debug=args.debug) controller=Api(engine=g4f, debug=args.debug, list_ignored_providers=args.ignored_providers)
controller.list_ignored_providers=args.ignored_providers controller.run(args.bind)
controller.run(args.bind, args.num_threads)
elif args.mode == "gui": elif args.mode == "gui":
run_gui_args(args) run_gui_args(args)
else: else:
parser.print_help() parser.print_help()
exit(1) exit(1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -6,8 +6,6 @@ certifi
browser_cookie3 browser_cookie3
websockets websockets
js2py js2py
flask[async]
flask-cors
typing-extensions typing-extensions
PyExecJS PyExecJS
duckduckgo-search duckduckgo-search
@ -20,3 +18,5 @@ pillow
platformdirs platformdirs
numpy numpy
asgiref asgiref
fastapi
uvicorn