From a1603fccfbde7c44a66e313587f9fe075dd76978 Mon Sep 17 00:00:00 2001 From: Delgermurun Date: Sat, 8 Jul 2023 08:17:04 +0200 Subject: [PATCH] integrate JinaChat (#6927) Integration with https://chat.jina.ai/api. It is OpenAI compatible API. - Twitter handle: [https://twitter.com/JinaAI_](https://twitter.com/JinaAI_) --------- Co-authored-by: Harrison Chase --- .../models/chat/integrations/jinachat.ipynb | 162 ++++++++ langchain/chat_models/__init__.py | 2 + langchain/chat_models/jinachat.py | 357 ++++++++++++++++++ .../chat_models/test_jinachat.py | 127 +++++++ 4 files changed, 648 insertions(+) create mode 100644 docs/extras/modules/model_io/models/chat/integrations/jinachat.ipynb create mode 100644 langchain/chat_models/jinachat.py create mode 100644 tests/integration_tests/chat_models/test_jinachat.py diff --git a/docs/extras/modules/model_io/models/chat/integrations/jinachat.ipynb b/docs/extras/modules/model_io/models/chat/integrations/jinachat.ipynb new file mode 100644 index 0000000000..18fac8b41a --- /dev/null +++ b/docs/extras/modules/model_io/models/chat/integrations/jinachat.ipynb @@ -0,0 +1,162 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e49f1e0d", + "metadata": {}, + "source": [ + "# JinaChat\n", + "\n", + "This notebook covers how to get started with JinaChat chat models." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "522686de", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.chat_models import JinaChat\n", + "from langchain.prompts.chat import (\n", + " ChatPromptTemplate,\n", + " SystemMessagePromptTemplate,\n", + " AIMessagePromptTemplate,\n", + " HumanMessagePromptTemplate,\n", + ")\n", + "from langchain.schema import AIMessage, HumanMessage, SystemMessage" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "62e0dbc3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "chat = JinaChat(temperature=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ce16ad78-8e6f-48cd-954e-98be75eb5836", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"J'aime programmer.\", additional_kwargs={}, example=False)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages = [\n", + " SystemMessage(\n", + " content=\"You are a helpful assistant that translates English to French.\"\n", + " ),\n", + " HumanMessage(\n", + " content=\"Translate this sentence from English to French. I love programming.\"\n", + " ),\n", + "]\n", + "chat(messages)" + ] + }, + { + "cell_type": "markdown", + "id": "778f912a-66ea-4a5d-b3de-6c7db4baba26", + "metadata": {}, + "source": [ + "You can make use of templating by using a `MessagePromptTemplate`. You can build a `ChatPromptTemplate` from one or more `MessagePromptTemplates`. You can use `ChatPromptTemplate`'s `format_prompt` -- this returns a `PromptValue`, which you can convert to a string or Message object, depending on whether you want to use the formatted value as input to an llm or chat model.\n", + "\n", + "For convenience, there is a `from_template` method exposed on the template. 
If you were to use this template, this is what it would look like:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "180c5cc8", + "metadata": {}, + "outputs": [], + "source": [ + "template = (\n", + " \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n", + ")\n", + "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n", + "human_template = \"{text}\"\n", + "human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fbb043e6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"J'aime programmer.\", additional_kwargs={}, example=False)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chat_prompt = ChatPromptTemplate.from_messages(\n", + " [system_message_prompt, human_message_prompt]\n", + ")\n", + "\n", + "# get a chat completion from the formatted messages\n", + "chat(\n", + " chat_prompt.format_prompt(\n", + " input_language=\"English\", output_language=\"French\", text=\"I love programming.\"\n", + " ).to_messages()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c095285d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chat_models/__init__.py b/langchain/chat_models/__init__.py index 8f02497bd8..f58acc8dd4 100644 --- a/langchain/chat_models/__init__.py +++ b/langchain/chat_models/__init__.py @@ -3,6 +3,7 @@ from langchain.chat_models.azure_openai import AzureChatOpenAI from langchain.chat_models.fake import FakeListChatModel from langchain.chat_models.google_palm import ChatGooglePalm from langchain.chat_models.human import HumanInputChatModel +from langchain.chat_models.jinachat import JinaChat from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI from langchain.chat_models.vertexai import ChatVertexAI @@ -15,5 +16,6 @@ __all__ = [ "ChatAnthropic", "ChatGooglePalm", "ChatVertexAI", + "JinaChat", "HumanInputChatModel", ] diff --git a/langchain/chat_models/jinachat.py b/langchain/chat_models/jinachat.py new file mode 100644 index 0000000000..e5410b92b7 --- /dev/null +++ b/langchain/chat_models/jinachat.py @@ -0,0 +1,357 @@ +"""JinaChat wrapper.""" +from __future__ import annotations + +import logging +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Tuple, + Union, +) + +from pydantic import Field, root_validator +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.schema import ( + AIMessage, + BaseMessage, + ChatGeneration, + ChatMessage, + ChatResult, + HumanMessage, + SystemMessage, +) +from langchain.utils import get_from_dict_or_env + +logger = 
logging.getLogger(__name__) + + +def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]: + import openai + + min_seconds = 1 + max_seconds = 60 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(llm.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(openai.error.Timeout) + | retry_if_exception_type(openai.error.APIError) + | retry_if_exception_type(openai.error.APIConnectionError) + | retry_if_exception_type(openai.error.RateLimitError) + | retry_if_exception_type(openai.error.ServiceUnavailableError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + +async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any: + """Use tenacity to retry the async completion call.""" + retry_decorator = _create_retry_decorator(llm) + + @retry_decorator + async def _completion_with_retry(**kwargs: Any) -> Any: + # Use OpenAI's async api https://github.com/openai/openai-python#async-api + return await llm.client.acreate(**kwargs) + + return await _completion_with_retry(**kwargs) + + +def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: + role = _dict["role"] + if role == "user": + return HumanMessage(content=_dict["content"]) + elif role == "assistant": + content = _dict["content"] or "" + return AIMessage(content=content) + elif role == "system": + return SystemMessage(content=_dict["content"]) + else: + return ChatMessage(content=_dict["content"], role=role) + + +def _convert_message_to_dict(message: BaseMessage) -> dict: + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +class JinaChat(BaseChatModel): + """Wrapper around JinaChat API. + + To use, you should have the ``openai`` python package installed, and the + environment variable ``JINACHAT_API_KEY`` set with your API key. + + Any parameters that are valid to be passed to the openai.create call can be passed + in, even if not explicitly saved on this class. + + Example: + .. code-block:: python + + from langchain.chat_models import JinaChat + chat = JinaChat() + """ + + @property + def lc_secrets(self) -> Dict[str, str]: + return {"jinachat_api_key": "JINACHAT_API_KEY"} + + @property + def lc_serializable(self) -> bool: + return True + + client: Any #: :meta private: + temperature: float = 0.7 + """What sampling temperature to use.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + jinachat_api_key: Optional[str] = None + """Base URL path for API requests, + leave blank if not using a proxy or service emulator.""" + request_timeout: Optional[Union[float, Tuple[float, float]]] = None + """Timeout for requests to JinaChat completion API. 
Default is 600 seconds.""" + max_retries: int = 6 + """Maximum number of retries to make when generating.""" + streaming: bool = False + """Whether to stream the results or not.""" + max_tokens: Optional[int] = None + """Maximum number of tokens to generate.""" + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator(pre=True) + def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Build extra kwargs from additional params that were passed in.""" + all_required_field_names = cls.all_required_field_names() + extra = values.get("model_kwargs", {}) + for field_name in list(values): + if field_name in extra: + raise ValueError(f"Found {field_name} supplied twice.") + if field_name not in all_required_field_names: + logger.warning( + f"""WARNING! {field_name} is not default parameter. + {field_name} was transferred to model_kwargs. + Please confirm that {field_name} is what you intended.""" + ) + extra[field_name] = values.pop(field_name) + + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) + if invalid_model_kwargs: + raise ValueError( + f"Parameters {invalid_model_kwargs} should be specified explicitly. " + f"Instead they were passed in as part of `model_kwargs` parameter." + ) + + values["model_kwargs"] = extra + return values + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["jinachat_api_key"] = get_from_dict_or_env( + values, "jinachat_api_key", "JINACHAT_API_KEY" + ) + try: + import openai + + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + try: + values["client"] = openai.ChatCompletion + except AttributeError: + raise ValueError( + "`openai` has no `ChatCompletion` attribute, this is likely " + "due to an old version of the openai package. Try upgrading it " + "with `pip install --upgrade openai`." 
+ ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling JinaChat API.""" + return { + "request_timeout": self.request_timeout, + "max_tokens": self.max_tokens, + "stream": self.streaming, + "temperature": self.temperature, + **self.model_kwargs, + } + + def _create_retry_decorator(self) -> Callable[[Any], Any]: + import openai + + min_seconds = 1 + max_seconds = 60 + # Wait 2^x * 1 second between each retry starting with + # 4 seconds, then up to 10 seconds, then 10 seconds afterwards + return retry( + reraise=True, + stop=stop_after_attempt(self.max_retries), + wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), + retry=( + retry_if_exception_type(openai.error.Timeout) + | retry_if_exception_type(openai.error.APIError) + | retry_if_exception_type(openai.error.APIConnectionError) + | retry_if_exception_type(openai.error.RateLimitError) + | retry_if_exception_type(openai.error.ServiceUnavailableError) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + + def completion_with_retry(self, **kwargs: Any) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = self._create_retry_decorator() + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + return self.client.create(**kwargs) + + return _completion_with_retry(**kwargs) + + def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: + overall_token_usage: dict = {} + for output in llm_outputs: + if output is None: + # Happens in streaming + continue + token_usage = output["token_usage"] + for k, v in token_usage.items(): + if k in overall_token_usage: + overall_token_usage[k] += v + else: + overall_token_usage[k] = v + return {"token_usage": overall_token_usage} + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + if self.streaming: + inner_completion = "" + role = "assistant" + params["stream"] = True + for stream_resp in self.completion_with_retry( + messages=message_dicts, **params + ): + role = stream_resp["choices"][0]["delta"].get("role", role) + token = stream_resp["choices"][0]["delta"].get("content") or "" + inner_completion += token + if run_manager: + run_manager.on_llm_new_token(token) + message = _convert_dict_to_message( + { + "content": inner_completion, + "role": role, + } + ) + return ChatResult(generations=[ChatGeneration(message=message)]) + response = self.completion_with_retry(messages=message_dicts, **params) + return self._create_chat_result(response) + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = dict(self._invocation_params) + if stop is not None: + if "stop" in params: + raise ValueError("`stop` found in both the input and default params.") + params["stop"] = stop + message_dicts = [_convert_message_to_dict(m) for m in messages] + return message_dicts, params + + def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult: + generations = [] + for res in response["choices"]: + message = _convert_dict_to_message(res["message"]) + gen = ChatGeneration(message=message) + generations.append(gen) + llm_output = {"token_usage": response["usage"]} + return ChatResult(generations=generations, 
llm_output=llm_output) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} + if self.streaming: + inner_completion = "" + role = "assistant" + params["stream"] = True + async for stream_resp in await acompletion_with_retry( + self, messages=message_dicts, **params + ): + role = stream_resp["choices"][0]["delta"].get("role", role) + token = stream_resp["choices"][0]["delta"].get("content", "") + inner_completion += token or "" + if run_manager: + await run_manager.on_llm_new_token(token) + message = _convert_dict_to_message( + { + "content": inner_completion, + "role": role, + } + ) + return ChatResult(generations=[ChatGeneration(message=message)]) + else: + response = await acompletion_with_retry( + self, messages=message_dicts, **params + ) + return self._create_chat_result(response) + + @property + def _invocation_params(self) -> Mapping[str, Any]: + """Get the parameters used to invoke the model.""" + jinachat_creds: Dict[str, Any] = { + "api_key": self.jinachat_api_key, + "api_base": "https://api.chat.jina.ai/v1", + "model": "jinachat", + } + return {**jinachat_creds, **self._default_params} + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "jinachat" diff --git a/tests/integration_tests/chat_models/test_jinachat.py b/tests/integration_tests/chat_models/test_jinachat.py new file mode 100644 index 0000000000..f4100a2138 --- /dev/null +++ b/tests/integration_tests/chat_models/test_jinachat.py @@ -0,0 +1,127 @@ +"""Test JinaChat wrapper.""" + + +import pytest + +from langchain.callbacks.manager import CallbackManager +from langchain.chat_models.jinachat import JinaChat +from langchain.schema import ( + BaseMessage, + ChatGeneration, + HumanMessage, + LLMResult, + SystemMessage, +) +from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler + + +def test_jinachat() -> None: + """Test JinaChat wrapper.""" + chat = JinaChat(max_tokens=10) + message = HumanMessage(content="Hello") + response = chat([message]) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + +def test_jinachat_system_message() -> None: + """Test JinaChat wrapper with system message.""" + chat = JinaChat(max_tokens=10) + system_message = SystemMessage(content="You are to chat with the user.") + human_message = HumanMessage(content="Hello") + response = chat([system_message, human_message]) + assert isinstance(response, BaseMessage) + assert isinstance(response.content, str) + + +def test_jinachat_generate() -> None: + """Test JinaChat wrapper with generate.""" + chat = JinaChat(max_tokens=10) + message = HumanMessage(content="Hello") + response = chat.generate([[message], [message]]) + assert isinstance(response, LLMResult) + assert len(response.generations) == 2 + for generations in response.generations: + assert len(generations) == 1 + for generation in generations: + assert isinstance(generation, ChatGeneration) + assert isinstance(generation.text, str) + assert generation.text == generation.message.content + + +def test_jinachat_streaming() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + chat = JinaChat( + max_tokens=10, + streaming=True, + 
temperature=0, + callback_manager=callback_manager, + verbose=True, + ) + message = HumanMessage(content="Hello") + response = chat([message]) + assert callback_handler.llm_streams > 0 + assert isinstance(response, BaseMessage) + + +@pytest.mark.asyncio +async def test_async_jinachat() -> None: + """Test async generation.""" + chat = JinaChat(max_tokens=102) + message = HumanMessage(content="Hello") + response = await chat.agenerate([[message], [message]]) + assert isinstance(response, LLMResult) + assert len(response.generations) == 2 + for generations in response.generations: + assert len(generations) == 1 + for generation in generations: + assert isinstance(generation, ChatGeneration) + assert isinstance(generation.text, str) + assert generation.text == generation.message.content + + +@pytest.mark.asyncio +async def test_async_jinachat_streaming() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + chat = JinaChat( + max_tokens=10, + streaming=True, + temperature=0, + callback_manager=callback_manager, + verbose=True, + ) + message = HumanMessage(content="Hello") + response = await chat.agenerate([[message], [message]]) + assert callback_handler.llm_streams > 0 + assert isinstance(response, LLMResult) + assert len(response.generations) == 2 + for generations in response.generations: + assert len(generations) == 1 + for generation in generations: + assert isinstance(generation, ChatGeneration) + assert isinstance(generation.text, str) + assert generation.text == generation.message.content + + +def test_jinachat_extra_kwargs() -> None: + """Test extra kwargs to chat openai.""" + # Check that foo is saved in extra_kwargs. + llm = JinaChat(foo=3, max_tokens=10) + assert llm.max_tokens == 10 + assert llm.model_kwargs == {"foo": 3} + + # Test that if extra_kwargs are provided, they are added to it. + llm = JinaChat(foo=3, model_kwargs={"bar": 2}) + assert llm.model_kwargs == {"foo": 3, "bar": 2} + + # Test that if provided twice it errors + with pytest.raises(ValueError): + JinaChat(foo=3, model_kwargs={"foo": 2}) + + # Test that if explicit param is specified in kwargs it errors + with pytest.raises(ValueError): + JinaChat(model_kwargs={"temperature": 0.2})
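A minimal usage sketch of the wrapper added above, for quick reference. It assumes the JINACHAT_API_KEY environment variable (or the jinachat_api_key constructor argument) holds a valid key; the key value below is a placeholder, and the messages are the same ones used in the notebook.

import os

from langchain.chat_models import JinaChat
from langchain.schema import HumanMessage, SystemMessage

# The wrapper reads the key from JINACHAT_API_KEY (or the jinachat_api_key kwarg).
os.environ["JINACHAT_API_KEY"] = "<your-api-key>"  # placeholder

chat = JinaChat(temperature=0, max_tokens=256)
messages = [
    SystemMessage(
        content="You are a helpful assistant that translates English to French."
    ),
    HumanMessage(
        content="Translate this sentence from English to French. I love programming."
    ),
]
# Calling the chat model returns an AIMessage, as shown in the notebook above.
print(chat(messages).content)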
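Streaming and async generation are handled by _generate and _agenerate when streaming=True. A sketch under the assumption that the stock StreamingStdOutCallbackHandler is an acceptable token callback (the integration tests use a fake handler for the same purpose):

import asyncio

from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import JinaChat
from langchain.schema import HumanMessage

# streaming=True makes the wrapper iterate over chunked responses and forward
# each token to on_llm_new_token, so tokens are printed as they arrive.
chat = JinaChat(
    streaming=True,
    temperature=0,
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
)
chat([HumanMessage(content="Tell me a joke.")])


# The async path goes through acompletion_with_retry / agenerate.
async def main() -> None:
    result = await chat.agenerate([[HumanMessage(content="Hello")]])
    print(result.generations[0][0].text)


asyncio.run(main())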
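The build_extra root validator moves any constructor argument that is not an explicit field into model_kwargs, which is then merged into every create call, and it rejects parameters supplied twice or explicit fields hidden inside model_kwargs. A sketch mirroring test_jinachat_extra_kwargs; top_p is only an illustrative extra argument and is not guaranteed to be accepted by the API:

from langchain.chat_models import JinaChat

# An unknown kwarg is moved into model_kwargs (a warning is logged) ...
chat = JinaChat(top_p=0.9, max_tokens=10)
assert chat.model_kwargs == {"top_p": 0.9}

# ... supplying the same name both ways raises a ValueError ...
try:
    JinaChat(top_p=0.9, model_kwargs={"top_p": 0.5})
except ValueError as err:
    print(err)  # "Found top_p supplied twice."

# ... and explicit fields may not be passed through model_kwargs either.
try:
    JinaChat(model_kwargs={"temperature": 0.2})
except ValueError as err:
    print(err)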
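Because the service is OpenAI compatible, _invocation_params simply points the pre-1.0 openai client at https://api.chat.jina.ai/v1 with model="jinachat" plus the API key and default parameters. The raw call below illustrates roughly what the wrapper sends; the key is a placeholder and the request parameters are only examples.

import openai  # openai<1.0, matching the openai.ChatCompletion usage in this patch

response = openai.ChatCompletion.create(
    api_key="<your-api-key>",  # placeholder; the wrapper reads JINACHAT_API_KEY
    api_base="https://api.chat.jina.ai/v1",
    model="jinachat",
    messages=[{"role": "user", "content": "I love programming."}],
    temperature=0.7,
    max_tokens=256,
)
# _create_chat_result reads the reply from the same fields.
print(response["choices"][0]["message"]["content"])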