integrate JinaChat (#6927)

Integration with https://chat.jina.ai/api. It is OpenAI compatible API.

- Twitter handle:
[https://twitter.com/JinaAI_](https://twitter.com/JinaAI_)

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/7392/head
Delgermurun 1 year ago committed by GitHub
parent 4ba7396f96
commit a1603fccfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,162 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "e49f1e0d",
"metadata": {},
"source": [
"# JinaChat\n",
"\n",
"This notebook covers how to get started with JinaChat chat models."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "522686de",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chat_models import JinaChat\n",
"from langchain.prompts.chat import (\n",
" ChatPromptTemplate,\n",
" SystemMessagePromptTemplate,\n",
" AIMessagePromptTemplate,\n",
" HumanMessagePromptTemplate,\n",
")\n",
"from langchain.schema import AIMessage, HumanMessage, SystemMessage"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "62e0dbc3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat = JinaChat(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ce16ad78-8e6f-48cd-954e-98be75eb5836",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"J'aime programmer.\", additional_kwargs={}, example=False)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" SystemMessage(\n",
" content=\"You are a helpful assistant that translates English to French.\"\n",
" ),\n",
" HumanMessage(\n",
" content=\"Translate this sentence from English to French. I love programming.\"\n",
" ),\n",
"]\n",
"chat(messages)"
]
},
{
"cell_type": "markdown",
"id": "778f912a-66ea-4a5d-b3de-6c7db4baba26",
"metadata": {},
"source": [
"You can make use of templating by using a `MessagePromptTemplate`. You can build a `ChatPromptTemplate` from one or more `MessagePromptTemplates`. You can use `ChatPromptTemplate`'s `format_prompt` -- this returns a `PromptValue`, which you can convert to a string or Message object, depending on whether you want to use the formatted value as input to an llm or chat model.\n",
"\n",
"For convenience, there is a `from_template` method exposed on the template. If you were to use this template, this is what it would look like:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "180c5cc8",
"metadata": {},
"outputs": [],
"source": [
"template = (\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
")\n",
"system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
"human_template = \"{text}\"\n",
"human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "fbb043e6",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"J'aime programmer.\", additional_kwargs={}, example=False)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_prompt = ChatPromptTemplate.from_messages(\n",
" [system_message_prompt, human_message_prompt]\n",
")\n",
"\n",
"# get a chat completion from the formatted messages\n",
"chat(\n",
" chat_prompt.format_prompt(\n",
" input_language=\"English\", output_language=\"French\", text=\"I love programming.\"\n",
" ).to_messages()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c095285d",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -3,6 +3,7 @@ from langchain.chat_models.azure_openai import AzureChatOpenAI
from langchain.chat_models.fake import FakeListChatModel
from langchain.chat_models.google_palm import ChatGooglePalm
from langchain.chat_models.human import HumanInputChatModel
from langchain.chat_models.jinachat import JinaChat
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain.chat_models.vertexai import ChatVertexAI
@ -15,5 +16,6 @@ __all__ = [
"ChatAnthropic",
"ChatGooglePalm",
"ChatVertexAI",
"JinaChat",
"HumanInputChatModel",
]

@ -0,0 +1,357 @@
"""JinaChat wrapper."""
from __future__ import annotations
import logging
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Union,
)
from pydantic import Field, root_validator
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
ChatMessage,
ChatResult,
HumanMessage,
SystemMessage,
)
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]:
import openai
min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(llm.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm)
@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
# Use OpenAI's async api https://github.com/openai/openai-python#async-api
return await llm.client.acreate(**kwargs)
return await _completion_with_retry(**kwargs)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
content = _dict["content"] or ""
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict
class JinaChat(BaseChatModel):
"""Wrapper around JinaChat API.
To use, you should have the ``openai`` python package installed, and the
environment variable ``JINACHAT_API_KEY`` set with your API key.
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
Example:
.. code-block:: python
from langchain.chat_models import JinaChat
chat = JinaChat()
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"jinachat_api_key": "JINACHAT_API_KEY"}
@property
def lc_serializable(self) -> bool:
return True
client: Any #: :meta private:
temperature: float = 0.7
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
jinachat_api_key: Optional[str] = None
"""Base URL path for API requests,
leave blank if not using a proxy or service emulator."""
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to JinaChat completion API. Default is 600 seconds."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""
max_tokens: Optional[int] = None
"""Maximum number of tokens to generate."""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = cls.all_required_field_names()
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
logger.warning(
f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
extra[field_name] = values.pop(field_name)
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "
f"Instead they were passed in as part of `model_kwargs` parameter."
)
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["jinachat_api_key"] = get_from_dict_or_env(
values, "jinachat_api_key", "JINACHAT_API_KEY"
)
try:
import openai
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
try:
values["client"] = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling JinaChat API."""
return {
"request_timeout": self.request_timeout,
"max_tokens": self.max_tokens,
"stream": self.streaming,
"temperature": self.temperature,
**self.model_kwargs,
}
def _create_retry_decorator(self) -> Callable[[Any], Any]:
import openai
min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = self._create_retry_decorator()
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage}
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
params["stream"] = True
for stream_resp in self.completion_with_retry(
messages=message_dicts, **params
):
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content") or ""
inner_completion += token
if run_manager:
run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = dict(self._invocation_params)
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
generations = []
for res in response["choices"]:
message = _convert_dict_to_message(res["message"])
gen = ChatGeneration(message=message)
generations.append(gen)
llm_output = {"token_usage": response["usage"]}
return ChatResult(generations=generations, llm_output=llm_output)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
params["stream"] = True
async for stream_resp in await acompletion_with_retry(
self, messages=message_dicts, **params
):
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content", "")
inner_completion += token or ""
if run_manager:
await run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
else:
response = await acompletion_with_retry(
self, messages=message_dicts, **params
)
return self._create_chat_result(response)
@property
def _invocation_params(self) -> Mapping[str, Any]:
"""Get the parameters used to invoke the model."""
jinachat_creds: Dict[str, Any] = {
"api_key": self.jinachat_api_key,
"api_base": "https://api.chat.jina.ai/v1",
"model": "jinachat",
}
return {**jinachat_creds, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "jinachat"

@ -0,0 +1,127 @@
"""Test JinaChat wrapper."""
import pytest
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.jinachat import JinaChat
from langchain.schema import (
BaseMessage,
ChatGeneration,
HumanMessage,
LLMResult,
SystemMessage,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_jinachat() -> None:
"""Test JinaChat wrapper."""
chat = JinaChat(max_tokens=10)
message = HumanMessage(content="Hello")
response = chat([message])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_jinachat_system_message() -> None:
"""Test JinaChat wrapper with system message."""
chat = JinaChat(max_tokens=10)
system_message = SystemMessage(content="You are to chat with the user.")
human_message = HumanMessage(content="Hello")
response = chat([system_message, human_message])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_jinachat_generate() -> None:
"""Test JinaChat wrapper with generate."""
chat = JinaChat(max_tokens=10)
message = HumanMessage(content="Hello")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
def test_jinachat_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = JinaChat(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
message = HumanMessage(content="Hello")
response = chat([message])
assert callback_handler.llm_streams > 0
assert isinstance(response, BaseMessage)
@pytest.mark.asyncio
async def test_async_jinachat() -> None:
"""Test async generation."""
chat = JinaChat(max_tokens=102)
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
@pytest.mark.asyncio
async def test_async_jinachat_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = JinaChat(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert callback_handler.llm_streams > 0
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content
def test_jinachat_extra_kwargs() -> None:
"""Test extra kwargs to chat openai."""
# Check that foo is saved in extra_kwargs.
llm = JinaChat(foo=3, max_tokens=10)
assert llm.max_tokens == 10
assert llm.model_kwargs == {"foo": 3}
# Test that if extra_kwargs are provided, they are added to it.
llm = JinaChat(foo=3, model_kwargs={"bar": 2})
assert llm.model_kwargs == {"foo": 3, "bar": 2}
# Test that if provided twice it errors
with pytest.raises(ValueError):
JinaChat(foo=3, model_kwargs={"foo": 2})
# Test that if explicit param is specified in kwargs it errors
with pytest.raises(ValueError):
JinaChat(model_kwargs={"temperature": 0.2})
Loading…
Cancel
Save