|
|
@ -5,15 +5,6 @@ import logging
|
|
|
|
import sys
|
|
|
|
import sys
|
|
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
|
|
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Extra, Field, root_validator
|
|
|
|
|
|
|
|
from tenacity import (
|
|
|
|
|
|
|
|
before_sleep_log,
|
|
|
|
|
|
|
|
retry,
|
|
|
|
|
|
|
|
retry_if_exception_type,
|
|
|
|
|
|
|
|
stop_after_attempt,
|
|
|
|
|
|
|
|
wait_exponential,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.chat_models.base import BaseChatModel
|
|
|
|
from langchain.chat_models.base import BaseChatModel
|
|
|
|
from langchain.schema import (
|
|
|
|
from langchain.schema import (
|
|
|
|
AIMessage,
|
|
|
|
AIMessage,
|
|
|
@ -25,6 +16,14 @@ from langchain.schema import (
|
|
|
|
SystemMessage,
|
|
|
|
SystemMessage,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
from langchain.utils import get_from_dict_or_env
|
|
|
|
from langchain.utils import get_from_dict_or_env
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Extra, Field, root_validator
|
|
|
|
|
|
|
|
from tenacity import (
|
|
|
|
|
|
|
|
before_sleep_log,
|
|
|
|
|
|
|
|
retry,
|
|
|
|
|
|
|
|
retry_if_exception_type,
|
|
|
|
|
|
|
|
stop_after_attempt,
|
|
|
|
|
|
|
|
wait_exponential,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
from logger import logger
|
|
|
|
from logger import logger
|
|
|
|
|
|
|
|
|
|
|
@ -128,7 +127,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
|
|
|
"""Whether to stream the results or not."""
|
|
|
|
"""Whether to stream the results or not."""
|
|
|
|
n: int = 1
|
|
|
|
n: int = 1
|
|
|
|
"""Number of chat completions to generate for each prompt."""
|
|
|
|
"""Number of chat completions to generate for each prompt."""
|
|
|
|
max_tokens: int = 256
|
|
|
|
max_tokens: int = 2048
|
|
|
|
"""Maximum number of tokens to generate."""
|
|
|
|
"""Maximum number of tokens to generate."""
|
|
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
class Config:
|
|
|
@ -226,7 +225,6 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
|
|
|
def _generate(
|
|
|
|
def _generate(
|
|
|
|
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
|
|
|
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
|
|
|
) -> ChatResult:
|
|
|
|
) -> ChatResult:
|
|
|
|
|
|
|
|
|
|
|
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
|
|
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
|
|
|
logger.debug("Messages:\n")
|
|
|
|
logger.debug("Messages:\n")
|
|
|
|
for item in message_dicts:
|
|
|
|
for item in message_dicts:
|
|
|
|