"""Wrapper around Prem's Chat API."""

from __future__ import annotations

import logging
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    SecretStr,
    root_validator,
)
from langchain_core.utils import get_from_dict_or_env

if TYPE_CHECKING:
    from premai.api.chat_completions.v1_chat_completions_create import (
        ChatCompletionResponseStream,
    )
    from premai.models.chat_completion_response import ChatCompletionResponse

logger = logging.getLogger(__name__)


class ChatPremAPIError(Exception):
    """Error with the `PremAI` API."""


def _truncate_at_stop_tokens(
    text: str,
    stop: Optional[List[str]],
) -> str:
    """Truncates text at the earliest stop token found."""
    if stop is None:
        return text
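
    # Truncating at each matching stop token in turn leaves the text cut at the
    # earliest stop token that appears, which is what the docstring promises.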
    for stop_token in stop:
        stop_token_idx = text.find(stop_token)
        if stop_token_idx != -1:
            text = text[:stop_token_idx]
    return text


def _response_to_result(
    response: ChatCompletionResponse,
    stop: Optional[List[str]],
) -> ChatResult:
    """Converts a Prem API response into a LangChain result."""

    if not response.choices:
        raise ChatPremAPIError("ChatResponse must have at least one candidate")
    generations: List[ChatGeneration] = []
    for choice in response.choices:
        role = choice.message.role
        if role is None:
            raise ChatPremAPIError(f"ChatResponse {choice} must have a role.")

        # If content is None then it will be replaced by ""
        content = _truncate_at_stop_tokens(text=choice.message.content or "", stop=stop)
        if content is None:
            raise ChatPremAPIError(f"ChatResponse must have a content: {content}")
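
        # Map the Prem role back onto the matching LangChain message type; any
        # role other than "assistant" or "user" falls back to a generic ChatMessage.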
        if role == "assistant":
            generations.append(
                ChatGeneration(text=content, message=AIMessage(content=content))
            )
        elif role == "user":
            generations.append(
                ChatGeneration(text=content, message=HumanMessage(content=content))
            )
        else:
            generations.append(
                ChatGeneration(
                    text=content, message=ChatMessage(role=role, content=content)
                )
            )
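
    # When repositories are attached, Prem can also return the retrieved document
    # chunks; surface them through llm_output so callers can inspect the sources.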
    if response.document_chunks is not None:
        return ChatResult(
            generations=generations,
            llm_output={
                "document_chunks": [
                    chunk.to_dict() for chunk in response.document_chunks
                ]
            },
        )
    else:
        return ChatResult(generations=generations, llm_output={"document_chunks": None})


def _convert_delta_response_to_message_chunk(
    response: ChatCompletionResponseStream, default_class: Type[BaseMessageChunk]
) -> Tuple[
    Union[BaseMessageChunk, HumanMessageChunk, AIMessageChunk, SystemMessageChunk],
    Optional[str],
]:
    """Converts delta response to message chunk"""
    _delta = response.choices[0].delta  # type: ignore
    role = _delta.get("role", "")  # type: ignore
    content = _delta.get("content", "")  # type: ignore
    additional_kwargs: Dict = {}
    finish_reasons: Optional[str] = response.choices[0].finish_reason
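
    # Dispatch on the streamed role first; when the role is missing, fall back to
    # the default chunk class carried over from earlier chunks.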
    if role == "user" or default_class == HumanMessageChunk:
        return HumanMessageChunk(content=content), finish_reasons
    elif role == "assistant" or default_class == AIMessageChunk:
        return (
            AIMessageChunk(content=content, additional_kwargs=additional_kwargs),
            finish_reasons,
        )
    elif role == "system" or default_class == SystemMessageChunk:
        return SystemMessageChunk(content=content), finish_reasons
    elif role or default_class == ChatMessageChunk:
        return ChatMessageChunk(content=content, role=role), finish_reasons
    else:
        return default_class(content=content), finish_reasons  # type: ignore[call-arg]


def _messages_to_prompt_dict(
    input_messages: List[BaseMessage],
    template_id: Optional[str] = None,
) -> Tuple[Optional[str], List[Dict[str, Any]]]:
    """Converts a list of LangChain messages into the simple dict
    structure that Prem's message format expects."""

    system_prompt: Optional[str] = None
    examples_and_messages: List[Dict[str, Any]] = []
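
    # When a template id is supplied, non-system messages are sent as template
    # parameters keyed by each message's ``id`` instead of as plain content.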
    if template_id is not None:
        params: Dict[str, str] = {}
        for input_msg in input_messages:
            if isinstance(input_msg, SystemMessage):
                system_prompt = str(input_msg.content)
            else:
                assert (input_msg.id is not None) and (input_msg.id != ""), ValueError(
                    "When using a prompt template there should be an id associated ",
                    "with each HumanMessage",
                )
                params[str(input_msg.id)] = str(input_msg.content)

        examples_and_messages.append(
            {"role": "user", "template_id": template_id, "params": params}
        )

        for input_msg in input_messages:
            if isinstance(input_msg, AIMessage):
                examples_and_messages.append(
                    {"role": "assistant", "content": str(input_msg.content)}
                )
    else:
        for input_msg in input_messages:
            if isinstance(input_msg, SystemMessage):
                system_prompt = str(input_msg.content)
            elif isinstance(input_msg, HumanMessage):
                examples_and_messages.append(
                    {"role": "user", "content": str(input_msg.content)}
                )
            elif isinstance(input_msg, AIMessage):
                examples_and_messages.append(
                    {"role": "assistant", "content": str(input_msg.content)}
                )
            else:
                raise ChatPremAPIError("No such role explicitly exists")
    return system_prompt, examples_and_messages


class ChatPremAI(BaseChatModel, BaseModel):
    """PremAI Chat models.

    To use, you will need to have an API key. You can find your existing API Key
    or generate a new one here: https://app.premai.io/api_keys/
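
    Example (a minimal sketch; the project id and API key below are placeholders
    for your own values):
        .. code-block:: python

            from langchain_core.messages import HumanMessage

            chat = ChatPremAI(project_id=8, api_key="<your-premai-api-key>")
            chat.invoke([HumanMessage(content="Hello, how are you?")])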
    """

    # TODO: Need to add the default parameters through prem-sdk here

    project_id: int
    """The project ID in which the experiments or deployments are carried out.
    You can find all your projects here: https://app.premai.io/projects/"""

    premai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """Prem AI API Key. Get it here: https://app.premai.io/api_keys/"""

    model: Optional[str] = Field(default=None, alias="model_name")
    """Name of the model. This is an optional parameter.
    The default model is the one deployed from Prem's LaunchPad: https://app.premai.io/projects/8/launchpad
    If a different model name is given, it overrides the model deployed
    from the LaunchPad for these calls."""

    temperature: Optional[float] = None
    """Model temperature. Value should be >= 0 and <= 1.0"""

    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate"""

    max_retries: int = 1
    """Max number of retries to call the API"""

    system_prompt: Optional[str] = ""
    """Acts like a default instruction that helps the LLM act or generate
    in a specific way. This is an optional parameter. By default the
    system prompt is the one set for the model on Prem's Launchpad.
    Changing the system prompt overrides that default.
    """

    repositories: Optional[dict] = None
    """Add valid repository ids. This overrides any repositories already
    connected (if any) and uses RAG with the given repos.
    """

    streaming: Optional[bool] = False
    """Whether to stream the responses or not."""

    client: Any

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    @root_validator()
    def validate_environments(cls, values: Dict) -> Dict:
        """Validate that the package is installed and that the API token is valid"""
        try:
            from premai import Prem
        except ImportError as error:
            raise ImportError(
                "Could not import Prem Python package. "
                "Please install it with: `pip install premai`"
            ) from error

        try:
            premai_api_key: Union[str, SecretStr] = get_from_dict_or_env(
                values, "premai_api_key", "PREMAI_API_KEY"
            )
            values["client"] = Prem(
                api_key=premai_api_key
                if isinstance(premai_api_key, str)
                else premai_api_key._secret_value
            )
        except Exception as error:
            raise ValueError("Your API Key is incorrect. Please try again.") from error
        return values

    @property
    def _llm_type(self) -> str:
        return "premai"

    @property
    def _default_params(self) -> Dict[str, Any]:
        return {
            "model": self.model,
            "system_prompt": self.system_prompt,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "repositories": self.repositories,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
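        # Drop kwargs that Prem's chat endpoint does not accept, then strip unset
        # defaults so the launchpad configuration is not overridden with None/"".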
        kwargs_to_ignore = [
            "top_p",
            "tools",
            "frequency_penalty",
            "presence_penalty",
            "logit_bias",
            "stop",
            "seed",
        ]
        keys_to_remove = []

        for key in kwargs:
            if key in kwargs_to_ignore:
                warnings.warn(f"WARNING: Parameter {key} is not supported in kwargs.")
                keys_to_remove.append(key)

        for key in keys_to_remove:
            kwargs.pop(key)

        all_kwargs = {**self._default_params, **kwargs}
        for key in list(self._default_params.keys()):
            if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
                all_kwargs.pop(key, None)
        return all_kwargs

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
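        # A ``template_id`` kwarg switches message conversion over to Prem's
        # prompt-template format (message ids become template parameters).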
if "template_id" in kwargs:
|
|
|
|
system_prompt, messages_to_pass = _messages_to_prompt_dict(
|
|
|
|
messages, template_id=kwargs["template_id"]
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
|
2024-03-26 01:37:19 +00:00
|
|
|
|
|
|
|
if system_prompt is not None and system_prompt != "":
|
|
|
|
kwargs["system_prompt"] = system_prompt
|
|
|
|
|
|
|
|
all_kwargs = self._get_all_kwargs(**kwargs)
|
|
|
|
response = chat_with_retry(
|
|
|
|
self,
|
|
|
|
project_id=self.project_id,
|
|
|
|
messages=messages_to_pass,
|
|
|
|
stream=False,
|
|
|
|
run_manager=run_manager,
|
|
|
|
**all_kwargs,
|
|
|
|
)
|
|
|
|
|
|
|
|
return _response_to_result(response=response, stop=stop)
|
|
|
|
|
|
|
|

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        if "template_id" in kwargs:
            system_prompt, messages_to_pass = _messages_to_prompt_dict(
                messages, template_id=kwargs["template_id"]
            )  # type: ignore
        else:
            system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)  # type: ignore

        if stop is not None:
            logger.warning("stop is not supported in langchain streaming")

        if "system_prompt" not in kwargs:
            if system_prompt is not None and system_prompt != "":
                kwargs["system_prompt"] = system_prompt

        all_kwargs = self._get_all_kwargs(**kwargs)

        default_chunk_class = AIMessageChunk

        for streamed_response in chat_with_retry(
            self,
            project_id=self.project_id,
            messages=messages_to_pass,
            stream=True,
            run_manager=run_manager,
            **all_kwargs,
        ):
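            # Malformed or empty stream events are skipped so a single bad chunk
            # does not abort the whole stream.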
            try:
                chunk, finish_reason = _convert_delta_response_to_message_chunk(
                    response=streamed_response, default_class=default_chunk_class
                )
                generation_info = (
                    dict(finish_reason=finish_reason)
                    if finish_reason is not None
                    else None
                )
                cg_chunk = ChatGenerationChunk(
                    message=chunk, generation_info=generation_info
                )
                if run_manager:
                    run_manager.on_llm_new_token(cg_chunk.text, chunk=cg_chunk)
                yield cg_chunk
            except Exception as _:
                continue


def create_prem_retry_decorator(
    llm: ChatPremAI,
    *,
    max_retries: int = 1,
    run_manager: Optional[Union[CallbackManagerForLLMRun]] = None,
) -> Callable[[Any], Any]:
    """Create a retry decorator for PremAI API errors."""
    import premai.models
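
    # Retry only on the error types raised by the premai SDK; anything else
    # propagates immediately.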
    errors = [
        premai.models.api_response_validation_error.APIResponseValidationError,
        premai.models.conflict_error.ConflictError,
        premai.models.model_not_found_error.ModelNotFoundError,
        premai.models.permission_denied_error.PermissionDeniedError,
        premai.models.provider_api_connection_error.ProviderAPIConnectionError,
        premai.models.provider_api_status_error.ProviderAPIStatusError,
        premai.models.provider_api_timeout_error.ProviderAPITimeoutError,
        premai.models.provider_internal_server_error.ProviderInternalServerError,
        premai.models.provider_not_found_error.ProviderNotFoundError,
        premai.models.rate_limit_error.RateLimitError,
        premai.models.unprocessable_entity_error.UnprocessableEntityError,
        premai.models.validation_error.ValidationError,
    ]

    decorator = create_base_retry_decorator(
        error_types=errors, max_retries=max_retries, run_manager=run_manager
    )
    return decorator


def chat_with_retry(
    llm: ChatPremAI,
    project_id: int,
    messages: List[dict],
    stream: bool = False,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_prem_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(
        project_id: int,
        messages: List[dict],
        stream: Optional[bool] = False,
        **kwargs: Any,
    ) -> Any:
        response = llm.client.chat.completions.create(
            project_id=project_id,
            messages=messages,
            stream=stream,
            **kwargs,
        )
        return response

    return _completion_with_retry(
        project_id=project_id,
        messages=messages,
        stream=stream,
        **kwargs,
    )