community[patch]: chat model mypy fixes (#17061)

Related to #17048
Bagatur authored 8 months ago; committed by GitHub
parent d93de71d08
commit 66e45e8ab7

libs/community/langchain_community/chat_models/azureml_endpoint.py

@@ -20,7 +20,7 @@ from langchain_community.llms.azureml_endpoint import (
 class LlamaContentFormatter(ContentFormatterBase):
-    def __init__(self):  # type: ignore[no-untyped-def]
+    def __init__(self) -> None:
         raise TypeError(
             "`LlamaContentFormatter` is deprecated for chat models. Use "
             "`LlamaChatContentFormatter` instead."
@@ -72,12 +72,12 @@ class LlamaChatContentFormatter(ContentFormatterBase):
     def supported_api_types(self) -> List[AzureMLEndpointApiType]:
         return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]

-    def format_request_payload(  # type: ignore[override]
+    def format_messages_request_payload(
         self,
         messages: List[BaseMessage],
         model_kwargs: Dict,
         api_type: AzureMLEndpointApiType,
-    ) -> str:
+    ) -> bytes:
         """Formats the request according to the chosen api"""
         chat_messages = [
             LlamaChatContentFormatter._convert_message_to_dict(message)
@@ -98,17 +98,19 @@ class LlamaChatContentFormatter(ContentFormatterBase):
             raise ValueError(
                 f"`api_type` {api_type} is not supported by this formatter"
             )
-        return str.encode(request_payload)  # type: ignore[return-value]
+        return str.encode(request_payload)

-    def format_response_payload(  # type: ignore[override]
-        self, output: bytes, api_type: AzureMLEndpointApiType
+    def format_response_payload(
+        self,
+        output: bytes,
+        api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
     ) -> ChatGeneration:
         """Formats response"""
         if api_type == AzureMLEndpointApiType.realtime:
             try:
                 choice = json.loads(output)["output"]
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
             return ChatGeneration(
                 message=BaseMessage(
                     content=choice.strip(),
@@ -125,7 +127,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
                         "model. Expected `dict` but `{type(choice)}` was received."
                     )
             except (KeyError, IndexError, TypeError) as e:
-                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e  # type: ignore[union-attr]
+                raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
             return ChatGeneration(
                 message=BaseMessage(
                     content=choice["message"]["content"].strip(),
@@ -187,7 +189,7 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
         if stop:
             _model_kwargs["stop"] = stop

-        request_payload = self.content_formatter.format_request_payload(
+        request_payload = self.content_formatter.format_messages_request_payload(
            messages, _model_kwargs, self.endpoint_api_type
         )
         response_payload = self.http_client.call(
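Note on the rename: the old `format_request_payload` took `messages` where the base class signature expects `prompt`, so mypy flagged the override as a Liskov violation. Renaming the chat variant removes the conflict entirely. A minimal sketch of the rule, with hypothetical class names:

    from typing import List

    class Base:
        def format_request_payload(self, prompt: str) -> bytes:
            raise NotImplementedError()

    class Chat(Base):
        # mypy: error [override] - argument 1 incompatible with supertype "Base"
        def format_request_payload(self, messages: List[str]) -> bytes:
            return b""

        # A differently named method has no supertype signature to conflict with.
        def format_messages_request_payload(self, messages: List[str]) -> bytes:
            return b""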

libs/community/langchain_community/chat_models/deepinfra.py

@@ -327,7 +327,7 @@ class ChatDeepInfra(BaseChatModel):
             if chunk:
                 yield ChatGenerationChunk(message=chunk, generation_info=None)
                 if run_manager:
-                    run_manager.on_llm_new_token(chunk.content)  # type: ignore[arg-type]
+                    run_manager.on_llm_new_token(str(chunk.content))

     async def _astream(
         self,
@@ -349,7 +349,7 @@ class ChatDeepInfra(BaseChatModel):
             if chunk:
                 yield ChatGenerationChunk(message=chunk, generation_info=None)
                 if run_manager:
-                    await run_manager.on_llm_new_token(chunk.content)  # type: ignore[arg-type]
+                    await run_manager.on_llm_new_token(str(chunk.content))

     async def _agenerate(
         self,
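The `str(...)` wrapper is needed because message content is typed as a union of `str` and a list of content parts, while `on_llm_new_token` accepts `str`. A sketch of the mismatch, with a simplified callback signature:

    from typing import Dict, List, Union

    # BaseMessage.content is a union along these lines in langchain-core:
    Content = Union[str, List[Union[str, Dict]]]

    def on_llm_new_token(token: str) -> None:  # simplified callback
        ...

    chunk_content: Content = "partial token"
    on_llm_new_token(chunk_content)       # mypy: error [arg-type] - may be a list
    on_llm_new_token(str(chunk_content))  # accepted; a list would be stringified as-is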

libs/community/langchain_community/chat_models/edenai.py

@@ -165,6 +165,12 @@ class ChatEdenAI(BaseChatModel):
         """Return type of chat model."""
         return "edenai-chat"

+    @property
+    def _api_key(self) -> str:
+        if self.edenai_api_key:
+            return self.edenai_api_key.get_secret_value()
+        return ""
+
     def _stream(
         self,
         messages: List[BaseMessage],
@@ -175,7 +181,7 @@ class ChatEdenAI(BaseChatModel):
         """Call out to EdenAI's chat endpoint."""
         url = f"{self.edenai_api_url}/text/chat/stream"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
@@ -216,7 +222,7 @@ class ChatEdenAI(BaseChatModel):
     ) -> AsyncIterator[ChatGenerationChunk]:
         url = f"{self.edenai_api_url}/text/chat/stream"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
@@ -265,7 +271,7 @@ class ChatEdenAI(BaseChatModel):
         url = f"{self.edenai_api_url}/text/chat"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
@@ -323,7 +329,7 @@ class ChatEdenAI(BaseChatModel):
         url = f"{self.edenai_api_url}/text/chat"
         headers = {
-            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",  # type: ignore[union-attr]
+            "Authorization": f"Bearer {self._api_key}",
             "User-Agent": self.get_user_agent(),
         }
         formatted_data = _format_edenai_messages(messages=messages)
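The new `_api_key` property narrows the `Optional[SecretStr]` field once, so the four call sites no longer need per-line `union-attr` ignores. The pattern, sketched on a hypothetical pydantic-v1 model:

    from typing import Optional
    from langchain_core.pydantic_v1 import BaseModel, SecretStr

    class Client(BaseModel):
        api_key: Optional[SecretStr] = None

        @property
        def _key(self) -> str:
            if self.api_key:  # truthiness check narrows away None
                return self.api_key.get_secret_value()
            return ""  # callers always receive a plain str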

libs/community/langchain_community/chat_models/ernie.py

@@ -214,7 +214,7 @@ class ErnieBotChat(BaseChatModel):
         generations = [
             ChatGeneration(
                 message=AIMessage(
-                    content=response.get("result"),  # type: ignore[arg-type]
+                    content=response.get("result", ""),
                     additional_kwargs={**additional_kwargs},
                 )
             )
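`dict.get(key)` types as `Optional[...]`, which is rejected where message content must not be `None`; supplying a default removes the `None` branch. Sketched:

    from typing import Dict, Optional

    response: Dict[str, str] = {}
    maybe: Optional[str] = response.get("result")  # None on a miss -> [arg-type] downstream
    always: str = response.get("result", "")       # the default eliminates the None branch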

libs/community/langchain_community/chat_models/gpt_router.py

@@ -14,6 +14,7 @@ from typing import (
     Mapping,
     Optional,
     Tuple,
+    Type,
     Union,
 )
@@ -27,7 +28,7 @@ from langchain_core.language_models.chat_models import (
     generate_from_stream,
 )
 from langchain_core.language_models.llms import create_base_retry_decorator
-from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -56,9 +57,9 @@ class GPTRouterModel(BaseModel):
     provider_name: str

-def get_ordered_generation_requests(  # type: ignore[no-untyped-def, no-untyped-def]
-    models_priority_list: List[GPTRouterModel], **kwargs
-):
+def get_ordered_generation_requests(
+    models_priority_list: List[GPTRouterModel], **kwargs: Any
+) -> List:
     """
     Return the body for the model router input.
     """
@@ -100,7 +101,7 @@ def completion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
+) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
     """Use tenacity to retry the completion call."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@@ -122,7 +123,7 @@ async def acompletion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
+) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
     """Use tenacity to retry the async completion call."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
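Background on the `[type-arg]` fix: `typing.Generator` requires all three parameters (yield, send, return) and `AsyncGenerator` requires two (yield, send), so the one-argument forms were invalid. A correct minimal example:

    from typing import AsyncGenerator, Generator

    def count(n: int) -> Generator[int, None, None]:  # (yield, send, return)
        for i in range(n):
            yield i

    async def acount(n: int) -> AsyncGenerator[int, None]:  # (yield, send)
        for i in range(n):
            yield i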
@@ -282,9 +283,9 @@ class GPTRouter(BaseChatModel):
         )
         return self._create_chat_result(response)

-    def _create_chat_generation_chunk(  # type: ignore[no-untyped-def, no-untyped-def]
-        self, data: Mapping[str, Any], default_chunk_class
-    ):
+    def _create_chat_generation_chunk(
+        self, data: Mapping[str, Any], default_chunk_class: Type[BaseMessageChunk]
+    ) -> Tuple[ChatGenerationChunk, Type[BaseMessageChunk]]:
         chunk = _convert_delta_to_message_chunk(
             {"content": data.get("text", "")}, default_chunk_class
         )
@@ -293,8 +294,8 @@ class GPTRouter(BaseChatModel):
             dict(finish_reason=finish_reason) if finish_reason is not None else None
         )
         default_chunk_class = chunk.__class__
-        chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)  # type: ignore[assignment]
-        return chunk, default_chunk_class
+        gen_chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+        return gen_chunk, default_chunk_class

     def _stream(
         self,
@@ -306,7 +307,7 @@ class GPTRouter(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}

-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         generator_response = completion_with_retry(
             self,
             messages=message_dicts,
@@ -339,7 +340,7 @@ class GPTRouter(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}

-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         generator_response = acompletion_with_retry(
             self,
             messages=message_dicts,

libs/community/langchain_community/chat_models/huggingface.py

@@ -44,7 +44,7 @@ class ChatHuggingFace(BaseChatModel):
     llm: Union[HuggingFaceTextGenInference, HuggingFaceEndpoint, HuggingFaceHub]
     system_message: SystemMessage = SystemMessage(content=DEFAULT_SYSTEM_PROMPT)
     tokenizer: Any = None
-    model_id: str = None  # type: ignore
+    model_id: Optional[str] = None

     def __init__(self, **kwargs: Any):
         super().__init__(**kwargs)
@@ -144,7 +144,7 @@ class ChatHuggingFace(BaseChatModel):
         elif isinstance(self.llm, HuggingFaceHub):
             # no need to look up model_id for HuggingFaceHub LLM
-            self.model_id = self.llm.repo_id  # type: ignore[assignment]
+            self.model_id = self.llm.repo_id
             return
         else:

libs/community/langchain_community/chat_models/javelin_ai_gateway.py

@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 # Ignoring type because below is valid pydantic code
 # Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
-class ChatParams(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
+class ChatParams(BaseModel, extra=Extra.allow):
     """Parameters for the `Javelin AI Gateway` LLM."""

     temperature: float = 0.0

libs/community/langchain_community/chat_models/konko.py

@@ -13,6 +13,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    cast,
 )

 import requests
@@ -169,7 +170,9 @@ class ChatKonko(ChatOpenAI):
         }

         if openai_api_key:
-            headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()  # type: ignore[union-attr]
+            headers["X-OpenAI-Api-Key"] = cast(
+                SecretStr, openai_api_key
+            ).get_secret_value()

         models_response = requests.get(models_url, headers=headers)
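`typing.cast` is a no-op at runtime; it only tells mypy to treat the value as `SecretStr` inside the `if openai_api_key:` branch that already rules out `None`. Sketched:

    from typing import Optional, cast
    from langchain_core.pydantic_v1 import SecretStr

    openai_api_key: Optional[SecretStr] = SecretStr("sk-...")
    if openai_api_key:
        # The truthiness check already excludes None, so the cast is safe.
        value = cast(SecretStr, openai_api_key).get_secret_value()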

libs/community/langchain_community/chat_models/mlflow_ai_gateway.py

@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
 # Ignoring type because below is valid pydantic code
 # Unexpected keyword argument "extra" for "__init_subclass__" of "object" [call-arg]
-class ChatParams(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
+class ChatParams(BaseModel, extra=Extra.allow):
     """Parameters for the `MLflow AI Gateway` LLM."""

     temperature: float = 0.0

libs/community/langchain_community/chat_models/ollama.py

@@ -1,5 +1,5 @@
 import json
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast

 from langchain_core._api import deprecated
 from langchain_core.callbacks import (
@@ -74,10 +74,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         if isinstance(message, ChatMessage):
             message_text = f"\n\n{message.role.capitalize()}: {message.content}"
         elif isinstance(message, HumanMessage):
-            if message.content[0].get("type") == "text":  # type: ignore[union-attr]
-                message_text = f"[INST] {message.content[0]['text']} [/INST]"  # type: ignore[index]
-            elif message.content[0].get("type") == "image_url":  # type: ignore[union-attr]
-                message_text = message.content[0]["image_url"]["url"]  # type: ignore[index, index]
+            if isinstance(message.content, List):
+                first_content = cast(List[Dict], message.content)[0]
+                content_type = first_content.get("type")
+                if content_type == "text":
+                    message_text = f"[INST] {first_content['text']} [/INST]"
+                elif content_type == "image_url":
+                    message_text = first_content["image_url"]["url"]
+            else:
+                message_text = f"[INST] {message.content} [/INST]"
         elif isinstance(message, AIMessage):
             message_text = f"{message.content}"
         elif isinstance(message, SystemMessage):
@@ -94,7 +99,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
     def _convert_messages_to_ollama_messages(
         self, messages: List[BaseMessage]
     ) -> List[Dict[str, Union[str, List[str]]]]:
-        ollama_messages = []
+        ollama_messages: List = []
         for message in messages:
             role = ""
             if isinstance(message, HumanMessage):
@@ -111,12 +116,12 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
             if isinstance(message.content, str):
                 content = message.content
             else:
-                for content_part in message.content:
-                    if content_part.get("type") == "text":  # type: ignore[union-attr]
-                        content += f"\n{content_part['text']}"  # type: ignore[index]
-                    elif content_part.get("type") == "image_url":  # type: ignore[union-attr]
-                        if isinstance(content_part.get("image_url"), str):  # type: ignore[union-attr]
-                            image_url_components = content_part["image_url"].split(",")  # type: ignore[index]
+                for content_part in cast(List[Dict], message.content):
+                    if content_part.get("type") == "text":
+                        content += f"\n{content_part['text']}"
+                    elif content_part.get("type") == "image_url":
+                        if isinstance(content_part.get("image_url"), str):
+                            image_url_components = content_part["image_url"].split(",")
                             # Support data:image/jpeg;base64,<image> format
                             # and base64 strings
                             if len(image_url_components) > 1:
@@ -142,7 +147,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
                 }
             )

-        return ollama_messages  # type: ignore[return-value]
+        return ollama_messages

     def _create_chat_stream(
         self,
@@ -324,21 +329,15 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        try:
-            async for stream_resp in self._acreate_chat_stream(
-                messages, stop, **kwargs
-            ):
-                if stream_resp:
-                    chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
-                    yield chunk
-                    if run_manager:
-                        await run_manager.on_llm_new_token(
-                            chunk.text,
-                            verbose=self.verbose,
-                        )
-        except OllamaEndpointNotFoundError:
-            async for chunk in self._legacy_astream(messages, stop, **kwargs):  # type: ignore[attr-defined]
-                yield chunk
+        async for stream_resp in self._acreate_chat_stream(messages, stop, **kwargs):
+            if stream_resp:
+                chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp)
+                yield chunk
+                if run_manager:
+                    await run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )

     @deprecated("0.0.3", alternative="_stream")
     def _legacy_stream(
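The ollama changes follow one pattern: `message.content` is a str-or-list union, so list access needs an `isinstance` check plus a `cast` to the multimodal dict shape instead of per-line ignores. A simplified sketch (shapes abbreviated):

    from typing import Dict, List, Union, cast

    content: Union[str, List[Union[str, Dict]]] = [{"type": "text", "text": "hi"}]

    if isinstance(content, str):
        text = content
    else:
        first = cast(List[Dict], content)[0]  # assert the multimodal dict shape
        text = first.get("text", "")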

libs/community/langchain_community/chat_models/openai.py

@@ -554,7 +554,7 @@ class ChatOpenAI(BaseChatModel):
         if self.openai_proxy:
             import openai

-            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore[assignment]  # noqa: E501
+            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}
         return {**self._default_params, **openai_creds}

     def _get_invocation_params(

libs/community/langchain_community/chat_models/tongyi.py

@@ -13,6 +13,7 @@ from typing import (
     Mapping,
     Optional,
     Union,
+    cast,
 )

 from langchain_core.callbacks import (
@@ -197,7 +198,7 @@ class ChatTongyi(BaseChatModel):
         return {
             "model": self.model_name,
             "top_p": self.top_p,
-            "api_key": self.dashscope_api_key.get_secret_value(),  # type: ignore[union-attr]
+            "api_key": cast(SecretStr, self.dashscope_api_key).get_secret_value(),
             "result_format": "message",
             **self.model_kwargs,
         }

libs/community/langchain_community/chat_models/vertexai.py

@@ -120,11 +120,10 @@ def _parse_chat_history_gemini(
                 image = load_image_from_gcs(path=path, project=project)
             elif path.startswith("data:image/"):
                 # extract base64 component from image uri
-                try:
-                    encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group(  # type: ignore[union-attr]
-                        1
-                    )
-                except AttributeError:
+                encoded: Any = re.search(r"data:image/\w{2,4};base64,(.*)", path)
+                if encoded:
+                    encoded = encoded.group(1)
+                else:
                     raise ValueError(
                         "Invalid image uri. It should be in the format "
                         "data:image/<image_type>;base64,<base64_encoded_image>."

libs/community/langchain_community/chat_models/yandex.py

@@ -52,7 +52,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> List[Dict[str, str]]:
     return chat_history

-class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):  # type: ignore[misc]
+class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
     """Wrapper around YandexGPT large language models.

     There are two authentication options for the service account
@@ -156,7 +156,7 @@ def _make_request(
         messages=[Message(**message) for message in message_history],
     )
     stub = TextGenerationServiceStub(channel)
-    res = stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
+    res = stub.Completion(request, metadata=self._grpc_metadata)
     return list(res)[0].alternatives[0].message.text
@@ -201,7 +201,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
         messages=[Message(**message) for message in message_history],
     )
     stub = TextGenerationAsyncServiceStub(channel)
-    operation = await stub.Completion(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
+    operation = await stub.Completion(request, metadata=self._grpc_metadata)
     async with grpc.aio.secure_channel(
         operation_api_url, channel_credentials
     ) as operation_channel:
@@ -211,7 +211,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
             operation_request = GetOperationRequest(operation_id=operation.id)
             operation = await operation_stub.Get(
                 operation_request,
-                metadata=self._grpc_metadata,  # type: ignore[attr-defined]
+                metadata=self._grpc_metadata,
             )
             completion_response = CompletionResponse()

libs/community/langchain_community/chat_models/zhipuai.py

@@ -5,7 +5,7 @@ import asyncio
 import json
 import logging
 from functools import partial
-from typing import Any, Dict, Iterator, List, Optional
+from typing import Any, Dict, Iterator, List, Optional, cast

 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.chat_models import (
@@ -161,7 +161,7 @@ class ChatZhipuAI(BaseChatModel):
         return attributes

-    def __init__(self, *args, **kwargs):  # type: ignore[no-untyped-def]
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         try:
             import zhipuai
@@ -174,7 +174,7 @@ class ChatZhipuAI(BaseChatModel):
                 "Please install it via 'pip install zhipuai'"
             )

-    def invoke(self, prompt):  # type: ignore[no-untyped-def]
+    def invoke(self, prompt: Any) -> Any:  # type: ignore[override]
         if self.model == "chatglm_turbo":
             return self.zhipuai.model_api.invoke(
                 model=self.model,
@@ -185,17 +185,17 @@ class ChatZhipuAI(BaseChatModel):
                 return_type=self.return_type,
             )
         elif self.model == "characterglm":
-            meta = self.meta.dict()
+            _meta = cast(meta, self.meta).dict()
             return self.zhipuai.model_api.invoke(
                 model=self.model,
-                meta=meta,
+                meta=_meta,
                 prompt=prompt,
                 request_id=self.request_id,
                 return_type=self.return_type,
             )
         return None

-    def sse_invoke(self, prompt):  # type: ignore[no-untyped-def]
+    def sse_invoke(self, prompt: Any) -> Any:
         if self.model == "chatglm_turbo":
             return self.zhipuai.model_api.sse_invoke(
                 model=self.model,
@@ -207,18 +207,18 @@ class ChatZhipuAI(BaseChatModel):
                 incremental=self.incremental,
             )
         elif self.model == "characterglm":
-            meta = self.meta.dict()
+            _meta = cast(meta, self.meta).dict()
             return self.zhipuai.model_api.sse_invoke(
                 model=self.model,
                 prompt=prompt,
-                meta=meta,
+                meta=_meta,
                 request_id=self.request_id,
                 return_type=self.return_type,
                 incremental=self.incremental,
             )
         return None

-    async def async_invoke(self, prompt):  # type: ignore[no-untyped-def]
+    async def async_invoke(self, prompt: Any) -> Any:
         loop = asyncio.get_running_loop()
         partial_func = partial(
             self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt
@@ -229,7 +229,7 @@ class ChatZhipuAI(BaseChatModel):
         )
         return response

-    async def async_invoke_result(self, task_id):  # type: ignore[no-untyped-def]
+    async def async_invoke_result(self, task_id: Any) -> Any:
         loop = asyncio.get_running_loop()
         response = await loop.run_in_executor(
             None,
@@ -247,7 +247,7 @@ class ChatZhipuAI(BaseChatModel):
         **kwargs: Any,
     ) -> ChatResult:
         """Generate a chat response."""
-        prompt = []
+        prompt: List = []
         for message in messages:
             if isinstance(message, AIMessage):
                 role = "assistant"
@@ -270,7 +270,7 @@ class ChatZhipuAI(BaseChatModel):
         else:
             stream_iter = self._stream(
-                prompt=prompt,  # type: ignore[arg-type]
+                prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
                 **kwargs,
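Most of the zhipuai churn is clearing `[no-untyped-def]`: annotating `*args`/`**kwargs` and the return type is all mypy needs, while `invoke` keeps an `[override]` ignore because its signature deliberately differs from `BaseChatModel.invoke`. A minimal sketch of the annotation fix:

    from typing import Any

    class Example:
        # Annotating *args/**kwargs and the return type clears [no-untyped-def].
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            super().__init__()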

libs/community/langchain_community/llms/azureml_endpoint.py

@@ -101,7 +101,7 @@ class ContentFormatterBase:
     accepts: Optional[str] = "application/json"
     """The MIME type of the response data returned from the endpoint"""

-    format_error_msg: Optional[str] = (
+    format_error_msg: str = (
         "Error while formatting response payload for chat model of type "
         " `{api_type}`. Are you using the right formatter for the deployed "
         " model and endpoint type?"
@@ -134,17 +134,17 @@ class ContentFormatterBase:
         return [AzureMLEndpointApiType.realtime]

-    @abstractmethod
     def format_request_payload(
         self,
         prompt: str,
         model_kwargs: Dict,
         api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime,
-    ) -> bytes:
+    ) -> Any:
         """Formats the request body according to the input schema of
         the model. Returns bytes or seekable file like object in the
         format specified in the content_type request header.
         """
+        raise NotImplementedError()

     @abstractmethod
     def format_response_payload(
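Dropping `@abstractmethod` here (while keeping a runtime `NotImplementedError`) means chat formatters are no longer forced to implement the prompt-based `format_request_payload`; they expose `format_messages_request_payload` instead. A sketch of the loosened contract, with hypothetical names:

    from typing import Any

    class FormatterBase:
        def format_request_payload(self, prompt: str) -> Any:
            raise NotImplementedError()  # runtime guard instead of an ABC constraint

    class ChatFormatter(FormatterBase):
        # Free to expose a differently named, chat-specific entry point.
        def format_messages_request_payload(self, messages: list) -> bytes:
            return b"{}"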

libs/community/langchain_community/llms/yandex.py

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Callable, Dict, List, Mapping, Optional
+from typing import Any, Callable, Dict, List, Optional, Sequence

 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -54,13 +54,14 @@ class _BaseYandexGPT(Serializable):
     """Maximum number of retries to make when generating."""
     sleep_interval: float = 1.0
     """Delay between API requests"""
+    _grpc_metadata: Sequence

     @property
     def _llm_type(self) -> str:
         return "yandex_gpt"

     @property
-    def _identifying_params(self) -> Mapping[str, Any]:
+    def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
         return {
             "model_uri": self.model_uri,
