Improve code by AI

pull/1000/head
Heiner Lohaus 1 year ago
parent 4fa6e9c0f5
commit f7bb30036e

@ -0,0 +1,44 @@
import sys, re
from pathlib import Path
from os import path
sys.path.append(str(Path(__file__).parent.parent.parent))
import g4f
def read_code(text):
match = re.search(r"```(python|py|)\n(?P<code>[\S\s]+?)\n```", text)
if match:
return match.group("code")
path = input("Path: ")
with open(path, "r") as file:
code = file.read()
prompt = f"""
Improve the code in this file:
```py
{code}
```
Don't remove anything. Add type hints if possible.
"""
print("Create code...")
response = []
for chunk in g4f.ChatCompletion.create(
model=g4f.models.gpt_35_long,
messages=[{"role": "user", "content": prompt}],
timeout=0,
stream=True
):
response.append(chunk)
print(chunk, end="", flush=True)
print()
response = "".join(response)
code = read_code(response)
if code:
with open(path, "w") as file:
file.write(code)

@ -10,11 +10,11 @@ from ..typing import AsyncGenerator, CreateResult
class BaseProvider(ABC): class BaseProvider(ABC):
url: str url: str
working = False working: bool = False
needs_auth = False needs_auth: bool = False
supports_stream = False supports_stream: bool = False
supports_gpt_35_turbo = False supports_gpt_35_turbo: bool = False
supports_gpt_4 = False supports_gpt_4: bool = False
@staticmethod @staticmethod
@abstractmethod @abstractmethod
@ -38,13 +38,15 @@ class BaseProvider(ABC):
) -> str: ) -> str:
if not loop: if not loop:
loop = get_event_loop() loop = get_event_loop()
def create_func():
def create_func() -> str:
return "".join(cls.create_completion( return "".join(cls.create_completion(
model, model,
messages, messages,
False, False,
**kwargs **kwargs
)) ))
return await loop.run_in_executor( return await loop.run_in_executor(
executor, executor,
create_func create_func
@ -52,7 +54,7 @@ class BaseProvider(ABC):
@classmethod @classmethod
@property @property
def params(cls): def params(cls) -> str:
params = [ params = [
("model", "str"), ("model", "str"),
("messages", "list[dict[str, str]]"), ("messages", "list[dict[str, str]]"),
@ -103,7 +105,7 @@ class AsyncGeneratorProvider(AsyncProvider):
stream=stream, stream=stream,
**kwargs **kwargs
) )
gen = generator.__aiter__() gen = generator.__aiter__()
while True: while True:
try: try:
yield loop.run_until_complete(gen.__anext__()) yield loop.run_until_complete(gen.__anext__())
@ -125,7 +127,7 @@ class AsyncGeneratorProvider(AsyncProvider):
**kwargs **kwargs
) )
]) ])
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def create_async_generator( def create_async_generator(

@ -1,33 +1,33 @@
from __future__ import annotations from __future__ import annotations
import random import random
from typing import List, Type, Dict
from ..typing import CreateResult from ..typing import CreateResult
from .base_provider import BaseProvider, AsyncProvider from .base_provider import BaseProvider, AsyncProvider
from ..debug import logging from ..debug import logging
class RetryProvider(AsyncProvider): class RetryProvider(AsyncProvider):
__name__ = "RetryProvider" __name__: str = "RetryProvider"
working = True working: bool = True
needs_auth = False needs_auth: bool = False
supports_stream = True supports_stream: bool = True
supports_gpt_35_turbo = False supports_gpt_35_turbo: bool = False
supports_gpt_4 = False supports_gpt_4: bool = False
def __init__( def __init__(
self, self,
providers: list[type[BaseProvider]], providers: List[Type[BaseProvider]],
shuffle: bool = True shuffle: bool = True
) -> None: ) -> None:
self.providers = providers self.providers: List[Type[BaseProvider]] = providers
self.shuffle = shuffle self.shuffle: bool = shuffle
def create_completion( def create_completion(
self, self,
model: str, model: str,
messages: list[dict[str, str]], messages: List[Dict[str, str]],
stream: bool = False, stream: bool = False,
**kwargs **kwargs
) -> CreateResult: ) -> CreateResult:
@ -38,8 +38,8 @@ class RetryProvider(AsyncProvider):
if self.shuffle: if self.shuffle:
random.shuffle(providers) random.shuffle(providers)
self.exceptions = {} self.exceptions: Dict[str, Exception] = {}
started = False started: bool = False
for provider in providers: for provider in providers:
try: try:
if logging: if logging:
@ -61,14 +61,14 @@ class RetryProvider(AsyncProvider):
async def create_async( async def create_async(
self, self,
model: str, model: str,
messages: list[dict[str, str]], messages: List[Dict[str, str]],
**kwargs **kwargs
) -> str: ) -> str:
providers = [provider for provider in self.providers] providers = [provider for provider in self.providers]
if self.shuffle: if self.shuffle:
random.shuffle(providers) random.shuffle(providers)
self.exceptions = {} self.exceptions: Dict[str, Exception] = {}
for provider in providers: for provider in providers:
try: try:
return await provider.create_async(model, messages, **kwargs) return await provider.create_async(model, messages, **kwargs)
@ -79,7 +79,7 @@ class RetryProvider(AsyncProvider):
self.raise_exceptions() self.raise_exceptions()
def raise_exceptions(self): def raise_exceptions(self) -> None:
if self.exceptions: if self.exceptions:
raise RuntimeError("\n".join(["All providers failed:"] + [ raise RuntimeError("\n".join(["All providers failed:"] + [
f"{p}: {self.exceptions[p].__class__.__name__}: {self.exceptions[p]}" for p in self.exceptions f"{p}: {self.exceptions[p].__class__.__name__}: {self.exceptions[p]}" for p in self.exceptions

@ -1,30 +1,30 @@
from __future__ import annotations from __future__ import annotations
from g4f import models from requests import get
from .Provider import BaseProvider from g4f.models import Model, ModelUtils
from .typing import CreateResult, Union from .Provider import BaseProvider
from .debug import logging from .typing import CreateResult, Union
from requests import get from .debug import logging
version = '0.1.5.4' version = '0.1.5.4'
def check_pypi_version(): def check_pypi_version() -> None:
try: try:
response = get(f"https://pypi.org/pypi/g4f/json").json() response = get("https://pypi.org/pypi/g4f/json").json()
latest_version = response["info"]["version"] latest_version = response["info"]["version"]
if version != latest_version: if version != latest_version:
print(f'New pypi version: {latest_version} (current: {version}) | pip install -U g4f') print(f'New pypi version: {latest_version} (current: {version}) | pip install -U g4f')
except Exception as e: except Exception as e:
print(f'Failed to check g4f pypi version: {e}') print(f'Failed to check g4f pypi version: {e}')
check_pypi_version() check_pypi_version()
def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseProvider], stream: bool): def get_model_and_provider(model: Union[Model, str], provider: Union[type[BaseProvider], None], stream: bool) -> tuple[Model, type[BaseProvider]]:
if isinstance(model, str): if isinstance(model, str):
if model in models.ModelUtils.convert: if model in ModelUtils.convert:
model = models.ModelUtils.convert[model] model = ModelUtils.convert[model]
else: else:
raise Exception(f'The model: {model} does not exist') raise Exception(f'The model: {model} does not exist')
@ -33,14 +33,13 @@ def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseP
if not provider: if not provider:
raise Exception(f'No provider found for model: {model}') raise Exception(f'No provider found for model: {model}')
if not provider.working: if not provider.working:
raise Exception(f'{provider.__name__} is not working') raise Exception(f'{provider.__name__} is not working')
if not provider.supports_stream and stream: if not provider.supports_stream and stream:
raise Exception( raise Exception(f'ValueError: {provider.__name__} does not support "stream" argument')
f'ValueError: {provider.__name__} does not support "stream" argument')
if logging: if logging:
print(f'Using {provider.__name__} provider') print(f'Using {provider.__name__} provider')
@ -49,11 +48,11 @@ def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseP
class ChatCompletion: class ChatCompletion:
@staticmethod @staticmethod
def create( def create(
model : Union[models.Model, str], model: Union[Model, str],
messages : list[dict[str, str]], messages: list[dict[str, str]],
provider : Union[type[BaseProvider], None] = None, provider: Union[type[BaseProvider], None] = None,
stream : bool = False, stream: bool = False,
auth : Union[str, None] = None, auth: Union[str, None] = None,
**kwargs **kwargs
) -> Union[CreateResult, str]: ) -> Union[CreateResult, str]:
@ -62,7 +61,7 @@ class ChatCompletion:
if provider.needs_auth and not auth: if provider.needs_auth and not auth:
raise Exception( raise Exception(
f'ValueError: {provider.__name__} requires authentication (use auth=\'cookie or token or jwt ...\' param)') f'ValueError: {provider.__name__} requires authentication (use auth=\'cookie or token or jwt ...\' param)')
if provider.needs_auth: if provider.needs_auth:
kwargs['auth'] = auth kwargs['auth'] = auth
@ -71,9 +70,9 @@ class ChatCompletion:
@staticmethod @staticmethod
async def create_async( async def create_async(
model : Union[models.Model, str], model: Union[Model, str],
messages : list[dict[str, str]], messages: list[dict[str, str]],
provider : Union[type[BaseProvider], None] = None, provider: Union[type[BaseProvider], None] = None,
**kwargs **kwargs
) -> str: ) -> str:
model, provider = get_model_and_provider(model, provider, False) model, provider = get_model_and_provider(model, provider, False)
@ -83,11 +82,13 @@ class ChatCompletion:
class Completion: class Completion:
@staticmethod @staticmethod
def create( def create(
model : Union[models.Model, str], model: str,
prompt : str, prompt: str,
provider : Union[type[BaseProvider], None] = None, provider: Union[type[BaseProvider], None] = None,
stream : bool = False, **kwargs) -> Union[CreateResult, str]: stream: bool = False,
**kwargs
) -> Union[CreateResult, str]:
allowed_models = [ allowed_models = [
'code-davinci-002', 'code-davinci-002',
'text-ada-001', 'text-ada-001',
@ -96,13 +97,12 @@ class Completion:
'text-davinci-002', 'text-davinci-002',
'text-davinci-003' 'text-davinci-003'
] ]
if model not in allowed_models: if model not in allowed_models:
raise Exception(f'ValueError: Can\'t use {model} with Completion.create()') raise Exception(f'ValueError: Can\'t use {model} with Completion.create()')
model, provider = get_model_and_provider(model, provider, stream) model, provider = get_model_and_provider(model, provider, stream)
result = provider.create_completion(model.name, result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs)
[{"role": "user", "content": prompt}], stream, **kwargs)
return result if stream else ''.join(result) return result if stream else ''.join(result)

@ -1,47 +1,44 @@
from __future__ import annotations from __future__ import annotations
import warnings, json, asyncio import warnings
import json
import asyncio
from functools import partialmethod from functools import partialmethod
from asyncio import Future, Queue from asyncio import Future, Queue
from typing import AsyncGenerator from typing import AsyncGenerator
from curl_cffi.requests import AsyncSession, Response from curl_cffi.requests import AsyncSession, Response
import curl_cffi import curl_cffi
is_newer_0_5_8 = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl") is_newer_0_5_8: bool = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl")
is_newer_0_5_9 = hasattr(curl_cffi.AsyncCurl, "remove_handle") is_newer_0_5_9: bool = hasattr(curl_cffi.AsyncCurl, "remove_handle")
is_newer_0_5_10 = hasattr(AsyncSession, "release_curl") is_newer_0_5_10: bool = hasattr(AsyncSession, "release_curl")
class StreamResponse: class StreamResponse:
def __init__(self, inner: Response, queue: Queue): def __init__(self, inner: Response, queue: Queue[bytes]) -> None:
self.inner = inner self.inner: Response = inner
self.queue = queue self.queue: Queue[bytes] = queue
self.request = inner.request self.request = inner.request
self.status_code = inner.status_code self.status_code: int = inner.status_code
self.reason = inner.reason self.reason: str = inner.reason
self.ok = inner.ok self.ok: bool = inner.ok
self.headers = inner.headers self.headers = inner.headers
self.cookies = inner.cookies self.cookies = inner.cookies
async def text(self) -> str: async def text(self) -> str:
content = await self.read() content: bytes = await self.read()
return content.decode() return content.decode()
def raise_for_status(self): def raise_for_status(self) -> None:
if not self.ok: if not self.ok:
raise RuntimeError(f"HTTP Error {self.status_code}: {self.reason}") raise RuntimeError(f"HTTP Error {self.status_code}: {self.reason}")
async def json(self, **kwargs): async def json(self, **kwargs) -> dict:
return json.loads(await self.read(), **kwargs) return json.loads(await self.read(), **kwargs)
async def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None) -> AsyncGenerator[bytes]: async def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None) -> AsyncGenerator[bytes, None]:
""" pending: bytes = None
Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/
which is under the License: Apache 2.0
"""
pending = None
async for chunk in self.iter_content( async for chunk in self.iter_content(
chunk_size=chunk_size, decode_unicode=decode_unicode chunk_size=chunk_size, decode_unicode=decode_unicode
@ -63,7 +60,7 @@ class StreamResponse:
if pending is not None: if pending is not None:
yield pending yield pending
async def iter_content(self, chunk_size=None, decode_unicode=False) -> As: async def iter_content(self, chunk_size=None, decode_unicode=False) -> AsyncGenerator[bytes, None]:
if chunk_size: if chunk_size:
warnings.warn("chunk_size is ignored, there is no way to tell curl that.") warnings.warn("chunk_size is ignored, there is no way to tell curl that.")
if decode_unicode: if decode_unicode:
@ -77,22 +74,23 @@ class StreamResponse:
async def read(self) -> bytes: async def read(self) -> bytes:
return b"".join([chunk async for chunk in self.iter_content()]) return b"".join([chunk async for chunk in self.iter_content()])
class StreamRequest: class StreamRequest:
def __init__(self, session: AsyncSession, method: str, url: str, **kwargs): def __init__(self, session: AsyncSession, method: str, url: str, **kwargs) -> None:
self.session = session self.session: AsyncSession = session
self.loop = session.loop if session.loop else asyncio.get_running_loop() self.loop: asyncio.AbstractEventLoop = session.loop if session.loop else asyncio.get_running_loop()
self.queue = Queue() self.queue: Queue[bytes] = Queue()
self.method = method self.method: str = method
self.url = url self.url: str = url
self.options = kwargs self.options: dict = kwargs
self.handle = None self.handle: curl_cffi.AsyncCurl = None
def _on_content(self, data): def _on_content(self, data: bytes) -> None:
if not self.enter.done(): if not self.enter.done():
self.enter.set_result(None) self.enter.set_result(None)
self.queue.put_nowait(data) self.queue.put_nowait(data)
def _on_done(self, task: Future): def _on_done(self, task: Future) -> None:
if not self.enter.done(): if not self.enter.done():
self.enter.set_result(None) self.enter.set_result(None)
self.queue.put_nowait(None) self.queue.put_nowait(None)
@ -102,8 +100,8 @@ class StreamRequest:
async def fetch(self) -> StreamResponse: async def fetch(self) -> StreamResponse:
if self.handle: if self.handle:
raise RuntimeError("Request already started") raise RuntimeError("Request already started")
self.curl = await self.session.pop_curl() self.curl: curl_cffi.AsyncCurl = await self.session.pop_curl()
self.enter = self.loop.create_future() self.enter: asyncio.Future = self.loop.create_future()
if is_newer_0_5_10: if is_newer_0_5_10:
request, _, header_buffer, _, _ = self.session._set_curl_options( request, _, header_buffer, _, _ = self.session._set_curl_options(
self.curl, self.curl,
@ -121,7 +119,7 @@ class StreamRequest:
**self.options **self.options
) )
if is_newer_0_5_9: if is_newer_0_5_9:
self.handle = self.session.acurl.add_handle(self.curl) self.handle = self.session.acurl.add_handle(self.curl)
else: else:
await self.session.acurl.add_handle(self.curl, False) await self.session.acurl.add_handle(self.curl, False)
self.handle = self.session.acurl._curl2future[self.curl] self.handle = self.session.acurl._curl2future[self.curl]
@ -140,14 +138,14 @@ class StreamRequest:
response, response,
self.queue self.queue
) )
async def __aenter__(self) -> StreamResponse: async def __aenter__(self) -> StreamResponse:
return await self.fetch() return await self.fetch()
async def __aexit__(self, *args): async def __aexit__(self, *args) -> None:
self.release_curl() self.release_curl()
def release_curl(self): def release_curl(self) -> None:
if is_newer_0_5_10: if is_newer_0_5_10:
self.session.release_curl(self.curl) self.session.release_curl(self.curl)
return return
@ -162,6 +160,7 @@ class StreamRequest:
self.session.push_curl(self.curl) self.session.push_curl(self.curl)
self.curl = None self.curl = None
class StreamSession(AsyncSession): class StreamSession(AsyncSession):
def request( def request(
self, self,
@ -170,7 +169,7 @@ class StreamSession(AsyncSession):
**kwargs **kwargs
) -> StreamRequest: ) -> StreamRequest:
return StreamRequest(self, method, url, **kwargs) return StreamRequest(self, method, url, **kwargs)
head = partialmethod(request, "HEAD") head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET") get = partialmethod(request, "GET")
post = partialmethod(request, "POST") post = partialmethod(request, "POST")

Loading…
Cancel
Save