community[minor]: Integration for `Friendli` LLM and `ChatFriendli` ChatModel. (#17913)

## Description
- Add [Friendli](https://friendli.ai/) integration for `Friendli` LLM
and `ChatFriendli` chat model.
- Unit tests and integration tests corresponding to this change are
added.
- Documentation corresponding to this change is added.

## Dependencies
- The optional [`friendli-client`](https://pypi.org/project/friendli-client/) package is added as a dependency only for those who use the `Friendli` or `ChatFriendli` models; a minimal install command is shown below.
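
A minimal sketch of the install (pip assumed; this mirrors the command used in the notebooks below):

```sh
pip install -U langchain-community friendli-client
```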

## Twitter handle
- https://twitter.com/friendliai

@ -0,0 +1,286 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_label: Friendli\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ChatFriendli\n",
"\n",
"> [Friendli](https://friendli.ai/) enhances AI application performance and optimizes cost savings with scalable, efficient deployment options, tailored for high-demand AI workloads.\n",
"\n",
"This tutorial guides you through integrating `ChatFriendli` for chat applications using LangChain. `ChatFriendli` offers a flexible approach to generating conversational AI responses, supporting both synchronous and asynchronous calls."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"Ensure the `langchain_community` and `friendli-client` are installed.\n",
"\n",
"```sh\n",
"pip install -U langchain-comminity friendli-client.\n",
"```\n",
"\n",
"Sign in to [Friendli Suite](https://suite.friendli.ai/) to create a Personal Access Token, and set it as the `FRIENDLI_TOKEN` environment."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"FRIENDLI_TOKEN\"] = getpass.getpass(\"Friendi Personal Access Token: \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can initialize a Friendli chat model with selecting the model you want to use. The default model is `mixtral-8x7b-instruct-v0-1`. You can check the available models at [docs.friendli.ai](https://docs.periflow.ai/guides/serverless_endpoints/pricing#text-generation-models)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.friendli import ChatFriendli\n",
"\n",
"chat = ChatFriendli(model=\"llama-2-13b-chat\", max_tokens=100, temperature=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage\n",
"\n",
"`FrienliChat` supports all methods of [`ChatModel`](/docs/modules/model_io/chat/) including async APIs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also use functionality of `invoke`, `batch`, `generate`, and `stream`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\")"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.messages.human import HumanMessage\n",
"from langchain_core.messages.system import SystemMessage\n",
"\n",
"system_message = SystemMessage(content=\"Answer questions as short as you can.\")\n",
"human_message = HumanMessage(content=\"Tell me a joke.\")\n",
"messages = [system_message, human_message]\n",
"\n",
"chat.invoke(messages)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[AIMessage(content=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\"),\n",
" AIMessage(content=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\")]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat.batch([messages, messages])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LLMResult(generations=[[ChatGeneration(text=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\", message=AIMessage(content=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\"))], [ChatGeneration(text=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\", message=AIMessage(content=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\"))]], llm_output={}, run=[RunInfo(run_id=UUID('a0c2d733-6971-4ae7-beea-653856f4e57c')), RunInfo(run_id=UUID('f3d35e44-ac9a-459a-9e4b-b8e3a73a91e1'))])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat.generate([messages, messages])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Knock, knock!\n",
"Who's there?\n",
"Cows go.\n",
"Cows go who?\n",
"MOO!"
]
}
],
"source": [
"for chunk in chat.stream(messages):\n",
" print(chunk.content, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also use all functionality of async APIs: `ainvoke`, `abatch`, `agenerate`, and `astream`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\")"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await chat.ainvoke(messages)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[AIMessage(content=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\"),\n",
" AIMessage(content=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\")]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await chat.abatch([messages, messages])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LLMResult(generations=[[ChatGeneration(text=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\", message=AIMessage(content=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\"))], [ChatGeneration(text=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\", message=AIMessage(content=\" Knock, knock!\\nWho's there?\\nCows go.\\nCows go who?\\nMOO!\"))]], llm_output={}, run=[RunInfo(run_id=UUID('f2255321-2d8e-41cc-adbd-3f4facec7573')), RunInfo(run_id=UUID('fcc297d0-6ca9-48cb-9d86-e6f78cade8ee'))])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await chat.agenerate([messages, messages])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Knock, knock!\n",
"Who's there?\n",
"Cows go.\n",
"Cows go who?\n",
"MOO!"
]
}
],
"source": [
"async for chunk in chat.astream(messages):\n",
" print(chunk.content, end=\"\", flush=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "langchain",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
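
Beyond the notebook above, here is a minimal sketch of composing `ChatFriendli` into an LCEL chain. The prompt template and output parser are illustrative additions, not part of this PR, and a valid `FRIENDLI_TOKEN` is assumed to be set.

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from langchain_community.chat_models.friendli import ChatFriendli

# Simple prompt -> chat model -> string pipeline.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Answer questions as short as you can."),
        ("human", "{question}"),
    ]
)
chat = ChatFriendli(model="llama-2-13b-chat", max_tokens=100, temperature=0)
chain = prompt | chat | StrOutputParser()

print(chain.invoke({"question": "Tell me a joke."}))
```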

@ -0,0 +1,277 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {},
"source": [
"---\n",
"sidebar_label: Friendli\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Friendli\n",
"\n",
"> [Friendli](https://friendli.ai/) enhances AI application performance and optimizes cost savings with scalable, efficient deployment options, tailored for high-demand AI workloads.\n",
"\n",
"This tutorial guides you through integrating `Friendli` with LangChain."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"Ensure the `langchain_community` and `friendli-client` are installed.\n",
"\n",
"```sh\n",
"pip install -U langchain-comminity friendli-client.\n",
"```\n",
"\n",
"Sign in to [Friendli Suite](https://suite.friendli.ai/) to create a Personal Access Token, and set it as the `FRIENDLI_TOKEN` environment."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"FRIENDLI_TOKEN\"] = getpass.getpass(\"Friendi Personal Access Token: \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can initialize a Friendli chat model with selecting the model you want to use. The default model is `mixtral-8x7b-instruct-v0-1`. You can check the available models at [docs.friendli.ai](https://docs.periflow.ai/guides/serverless_endpoints/pricing#text-generation-models)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms.friendli import Friendli\n",
"\n",
"llm = Friendli(model=\"mixtral-8x7b-instruct-v0-1\", max_tokens=100, temperature=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage\n",
"\n",
"`Frienli` supports all methods of [`LLM`](/docs/modules/model_io/llms/) including async APIs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can use functionality of `invoke`, `batch`, `generate`, and `stream`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Username checks out.\\nUser 1: I\\'m not sure if you\\'re being sarcastic or not, but I\\'ll take it as a compliment.\\nUser 0: I\\'m not being sarcastic. I\\'m just saying that your username is very fitting.\\nUser 1: Oh, I thought you were saying that I\\'m a \"dumbass\" because I\\'m a \"dumbass\" who \"checks out\"'"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.invoke(\"Tell me a joke.\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Username checks out.\\nUser 1: I\\'m not sure if you\\'re being sarcastic or not, but I\\'ll take it as a compliment.\\nUser 0: I\\'m not being sarcastic. I\\'m just saying that your username is very fitting.\\nUser 1: Oh, I thought you were saying that I\\'m a \"dumbass\" because I\\'m a \"dumbass\" who \"checks out\"',\n",
" 'Username checks out.\\nUser 1: I\\'m not sure if you\\'re being sarcastic or not, but I\\'ll take it as a compliment.\\nUser 0: I\\'m not being sarcastic. I\\'m just saying that your username is very fitting.\\nUser 1: Oh, I thought you were saying that I\\'m a \"dumbass\" because I\\'m a \"dumbass\" who \"checks out\"']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.batch([\"Tell me a joke.\", \"Tell me a joke.\"])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LLMResult(generations=[[Generation(text='Username checks out.\\nUser 1: I\\'m not sure if you\\'re being sarcastic or not, but I\\'ll take it as a compliment.\\nUser 0: I\\'m not being sarcastic. I\\'m just saying that your username is very fitting.\\nUser 1: Oh, I thought you were saying that I\\'m a \"dumbass\" because I\\'m a \"dumbass\" who \"checks out\"')], [Generation(text='Username checks out.\\nUser 1: I\\'m not sure if you\\'re being sarcastic or not, but I\\'ll take it as a compliment.\\nUser 0: I\\'m not being sarcastic. I\\'m just saying that your username is very fitting.\\nUser 1: Oh, I thought you were saying that I\\'m a \"dumbass\" because I\\'m a \"dumbass\" who \"checks out\"')]], llm_output={'model': 'mixtral-8x7b-instruct-v0-1'}, run=[RunInfo(run_id=UUID('a2009600-baae-4f5a-9f69-23b2bc916e4c')), RunInfo(run_id=UUID('acaf0838-242c-4255-85aa-8a62b675d046'))])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.generate([\"Tell me a joke.\", \"Tell me a joke.\"])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Username checks out.\n",
"User 1: I'm not sure if you're being sarcastic or not, but I'll take it as a compliment.\n",
"User 0: I'm not being sarcastic. I'm just saying that your username is very fitting.\n",
"User 1: Oh, I thought you were saying that I'm a \"dumbass\" because I'm a \"dumbass\" who \"checks out\""
]
}
],
"source": [
"for chunk in llm.stream(\"Tell me a joke.\"):\n",
" print(chunk, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also use all functionality of async APIs: `ainvoke`, `abatch`, `agenerate`, and `astream`."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Username checks out.\\nUser 1: I\\'m not sure if you\\'re being sarcastic or not, but I\\'ll take it as a compliment.\\nUser 0: I\\'m not being sarcastic. I\\'m just saying that your username is very fitting.\\nUser 1: Oh, I thought you were saying that I\\'m a \"dumbass\" because I\\'m a \"dumbass\" who \"checks out\"'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await llm.ainvoke(\"Tell me a joke.\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Username checks out.\\nUser 1: I\\'m not sure if you\\'re being sarcastic or not, but I\\'ll take it as a compliment.\\nUser 0: I\\'m not being sarcastic. I\\'m just saying that your username is very fitting.\\nUser 1: Oh, I thought you were saying that I\\'m a \"dumbass\" because I\\'m a \"dumbass\" who \"checks out\"',\n",
" 'Username checks out.\\nUser 1: I\\'m not sure if you\\'re being sarcastic or not, but I\\'ll take it as a compliment.\\nUser 0: I\\'m not being sarcastic. I\\'m just saying that your username is very fitting.\\nUser 1: Oh, I thought you were saying that I\\'m a \"dumbass\" because I\\'m a \"dumbass\" who \"checks out\"']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await llm.abatch([\"Tell me a joke.\", \"Tell me a joke.\"])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LLMResult(generations=[[Generation(text=\"Username checks out.\\nUser 1: I'm not sure if you're being serious or not, but I'll take it as a compliment.\\nUser 0: I'm being serious. I'm not sure if you're being serious or not.\\nUser 1: I'm being serious. I'm not sure if you're being serious or not.\\nUser 0: I'm being serious. I'm not sure\")], [Generation(text=\"Username checks out.\\nUser 1: I'm not sure if you're being serious or not, but I'll take it as a compliment.\\nUser 0: I'm being serious. I'm not sure if you're being serious or not.\\nUser 1: I'm being serious. I'm not sure if you're being serious or not.\\nUser 0: I'm being serious. I'm not sure\")]], llm_output={'model': 'mixtral-8x7b-instruct-v0-1'}, run=[RunInfo(run_id=UUID('46144905-7350-4531-a4db-22e6a827c6e3')), RunInfo(run_id=UUID('e2b06c30-ffff-48cf-b792-be91f2144aa6'))])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await llm.agenerate([\"Tell me a joke.\", \"Tell me a joke.\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Username checks out.\n",
"User 1: I'm not sure if you're being sarcastic or not, but I'll take it as a compliment.\n",
"User 0: I'm not being sarcastic. I'm just saying that your username is very fitting.\n",
"User 1: Oh, I thought you were saying that I'm a \"dumbass\" because I'm a \"dumbass\" who \"checks out\""
]
}
],
"source": [
"async for chunk in llm.astream(\"Tell me a joke.\"):\n",
" print(chunk, end=\"\", flush=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "langchain",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
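
As with the chat notebook, a short sketch (not part of this PR) of using the `Friendli` LLM behind a prompt template; `FRIENDLI_TOKEN` is assumed to be set in the environment.

```python
from langchain_core.prompts import PromptTemplate

from langchain_community.llms.friendli import Friendli

# Simple prompt -> completion pipeline; the template wording is illustrative.
prompt = PromptTemplate.from_template("Q: {question}\nA:")
llm = Friendli(model="mixtral-8x7b-instruct-v0-1", max_tokens=100, temperature=0)
chain = prompt | llm

print(chain.invoke({"question": "What is generative AI?"}))
```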

@ -30,6 +30,7 @@ from langchain_community.chat_models.ernie import ErnieBotChat
from langchain_community.chat_models.everlyai import ChatEverlyAI
from langchain_community.chat_models.fake import FakeListChatModel
from langchain_community.chat_models.fireworks import ChatFireworks
from langchain_community.chat_models.friendli import ChatFriendli
from langchain_community.chat_models.gigachat import GigaChat
from langchain_community.chat_models.google_palm import ChatGooglePalm
from langchain_community.chat_models.gpt_router import GPTRouter
@ -94,6 +95,7 @@ __all__ = [
"ChatYandexGPT",
"ChatBaichuan",
"ChatHunyuan",
"ChatFriendli",
"GigaChat",
"ChatSparkLLM",
"VolcEngineMaasChat",

@ -0,0 +1,217 @@
from __future__ import annotations
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_community.llms.friendli import BaseFriendli
def get_role(message: BaseMessage) -> str:
"""Get role of the message.
Args:
message (BaseMessage): The message object.
Raises:
ValueError: Raised when the message is of an unknown type.
Returns:
str: The role of the message.
"""
if isinstance(message, ChatMessage) or isinstance(message, HumanMessage):
return "user"
if isinstance(message, AIMessage):
return "assistant"
if isinstance(message, SystemMessage):
return "system"
raise ValueError(f"Got unknown type {message}")
def get_chat_request(messages: List[BaseMessage]) -> Dict[str, Any]:
"""Get a request of the Friendli chat API.
Args:
messages (List[BaseMessage]): Messages comprising the conversation so far.
Returns:
Dict[str, Any]: The request for the Friendli chat API.
"""
return {
"messages": [
{"role": get_role(message), "content": message.content}
for message in messages
]
}
class ChatFriendli(BaseChatModel, BaseFriendli):
"""Friendli LLM for chat.
``friendli-client`` package should be installed with `pip install friendli-client`.
You must set ``FRIENDLI_TOKEN`` environment variable or provide the value of your
personal access token for the ``friendli_token`` argument.
Example:
.. code-block:: python
from langchain_community.chat_models import ChatFriendli
chat = ChatFriendli(
model="llama-2-13b-chat", friendli_token="YOUR FRIENDLI TOKEN"
)
chat.invoke("What is generative AI?")
"""
model: str = "llama-2-13b-chat"
@property
def lc_secrets(self) -> Dict[str, str]:
return {"friendli_token": "FRIENDLI_TOKEN"}
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Friendli completions API."""
return {
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"max_tokens": self.max_tokens,
"stop": self.stop,
"temperature": self.temperature,
"top_p": self.top_p,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {"model": self.model, **self._default_params}
@property
def _llm_type(self) -> str:
return "friendli-chat"
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
params = self._default_params
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
params["stop"] = self.stop
else:
params["stop"] = stop
return {**params, **kwargs}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._get_invocation_params(stop=stop, **kwargs)
stream = self.client.chat.completions.create(
**get_chat_request(messages), stream=True, model=self.model, **params
)
for chunk in stream:
delta = chunk.choices[0].delta.content
if delta:
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._get_invocation_params(stop=stop, **kwargs)
stream = await self.async_client.chat.completions.create(
**get_chat_request(messages), stream=True, model=self.model, **params
)
async for chunk in stream:
delta = chunk.choices[0].delta.content
if delta:
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
await run_manager.on_llm_new_token(delta)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
params = self._get_invocation_params(stop=stop, **kwargs)
response = self.client.chat.completions.create(
messages=[
{
"role": get_role(message),
"content": message.content,
}
for message in messages
],
stream=False,
model=self.model,
**params,
)
message = AIMessage(content=response.choices[0].message.content)
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
params = self._get_invocation_params(stop=stop, **kwargs)
response = await self.async_client.chat.completions.create(
messages=[
{
"role": get_role(message),
"content": message.content,
}
for message in messages
],
stream=False,
model=self.model,
**params,
)
message = AIMessage(content=response.choices[0].message.content)
return ChatResult(generations=[ChatGeneration(message=message)])
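
For reference, a small sketch (not part of the diff) of what `get_chat_request` produces for a mixed message list; this is the payload that `_generate` forwards to the Friendli chat completions endpoint.

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_community.chat_models.friendli import get_chat_request

messages = [
    SystemMessage(content="Answer questions as short as you can."),
    HumanMessage(content="Tell me a joke."),
    AIMessage(content="Knock, knock!"),
]

# get_role maps SystemMessage -> "system", HumanMessage/ChatMessage -> "user",
# and AIMessage -> "assistant".
print(get_chat_request(messages))
# {'messages': [{'role': 'system', 'content': 'Answer questions as short as you can.'},
#               {'role': 'user', 'content': 'Tell me a joke.'},
#               {'role': 'assistant', 'content': 'Knock, knock!'}]}
```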

@ -209,6 +209,12 @@ def _import_forefrontai() -> Type[BaseLLM]:
return ForefrontAI
def _import_friendli() -> Type[BaseLLM]:
from langchain_community.llms.friendli import Friendli
return Friendli
def _import_gigachat() -> Type[BaseLLM]:
from langchain_community.llms.gigachat import GigaChat
@ -665,6 +671,8 @@ def __getattr__(name: str) -> Any:
return _import_fireworks()
elif name == "ForefrontAI":
return _import_forefrontai()
elif name == "Friendli":
return _import_friendli()
elif name == "GigaChat":
return _import_gigachat()
elif name == "GooglePalm":
@ -827,6 +835,7 @@ __all__ = [
"FakeListLLM",
"Fireworks",
"ForefrontAI",
"Friendli",
"GigaChat",
"GPT4All",
"GooglePalm",
@ -919,6 +928,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"edenai": _import_edenai,
"fake-list": _import_fake,
"forefrontai": _import_forefrontai,
"friendli": _import_friendli,
"giga-chat-model": _import_gigachat,
"google_palm": _import_google_palm,
"gooseai": _import_gooseai,

@ -0,0 +1,350 @@
from __future__ import annotations
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain_core.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.load.serializable import Serializable
from langchain_core.outputs import GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils.env import get_from_dict_or_env
from langchain_core.utils.utils import convert_to_secret_str
def _stream_response_to_generation_chunk(stream_response: Any) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
if stream_response.event == "token_sampled":
return GenerationChunk(
text=stream_response.text,
generation_info={"token": str(stream_response.token)},
)
return GenerationChunk(text="")
class BaseFriendli(Serializable):
"""Base class of Friendli."""
# Friendli client.
client: Any = Field(default=None, exclude=True)
# Friendli Async client.
async_client: Any = Field(default=None, exclude=True)
# Model name to use.
model: str = "mixtral-8x7b-instruct-v0-1"
# Friendli personal access token to run as.
friendli_token: Optional[SecretStr] = None
# Friendli team ID to run as.
friendli_team: Optional[str] = None
# Whether to enable streaming mode.
streaming: bool = False
# Number between -2.0 and 2.0. Positive values penalize tokens that have been
# sampled, taking into account their frequency in the preceding text. This
# penalization diminishes the model's tendency to reproduce identical lines
# verbatim.
frequency_penalty: Optional[float] = None
# Number between -2.0 and 2.0. Positive values penalize tokens that have been
# sampled at least once in the existing text.
presence_penalty: Optional[float] = None
# The maximum number of tokens to generate. The length of your input tokens plus
# `max_tokens` should not exceed the model's maximum length (e.g., 2048 for OpenAI
# GPT-3)
max_tokens: Optional[int] = None
# When one of the stop phrases appears in the generation result, the API will stop
# generation. The phrase is included in the generated result. If you are using
# beam search, all of the active beams should contain the stop phrase to terminate
# generation. Before checking whether a stop phrase is included in the result, the
# phrase is converted into tokens.
stop: Optional[List[str]] = None
# Sampling temperature. Smaller temperature makes the generation result closer to
# greedy, argmax (i.e., `top_k = 1`) sampling. If it is `None`, then 1.0 is used.
temperature: Optional[float] = None
# Tokens comprising the top `top_p` probability mass are kept for sampling. Numbers
# between 0.0 (exclusive) and 1.0 (inclusive) are allowed. If it is `None`, then 1.0
# is used by default.
top_p: Optional[float] = None
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate if personal access token is provided in environment."""
try:
import friendli
except ImportError as e:
raise ImportError(
"Could not import friendli-client python package. "
"Please install it with `pip install friendli-client`."
) from e
friendli_token = convert_to_secret_str(
get_from_dict_or_env(values, "friendli_token", "FRIENDLI_TOKEN")
)
values["friendli_token"] = friendli_token
friendli_token_str = friendli_token.get_secret_value()
friendli_team = values["friendli_team"] or os.getenv("FRIENDLI_TEAM")
values["friendli_team"] = friendli_team
values["client"] = values["client"] or friendli.Friendli(
token=friendli_token_str, team_id=friendli_team
)
values["async_client"] = values["async_client"] or friendli.AsyncFriendli(
token=friendli_token_str, team_id=friendli_team
)
return values
class Friendli(LLM, BaseFriendli):
"""Friendli LLM.
``friendli-client`` package should be installed with `pip install friendli-client`.
You must set ``FRIENDLI_TOKEN`` environment variable or provide the value of your
personal access token for the ``friendli_token`` argument.
Example:
.. code-block:: python
from langchain_community.llms import Friendli
friendli = Friendli(
model="mixtral-8x7b-instruct-v0-1", friendli_token="YOUR FRIENDLI TOKEN"
)
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"friendli_token": "FRIENDLI_TOKEN"}
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Friendli completions API."""
return {
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"max_tokens": self.max_tokens,
"stop": self.stop,
"temperature": self.temperature,
"top_p": self.top_p,
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {"model": self.model, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "friendli"
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
params = self._default_params
if self.stop is not None and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop is not None:
params["stop"] = self.stop
else:
params["stop"] = stop
return {**params, **kwargs}
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out Friendli's completions API.
Args:
prompt (str): The text prompt to generate completion for.
stop (Optional[List[str]], optional): When one of the stop phrases appears
in the generation result, the API will stop generation. The stop phrases
are excluded from the result. If beam search is enabled, all of the
active beams should contain the stop phrase to terminate generation.
Before checking whether a stop phrase is included in the result, the
phrase is converted into tokens. We recommend using stop_tokens because
it is clearer. For example, after tokenization, phrases "clear" and
" clear" can result in different token sequences due to the prepended
space character. Defaults to None.
Returns:
str: The generated text output.
Example:
.. code-block:: python
response = frienldi("Give me a recipe for the Old Fashioned cocktail.")
"""
params = self._get_invocation_params(stop=stop, **kwargs)
completion = self.client.completions.create(
model=self.model, prompt=prompt, stream=False, **params
)
return completion.choices[0].text
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out Friendli's completions API Asynchronously.
Args:
prompt (str): The text prompt to generate completion for.
stop (Optional[List[str]], optional): When one of the stop phrases appears
in the generation result, the API will stop generation. The stop phrases
are excluded from the result. If beam search is enabled, all of the
active beams should contain the stop phrase to terminate generation.
Before checking whether a stop phrase is included in the result, the
phrase is converted into tokens. We recommend using stop_tokens because
it is clearer. For example, after tokenization, phrases "clear" and
" clear" can result in different token sequences due to the prepended
space character. Defaults to None.
Returns:
str: The generated text output.
Example:
.. code-block:: python
response = await frienldi("Tell me a joke.")
"""
params = self._get_invocation_params(stop=stop, **kwargs)
completion = await self.async_client.completions.create(
model=self.model, prompt=prompt, stream=False, **params
)
return completion.choices[0].text
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = self._get_invocation_params(stop=stop, **kwargs)
stream = self.client.completions.create(
model=self.model, prompt=prompt, stream=True, **params
)
for line in stream:
chunk = _stream_response_to_generation_chunk(line)
yield chunk
if run_manager:
run_manager.on_llm_new_token(line.text, chunk=chunk)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
params = self._get_invocation_params(stop=stop, **kwargs)
stream = await self.async_client.completions.create(
model=self.model, prompt=prompt, stream=True, **params
)
async for line in stream:
chunk = _stream_response_to_generation_chunk(line)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(line.text, chunk=chunk)
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out Friendli's completions API with k unique prompts.
Args:
prompts (List[str]): The text prompts to generate completions for.
stop (Optional[List[str]], optional): When one of the stop phrases appears
in the generation result, the API will stop generation. The stop phrases
are excluded from the result. If beam search is enabled, all of the
active beams should contain the stop phrase to terminate generation.
Before checking whether a stop phrase is included in the result, the
phrase is converted into tokens. We recommend using stop_tokens because
it is clearer. For example, after tokenization, phrases "clear" and
" clear" can result in different token sequences due to the prepended
space character. Defaults to None.
Returns:
LLMResult: The result of the generation.
Example:
.. code-block:: python
response = friendli.generate(["Tell me a joke."])
"""
llm_output = {"model": self.model}
if self.streaming:
if len(prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
generation: Optional[GenerationChunk] = None
for chunk in self._stream(prompts[0], stop, run_manager, **kwargs):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return LLMResult(generations=[[generation]], llm_output=llm_output)
llm_result = super()._generate(prompts, stop, run_manager, **kwargs)
llm_result.llm_output = llm_output
return llm_result
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out Friendli's completions API asynchronously with k unique prompts.
Args:
prompts (List[str]): The text prompts to generate completions for.
stop (Optional[List[str]], optional): When one of the stop phrases appears
in the generation result, the API will stop generation. The stop phrases
are excluded from the result. If beam search is enabled, all of the
active beams should contain the stop phrase to terminate generation.
Before checking whether a stop phrase is included in the result, the
phrase is converted into tokens. We recommend using stop_tokens because
it is clearer. For example, after tokenization, phrases "clear" and
" clear" can result in different token sequences due to the prepended
space character. Defaults to None.
Returns:
LLMResult: The result of the generation.
Example:
.. code-block:: python
response = await friendli.agenerate(
["Give me a recipe for the Old Fashioned cocktail."]
)
"""
llm_output = {"model": self.model}
if self.streaming:
if len(prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
generation = None
async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return LLMResult(generations=[[generation]], llm_output=llm_output)
llm_result = await super()._agenerate(prompts, stop, run_manager, **kwargs)
llm_result.llm_output = llm_output
return llm_result
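
A short sketch (not part of the diff) of how the `stop` handling in `_get_invocation_params` above plays out in practice; `FRIENDLI_TOKEN` is assumed to be set.

```python
from langchain_community.llms.friendli import Friendli

# Constructor-level stop phrases become the default for every call.
llm = Friendli(model="mixtral-8x7b-instruct-v0-1", stop=["User"])
print(llm.invoke("Tell me a joke."))

# Without a default, stop phrases can be passed per call instead.
llm_no_default = Friendli(model="mixtral-8x7b-instruct-v0-1")
print(llm_no_default.invoke("Tell me a joke.", stop=["User"]))

# Providing stop phrases both at construction and at call time raises a
# ValueError, per _get_invocation_params above.
```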

File diff suppressed because one or more lines are too long

@ -96,6 +96,7 @@ oci = {version = "^2.119.1", optional = true}
rdflib = {version = "7.0.0", optional = true}
nvidia-riva-client = {version = "^2.14.0", optional = true}
tidb-vector = {version = ">=0.0.3,<1.0.0", optional = true}
friendli-client = {version = "^1.2.4", optional = true}
[tool.poetry.group.test]
optional = true
@ -266,6 +267,7 @@ extended_testing = [
"rdflib",
"tidb-vector",
"cloudpickle",
"friendli-client"
]
[tool.ruff]
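
For local extended testing, the optional dependency can be pulled in through the extras declared above; a plausible Poetry invocation (paths and flags assumed, not prescribed by this PR) is:

```sh
cd libs/community
poetry install --with test -E extended_testing
```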

@ -0,0 +1,105 @@
"""Test Friendli chat API."""
import pytest
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.outputs.generation import Generation
from langchain_core.outputs.llm_result import LLMResult
from langchain_community.chat_models.friendli import ChatFriendli
@pytest.fixture
def friendli_chat() -> ChatFriendli:
"""Friendli LLM for chat."""
return ChatFriendli(temperature=0, max_tokens=10)
def test_friendli_call(friendli_chat: ChatFriendli) -> None:
"""Test call."""
message = HumanMessage(content="What is generative AI?")
output = friendli_chat([message])
assert isinstance(output, AIMessage)
assert isinstance(output.content, str)
def test_friendli_invoke(friendli_chat: ChatFriendli) -> None:
"""Test invoke."""
output = friendli_chat.invoke("What is generative AI?")
assert isinstance(output, AIMessage)
assert isinstance(output.content, str)
async def test_friendli_ainvoke(friendli_chat: ChatFriendli) -> None:
"""Test async invoke."""
output = await friendli_chat.ainvoke("What is generative AI?")
assert isinstance(output, AIMessage)
assert isinstance(output.content, str)
def test_friendli_batch(friendli_chat: ChatFriendli) -> None:
"""Test batch."""
outputs = friendli_chat.batch(["What is generative AI?", "What is generative AI?"])
for output in outputs:
assert isinstance(output, AIMessage)
assert isinstance(output.content, str)
async def test_friendli_abatch(friendli_chat: ChatFriendli) -> None:
"""Test async batch."""
outputs = await friendli_chat.abatch(
["What is generative AI?", "What is generative AI?"]
)
for output in outputs:
assert isinstance(output, AIMessage)
assert isinstance(output.content, str)
def test_friendli_generate(friendli_chat: ChatFriendli) -> None:
"""Test generate."""
message = HumanMessage(content="What is generative AI?")
result = friendli_chat.generate([[message], [message]])
assert isinstance(result, LLMResult)
generations = result.generations
assert len(generations) == 2
for generation in generations:
gen_ = generation[0]
assert isinstance(gen_, Generation)
text = gen_.text
assert isinstance(text, str)
generation_info = gen_.generation_info
if generation_info is not None:
assert "token" in generation_info
async def test_friendli_agenerate(friendli_chat: ChatFriendli) -> None:
"""Test async generate."""
message = HumanMessage(content="What is generative AI?")
result = await friendli_chat.agenerate([[message], [message]])
assert isinstance(result, LLMResult)
generations = result.generations
assert len(generations) == 2
for generation in generations:
gen_ = generation[0]
assert isinstance(gen_, Generation)
text = gen_.text
assert isinstance(text, str)
generation_info = gen_.generation_info
if generation_info is not None:
assert "token" in generation_info
def test_friendli_stream(friendli_chat: ChatFriendli) -> None:
"""Test stream."""
stream = friendli_chat.stream("Say hello world.")
for chunk in stream:
assert isinstance(chunk, AIMessage)
assert isinstance(chunk.content, str)
async def test_friendli_astream(friendli_chat: ChatFriendli) -> None:
"""Test async stream."""
stream = friendli_chat.astream("Say hello world.")
async for chunk in stream:
assert isinstance(chunk, AIMessage)
assert isinstance(chunk.content, str)

@ -0,0 +1,91 @@
"""Test Friendli API."""
import pytest
from langchain_core.outputs.generation import Generation
from langchain_core.outputs.llm_result import LLMResult
from langchain_community.llms.friendli import Friendli
@pytest.fixture
def friendli_llm() -> Friendli:
"""Friendli LLM."""
return Friendli(temperature=0, max_tokens=10)
def test_friendli_call(friendli_llm: Friendli) -> None:
"""Test call."""
output = friendli_llm("Say hello world.")
assert isinstance(output, str)
def test_friendli_invoke(friendli_llm: Friendli) -> None:
"""Test invoke."""
output = friendli_llm.invoke("Say hello world.")
assert isinstance(output, str)
async def test_friendli_ainvoke(friendli_llm: Friendli) -> None:
"""Test async invoke."""
output = await friendli_llm.ainvoke("Say hello world.")
assert isinstance(output, str)
def test_friendli_batch(friendli_llm: Friendli) -> None:
"""Test batch."""
outputs = friendli_llm.batch(["Say hello world.", "Say bye world."])
for output in outputs:
assert isinstance(output, str)
async def test_friendli_abatch(friendli_llm: Friendli) -> None:
"""Test async batch."""
outputs = await friendli_llm.abatch(["Say hello world.", "Say bye world."])
for output in outputs:
assert isinstance(output, str)
def test_friendli_generate(friendli_llm: Friendli) -> None:
"""Test generate."""
result = friendli_llm.generate(["Say hello world.", "Say bye world."])
assert isinstance(result, LLMResult)
generations = result.generations
assert len(generations) == 2
for generation in generations:
gen_ = generation[0]
assert isinstance(gen_, Generation)
text = gen_.text
assert isinstance(text, str)
generation_info = gen_.generation_info
if generation_info is not None:
assert "token" in generation_info
async def test_friendli_agenerate(friendli_llm: Friendli) -> None:
"""Test async generate."""
result = await friendli_llm.agenerate(["Say hello world.", "Say bye world."])
assert isinstance(result, LLMResult)
generations = result.generations
assert len(generations) == 2
for generation in generations:
gen_ = generation[0]
assert isinstance(gen_, Generation)
text = gen_.text
assert isinstance(text, str)
generation_info = gen_.generation_info
if generation_info is not None:
assert "token" in generation_info
def test_friendli_stream(friendli_llm: Friendli) -> None:
"""Test stream."""
stream = friendli_llm.stream("Say hello world.")
for chunk in stream:
assert isinstance(chunk, str)
async def test_friendli_astream(friendli_llm: Friendli) -> None:
"""Test async stream."""
stream = friendli_llm.astream("Say hello world.")
async for chunk in stream:
assert isinstance(chunk, str)
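
The integration tests above call the live API, so they need a real token; a plausible way to run them (paths follow the repository's usual layout and are illustrative) is:

```sh
export FRIENDLI_TOKEN="<personal access token>"
cd libs/community
poetry run pytest tests/integration_tests/llms/test_friendli.py \
  tests/integration_tests/chat_models/test_friendli.py
```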

@ -0,0 +1,197 @@
"""Test Friendli LLM for chat."""
from unittest.mock import AsyncMock, MagicMock, Mock
import pytest
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain_community.adapters.openai import aenumerate
from langchain_community.chat_models import ChatFriendli
@pytest.fixture
def mock_friendli_client() -> Mock:
"""Mock instance of Friendli client."""
return Mock()
@pytest.fixture
def mock_friendli_async_client() -> AsyncMock:
"""Mock instance of Friendli async client."""
return AsyncMock()
@pytest.fixture
def chat_friendli(
mock_friendli_client: Mock, mock_friendli_async_client: AsyncMock
) -> ChatFriendli:
"""Friendli LLM for chat with mock clients."""
return ChatFriendli(
friendli_token=SecretStr("personal-access-token"),
client=mock_friendli_client,
async_client=mock_friendli_async_client,
)
@pytest.mark.requires("friendli")
def test_friendli_token_is_secret_string(capsys: CaptureFixture) -> None:
"""Test if friendli token is stored as a SecretStr."""
fake_token_value = "personal-access-token"
chat = ChatFriendli(friendli_token=fake_token_value)
assert isinstance(chat.friendli_token, SecretStr)
assert chat.friendli_token.get_secret_value() == fake_token_value
print(chat.friendli_token, end="") # noqa: T201
captured = capsys.readouterr()
assert captured.out == "**********"
@pytest.mark.requires("friendli")
def test_friendli_token_read_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test if friendli token can be parsed from environment."""
fake_token_value = "personal-access-token"
monkeypatch.setenv("FRIENDLI_TOKEN", fake_token_value)
chat = ChatFriendli()
assert isinstance(chat.friendli_token, SecretStr)
assert chat.friendli_token.get_secret_value() == fake_token_value
print(chat.friendli_token, end="") # noqa: T201
captured = capsys.readouterr()
assert captured.out == "**********"
@pytest.mark.requires("friendli")
def test_friendli_invoke(
mock_friendli_client: Mock, chat_friendli: ChatFriendli
) -> None:
"""Test invocation with friendli."""
mock_message = Mock()
mock_message.content = "Hello Friendli"
mock_message.role = "assistant"
mock_choice = Mock()
mock_choice.message = mock_message
mock_response = Mock()
mock_response.choices = [mock_choice]
mock_friendli_client.chat.completions.create.return_value = mock_response
result = chat_friendli.invoke("Hello langchain")
assert result.content == "Hello Friendli"
mock_friendli_client.chat.completions.create.assert_called_once_with(
messages=[{"role": "user", "content": "Hello langchain"}],
stream=False,
model=chat_friendli.model,
frequency_penalty=None,
presence_penalty=None,
max_tokens=None,
stop=None,
temperature=None,
top_p=None,
)
@pytest.mark.requires("friendli")
async def test_friendli_ainvoke(
mock_friendli_async_client: AsyncMock, chat_friendli: ChatFriendli
) -> None:
"""Test async invocation with friendli."""
mock_message = Mock()
mock_message.content = "Hello Friendli"
mock_message.role = "assistant"
mock_choice = Mock()
mock_choice.message = mock_message
mock_response = Mock()
mock_response.choices = [mock_choice]
mock_friendli_async_client.chat.completions.create.return_value = mock_response
result = await chat_friendli.ainvoke("Hello langchain")
assert result.content == "Hello Friendli"
mock_friendli_async_client.chat.completions.create.assert_awaited_once_with(
messages=[{"role": "user", "content": "Hello langchain"}],
stream=False,
model=chat_friendli.model,
frequency_penalty=None,
presence_penalty=None,
max_tokens=None,
stop=None,
temperature=None,
top_p=None,
)
@pytest.mark.requires("friendli")
def test_friendli_stream(
mock_friendli_client: Mock, chat_friendli: ChatFriendli
) -> None:
"""Test stream with friendli."""
mock_delta_0 = Mock()
mock_delta_0.content = "Hello "
mock_delta_1 = Mock()
mock_delta_1.content = "Friendli"
mock_choice_0 = Mock()
mock_choice_0.delta = mock_delta_0
mock_choice_1 = Mock()
mock_choice_1.delta = mock_delta_1
mock_chunk_0 = Mock()
mock_chunk_0.choices = [mock_choice_0]
mock_chunk_1 = Mock()
mock_chunk_1.choices = [mock_choice_1]
mock_stream = MagicMock()
mock_chunks = [mock_chunk_0, mock_chunk_1]
mock_stream.__iter__.return_value = mock_chunks
mock_friendli_client.chat.completions.create.return_value = mock_stream
stream = chat_friendli.stream("Hello langchain")
for i, chunk in enumerate(stream):
assert chunk.content == mock_chunks[i].choices[0].delta.content
mock_friendli_client.chat.completions.create.assert_called_once_with(
messages=[{"role": "user", "content": "Hello langchain"}],
stream=True,
model=chat_friendli.model,
frequency_penalty=None,
presence_penalty=None,
max_tokens=None,
stop=None,
temperature=None,
top_p=None,
)
@pytest.mark.requires("friendli")
async def test_friendli_astream(
mock_friendli_async_client: AsyncMock, chat_friendli: ChatFriendli
) -> None:
"""Test async stream with friendli."""
mock_delta_0 = Mock()
mock_delta_0.content = "Hello "
mock_delta_1 = Mock()
mock_delta_1.content = "Friendli"
mock_choice_0 = Mock()
mock_choice_0.delta = mock_delta_0
mock_choice_1 = Mock()
mock_choice_1.delta = mock_delta_1
mock_chunk_0 = Mock()
mock_chunk_0.choices = [mock_choice_0]
mock_chunk_1 = Mock()
mock_chunk_1.choices = [mock_choice_1]
mock_stream = AsyncMock()
mock_chunks = [mock_chunk_0, mock_chunk_1]
mock_stream.__aiter__.return_value = mock_chunks
mock_friendli_async_client.chat.completions.create.return_value = mock_stream
stream = chat_friendli.astream("Hello langchain")
async for i, chunk in aenumerate(stream):
assert chunk.content == mock_chunks[i].choices[0].delta.content
mock_friendli_async_client.chat.completions.create.assert_awaited_once_with(
messages=[{"role": "user", "content": "Hello langchain"}],
stream=True,
model=chat_friendli.model,
frequency_penalty=None,
presence_penalty=None,
max_tokens=None,
stop=None,
temperature=None,
top_p=None,
)

@ -43,6 +43,7 @@ EXPECTED_ALL = [
"ChatZhipuAI",
"ChatPerplexity",
"ChatKinetica",
"ChatFriendli",
]

@ -0,0 +1,179 @@
"""Test Friendli LLM."""
from unittest.mock import AsyncMock, MagicMock, Mock
import pytest
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain_community.adapters.openai import aenumerate
from langchain_community.llms.friendli import Friendli
@pytest.fixture
def mock_friendli_client() -> Mock:
"""Mock instance of Friendli client."""
return Mock()
@pytest.fixture
def mock_friendli_async_client() -> AsyncMock:
"""Mock instance of Friendli async client."""
return AsyncMock()
@pytest.fixture
def friendli_llm(
mock_friendli_client: Mock, mock_friendli_async_client: AsyncMock
) -> Friendli:
"""Friendli LLM with mock clients."""
return Friendli(
friendli_token=SecretStr("personal-access-token"),
client=mock_friendli_client,
async_client=mock_friendli_async_client,
)
@pytest.mark.requires("friendli")
def test_friendli_token_is_secret_string(capsys: CaptureFixture) -> None:
"""Test if friendli token is stored as a SecretStr."""
fake_token_value = "personal-access-token"
chat = Friendli(friendli_token=fake_token_value)
assert isinstance(chat.friendli_token, SecretStr)
assert chat.friendli_token.get_secret_value() == fake_token_value
print(chat.friendli_token, end="") # noqa: T201
captured = capsys.readouterr()
assert captured.out == "**********"
@pytest.mark.requires("friendli")
def test_friendli_token_read_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test if friendli token can be parsed from environment."""
fake_token_value = "personal-access-token"
monkeypatch.setenv("FRIENDLI_TOKEN", fake_token_value)
chat = Friendli()
assert isinstance(chat.friendli_token, SecretStr)
assert chat.friendli_token.get_secret_value() == fake_token_value
print(chat.friendli_token, end="") # noqa: T201
captured = capsys.readouterr()
assert captured.out == "**********"
@pytest.mark.requires("friendli")
def test_friendli_invoke(mock_friendli_client: Mock, friendli_llm: Friendli) -> None:
"""Test invocation with friendli."""
mock_choice = Mock()
mock_choice.text = "Hello Friendli"
mock_response = Mock()
mock_response.choices = [mock_choice]
mock_friendli_client.completions.create.return_value = mock_response
result = friendli_llm.invoke("Hello langchain")
assert result == "Hello Friendli"
mock_friendli_client.completions.create.assert_called_once_with(
model=friendli_llm.model,
prompt="Hello langchain",
stream=False,
frequency_penalty=None,
presence_penalty=None,
max_tokens=None,
stop=None,
temperature=None,
top_p=None,
)
@pytest.mark.requires("friendli")
async def test_friendli_ainvoke(
mock_friendli_async_client: AsyncMock, friendli_llm: Friendli
) -> None:
"""Test async invocation with friendli."""
mock_choice = Mock()
mock_choice.text = "Hello Friendli"
mock_response = Mock()
mock_response.choices = [mock_choice]
mock_friendli_async_client.completions.create.return_value = mock_response
result = await friendli_llm.ainvoke("Hello langchain")
assert result == "Hello Friendli"
mock_friendli_async_client.completions.create.assert_awaited_once_with(
model=friendli_llm.model,
prompt="Hello langchain",
stream=False,
frequency_penalty=None,
presence_penalty=None,
max_tokens=None,
stop=None,
temperature=None,
top_p=None,
)
@pytest.mark.requires("friendli")
def test_friendli_stream(mock_friendli_client: Mock, friendli_llm: Friendli) -> None:
"""Test stream with friendli."""
mock_chunk_0 = Mock()
mock_chunk_0.event = "token_sampled"
mock_chunk_0.text = "Hello "
mock_chunk_0.token = 0
mock_chunk_1 = Mock()
mock_chunk_1.event = "token_sampled"
mock_chunk_1.text = "Friendli"
mock_chunk_1.token = 1
mock_stream = MagicMock()
mock_chunks = [mock_chunk_0, mock_chunk_1]
mock_stream.__iter__.return_value = mock_chunks
mock_friendli_client.completions.create.return_value = mock_stream
stream = friendli_llm.stream("Hello langchain")
for i, chunk in enumerate(stream):
assert chunk == mock_chunks[i].text
mock_friendli_client.completions.create.assert_called_once_with(
model=friendli_llm.model,
prompt="Hello langchain",
stream=True,
frequency_penalty=None,
presence_penalty=None,
max_tokens=None,
stop=None,
temperature=None,
top_p=None,
)
@pytest.mark.requires("friendli")
async def test_friendli_astream(
mock_friendli_async_client: AsyncMock, friendli_llm: Friendli
) -> None:
"""Test async stream with friendli."""
mock_chunk_0 = Mock()
mock_chunk_0.event = "token_sampled"
mock_chunk_0.text = "Hello "
mock_chunk_0.token = 0
mock_chunk_1 = Mock()
mock_chunk_1.event = "token_sampled"
mock_chunk_1.text = "Friendli"
mock_chunk_1.token = 1
mock_stream = AsyncMock()
mock_chunks = [mock_chunk_0, mock_chunk_1]
mock_stream.__aiter__.return_value = mock_chunks
mock_friendli_async_client.completions.create.return_value = mock_stream
stream = friendli_llm.astream("Hello langchain")
async for i, chunk in aenumerate(stream):
assert chunk == mock_chunks[i].text
mock_friendli_async_client.completions.create.assert_awaited_once_with(
model=friendli_llm.model,
prompt="Hello langchain",
stream=True,
frequency_penalty=None,
presence_penalty=None,
max_tokens=None,
stop=None,
temperature=None,
top_p=None,
)

@ -30,6 +30,7 @@ EXPECT_ALL = [
"FakeListLLM",
"Fireworks",
"ForefrontAI",
"Friendli",
"GigaChat",
"GPT4All",
"GooglePalm",
