mirror of https://github.com/hwchase17/langchain
integrate JinaChat (#6927)
Integration with https://chat.jina.ai/api, an OpenAI-compatible API.

- Twitter handle: [https://twitter.com/JinaAI_](https://twitter.com/JinaAI_)

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>

pull/7392/head
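A minimal usage sketch of the new wrapper (it mirrors the notebook and tests in this PR, and assumes the `JINACHAT_API_KEY` environment variable is set):

```python
from langchain.chat_models import JinaChat
from langchain.schema import HumanMessage

# The wrapper reads the JINACHAT_API_KEY environment variable for auth
# and talks to https://api.chat.jina.ai/v1 through the `openai` package.
chat = JinaChat(temperature=0)

# Send a single human message and print the model's reply.
response = chat([HumanMessage(content="Hello")])
print(response.content)
```

Streaming (`streaming=True` with a callback manager) and async generation via `agenerate` are also supported; see the integration tests below.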
parent
4ba7396f96
commit
a1603fccfb
@ -0,0 +1,162 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e49f1e0d",
   "metadata": {},
   "source": [
    "# JinaChat\n",
    "\n",
    "This notebook covers how to get started with JinaChat chat models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "522686de",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models import JinaChat\n",
    "from langchain.prompts.chat import (\n",
    "    ChatPromptTemplate,\n",
    "    SystemMessagePromptTemplate,\n",
    "    AIMessagePromptTemplate,\n",
    "    HumanMessagePromptTemplate,\n",
    ")\n",
    "from langchain.schema import AIMessage, HumanMessage, SystemMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "62e0dbc3",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "chat = JinaChat(temperature=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ce16ad78-8e6f-48cd-954e-98be75eb5836",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"J'aime programmer.\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [\n",
    "    SystemMessage(\n",
    "        content=\"You are a helpful assistant that translates English to French.\"\n",
    "    ),\n",
    "    HumanMessage(\n",
    "        content=\"Translate this sentence from English to French. I love programming.\"\n",
    "    ),\n",
    "]\n",
    "chat(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "778f912a-66ea-4a5d-b3de-6c7db4baba26",
   "metadata": {},
   "source": [
    "You can make use of templating by using a `MessagePromptTemplate`. You can build a `ChatPromptTemplate` from one or more `MessagePromptTemplates`. You can use `ChatPromptTemplate`'s `format_prompt` -- this returns a `PromptValue`, which you can convert to a string or Message object, depending on whether you want to use the formatted value as input to an LLM or chat model.\n",
    "\n",
    "For convenience, there is a `from_template` method exposed on the template. If you were to use this template, this is what it would look like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "180c5cc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = (\n",
    "    \"You are a helpful assistant that translates {input_language} to {output_language}.\"\n",
    ")\n",
    "system_message_prompt = SystemMessagePromptTemplate.from_template(template)\n",
    "human_template = \"{text}\"\n",
    "human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fbb043e6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"J'aime programmer.\", additional_kwargs={}, example=False)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat_prompt = ChatPromptTemplate.from_messages(\n",
    "    [system_message_prompt, human_message_prompt]\n",
    ")\n",
    "\n",
    "# get a chat completion from the formatted messages\n",
    "chat(\n",
    "    chat_prompt.format_prompt(\n",
    "        input_language=\"English\", output_language=\"French\", text=\"I love programming.\"\n",
    "    ).to_messages()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c095285d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@ -0,0 +1,357 @@
"""JinaChat wrapper."""
from __future__ import annotations

import logging
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)

from pydantic import Field, root_validator
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatMessage,
    ChatResult,
    HumanMessage,
    SystemMessage,
)
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


def _create_retry_decorator(llm: JinaChat) -> Callable[[Any], Any]:
    import openai

    min_seconds = 1
    max_seconds = 60
    # Wait 2^x * 1 second between retries, starting at 1 second and
    # capping at 60 seconds between attempts.
    return retry(
        reraise=True,
        stop=stop_after_attempt(llm.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(
            retry_if_exception_type(openai.error.Timeout)
            | retry_if_exception_type(openai.error.APIError)
            | retry_if_exception_type(openai.error.APIConnectionError)
            | retry_if_exception_type(openai.error.RateLimitError)
            | retry_if_exception_type(openai.error.ServiceUnavailableError)
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        # Use OpenAI's async api https://github.com/openai/openai-python#async-api
        return await llm.client.acreate(**kwargs)

    return await _completion_with_retry(**kwargs)


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        content = _dict["content"] or ""
        return AIMessage(content=content)
    elif role == "system":
        return SystemMessage(content=_dict["content"])
    else:
        return ChatMessage(content=_dict["content"], role=role)


def _convert_message_to_dict(message: BaseMessage) -> dict:
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    else:
        raise ValueError(f"Got unknown type {message}")
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict


class JinaChat(BaseChatModel):
    """Wrapper around JinaChat API.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``JINACHAT_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.chat_models import JinaChat
            chat = JinaChat()
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"jinachat_api_key": "JINACHAT_API_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any  #: :meta private:
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    jinachat_api_key: Optional[str] = None
    """JinaChat API key. Read from the ``JINACHAT_API_KEY`` environment
    variable if not provided."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to JinaChat completion API. Default is 600 seconds."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = cls.all_required_field_names()
        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not a default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["jinachat_api_key"] = get_from_dict_or_env(
            values, "jinachat_api_key", "JINACHAT_API_KEY"
        )
        try:
            import openai

        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling JinaChat API."""
        return {
            "request_timeout": self.request_timeout,
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "temperature": self.temperature,
            **self.model_kwargs,
        }

    def _create_retry_decorator(self) -> Callable[[Any], Any]:
        import openai

        min_seconds = 1
        max_seconds = 60
        # Wait 2^x * 1 second between retries, starting at 1 second and
        # capping at 60 seconds between attempts.
        return retry(
            reraise=True,
            stop=stop_after_attempt(self.max_retries),
            wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
            retry=(
                retry_if_exception_type(openai.error.Timeout)
                | retry_if_exception_type(openai.error.APIError)
                | retry_if_exception_type(openai.error.APIConnectionError)
                | retry_if_exception_type(openai.error.RateLimitError)
                | retry_if_exception_type(openai.error.ServiceUnavailableError)
            ),
            before_sleep=before_sleep_log(logger, logging.WARNING),
        )

    def completion_with_retry(self, **kwargs: Any) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = self._create_retry_decorator()

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

        return _completion_with_retry(**kwargs)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            for k, v in token_usage.items():
                if k in overall_token_usage:
                    overall_token_usage[k] += v
                else:
                    overall_token_usage[k] = v
        return {"token_usage": overall_token_usage}

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            for stream_resp in self.completion_with_retry(
                messages=message_dicts, **params
            ):
                role = stream_resp["choices"][0]["delta"].get("role", role)
                token = stream_resp["choices"][0]["delta"].get("content") or ""
                inner_completion += token
                if run_manager:
                    run_manager.on_llm_new_token(token)
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        response = self.completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        params = dict(self._invocation_params)
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            gen = ChatGeneration(message=message)
            generations.append(gen)
        llm_output = {"token_usage": response["usage"]}
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            async for stream_resp in await acompletion_with_retry(
                self, messages=message_dicts, **params
            ):
                role = stream_resp["choices"][0]["delta"].get("role", role)
                token = stream_resp["choices"][0]["delta"].get("content", "")
                inner_completion += token or ""
                if run_manager:
                    await run_manager.on_llm_new_token(token)
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        else:
            response = await acompletion_with_retry(
                self, messages=message_dicts, **params
            )
            return self._create_chat_result(response)

    @property
    def _invocation_params(self) -> Mapping[str, Any]:
        """Get the parameters used to invoke the model."""
        jinachat_creds: Dict[str, Any] = {
            "api_key": self.jinachat_api_key,
            "api_base": "https://api.chat.jina.ai/v1",
            "model": "jinachat",
        }
        return {**jinachat_creds, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "jinachat"
@ -0,0 +1,127 @@
"""Test JinaChat wrapper."""


import pytest

from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.jinachat import JinaChat
from langchain.schema import (
    BaseMessage,
    ChatGeneration,
    HumanMessage,
    LLMResult,
    SystemMessage,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_jinachat() -> None:
    """Test JinaChat wrapper."""
    chat = JinaChat(max_tokens=10)
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_jinachat_system_message() -> None:
    """Test JinaChat wrapper with system message."""
    chat = JinaChat(max_tokens=10)
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_jinachat_generate() -> None:
    """Test JinaChat wrapper with generate."""
    chat = JinaChat(max_tokens=10)
    message = HumanMessage(content="Hello")
    response = chat.generate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


def test_jinachat_streaming() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    chat = JinaChat(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=callback_manager,
        verbose=True,
    )
    message = HumanMessage(content="Hello")
    response = chat([message])
    assert callback_handler.llm_streams > 0
    assert isinstance(response, BaseMessage)


@pytest.mark.asyncio
async def test_async_jinachat() -> None:
    """Test async generation."""
    chat = JinaChat(max_tokens=102)
    message = HumanMessage(content="Hello")
    response = await chat.agenerate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


@pytest.mark.asyncio
async def test_async_jinachat_streaming() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    chat = JinaChat(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=callback_manager,
        verbose=True,
    )
    message = HumanMessage(content="Hello")
    response = await chat.agenerate([[message], [message]])
    assert callback_handler.llm_streams > 0
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content


def test_jinachat_extra_kwargs() -> None:
    """Test extra kwargs to JinaChat."""
    # Check that foo is saved in extra_kwargs.
    llm = JinaChat(foo=3, max_tokens=10)
    assert llm.max_tokens == 10
    assert llm.model_kwargs == {"foo": 3}

    # Test that if extra_kwargs are provided, they are added to it.
    llm = JinaChat(foo=3, model_kwargs={"bar": 2})
    assert llm.model_kwargs == {"foo": 3, "bar": 2}

    # Test that if provided twice it errors
    with pytest.raises(ValueError):
        JinaChat(foo=3, model_kwargs={"foo": 2})

    # Test that if explicit param is specified in kwargs it errors
    with pytest.raises(ValueError):
        JinaChat(model_kwargs={"temperature": 0.2})