Improve code by AI

pull/1000/head
Heiner Lohaus 1 year ago
parent 4fa6e9c0f5
commit f7bb30036e

@ -0,0 +1,44 @@
import sys, re
from pathlib import Path
from os import path
sys.path.append(str(Path(__file__).parent.parent.parent))
import g4f
def read_code(text):
match = re.search(r"```(python|py|)\n(?P<code>[\S\s]+?)\n```", text)
if match:
return match.group("code")
path = input("Path: ")
with open(path, "r") as file:
code = file.read()
prompt = f"""
Improve the code in this file:
```py
{code}
```
Don't remove anything. Add type hints if possible.
"""
print("Create code...")
response = []
for chunk in g4f.ChatCompletion.create(
model=g4f.models.gpt_35_long,
messages=[{"role": "user", "content": prompt}],
timeout=0,
stream=True
):
response.append(chunk)
print(chunk, end="", flush=True)
print()
response = "".join(response)
code = read_code(response)
if code:
with open(path, "w") as file:
file.write(code)

@ -10,11 +10,11 @@ from ..typing import AsyncGenerator, CreateResult
class BaseProvider(ABC):
url: str
working = False
needs_auth = False
supports_stream = False
supports_gpt_35_turbo = False
supports_gpt_4 = False
working: bool = False
needs_auth: bool = False
supports_stream: bool = False
supports_gpt_35_turbo: bool = False
supports_gpt_4: bool = False
@staticmethod
@abstractmethod
@ -38,13 +38,15 @@ class BaseProvider(ABC):
) -> str:
if not loop:
loop = get_event_loop()
def create_func():
def create_func() -> str:
return "".join(cls.create_completion(
model,
messages,
False,
**kwargs
))
return await loop.run_in_executor(
executor,
create_func
@ -52,7 +54,7 @@ class BaseProvider(ABC):
@classmethod
@property
def params(cls):
def params(cls) -> str:
params = [
("model", "str"),
("messages", "list[dict[str, str]]"),
@ -103,7 +105,7 @@ class AsyncGeneratorProvider(AsyncProvider):
stream=stream,
**kwargs
)
gen = generator.__aiter__()
gen = generator.__aiter__()
while True:
try:
yield loop.run_until_complete(gen.__anext__())
@ -125,7 +127,7 @@ class AsyncGeneratorProvider(AsyncProvider):
**kwargs
)
])
@staticmethod
@abstractmethod
def create_async_generator(

@ -1,33 +1,33 @@
from __future__ import annotations
import random
from typing import List, Type, Dict
from ..typing import CreateResult
from .base_provider import BaseProvider, AsyncProvider
from ..debug import logging
class RetryProvider(AsyncProvider):
__name__ = "RetryProvider"
working = True
needs_auth = False
supports_stream = True
supports_gpt_35_turbo = False
supports_gpt_4 = False
__name__: str = "RetryProvider"
working: bool = True
needs_auth: bool = False
supports_stream: bool = True
supports_gpt_35_turbo: bool = False
supports_gpt_4: bool = False
def __init__(
self,
providers: list[type[BaseProvider]],
providers: List[Type[BaseProvider]],
shuffle: bool = True
) -> None:
self.providers = providers
self.shuffle = shuffle
self.providers: List[Type[BaseProvider]] = providers
self.shuffle: bool = shuffle
def create_completion(
self,
model: str,
messages: list[dict[str, str]],
messages: List[Dict[str, str]],
stream: bool = False,
**kwargs
) -> CreateResult:
@ -38,8 +38,8 @@ class RetryProvider(AsyncProvider):
if self.shuffle:
random.shuffle(providers)
self.exceptions = {}
started = False
self.exceptions: Dict[str, Exception] = {}
started: bool = False
for provider in providers:
try:
if logging:
@ -61,14 +61,14 @@ class RetryProvider(AsyncProvider):
async def create_async(
self,
model: str,
messages: list[dict[str, str]],
messages: List[Dict[str, str]],
**kwargs
) -> str:
providers = [provider for provider in self.providers]
if self.shuffle:
random.shuffle(providers)
self.exceptions = {}
self.exceptions: Dict[str, Exception] = {}
for provider in providers:
try:
return await provider.create_async(model, messages, **kwargs)
@ -79,7 +79,7 @@ class RetryProvider(AsyncProvider):
self.raise_exceptions()
def raise_exceptions(self):
def raise_exceptions(self) -> None:
if self.exceptions:
raise RuntimeError("\n".join(["All providers failed:"] + [
f"{p}: {self.exceptions[p].__class__.__name__}: {self.exceptions[p]}" for p in self.exceptions

@ -1,30 +1,30 @@
from __future__ import annotations
from g4f import models
from .Provider import BaseProvider
from .typing import CreateResult, Union
from .debug import logging
from requests import get
from requests import get
from g4f.models import Model, ModelUtils
from .Provider import BaseProvider
from .typing import CreateResult, Union
from .debug import logging
version = '0.1.5.4'
def check_pypi_version():
def check_pypi_version() -> None:
try:
response = get(f"https://pypi.org/pypi/g4f/json").json()
response = get("https://pypi.org/pypi/g4f/json").json()
latest_version = response["info"]["version"]
if version != latest_version:
print(f'New pypi version: {latest_version} (current: {version}) | pip install -U g4f')
except Exception as e:
print(f'Failed to check g4f pypi version: {e}')
check_pypi_version()
def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseProvider], stream: bool):
def get_model_and_provider(model: Union[Model, str], provider: Union[type[BaseProvider], None], stream: bool) -> tuple[Model, type[BaseProvider]]:
if isinstance(model, str):
if model in models.ModelUtils.convert:
model = models.ModelUtils.convert[model]
if model in ModelUtils.convert:
model = ModelUtils.convert[model]
else:
raise Exception(f'The model: {model} does not exist')
@ -33,14 +33,13 @@ def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseP
if not provider:
raise Exception(f'No provider found for model: {model}')
if not provider.working:
raise Exception(f'{provider.__name__} is not working')
if not provider.supports_stream and stream:
raise Exception(
f'ValueError: {provider.__name__} does not support "stream" argument')
raise Exception(f'ValueError: {provider.__name__} does not support "stream" argument')
if logging:
print(f'Using {provider.__name__} provider')
@ -49,11 +48,11 @@ def get_model_and_provider(model: Union[models.Model, str], provider: type[BaseP
class ChatCompletion:
@staticmethod
def create(
model : Union[models.Model, str],
messages : list[dict[str, str]],
provider : Union[type[BaseProvider], None] = None,
stream : bool = False,
auth : Union[str, None] = None,
model: Union[Model, str],
messages: list[dict[str, str]],
provider: Union[type[BaseProvider], None] = None,
stream: bool = False,
auth: Union[str, None] = None,
**kwargs
) -> Union[CreateResult, str]:
@ -62,7 +61,7 @@ class ChatCompletion:
if provider.needs_auth and not auth:
raise Exception(
f'ValueError: {provider.__name__} requires authentication (use auth=\'cookie or token or jwt ...\' param)')
if provider.needs_auth:
kwargs['auth'] = auth
@ -71,9 +70,9 @@ class ChatCompletion:
@staticmethod
async def create_async(
model : Union[models.Model, str],
messages : list[dict[str, str]],
provider : Union[type[BaseProvider], None] = None,
model: Union[Model, str],
messages: list[dict[str, str]],
provider: Union[type[BaseProvider], None] = None,
**kwargs
) -> str:
model, provider = get_model_and_provider(model, provider, False)
@ -83,11 +82,13 @@ class ChatCompletion:
class Completion:
@staticmethod
def create(
model : Union[models.Model, str],
prompt : str,
provider : Union[type[BaseProvider], None] = None,
stream : bool = False, **kwargs) -> Union[CreateResult, str]:
model: str,
prompt: str,
provider: Union[type[BaseProvider], None] = None,
stream: bool = False,
**kwargs
) -> Union[CreateResult, str]:
allowed_models = [
'code-davinci-002',
'text-ada-001',
@ -96,13 +97,12 @@ class Completion:
'text-davinci-002',
'text-davinci-003'
]
if model not in allowed_models:
raise Exception(f'ValueError: Can\'t use {model} with Completion.create()')
model, provider = get_model_and_provider(model, provider, stream)
result = provider.create_completion(model.name,
[{"role": "user", "content": prompt}], stream, **kwargs)
result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs)
return result if stream else ''.join(result)
return result if stream else ''.join(result)

@ -1,47 +1,44 @@
from __future__ import annotations
import warnings, json, asyncio
import warnings
import json
import asyncio
from functools import partialmethod
from asyncio import Future, Queue
from typing import AsyncGenerator
from curl_cffi.requests import AsyncSession, Response
import curl_cffi
is_newer_0_5_8 = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl")
is_newer_0_5_9 = hasattr(curl_cffi.AsyncCurl, "remove_handle")
is_newer_0_5_10 = hasattr(AsyncSession, "release_curl")
is_newer_0_5_8: bool = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl")
is_newer_0_5_9: bool = hasattr(curl_cffi.AsyncCurl, "remove_handle")
is_newer_0_5_10: bool = hasattr(AsyncSession, "release_curl")
class StreamResponse:
def __init__(self, inner: Response, queue: Queue):
self.inner = inner
self.queue = queue
def __init__(self, inner: Response, queue: Queue[bytes]) -> None:
self.inner: Response = inner
self.queue: Queue[bytes] = queue
self.request = inner.request
self.status_code = inner.status_code
self.reason = inner.reason
self.ok = inner.ok
self.status_code: int = inner.status_code
self.reason: str = inner.reason
self.ok: bool = inner.ok
self.headers = inner.headers
self.cookies = inner.cookies
async def text(self) -> str:
content = await self.read()
content: bytes = await self.read()
return content.decode()
def raise_for_status(self):
def raise_for_status(self) -> None:
if not self.ok:
raise RuntimeError(f"HTTP Error {self.status_code}: {self.reason}")
async def json(self, **kwargs):
async def json(self, **kwargs) -> dict:
return json.loads(await self.read(), **kwargs)
async def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None) -> AsyncGenerator[bytes]:
"""
Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/
which is under the License: Apache 2.0
"""
pending = None
async def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None) -> AsyncGenerator[bytes, None]:
pending: bytes = None
async for chunk in self.iter_content(
chunk_size=chunk_size, decode_unicode=decode_unicode
@ -63,7 +60,7 @@ class StreamResponse:
if pending is not None:
yield pending
async def iter_content(self, chunk_size=None, decode_unicode=False) -> As:
async def iter_content(self, chunk_size=None, decode_unicode=False) -> AsyncGenerator[bytes, None]:
if chunk_size:
warnings.warn("chunk_size is ignored, there is no way to tell curl that.")
if decode_unicode:
@ -77,22 +74,23 @@ class StreamResponse:
async def read(self) -> bytes:
return b"".join([chunk async for chunk in self.iter_content()])
class StreamRequest:
def __init__(self, session: AsyncSession, method: str, url: str, **kwargs):
self.session = session
self.loop = session.loop if session.loop else asyncio.get_running_loop()
self.queue = Queue()
self.method = method
self.url = url
self.options = kwargs
self.handle = None
def _on_content(self, data):
def __init__(self, session: AsyncSession, method: str, url: str, **kwargs) -> None:
self.session: AsyncSession = session
self.loop: asyncio.AbstractEventLoop = session.loop if session.loop else asyncio.get_running_loop()
self.queue: Queue[bytes] = Queue()
self.method: str = method
self.url: str = url
self.options: dict = kwargs
self.handle: curl_cffi.AsyncCurl = None
def _on_content(self, data: bytes) -> None:
if not self.enter.done():
self.enter.set_result(None)
self.queue.put_nowait(data)
def _on_done(self, task: Future):
def _on_done(self, task: Future) -> None:
if not self.enter.done():
self.enter.set_result(None)
self.queue.put_nowait(None)
@ -102,8 +100,8 @@ class StreamRequest:
async def fetch(self) -> StreamResponse:
if self.handle:
raise RuntimeError("Request already started")
self.curl = await self.session.pop_curl()
self.enter = self.loop.create_future()
self.curl: curl_cffi.AsyncCurl = await self.session.pop_curl()
self.enter: asyncio.Future = self.loop.create_future()
if is_newer_0_5_10:
request, _, header_buffer, _, _ = self.session._set_curl_options(
self.curl,
@ -121,7 +119,7 @@ class StreamRequest:
**self.options
)
if is_newer_0_5_9:
self.handle = self.session.acurl.add_handle(self.curl)
self.handle = self.session.acurl.add_handle(self.curl)
else:
await self.session.acurl.add_handle(self.curl, False)
self.handle = self.session.acurl._curl2future[self.curl]
@ -140,14 +138,14 @@ class StreamRequest:
response,
self.queue
)
async def __aenter__(self) -> StreamResponse:
return await self.fetch()
async def __aexit__(self, *args):
async def __aexit__(self, *args) -> None:
self.release_curl()
def release_curl(self):
def release_curl(self) -> None:
if is_newer_0_5_10:
self.session.release_curl(self.curl)
return
@ -162,6 +160,7 @@ class StreamRequest:
self.session.push_curl(self.curl)
self.curl = None
class StreamSession(AsyncSession):
def request(
self,
@ -170,7 +169,7 @@ class StreamSession(AsyncSession):
**kwargs
) -> StreamRequest:
return StreamRequest(self, method, url, **kwargs)
head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET")
post = partialmethod(request, "POST")

Loading…
Cancel
Save