community[patch]: chat model mypy fixes (#17061)

Related to #17048
Bagatur committed 5 months ago (via GitHub)
parent d93de71d08
commit 66e45e8ab7

@@ -20,7 +20,7 @@ from langchain_community.llms.azureml_endpoint import (
 class LlamaContentFormatter(ContentFormatterBase):
-    def __init__(self):  # type: ignore[no-untyped-def]
+    def __init__(self) -> None:
         raise TypeError(
             "`LlamaContentFormatter` is deprecated for chat models. Use "
             "`LlamaChatContentFormatter` instead."
@@ -72,12 +72,12 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
         return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]

-    def format_request_payload(  # type: ignore[override]
+    def format_messages_request_payload(
         self,
         messages: List[BaseMessage],
         model_kwargs: Dict,
         api_type: AzureMLEndpointApiType,
-    ) -> str:
+    ) -> bytes:
         """Formats the request according to the chosen api"""
         chat_messages = [
             LlamaChatContentFormatter._convert_message_to_dict(message)
@@ -98,17 +98,19 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             raise ValueError(
                 f"`api_type` {api_type} is not supported by this formatter"
             )
-        return str.encode(request_payload)  # type: ignore[return-value]
+        return str.encode(request_payload)

-    def format_response_payload(  # type: ignore[override]
-        self, output: bytes, api_type: AzureMLEndpointApiType
+    def format_response_payload(
+        self,
+        output: bytes,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
     ) -> ChatGeneration:
         """Formats response"""
         if api_type == AzureMLEndpointApiType.realtime:
             try:
                 choice = json.loads(output)["output"]
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
             return ChatGeneration(
                 message=BaseMessage(
                     content=choice.strip(),
@@ -125,7 +127,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
                     "model. Expected `dict` but `{type(choice)}` was received."
                 )
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
             return ChatGeneration(
                 message=BaseMessage(
                     content=choice["message"]["content"].strip(),
@@ -187,7 +189,7 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
         if stop:
             _model_kwargs["stop"] = stop

-        request_payload = self.content_formatter.format_request_payload(
+        request_payload = self.content_formatter.format_messages_request_payload(
             messages, _model_kwargs, self.endpoint_api_type
         )
         response_payload = self.http_client.call(
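
The rename is the standard way out of mypy's [override] error: a subclass method must accept at least what the base method accepts (Liskov substitution), so a chat formatter that takes messages cannot compatibly override a prompt-based format_request_payload. A minimal sketch of the pattern, using illustrative class names rather than the real langchain ones:

from typing import Dict, List


class FormatterBase:
    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
        raise NotImplementedError()


class ChatFormatter(FormatterBase):
    # Overriding format_request_payload with a messages parameter would change
    # the signature and trip mypy's [override] check, so the chat-shaped entry
    # point gets its own name instead of an ignore comment.
    def format_messages_request_payload(
        self, messages: List[str], model_kwargs: Dict
    ) -> bytes:
        return str.encode("\n".join(messages))


print(ChatFormatter().format_messages_request_payload(["hello"], {}))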

@@ -327,7 +327,7 @@ class ChatDeepInfra(BaseChatModel):
             if chunk:
                 yield ChatGenerationChunk(message=chunk, generation_info=None)
                 if run_manager:
-                    run_manager.on_llm_new_token(chunk.content)  # type: ignore[arg-type]
+                    run_manager.on_llm_new_token(str(chunk.content))

     async def _astream(
         self,
@@ -349,7 +349,7 @@ class ChatDeepInfra(BaseChatModel):
             if chunk:
                 yield ChatGenerationChunk(message=chunk, generation_info=None)
                 if run_manager:
-                    await run_manager.on_llm_new_token(chunk.content)  # type: ignore[arg-type]
+                    await run_manager.on_llm_new_token(str(chunk.content))

     async def _agenerate(
         self,
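
Both ignores here covered the same mismatch: BaseMessage.content is a union of str and a list of content parts, while on_llm_new_token requires str. Wrapping the argument in str() type-checks in every branch of the union (note it stringifies list-shaped content rather than extracting its text). A self-contained sketch with a stand-in callback:

from typing import Dict, List, Union

# Assumed shape of BaseMessage.content.
MessageContent = Union[str, List[Union[str, Dict]]]


def on_llm_new_token(token: str) -> None:
    """Stand-in for the callback manager hook, which requires str."""
    print(token, end="")


content: MessageContent = "hello"
# Passing content directly fails mypy's arg-type check because it may be a
# list; str(content) is a str no matter which branch of the union holds.
on_llm_new_token(str(content))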

@@ -165,6 +165,12 @@ class ChatEdenAI(BaseChatModel):
         """Return type of chat model."""
         return "edenai-chat"

+    @property
+    def _api_key(self) -> str:
+        if self.edenai_api_key:
+            return self.edenai_api_key.get_secret_value()
+        return ""
+
     def _stream(
         self,
         messages: List[BaseMessage],
@@ -175,7 +181,7 @@ class ChatEdenAI(BaseChatModel):
         """Call out to EdenAI's chat endpoint."""
         url = f"{self.edenai_api_url}/text/chat/stream"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
@@ -216,7 +222,7 @@ class ChatEdenAI(BaseChatModel):
     ) -> AsyncIterator[ChatGenerationChunk]:
         url = f"{self.edenai_api_url}/text/chat/stream"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
@@ -265,7 +271,7 @@ class ChatEdenAI(BaseChatModel):
         url = f"{self.edenai_api_url}/text/chat"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
@@ -323,7 +329,7 @@ class ChatEdenAI(BaseChatModel):
         url = f"{self.edenai_api_url}/text/chat"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
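
The new _api_key property does the Optional[SecretStr] unwrapping once, so the four request-building methods above can drop their per-line union-attr ignores. A standalone sketch of the pattern (plain pydantic here as a stand-in for langchain_core.pydantic_v1):

from typing import Optional

from pydantic import BaseModel, SecretStr


class EdenClient(BaseModel):
    edenai_api_key: Optional[SecretStr] = None

    @property
    def _api_key(self) -> str:
        # Narrow Optional[SecretStr] to str in exactly one place; callers can
        # then interpolate self._api_key with no ignore comments.
        if self.edenai_api_key:
            return self.edenai_api_key.get_secret_value()
        return ""


client = EdenClient(edenai_api_key=SecretStr("sk-test"))
headers = {"Authorization": f"Bearer {client._api_key}"}
print(headers)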

@@ -214,7 +214,7 @@ class ErnieBotChat(BaseChatModel):
         generations = [
             ChatGeneration(
                 message=AIMessage(
-                    content=response.get("result"),  # type: ignore[arg-type]
+                    content=response.get("result", ""),
                     additional_kwargs={**additional_kwargs},
                 )
             )
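
dict.get(key) types as Optional, so feeding it to a str-typed message field is an arg-type error; supplying a "" default collapses the Optional, at the cost of silently mapping a missing key to an empty string. In miniature:

from typing import Dict


def set_content(content: str) -> str:
    """Stand-in for a message constructor that requires str content."""
    return content


response: Dict[str, str] = {"result": "hi"}
# response.get("result") is Optional[str] and fails mypy here;
# response.get("result", "") is unconditionally str.
print(set_content(response.get("result", "")))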

@@ -14,6 +14,7 @@ from typing import (
     Mapping,
     Optional,
     Tuple,
+    Type,
     Union,
 )
@@ -27,7 +28,7 @@ from langchain_core.language_models.chat_models import (
     generate_from_stream,
 )
 from langchain_core.language_models.llms import create_base_retry_decorator
-from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -56,9 +57,9 @@ class GPTRouterModel(BaseModel):
     provider_name: str


-def get_ordered_generation_requests(  # type: ignore[no-untyped-def, no-untyped-def]
-    models_priority_list: List[GPTRouterModel], **kwargs
-):
+def get_ordered_generation_requests(
+    models_priority_list: List[GPTRouterModel], **kwargs: Any
+) -> List:
     """
     Return the body for the model router input.
     """
@@ -100,7 +101,7 @@ def completion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
+) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
     """Use tenacity to retry the completion call."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@@ -122,7 +123,7 @@ async def acompletion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
+) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
     """Use tenacity to retry the async completion call."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@@ -282,9 +283,9 @@ class GPTRouter(BaseChatModel):
         )
         return self._create_chat_result(response)

-    def _create_chat_generation_chunk(  # type: ignore[no-untyped-def, no-untyped-def]
-        self, data: Mapping[str, Any], default_chunk_class
-    ):
+    def _create_chat_generation_chunk(
+        self, data: Mapping[str, Any], default_chunk_class: Type[BaseMessageChunk]
+    ) -> Tuple[ChatGenerationChunk, Type[BaseMessageChunk]]:
         chunk = _convert_delta_to_message_chunk(
             {"content": data.get("text", "")}, default_chunk_class
         )
@@ -293,8 +294,8 @@ class GPTRouter(BaseChatModel):
             dict(finish_reason=finish_reason) if finish_reason is not None else None
         )
         default_chunk_class = chunk.__class__
-        chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)  # type: ignore[assignment]
-        return chunk, default_chunk_class
+        gen_chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+        return gen_chunk, default_chunk_class

     def _stream(
         self,
@@ -306,7 +307,7 @@ class GPTRouter(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}

-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         generator_response = completion_with_retry(
             self,
             messages=message_dicts,
@@ -339,7 +340,7 @@ class GPTRouter(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}

-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         generator_response = acompletion_with_retry(
             self,
             messages=message_dicts,
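
The corrected annotations spell out the full parameter lists that typing requires: Generator takes three type arguments (yield, send, return) while AsyncGenerator takes two (yield, send), which is exactly what the removed type-arg ignores had been suppressing. For example:

from typing import AsyncGenerator, Generator


def stream_chunks() -> Generator[str, None, None]:
    # A bare Generator[str] is a [type-arg] error: the send and return
    # types must be given, None and None for a plain generator.
    yield "chunk"


async def astream_chunks() -> AsyncGenerator[str, None]:
    # AsyncGenerator has no return slot; async generators cannot return
    # a value, so only yield and send types are parameterized.
    yield "chunk"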

@@ -44,7 +44,7 @@ class ChatHuggingFace(BaseChatModel):
     llm: Union[HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub]
     system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
     tokenizer: Any = None
-    model_id: str = None  # type: ignore
+    model_id: Optional[str] = None

     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
@@ -144,7 +144,7 @@ class ChatHuggingFace(BaseChatModel):
         elif isinstance(self.llm, HuggingFaceHub):
             # no need to look up model_id for HuggingFaceHub LLM
-            self.model_id = self.llm.repo_id  # type: ignore[assignment]
+            self.model_id = self.llm.repo_id
             return

         else:
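
model_id: str = None needs a blanket ignore because None is not a str; declaring the field Optional[str] makes the None default and the later self.model_id = self.llm.repo_id assignment both check cleanly. A sketch with plain pydantic standing in for langchain's pydantic_v1:

from typing import Optional

from pydantic import BaseModel


class ChatWrapper(BaseModel):
    # Optional[str] admits both the None default and any later str
    # assignment, removing the field-level and assignment-site ignores.
    model_id: Optional[str] = None


w = ChatWrapper()
w.model_id = "some-org/some-model"  # hypothetical repo id
print(w.model_id)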

@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 # Ignoring type because below is valid pydantic code
 # Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
-class ChatParams(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
+class ChatParams(BaseModel, extra=Extra.allow):
     """Parameters for the `Javelin AI Gateway` LLM."""

     temperature: float = 0.0

@@ -13,6 +13,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    cast,
 )

 import requests
@@ -169,7 +170,9 @@ class ChatKonko(ChatOpenAI):
         }

         if openai_api_key:
-            headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()  # type: ignore[union-attr]
+            headers["X-OpenAI-Api-Key"] = cast(
+                SecretStr, openai_api_key
+            ).get_secret_value()

         models_response = requests.get(models_url, headers=headers)
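
typing.cast is a pure static assertion: it performs no runtime check and costs nothing, so it is only sound here because the if openai_api_key: guard has already excluded None. Roughly:

from typing import Dict, Optional, cast

from pydantic import SecretStr


openai_api_key: Optional[SecretStr] = SecretStr("sk-test")

headers: Dict[str, str] = {}
if openai_api_key:
    # The truthiness check rules out None at runtime; cast() merely informs
    # mypy of that fact, replacing the union-attr ignore.
    headers["X-OpenAI-Api-Key"] = cast(SecretStr, openai_api_key).get_secret_value()
print(headers)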

@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 # Ignoring type because below is valid pydantic code
 # Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
-class ChatParams(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
+class ChatParams(BaseModel, extra=Extra.allow):
     """Parameters for the `MLflow AI Gateway` LLM."""

     temperature: float = 0.0

@@ -1,5 +1,5 @@
 import json
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -74,10 +74,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         if isinstance(message, ChatMessage):
             message_text = f"\n\n{message.role.capitalize()}: {message.content}"
         elif isinstance(message, HumanMessage):
-            if message.content[0].get("type") == "text":  # type: ignore[union-attr]
-                message_text = f"[INST] {message.content[0]['text']} [/INST]"  # type: ignore[index]
-            elif message.content[0].get("type") == "image_url":  # type: ignore[union-attr]
-                message_text = message.content[0]["image_url"]["url"]  # type: ignore[index, index]
+            if isinstance(message.content, List):
+                first_content = cast(List[Dict], message.content)[0]
+                content_type = first_content.get("type")
+                if content_type == "text":
+                    message_text = f"[INST] {first_content['text']} [/INST]"
+                elif content_type == "image_url":
+                    message_text = first_content["image_url"]["url"]
+            else:
+                message_text = f"[INST] {message.content} [/INST]"
         elif isinstance(message, AIMessage):
             message_text = f"{message.content}"
         elif isinstance(message, SystemMessage):
@@ -94,7 +99,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
     def _convert_messages_to_ollama_messages(
         self, messages: List[BaseMessage]
     ) -> List[Dict[str, Union[str, List[str]]]]:
-        ollama_messages = []
+        ollama_messages: List = []
         for message in messages:
             role = ""
             if isinstance(message, HumanMessage):
@@ -111,12 +116,12 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
             if isinstance(message.content, str):
                 content = message.content
             else:
-                for content_part in message.content:
-                    if content_part.get("type") == "text":  # type: ignore[union-attr]
-                        content += f"\n{content_part['text']}"  # type: ignore[index]
-                    elif content_part.get("type") == "image_url":  # type: ignore[union-attr]
-                        if isinstance(content_part.get("image_url"), str):  # type: ignore[union-attr]
-                            image_url_components = content_part["image_url"].split(",")  # type: ignore[index]
+                for content_part in cast(List[Dict], message.content):
+                    if content_part.get("type") == "text":
+                        content += f"\n{content_part['text']}"
+                    elif content_part.get("type") == "image_url":
+                        if isinstance(content_part.get("image_url"), str):
+                            image_url_components = content_part["image_url"].split(",")
                             # Support data:image/jpeg;base64,<image> format
                             # and base64 strings
                             if len(image_url_components) > 1:
@@ -142,7 +147,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                 }
             )

-        return ollama_messages  # type: ignore[return-value]
+        return ollama_messages

     def _create_chat_stream(
         self,
@@ -324,21 +329,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        try:
-            async for stream_resp in self._acreate_chat_stream(
-                messages, stop, **kwargs
-            ):
-                if stream_resp:
-                    chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
-                    yield chunk
-                    if run_manager:
-                        await run_manager.on_llm_new_token(
-                            chunk.text,
-                            verbose=self.verbose,
-                        )
-        except OllamaEndpointNotFoundError:
-            async for chunk in self._legacy_astream(messages, stop, **kwargs):  # type: ignore[attr-defined]
+        async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
+            if stream_resp:
+                chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
+                yield chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )

     @deprecated("0.0.3", alternative="_stream")
     def _legacy_stream(
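
Instead of an ignore on every subscript, the rewrite narrows the union once with isinstance and then casts the element type a single time before indexing. A condensed sketch of the same shape:

from typing import Dict, List, Union, cast

# Assumed shape of BaseMessage.content.
MessageContent = Union[str, List[Union[str, Dict]]]


def render(content: MessageContent) -> str:
    if isinstance(content, list):
        # isinstance() removes the str branch; the single cast() fixes the
        # element type so the subscripts below need no per-line ignores.
        first = cast(List[Dict], content)[0]
        if first.get("type") == "text":
            return f"[INST] {first['text']} [/INST]"
        if first.get("type") == "image_url":
            return str(first["image_url"])
    return f"[INST] {content} [/INST]"


print(render([{"type": "text", "text": "hi"}]))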

@@ -554,7 +554,7 @@ class ChatOpenAI(BaseChatModel):
         if self.openai_proxy:
             import openai

-            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore[assignment]  # noqa: E501
+            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}
         return {**self._default_params, **openai_creds}

     def _get_invocation_params(
def _get_invocation_params(

@@ -13,6 +13,7 @@ from typing import (
     Mapping,
     Optional,
     Union,
+    cast,
 )

 from langchain_core.callbacks import (
@@ -197,7 +198,7 @@ class ChatTongyi(BaseChatModel):
         return {
             "model": self.model_name,
             "top_p": self.top_p,
-            "api_key": self.dashscope_api_key.get_secret_value(),  # type: ignore[union-attr]
+            "api_key": cast(SecretStr, self.dashscope_api_key).get_secret_value(),
             "result_format": "message",
             **self.model_kwargs,
         }

@@ -120,11 +120,10 @@ def _parse_chat_history_gemini(
             image = load_image_from_gcs(path=path, project=project)
         elif path.startswith("data:image/"):
             # extract base64 component from image uri
-            try:
-                encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group(  # type: ignore[union-attr]
-                    1
-                )
-            except AttributeError:
+            encoded: Any = re.search(r"data:image/\w{2,4};base64,(.*)", path)
+            if encoded:
+                encoded = encoded.group(1)
+            else:
                 raise ValueError(
                     "Invalid image uri. It should be in the format "
                     "data:image/<image_type>;base64,<base64_encoded_image>."

@@ -52,7 +52,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> List[Dict[str, str]]:
     return chat_history


-class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):  # type: ignore[misc]
+class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
     """Wrapper around YandexGPT large language models.

     There are two authentication options for the service account
@@ -156,7 +156,7 @@ def _make_request(
             messages=[Message(**message) for message in message_history],
         )
         stub = TextGenerationServiceStub(channel)
-        res = stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
+        res = stub.Completion(request, metadata=self._grpc_metadata)
         return list(res)[0].alternatives[0].message.text
@@ -201,7 +201,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> str
             messages=[Message(**message) for message in message_history],
         )
         stub = TextGenerationAsyncServiceStub(channel)
-        operation = await stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
+        operation = await stub.Completion(request, metadata=self._grpc_metadata)

     async with grpc.aio.secure_channel(
         operation_api_url, channel_credentials
     ) as operation_channel:
@@ -211,7 +211,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> str
             operation_request = GetOperationRequest(operation_id=operation.id)
             operation = await operation_stub.Get(
                 operation_request,
-                metadata=self._grpc_metadata,  # type: ignore[attr-defined]
+                metadata=self._grpc_metadata,
             )

     completion_response = CompletionResponse()
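
Declaring _grpc_metadata: Sequence in the class body (see the llms/yandex hunk below) is an annotation-only statement: it creates no runtime attribute, but it tells mypy the attribute exists and what it holds, which is what removes the attr-defined ignores at the gRPC call sites above. A minimal shape (the tuple format is illustrative, not necessarily what YandexGPT sends):

from typing import Sequence


class Base:
    # Annotation only: no value is assigned here, but mypy now accepts
    # self._grpc_metadata wherever a Base instance is used.
    _grpc_metadata: Sequence

    def __init__(self, api_key: str) -> None:
        self._grpc_metadata = (("authorization", f"Api-Key {api_key}"),)


def call_metadata(obj: Base) -> Sequence:
    return obj._grpc_metadata


print(call_metadata(Base("secret")))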

@@ -5,7 +5,7 @@ import asyncio
 import json
 import logging
 from functools import partial
-from typing import Any, Dict, Iterator, List, Optional
+from typing import Any, Dict, Iterator, List, Optional, cast

 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.chat_models import (
@@ -161,7 +161,7 @@ class ChatZhipuAI(BaseChatModel):
         return attributes

-    def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         try:
             import zhipuai
@@ -174,7 +174,7 @@ class ChatZhipuAI(BaseChatModel):
                 "Please install it via 'pip install zhipuai'"
             )

-    def invoke(self, prompt):  # type: ignore[no-untyped-def]
+    def invoke(self, prompt: Any) -> Any:  # type: ignore[override]
         if self.model == "chatglm_turbo":
             return self.zhipuai.model_api.invoke(
                 model=self.model,
@@ -185,17 +185,17 @@ class ChatZhipuAI(BaseChatModel):
                 return_type=self.return_type,
             )
         elif self.model == "characterglm":
-            meta = self.meta.dict()
+            _meta = cast(meta, self.meta).dict()
             return self.zhipuai.model_api.invoke(
                 model=self.model,
-                meta=meta,
+                meta=_meta,
                 prompt=prompt,
                 request_id=self.request_id,
                 return_type=self.return_type,
             )
         return None

-    def sse_invoke(self, prompt):  # type: ignore[no-untyped-def]
+    def sse_invoke(self, prompt: Any) -> Any:
         if self.model == "chatglm_turbo":
             return self.zhipuai.model_api.sse_invoke(
                 model=self.model,
@@ -207,18 +207,18 @@ class ChatZhipuAI(BaseChatModel):
                 incremental=self.incremental,
             )
         elif self.model == "characterglm":
-            meta = self.meta.dict()
+            _meta = cast(meta, self.meta).dict()
             return self.zhipuai.model_api.sse_invoke(
                 model=self.model,
                 prompt=prompt,
-                meta=meta,
+                meta=_meta,
                 request_id=self.request_id,
                 return_type=self.return_type,
                 incremental=self.incremental,
             )
         return None

-    async def async_invoke(self, prompt):  # type: ignore[no-untyped-def]
+    async def async_invoke(self, prompt: Any) -> Any:
         loop = asyncio.get_running_loop()
         partial_func = partial(
             self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt
@@ -229,7 +229,7 @@ class ChatZhipuAI(BaseChatModel):
         )
         return response

-    async def async_invoke_result(self, task_id):  # type: ignore[no-untyped-def]
+    async def async_invoke_result(self, task_id: Any) -> Any:
         loop = asyncio.get_running_loop()
         response = await loop.run_in_executor(
             None,
@@ -247,7 +247,7 @@ class ChatZhipuAI(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         """Generate a chat response."""
-        prompt = []
+        prompt: List = []
         for message in messages:
             if isinstance(message, AIMessage):
                 role = "assistant"
@@ -270,7 +270,7 @@ class ChatZhipuAI(BaseChatModel):
         else:
             stream_iter = self._stream(
-                prompt=prompt,  # type: ignore[arg-type]
+                prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
                 **kwargs,
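
The remaining changes are the recurring no-untyped-def fix — every parameter annotated, even if only as Any, plus an explicit return type — and an annotation for an empty-list local whose element type mypy cannot infer on its own. In brief:

from typing import Any, List


class Chat:
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Annotating *args/**kwargs as Any satisfies mypy's untyped-def
        # checks without constraining callers.
        self.args = args


def build_prompt() -> List:
    # An empty literal gives mypy nothing to infer from; the annotation
    # (left unparameterized, as in the commit) settles it.
    prompt: List = []
    prompt.append({"role": "user", "content": "hi"})
    return prompt


print(build_prompt())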

@@ -101,7 +101,7 @@ class ContentFormatterBase:
     accepts: Optional[str] = "application/json"
     """The MIME type of the response data returned from the endpoint"""

-    format_error_msg: Optional[str] = (
+    format_error_msg: str = (
         "Error while formatting response payload for chat model of type "
         " `{api_type}`. Are you using the right formatter for the deployed "
         " model and endpoint type?"
@@ -134,17 +134,17 @@ class ContentFormatterBase:
         return [AzureMLEndpointApiType.realtime]

     @abstractmethod
     def format_request_payload(
         self,
         prompt: str,
         model_kwargs: Dict,
         api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
-    ) -> bytes:
+    ) -> Any:
         """Formats the request body according to the input schema of
         the model. Returns bytes or seekable file like object in the
         format specified in the content_type request header.
         """
         raise NotImplementedError()

     @abstractmethod
     def format_response_payload(

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Callable, Dict, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Optional, Sequence

 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -54,13 +54,14 @@ class _BaseYandexGPT(Serializable):
     """Maximum number of retries to make when generating."""
     sleep_interval: float = 1.0
     """Delay between API requests"""
+    _grpc_metadata: Sequence

     @property
     def _llm_type(self) -> str:
         return "yandex_gpt"

     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
+    def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
         return {
             "model_uri": self.model_uri,