mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
9678797625
- Description: Invoke the `on_llm_new_token` callback before yielding each chunk in `_stream`/`_astream` for some chat models, so that all chat models behave consistently. - Issue: N/A - Dependencies: N/A
182 lines
6.3 KiB
Python
182 lines
6.3 KiB
Python
import logging
|
|
from typing import Any, AsyncIterator, Iterator, List, Optional
|
|
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain_core.language_models.chat_models import (
|
|
BaseChatModel,
|
|
agenerate_from_stream,
|
|
generate_from_stream,
|
|
)
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
AIMessageChunk,
|
|
BaseMessage,
|
|
ChatMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
|
|
from langchain_community.llms.gigachat import _BaseGigaChat
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _convert_dict_to_message(message: Any) -> BaseMessage:
    """Turn a GigaChat message object into the matching LangChain message.

    Args:
        message: Object exposing ``role`` and ``content`` attributes.

    Returns:
        A ``SystemMessage``, ``HumanMessage`` or ``AIMessage`` carrying the
        original ``content``, selected by ``message.role``.

    Raises:
        TypeError: If ``message.role`` equals none of the handled roles.
    """
    from gigachat.models import MessagesRole

    # Ordered (role, message class) pairs; compared with ``==`` one by one.
    role_to_message_cls = (
        (MessagesRole.SYSTEM, SystemMessage),
        (MessagesRole.USER, HumanMessage),
        (MessagesRole.ASSISTANT, AIMessage),
    )
    role = message.role
    for known_role, message_cls in role_to_message_cls:
        if role == known_role:
            return message_cls(content=message.content)

    raise TypeError(f"Got unknown role {message.role} {message}")
|
|
|
|
|
|
def _convert_message_to_dict(message: BaseMessage) -> Any:
    """Turn a LangChain message into a GigaChat ``Messages`` object.

    Args:
        message: The LangChain message to convert.

    Returns:
        A ``gigachat.models.Messages`` instance with the message's content.
        System, human and AI messages get a fixed role; a ``ChatMessage``
        has its own ``role`` passed through ``MessagesRole(...)``.

    Raises:
        TypeError: If ``message`` is none of the handled message types.
    """
    from gigachat.models import Messages, MessagesRole

    # Checked in order, so the first matching type decides the role.
    fixed_roles = (
        (SystemMessage, MessagesRole.SYSTEM),
        (HumanMessage, MessagesRole.USER),
        (AIMessage, MessagesRole.ASSISTANT),
    )
    for message_cls, giga_role in fixed_roles:
        if isinstance(message, message_cls):
            return Messages(role=giga_role, content=message.content)

    if isinstance(message, ChatMessage):
        return Messages(role=MessagesRole(message.role), content=message.content)

    raise TypeError(f"Got unknown type {message}")
|
|
|
|
|
|
class GigaChat(_BaseGigaChat, BaseChatModel):
    """`GigaChat` large language models API.

    To use, you should pass login and password to access GigaChat API or use token.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import GigaChat
            giga = GigaChat(credentials=..., verify_ssl_certs=False)
    """

    def _build_payload(self, messages: List[BaseMessage]) -> Any:
        """Assemble the GigaChat ``Chat`` request for ``messages``.

        Args:
            messages: LangChain messages to send.

        Returns:
            A ``gigachat.models.Chat`` object with converted messages and
            ``profanity_check`` taken from ``self.profanity``.
        """
        from gigachat.models import Chat

        giga_messages = [_convert_message_to_dict(msg) for msg in messages]
        request = Chat(messages=giga_messages, profanity_check=self.profanity)

        # Optional settings are only written to the request when configured.
        for setting in ("temperature", "max_tokens"):
            configured = getattr(self, setting)
            if configured is not None:
                setattr(request, setting, configured)

        if self.verbose:
            logger.info("Giga request: %s", request.dict())

        return request

    def _create_chat_result(self, response: Any) -> ChatResult:
        """Convert a GigaChat response into a ``ChatResult``.

        Args:
            response: Object exposing ``choices``, ``usage`` and ``model``.

        Returns:
            A ``ChatResult`` with one ``ChatGeneration`` per choice and an
            ``llm_output`` holding ``token_usage`` and ``model_name``.
        """
        generations = []
        for choice in response.choices:
            converted = _convert_dict_to_message(choice.message)
            reason = choice.finish_reason
            generations.append(
                ChatGeneration(
                    message=converted,
                    generation_info={"finish_reason": reason},
                )
            )
            # Any finish reason other than "stop" is reported as a warning.
            if reason != "stop":
                logger.warning("Giga generation stopped with reason: %s", reason)
            if self.verbose:
                logger.info("Giga response: %s", converted.content)

        return ChatResult(
            generations=generations,
            llm_output={"token_usage": response.usage, "model_name": response.model},
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce a chat result synchronously.

        Args:
            messages: Conversation to send.
            stop: Forwarded to ``_stream`` on the streaming path only.
            run_manager: Forwarded to ``_stream`` on the streaming path only.
            stream: Explicit streaming switch; ``None`` falls back to
                ``self.streaming``.
            **kwargs: Forwarded to ``_stream`` on the streaming path only.

        Returns:
            The ``ChatResult`` built from the streamed chunks or from a
            single ``self._client.chat`` call.
        """
        use_stream = self.streaming if stream is None else stream
        if not use_stream:
            request = self._build_payload(messages)
            return self._create_chat_result(self._client.chat(request))

        chunks = self._stream(messages, stop=stop, run_manager=run_manager, **kwargs)
        return generate_from_stream(chunks)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        stream: Optional[bool] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Produce a chat result asynchronously.

        Args:
            messages: Conversation to send.
            stop: Forwarded to ``_astream`` on the streaming path only.
            run_manager: Forwarded to ``_astream`` on the streaming path only.
            stream: Explicit streaming switch; ``None`` falls back to
                ``self.streaming``.
            **kwargs: Forwarded to ``_astream`` on the streaming path only.

        Returns:
            The ``ChatResult`` built from the streamed chunks or from a
            single awaited ``self._client.achat`` call.
        """
        use_stream = self.streaming if stream is None else stream
        if not use_stream:
            request = self._build_payload(messages)
            return self._create_chat_result(await self._client.achat(request))

        chunks = self._astream(messages, stop=stop, run_manager=run_manager, **kwargs)
        return await agenerate_from_stream(chunks)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield generation chunks from ``self._client.stream``.

        Args:
            messages: Conversation to send.
            stop: Not used by this method.
            run_manager: When truthy, ``on_llm_new_token`` is called for each
                chunk before that chunk is yielded.
            **kwargs: Not used by this method.

        Yields:
            One ``ChatGenerationChunk`` per streamed chunk that has choices.
        """
        request = self._build_payload(messages)

        for raw_chunk in self._client.stream(request):
            # Chunks without any choices produce no output.
            if not raw_chunk.choices:
                continue
            token = raw_chunk.choices[0].delta.content
            generation_chunk = ChatGenerationChunk(
                message=AIMessageChunk(content=token)
            )
            if run_manager:
                run_manager.on_llm_new_token(token, chunk=generation_chunk)
            yield generation_chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Asynchronously yield generation chunks from ``self._client.astream``.

        Args:
            messages: Conversation to send.
            stop: Not used by this method.
            run_manager: When truthy, ``on_llm_new_token`` is awaited for each
                chunk before that chunk is yielded.
            **kwargs: Not used by this method.

        Yields:
            One ``ChatGenerationChunk`` per streamed chunk that has choices.
        """
        request = self._build_payload(messages)

        async for raw_chunk in self._client.astream(request):
            # Chunks without any choices produce no output.
            if not raw_chunk.choices:
                continue
            token = raw_chunk.choices[0].delta.content
            generation_chunk = ChatGenerationChunk(
                message=AIMessageChunk(content=token)
            )
            if run_manager:
                await run_manager.on_llm_new_token(token, chunk=generation_chunk)
            yield generation_chunk

    def get_num_tokens(self, text: str) -> int:
        """Estimate the token count of ``text`` as its length divided by 4.6.

        Args:
            text: The string to measure.

        Returns:
            The rounded estimate.
        """
        chars_per_token = 4.6
        return round(len(text) / chars_per_token)
|