@@ -14,6 +14,7 @@ from typing import (
     Mapping,
     Optional,
     Tuple,
+    Type,
     Union,
 )
 
@@ -27,7 +28,7 @@ from langchain_core.language_models.chat_models import (
     generate_from_stream,
 )
 from langchain_core.language_models.llms import create_base_retry_decorator
-from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
@@ -56,9 +57,9 @@ class GPTRouterModel(BaseModel):
     provider_name: str
 
 
-def get_ordered_generation_requests(  # type: ignore[no-untyped-def, no-untyped-def]
-    models_priority_list: List[GPTRouterModel], **kwargs
-):
+def get_ordered_generation_requests(
+    models_priority_list: List[GPTRouterModel], **kwargs: Any
+) -> List:
     """
     Return the body for the model router input.
     """
@@ -100,7 +101,7 @@ def completion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[CallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
+) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
     """Use tenacity to retry the completion call."""
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
 
@@ -122,7 +123,7 @@ async def acompletion_with_retry(
     models_priority_list: List[GPTRouterModel],
     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
     **kwargs: Any,
-) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:  # type: ignore[type-arg]
+) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
     """Use tenacity to retry the async completion call."""
 
     retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@@ -282,9 +283,9 @@ class GPTRouter(BaseChatModel):
         )
         return self._create_chat_result(response)
 
-    def _create_chat_generation_chunk(  # type: ignore[no-untyped-def, no-untyped-def]
-        self, data: Mapping[str, Any], default_chunk_class
-    ):
+    def _create_chat_generation_chunk(
+        self, data: Mapping[str, Any], default_chunk_class: Type[BaseMessageChunk]
+    ) -> Tuple[ChatGenerationChunk, Type[BaseMessageChunk]]:
         chunk = _convert_delta_to_message_chunk(
             {"content": data.get("text", "")}, default_chunk_class
         )
@@ -293,8 +294,8 @@ class GPTRouter(BaseChatModel):
             dict(finish_reason=finish_reason) if finish_reason is not None else None
         )
         default_chunk_class = chunk.__class__
-        chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)  # type: ignore[assignment]
-        return chunk, default_chunk_class
+        gen_chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
+        return gen_chunk, default_chunk_class
 
     def _stream(
         self,
@@ -306,7 +307,7 @@ class GPTRouter(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
 
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         generator_response = completion_with_retry(
             self,
             messages=message_dicts,
@@ -339,7 +340,7 @@ class GPTRouter(BaseChatModel):
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
 
-        default_chunk_class = AIMessageChunk
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         generator_response = acompletion_with_retry(
             self,
             messages=message_dicts,
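
Note: the recurring change above is replacing bare `Generator[...]`/`AsyncGenerator[...]` annotations (which trigger mypy's `type-arg` error) with fully parameterized forms. A minimal standalone sketch of the pattern, using a hypothetical `stream_text` helper that is not part of this diff:

    from typing import Generator, Iterator

    # Generator[str, None, None] declares the yield type (str), send type
    # (None), and return type (None) explicitly, so strict mypy accepts it
    # without a `# type: ignore[type-arg]` suppression.
    def stream_text(chunks: Iterator[str]) -> Generator[str, None, None]:
        for chunk in chunks:
            yield chunk.upper()

    for piece in stream_text(iter(["hello ", "world"])):
        print(piece, end="")

The same reasoning applies to `AsyncGenerator[ChunkedGenerationResponse, None]`, which takes only yield and send parameters, and to annotating `default_chunk_class` as `Type[BaseMessageChunk]` so mypy can check that every chunk class passed through `_create_chat_generation_chunk` is a `BaseMessageChunk` subclass.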