mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
fee6f983ef
## Description - Add [Friendli](https://friendli.ai/) integration for the `Friendli` LLM and the `ChatFriendli` chat model. - Unit tests and integration tests corresponding to this change are added. - Documentation corresponding to this change is added. ## Dependencies - The optional dependency [`friendli-client`](https://pypi.org/project/friendli-client/) package is added only for those who use the `Friendli` or `ChatFriendli` model. ## Twitter handle - https://twitter.com/friendliai
218 lines
6.9 KiB
Python
218 lines
6.9 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
|
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain_core.language_models.chat_models import (
|
|
BaseChatModel,
|
|
agenerate_from_stream,
|
|
generate_from_stream,
|
|
)
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
AIMessageChunk,
|
|
BaseMessage,
|
|
ChatMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
|
|
from langchain_community.llms.friendli import BaseFriendli
|
|
|
|
|
|
def get_role(message: BaseMessage) -> str:
    """Map a message object onto the role string used in a chat request.

    Args:
        message (BaseMessage): The message to classify.

    Raises:
        ValueError: If the message is not an instance of any handled
            message class.

    Returns:
        str: ``"user"`` for ``ChatMessage`` and ``HumanMessage``,
            ``"assistant"`` for ``AIMessage`` and ``"system"`` for
            ``SystemMessage``.
    """
    # Entries are tried in order and the first match wins. ChatMessage is
    # treated exactly like HumanMessage.
    role_table = (
        ((ChatMessage, HumanMessage), "user"),
        (AIMessage, "assistant"),
        (SystemMessage, "system"),
    )
    for message_types, role in role_table:
        if isinstance(message, message_types):
            return role
    raise ValueError(f"Got unknown type {message}")
|
|
|
|
|
|
def get_chat_request(messages: List[BaseMessage]) -> Dict[str, Any]:
    """Build the message payload of a Friendli chat API request.

    Args:
        messages (List[BaseMessage]): Messages comprising the conversation
            so far, in order.

    Returns:
        Dict[str, Any]: A dict with a single ``"messages"`` key whose value
            holds one ``{"role": ..., "content": ...}`` entry per input
            message, in the same order.
    """
    payload: List[Dict[str, Any]] = []
    for item in messages:
        # The role is resolved first, so an unsupported message type raises
        # before its content is read.
        payload.append({"role": get_role(item), "content": item.content})
    return {"messages": payload}
|
|
|
|
|
|
class ChatFriendli(BaseChatModel, BaseFriendli):
    """Friendli LLM for chat.

    ``friendli-client`` package should be installed with `pip install friendli-client`.
    You must set ``FRIENDLI_TOKEN`` environment variable or provide the value of your
    personal access token for the ``friendli_token`` argument.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatFriendli

            chat = ChatFriendli(
                model="llama-2-13b-chat", friendli_token="YOUR FRIENDLI TOKEN"
            )
            chat.invoke("What is generative AI?")
    """

    # Model name sent with every request.
    model: str = "llama-2-13b-chat"

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Associate the ``friendli_token`` argument with ``FRIENDLI_TOKEN``."""
        return {"friendli_token": "FRIENDLI_TOKEN"}

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Friendli chat completions API.

        A fresh dict is built on every access, so callers may mutate the
        result without affecting the instance.
        """
        # NOTE(review): these attributes are not declared in this class;
        # presumably they come from BaseFriendli - confirm.
        return {
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "max_tokens": self.max_tokens,
            "stop": self.stop,
            "temperature": self.temperature,
            "top_p": self.top_p,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {"model": self.model, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return the identifier string of this chat model type."""
        return "friendli-chat"

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model.

        Args:
            stop (Optional[List[str]]): Stop sequences supplied for this call.
            **kwargs (Any): Extra parameters; they override the defaults on
                key collisions.

        Raises:
            ValueError: If ``stop`` is given both here and on the instance.

        Returns:
            Dict[str, Any]: The default parameters merged with ``kwargs``.
        """
        params = self._default_params
        if stop is not None:
            if self.stop is not None:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        # When no per-call ``stop`` is given, ``params["stop"]`` already holds
        # ``self.stop`` (possibly None) from the defaults.
        return {**params, **kwargs}

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the chat completion as message chunks.

        Args:
            messages (List[BaseMessage]): Conversation so far.
            stop (Optional[List[str]]): Stop sequences for this call.
            run_manager (Optional[CallbackManagerForLLMRun]): If given, it is
                notified of every non-empty piece of content.
            **kwargs (Any): Extra request parameters.

        Yields:
            ChatGenerationChunk: One chunk per non-empty content delta.
        """
        params = self._get_invocation_params(stop=stop, **kwargs)
        stream = self.client.chat.completions.create(
            **get_chat_request(messages), stream=True, model=self.model, **params
        )
        for chunk in stream:
            delta = chunk.choices[0].delta.content
            # Chunks with empty or missing content are skipped.
            if delta:
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    run_manager.on_llm_new_token(delta)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously stream the chat completion as message chunks.

        Args:
            messages (List[BaseMessage]): Conversation so far.
            stop (Optional[List[str]]): Stop sequences for this call.
            run_manager (Optional[AsyncCallbackManagerForLLMRun]): If given,
                it is notified of every non-empty piece of content.
            **kwargs (Any): Extra request parameters.

        Yields:
            ChatGenerationChunk: One chunk per non-empty content delta.
        """
        params = self._get_invocation_params(stop=stop, **kwargs)
        stream = await self.async_client.chat.completions.create(
            **get_chat_request(messages), stream=True, model=self.model, **params
        )
        async for chunk in stream:
            delta = chunk.choices[0].delta.content
            # Chunks with empty or missing content are skipped.
            if delta:
                yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
                if run_manager:
                    await run_manager.on_llm_new_token(delta)

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat completion for the given conversation.

        Args:
            messages (List[BaseMessage]): Conversation so far.
            stop (Optional[List[str]]): Stop sequences for this call.
            run_manager (Optional[CallbackManagerForLLMRun]): Callback
                manager; only used on the streaming path.
            **kwargs (Any): Extra request parameters.

        Returns:
            ChatResult: A result holding a single generation.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        params = self._get_invocation_params(stop=stop, **kwargs)
        # Build the payload with the same helper the streaming path uses so
        # both paths always send identically shaped messages.
        response = self.client.chat.completions.create(
            **get_chat_request(messages), stream=False, model=self.model, **params
        )

        message = AIMessage(content=response.choices[0].message.content)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate a chat completion for the conversation.

        Args:
            messages (List[BaseMessage]): Conversation so far.
            stop (Optional[List[str]]): Stop sequences for this call.
            run_manager (Optional[AsyncCallbackManagerForLLMRun]): Callback
                manager; only used on the streaming path.
            **kwargs (Any): Extra request parameters.

        Returns:
            ChatResult: A result holding a single generation.
        """
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        params = self._get_invocation_params(stop=stop, **kwargs)
        # Build the payload with the same helper the streaming path uses so
        # both paths always send identically shaped messages.
        response = await self.async_client.chat.completions.create(
            **get_chat_request(messages), stream=False, model=self.model, **params
        )

        message = AIMessage(content=response.choices[0].message.content)
        return ChatResult(generations=[ChatGeneration(message=message)])
|