from __future__ import annotations

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    AsyncIterator,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Mapping,
    Optional,
    Tuple,
    Type,
    Union,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

from langchain_community.adapters.openai import (
    convert_dict_to_message,
    convert_message_to_dict,
)
from langchain_community.chat_models.openai import _convert_delta_to_message_chunk

if TYPE_CHECKING:
    from gpt_router.models import ChunkedGenerationResponse, GenerationResponse


logger = logging.getLogger(__name__)

DEFAULT_API_BASE_URL = "https://gpt-router-preview.writesonic.com"


class GPTRouterException(Exception):
    """Error with the `GPTRouter APIs`"""


class GPTRouterModel(BaseModel):
    """GPTRouter model."""

    name: str
    provider_name: str


def get_ordered_generation_requests(
    models_priority_list: List[GPTRouterModel], **kwargs: Any
) -> List:
    """Return the body for the model router input.

    Builds one ``ModelGenerationRequest`` per entry of ``models_priority_list``.
    The ``order`` of each request is its 1-based position in the list, and every
    request receives its own ``GenerationParams`` built from the same ``kwargs``.

    Args:
        models_priority_list: Models to try, in priority order.
        **kwargs: Forwarded unchanged to ``GenerationParams``.

    Returns:
        A list of ``ModelGenerationRequest`` objects, in the input order.
    """
    # Imported lazily so the module can be imported without `gpt_router`.
    from gpt_router.models import GenerationParams, ModelGenerationRequest

    return [
        ModelGenerationRequest(
            model_name=model.name,
            provider_name=model.provider_name,
            order=position,
            prompt_params=GenerationParams(**kwargs),
        )
        for position, model in enumerate(models_priority_list, start=1)
    ]


def _create_retry_decorator(
    llm: GPTRouter,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Build the retry decorator used around GPTRouter client calls.

    Args:
        llm: The chat model; its ``max_retries`` is passed as the retry limit.
        run_manager: Optional callback manager forwarded to the base decorator.

    Returns:
        The decorator produced by ``create_base_retry_decorator`` for the
        GPTRouter timeout, internal-server, not-available and
        too-many-requests error types.
    """
    from gpt_router import exceptions

    errors = [
        exceptions.GPTRouterApiTimeoutError,
        exceptions.GPTRouterInternalServerError,
        exceptions.GPTRouterNotAvailableError,
        exceptions.GPTRouterTooManyRequestsError,
    ]
    return create_base_retry_decorator(
        error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
    )


def completion_with_retry(
    llm: GPTRouter,
    models_priority_list: List[GPTRouterModel],
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse, None, None]]:
    """Use tenacity to retry the completion call.

    Args:
        llm: The chat model whose ``client`` performs the request.
        models_priority_list: Models to try, in priority order.
        run_manager: Optional callback manager forwarded to the retry decorator.
        **kwargs: Generation parameters. They are forwarded to
            ``get_ordered_generation_requests``; ``kwargs["stream"]`` (default
            ``False``) is additionally passed to the client as ``is_stream``.

    Returns:
        Whatever ``llm.client.generate`` returns.
    """
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    # Only this call is wrapped by the retry decorator; when streaming, the
    # returned chunks are consumed by the caller after this function returns.
    @retry_decorator
    def _completion_with_retry(**kwargs: Any) -> Any:
        ordered_generation_requests = get_ordered_generation_requests(
            models_priority_list, **kwargs
        )
        return llm.client.generate(
            ordered_generation_requests=ordered_generation_requests,
            is_stream=kwargs.get("stream", False),
        )

    return _completion_with_retry(**kwargs)


async def acompletion_with_retry(
    llm: GPTRouter,
    models_priority_list: List[GPTRouterModel],
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse, None]]:
    """Use tenacity to retry the async completion call.

    Args:
        llm: The chat model whose ``client`` performs the request.
        models_priority_list: Models to try, in priority order.
        run_manager: Optional callback manager forwarded to the retry decorator.
        **kwargs: Generation parameters. They are forwarded to
            ``get_ordered_generation_requests``; ``kwargs["stream"]`` (default
            ``False``) is additionally passed to the client as ``is_stream``.

    Returns:
        Whatever awaiting ``llm.client.agenerate`` returns.
    """
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    # Only this awaited call is wrapped by the retry decorator; when streaming,
    # the returned chunks are consumed by the caller afterwards.
    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        ordered_generation_requests = get_ordered_generation_requests(
            models_priority_list, **kwargs
        )
        return await llm.client.agenerate(
            ordered_generation_requests=ordered_generation_requests,
            is_stream=kwargs.get("stream", False),
        )

    return await _completion_with_retry(**kwargs)


class GPTRouter(BaseChatModel):
    """GPTRouter by Writesonic Inc.

    For more information, see https://gpt-router.writesonic.com/docs
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    models_priority_list: List[GPTRouterModel] = Field(min_items=1)
    gpt_router_api_base: Optional[str] = Field(default=None)
    """WriteSonic GPTRouter custom endpoint"""
    gpt_router_api_key: Optional[SecretStr] = None
    """WriteSonic GPTRouter API Key"""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    max_retries: int = 4
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: int = 256
    """Value sent as the `max_tokens` generation parameter."""

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the endpoint and API key, then build the GPTRouter client.

        The base URL is read from ``values`` or the ``GPT_ROUTER_API_BASE``
        environment variable, falling back to ``DEFAULT_API_BASE_URL``. The API
        key is read from ``values`` or ``GPT_ROUTER_API_KEY`` and has no
        fallback.

        Raises:
            GPTRouterException: If the ``gpt_router`` package cannot be imported.
        """
        values["gpt_router_api_base"] = get_from_dict_or_env(
            values,
            "gpt_router_api_base",
            "GPT_ROUTER_API_BASE",
            DEFAULT_API_BASE_URL,
        )

        values["gpt_router_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "gpt_router_api_key",
                "GPT_ROUTER_API_KEY",
            )
        )

        try:
            from gpt_router.client import GPTRouterClient
        except ImportError as err:
            # Chain explicitly so the underlying import failure stays visible.
            raise GPTRouterException(
                "Could not import GPTRouter python package. "
                "Please install it with `pip install GPTRouter`."
            ) from err

        gpt_router_client = GPTRouterClient(
            values["gpt_router_api_base"],
            values["gpt_router_api_key"].get_secret_value(),
        )
        values["client"] = gpt_router_client

        return values

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map the secret field name to its environment variable name."""
        return {"gpt_router_api_key": "GPT_ROUTER_API_KEY"}

    @property
    def lc_serializable(self) -> bool:
        """Always ``True``."""
        return True

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "gpt-router-chat"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {
            "models_priority_list": self.models_priority_list,
            **self._default_params,
        }

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling GPTRouter API.

        A new dict is built on every access; entries of ``model_kwargs`` are
        applied last and therefore override the named fields.
        """
        return {
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            **self.model_kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result for ``messages``.

        Args:
            messages: Conversation to send.
            stop: Optional stop sequences.
            run_manager: Optional callback manager.
            stream: When not ``None``, overrides ``self.streaming``.
            **kwargs: Extra generation parameters; they override the defaults.

        Returns:
            The result aggregated from the stream when streaming is selected,
            otherwise the result built from a single non-streaming call.
        """
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        # "stream" is forced off last so no default or kwarg can re-enable it.
        params = {**params, **kwargs, "stream": False}
        response = completion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        return self._create_chat_result(response)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate a chat result for ``messages``.

        Args:
            messages: Conversation to send.
            stop: Optional stop sequences.
            run_manager: Optional async callback manager.
            stream: When not ``None``, overrides ``self.streaming``.
            **kwargs: Extra generation parameters; they override the defaults.

        Returns:
            The result aggregated from the stream when streaming is selected,
            otherwise the result built from a single non-streaming call.
        """
        should_stream = stream if stream is not None else self.streaming
        if should_stream:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        message_dicts, params = self._create_message_dicts(messages, stop)
        # "stream" is forced off last so no default or kwarg can re-enable it.
        params = {**params, **kwargs, "stream": False}
        response = await acompletion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        return self._create_chat_result(response)

    def _create_chat_generation_chunk(
        self, data: Mapping[str, Any], default_chunk_class: Type[BaseMessageChunk]
    ) -> Tuple[ChatGenerationChunk, Type[BaseMessageChunk]]:
        """Convert one streamed payload into a ``ChatGenerationChunk``.

        Args:
            data: Payload of a stream event. ``data["text"]`` (default ``""``)
                becomes the message content; ``data["finish_reason"]`` is
                recorded in ``generation_info`` when it is not ``None``.
            default_chunk_class: Message chunk class handed to the converter.

        Returns:
            The generation chunk and the class of the message it wraps, which
            the callers pass back in for the next payload.
        """
        message_chunk = _convert_delta_to_message_chunk(
            {"content": data.get("text", "")}, default_chunk_class
        )
        finish_reason = data.get("finish_reason")
        generation_info = (
            {"finish_reason": finish_reason} if finish_reason is not None else None
        )
        gen_chunk = ChatGenerationChunk(
            message=message_chunk, generation_info=generation_info
        )
        return gen_chunk, message_chunk.__class__

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream generation chunks for ``messages``.

        Events whose ``event`` attribute is not ``"update"`` are skipped. For
        each remaining event the run manager (if any) is notified before the
        chunk is yielded.
        """
        message_dicts, params = self._create_message_dicts(messages, stop)
        # "stream" is forced on last so no default or kwarg can disable it.
        params = {**params, **kwargs, "stream": True}
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        generator_response = completion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        for event in generator_response:
            if event.event != "update":
                continue

            gen_chunk, default_chunk_class = self._create_chat_generation_chunk(
                event.data, default_chunk_class
            )

            if run_manager:
                run_manager.on_llm_new_token(
                    token=gen_chunk.message.content, chunk=gen_chunk.message
                )

            yield gen_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously stream generation chunks for ``messages``.

        Events whose ``event`` attribute is not ``"update"`` are skipped. For
        each remaining event the run manager (if any) is notified before the
        chunk is yielded.
        """
        message_dicts, params = self._create_message_dicts(messages, stop)
        # "stream" is forced on last so no default or kwarg can disable it.
        params = {**params, **kwargs, "stream": True}
        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
        generator_response = acompletion_with_retry(
            self,
            messages=message_dicts,
            models_priority_list=self.models_priority_list,
            run_manager=run_manager,
            **params,
        )
        # The coroutine is awaited first to obtain the async iterable of events.
        async for event in await generator_response:
            if event.event != "update":
                continue

            gen_chunk, default_chunk_class = self._create_chat_generation_chunk(
                event.data, default_chunk_class
            )

            if run_manager:
                await run_manager.on_llm_new_token(
                    token=gen_chunk.message.content, chunk=gen_chunk.message
                )

            yield gen_chunk

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Convert messages to dicts and assemble the call parameters.

        Args:
            messages: Messages to convert with ``convert_message_to_dict``.
            stop: Optional stop sequences, stored under ``"stop"``.

        Returns:
            The converted messages and the parameter dict.

        Raises:
            ValueError: If ``stop`` is given and the default parameters already
                contain a ``"stop"`` entry.
        """
        # `_default_params` builds a fresh dict, so mutating it here is safe.
        params = self._default_params
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: GenerationResponse) -> ChatResult:
        """Build a ``ChatResult`` from a non-streaming router response.

        Each entry of ``response.choices`` becomes one assistant message with
        its ``finish_reason`` as generation info; ``response.meta`` and
        ``response.model`` are reported in ``llm_output``.
        """
        generations = [
            ChatGeneration(
                message=convert_dict_to_message(
                    {
                        "role": "assistant",
                        "content": choice.text,
                    }
                ),
                generation_info={"finish_reason": choice.finish_reason},
            )
            for choice in response.choices
        ]
        llm_output = {"token_usage": response.meta, "model": response.model}
        return ChatResult(generations=generations, llm_output=llm_output)