|
|
|
@ -2,62 +2,50 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from typing import Any, Dict, List, Mapping, Optional, Tuple
|
|
|
|
|
from typing import Any, Dict
|
|
|
|
|
|
|
|
|
|
from pydantic import root_validator
|
|
|
|
|
|
|
|
|
|
from langchain.chat_models.openai import (
|
|
|
|
|
ChatOpenAI,
|
|
|
|
|
acompletion_with_retry,
|
|
|
|
|
)
|
|
|
|
|
from langchain.schema import (
|
|
|
|
|
AIMessage,
|
|
|
|
|
BaseMessage,
|
|
|
|
|
ChatGeneration,
|
|
|
|
|
ChatResult,
|
|
|
|
|
)
|
|
|
|
|
from langchain.utils import get_from_dict_or_env
|
|
|
|
|
|
|
|
|
|
# Name the logger after the module's dotted import path rather than the
# source file's filesystem path, so it participates in the standard logger
# hierarchy and can be configured through its parent package logger.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_chat_prompt(messages: List[BaseMessage]) -> str:
    """Render ``messages`` as a single ChatML prompt string.

    Every message contributes the text returned by its ``format_chatml()``
    method; the pieces are separated by newlines, and the prompt is closed
    with an opening assistant tag so the completion continues as the
    assistant turn.
    """
    rendered = [message.format_chatml() for message in messages]
    conversation = "\n".join(rendered)
    return conversation + "\n<|im_start|>assistant\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
    """Convert a completion response into a ``ChatResult``.

    Each entry of ``response["choices"]`` becomes one ``ChatGeneration``
    whose message is an ``AIMessage`` carrying that choice's ``"text"``.
    The order of the choices is preserved.
    """
    return ChatResult(
        generations=[
            ChatGeneration(message=AIMessage(content=choice["text"]))
            for choice in response["choices"]
        ]
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AzureChatOpenAI(ChatOpenAI):
|
|
|
|
|
"""Wrapper around Azure OpenAI Chat large language models.
|
|
|
|
|
"""Wrapper around Azure OpenAI Chat Completion API. To use this class you
|
|
|
|
|
must have a deployed model on Azure OpenAI. Use `deployment_name` in the
|
|
|
|
|
constructor to refer to the "Model deployment name" in the Azure portal.
|
|
|
|
|
|
|
|
|
|
To use, you should have the ``openai`` python package installed, and the
|
|
|
|
|
following environment variables set:
|
|
|
|
|
- ``OPENAI_API_TYPE``
|
|
|
|
|
In addition, you should have the ``openai`` python package installed, and the
|
|
|
|
|
following environment variables set or passed in constructor in lower case:
|
|
|
|
|
- ``OPENAI_API_TYPE`` (default: ``azure``)
|
|
|
|
|
- ``OPENAI_API_KEY``
|
|
|
|
|
- ``OPENAI_API_BASE``
|
|
|
|
|
- ``OPENAI_API_VERSION``
|
|
|
|
|
|
|
|
|
|
Any parameters that are valid to be passed to the openai.create call can be passed
|
|
|
|
|
in, even if not explicitly saved on this class.
|
|
|
|
|
For example, if you have `gpt-35-turbo` deployed, with the deployment name
|
|
|
|
|
`35-turbo-dev`, the constructor should look like:
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
AzureChatOpenAI(
|
|
|
|
|
deployment_name="35-turbo-dev",
|
|
|
|
|
openai_api_version="2023-03-15-preview",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from langchain.chat_models import AzureChatOpenAI
|
|
|
|
|
openai = AzureChatOpenAI(deployment_name="<your deployment name>")
|
|
|
|
|
Be aware the API version may change.
|
|
|
|
|
|
|
|
|
|
Any parameters that are valid to be passed to the openai.create call can be passed
|
|
|
|
|
in, even if not explicitly saved on this class.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
deployment_name: str = ""
|
|
|
|
|
stop: List[str] = ["<|im_end|>"]
|
|
|
|
|
openai_api_type: str = "azure"
|
|
|
|
|
openai_api_base: str = ""
|
|
|
|
|
openai_api_version: str = ""
|
|
|
|
|
openai_api_key: str = ""
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_environment(cls, values: Dict) -> Dict:
|
|
|
|
@ -95,10 +83,10 @@ class AzureChatOpenAI(ChatOpenAI):
|
|
|
|
|
"Please it install it with `pip install openai`."
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
values["client"] = openai.Completion
|
|
|
|
|
values["client"] = openai.ChatCompletion
|
|
|
|
|
except AttributeError:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"`openai` has no `Completion` attribute, this is likely "
|
|
|
|
|
"`openai` has no `ChatCompletion` attribute, this is likely "
|
|
|
|
|
"due to an old version of the openai package. Try upgrading it "
|
|
|
|
|
"with `pip install --upgrade openai`."
|
|
|
|
|
)
|
|
|
|
@ -113,66 +101,5 @@ class AzureChatOpenAI(ChatOpenAI):
|
|
|
|
|
"""Get the default parameters for calling OpenAI API."""
|
|
|
|
|
return {
|
|
|
|
|
**super()._default_params,
|
|
|
|
|
"stop": self.stop,
|
|
|
|
|
"engine": self.deployment_name,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def _generate(
    self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
    """Run a completion for ``messages`` and wrap the output in a ChatResult.

    The prompt and request parameters come from ``self._create_prompt``.
    Without streaming, the full response is converted by
    ``_create_chat_result``. With ``self.streaming`` set, the request is sent
    with ``stream=True``; each chunk's token is passed to
    ``self.callback_manager.on_llm_new_token`` and the concatenated tokens
    are returned as a single generation.
    """
    prompt, params = self._create_prompt(messages, stop)

    # Non-streaming path: one request, one response to convert.
    if not self.streaming:
        return _create_chat_result(
            self.completion_with_retry(prompt=prompt, **params)
        )

    params["stream"] = True
    streamed_text = ""
    for chunk in self.completion_with_retry(prompt=prompt, **params):
        # NOTE(review): chunks are read as choices[0]["delta"]["text"];
        # confirm this matches the streaming payload of the configured client.
        piece = chunk["choices"][0]["delta"].get("text", "")
        streamed_text += piece
        self.callback_manager.on_llm_new_token(
            piece,
            verbose=self.verbose,
        )
    return ChatResult(
        generations=[ChatGeneration(message=AIMessage(content=streamed_text))]
    )
|
|
|
|
|
|
|
|
|
|
def _create_prompt(
    self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[str, Dict[str, Any]]:
    """Build the ChatML prompt and the request parameters for ``messages``.

    The parameters start from ``model`` and ``engine`` (the deployment name)
    and are then overlaid with ``self._default_params``.

    Returns:
        A ``(prompt, params)`` tuple.

    Raises:
        ValueError: if ``stop`` is given while the parameters already
            contain a ``"stop"`` entry.
    """
    params: Dict[str, Any] = {
        "model": self.model_name,
        "engine": self.deployment_name,
    }
    # Defaults win over the two base entries above on key collisions.
    params.update(self._default_params)

    if stop is not None:
        # NOTE(review): the default params visible in this file appear to
        # always include "stop", in which case any non-None `stop` raises
        # here - confirm that this is intended.
        if "stop" in params:
            raise ValueError("`stop` found in both the input and default params.")
        params["stop"] = stop

    return _create_chat_prompt(messages), params
|
|
|
|
|
|
|
|
|
|
async def _agenerate(
    self, messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> ChatResult:
    """Asynchronously run a completion for ``messages``.

    The prompt and request parameters come from ``self._create_prompt``.
    Without streaming, the awaited response is converted by
    ``_create_chat_result``. With ``self.streaming`` set, the request is sent
    with ``stream=True``; each chunk's token is forwarded to
    ``self.callback_manager.on_llm_new_token`` (awaited when the manager
    reports ``is_async``) and the concatenated tokens are returned as a
    single generation.
    """
    prompt, params = self._create_prompt(messages, stop)

    # Non-streaming path: await the whole response and convert it.
    if not self.streaming:
        full_response = await acompletion_with_retry(
            self, prompt=prompt, **params
        )
        return _create_chat_result(full_response)

    params["stream"] = True
    streamed_text = ""
    chunk_stream = await acompletion_with_retry(self, prompt=prompt, **params)
    async for chunk in chunk_stream:
        piece = chunk["choices"][0]["delta"].get("text", "")
        streamed_text += piece
        # The callback manager may be sync or async; only await the latter.
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_new_token(
                piece,
                verbose=self.verbose,
            )
        else:
            self.callback_manager.on_llm_new_token(
                piece,
                verbose=self.verbose,
            )
    return ChatResult(
        generations=[ChatGeneration(message=AIMessage(content=streamed_text))]
    )
|
|
|
|
|