|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
import sys
|
|
|
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional
|
|
|
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Extra, Field, root_validator
|
|
|
|
|
from tenacity import (
|
|
|
|
@ -91,6 +91,15 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|
|
|
|
return message_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
|
|
|
|
|
generations = []
|
|
|
|
|
for res in response["choices"]:
|
|
|
|
|
message = _convert_dict_to_message(res["message"])
|
|
|
|
|
gen = ChatGeneration(message=message)
|
|
|
|
|
generations.append(gen)
|
|
|
|
|
return ChatResult(generations=generations)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatOpenAI(BaseChatModel, BaseModel):
|
|
|
|
|
"""Wrapper around OpenAI Chat large language models.
|
|
|
|
|
|
|
|
|
@ -215,12 +224,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
|
|
|
|
def _generate(
|
|
|
|
|
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
|
|
|
|
) -> ChatResult:
|
|
|
|
|
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
|
|
|
|
|
if stop is not None:
|
|
|
|
|
if "stop" in params:
|
|
|
|
|
raise ValueError("`stop` found in both the input and default params.")
|
|
|
|
|
params["stop"] = stop
|
|
|
|
|
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
|
|
|
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
|
|
|
|
if self.streaming:
|
|
|
|
|
inner_completion = ""
|
|
|
|
|
role = "assistant"
|
|
|
|
@ -240,22 +244,23 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
|
|
|
|
)
|
|
|
|
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
|
|
|
|
response = self.completion_with_retry(messages=message_dicts, **params)
|
|
|
|
|
generations = []
|
|
|
|
|
for res in response["choices"]:
|
|
|
|
|
message = _convert_dict_to_message(res["message"])
|
|
|
|
|
gen = ChatGeneration(message=message)
|
|
|
|
|
generations.append(gen)
|
|
|
|
|
return ChatResult(generations=generations)
|
|
|
|
|
return _create_chat_result(response)
|
|
|
|
|
|
|
|
|
|
async def _agenerate(
|
|
|
|
|
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
|
|
|
|
) -> ChatResult:
|
|
|
|
|
def _create_message_dicts(
|
|
|
|
|
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
|
|
|
|
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
|
|
|
|
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
|
|
|
|
|
if stop is not None:
|
|
|
|
|
if "stop" in params:
|
|
|
|
|
raise ValueError("`stop` found in both the input and default params.")
|
|
|
|
|
params["stop"] = stop
|
|
|
|
|
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
|
|
|
|
return message_dicts, params
|
|
|
|
|
|
|
|
|
|
async def _agenerate(
|
|
|
|
|
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
|
|
|
|
) -> ChatResult:
|
|
|
|
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
|
|
|
|
if self.streaming:
|
|
|
|
|
inner_completion = ""
|
|
|
|
|
role = "assistant"
|
|
|
|
@ -281,15 +286,10 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
|
|
|
|
)
|
|
|
|
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
|
|
|
|
else:
|
|
|
|
|
full_response = await acompletion_with_retry(
|
|
|
|
|
response = await acompletion_with_retry(
|
|
|
|
|
self, messages=message_dicts, **params
|
|
|
|
|
)
|
|
|
|
|
generations = []
|
|
|
|
|
for res in full_response["choices"]:
|
|
|
|
|
message = _convert_dict_to_message(res["message"])
|
|
|
|
|
gen = ChatGeneration(message=message)
|
|
|
|
|
generations.append(gen)
|
|
|
|
|
return ChatResult(generations=generations)
|
|
|
|
|
return _create_chat_result(response)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _identifying_params(self) -> Mapping[str, Any]:
|
|
|
|
|