@ -1,10 +1,21 @@
import json
from http import HTTPStatus
from typing import Any , Dict, List , Optional , Union
from typing import Any , AsyncIterator, Dict, Iterator , List , Optional , Union
import requests
from langchain_core . callbacks import CallbackManagerForLLMRun
from langchain_core . language_models . chat_models import SimpleChatModel
from langchain_core . messages import AIMessage , BaseMessage , HumanMessage , SystemMessage
from langchain_core . callbacks import (
AsyncCallbackManagerForLLMRun ,
CallbackManagerForLLMRun ,
)
from langchain_core . language_models . chat_models import BaseChatModel
from langchain_core . messages import (
AIMessage ,
AIMessageChunk ,
BaseMessage ,
HumanMessage ,
SystemMessage ,
)
from langchain_core . outputs import ChatGeneration , ChatGenerationChunk , ChatResult
from langchain_core . pydantic_v1 import Field
from requests import Response
from requests . exceptions import HTTPError
@ -34,7 +45,7 @@ class MaritalkHTTPError(HTTPError):
return formatted_message
class ChatMaritalk(SimpleChatModel):
class ChatMaritalk(BaseChatModel):
""" `MariTalk` Chat models API.
This class allows interacting with the MariTalk chatbot API .
@ -132,7 +143,51 @@ class ChatMaritalk(SimpleChatModel):
If an error occurs ( e . g . , rate limiting ) , returns a string
describing the error .
"""
url = " https://chat.maritaca.ai/api/chat/inference "
headers = { " authorization " : f " Key { self . api_key } " }
stopping_tokens = stop if stop is not None else [ ]
parsed_messages = self . parse_messages_for_model ( messages )
data = {
" messages " : parsed_messages ,
" model " : self . model ,
" do_sample " : self . do_sample ,
" max_tokens " : self . max_tokens ,
" temperature " : self . temperature ,
" top_p " : self . top_p ,
" stopping_tokens " : stopping_tokens ,
* * kwargs ,
}
response = requests . post ( url , json = data , headers = headers )
if response . ok :
return response . json ( ) . get ( " answer " , " No answer found " )
else :
raise MaritalkHTTPError ( response )
async def _acall(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> str:
    """Asynchronously send the parsed messages to the MariTalk API.

    Makes an async HTTP POST request to the MariTalk chat-inference
    endpoint with the conversation and the model's sampling parameters.

    Args:
        messages: Conversation history to send to the model.
        stop: Optional list of stopping tokens; an empty list is sent
            when not provided.
        run_manager: Optional async callback manager (not used by this
            method directly).
        **kwargs: Extra fields merged into the request payload.

    Returns:
        The model's answer as a string, or ``"No answer found"`` when the
        response body carries no ``answer`` field.

    Raises:
        MaritalkHTTPError: If the API responds with a non-200 status.
        ImportError: If the optional ``httpx`` dependency is not installed.
    """
    try:
        # httpx is an optional dependency; imported lazily so the
        # synchronous code paths keep working without it.
        import httpx

        url = "https://chat.maritaca.ai/api/chat/inference"
        headers = {"authorization": f"Key {self.api_key}"}
        stopping_tokens = stop if stop is not None else []

        parsed_messages = self.parse_messages_for_model(messages)

        data = {
            "messages": parsed_messages,
            "model": self.model,
            "do_sample": self.do_sample,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stopping_tokens": stopping_tokens,
            **kwargs,
        }

        async with httpx.AsyncClient() as client:
            # timeout=None: generation can take longer than httpx's
            # default 5-second timeout.
            response = await client.post(
                url, json=data, headers=headers, timeout=None
            )

        if response.status_code == 200:
            return response.json().get("answer", "No answer found")
        else:
            raise MaritalkHTTPError(response)
    except ImportError:
        raise ImportError(
            "Could not import httpx python package. "
            "Please install it with `pip install httpx`."
        )
def _stream(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
    """Stream the MariTalk response as it is generated.

    Sends the conversation with ``"stream": True`` and yields one
    ``ChatGenerationChunk`` per text delta received from the
    server-sent-events response.

    Args:
        messages: Conversation history to send to the model.
        stop: Optional list of stopping tokens; an empty list is sent
            when not provided.
        run_manager: Optional callback manager notified of each new token.
        **kwargs: Extra fields merged into the request payload.

    Yields:
        ``ChatGenerationChunk`` objects wrapping each text delta.

    Raises:
        MaritalkHTTPError: If the API responds with a non-OK status.
    """
    headers = {"Authorization": f"Key {self.api_key}"}
    stopping_tokens = stop if stop is not None else []

    parsed_messages = self.parse_messages_for_model(messages)

    data = {
        "messages": parsed_messages,
        "model": self.model,
        "do_sample": self.do_sample,
        "max_tokens": self.max_tokens,
        "temperature": self.temperature,
        "top_p": self.top_p,
        "stopping_tokens": stopping_tokens,
        "stream": True,
        **kwargs,
    }

    # json= (rather than data=json.dumps(data)) serializes the payload
    # AND sets the Content-Type: application/json header.
    response = requests.post(
        "https://chat.maritaca.ai/api/chat/inference",
        json=data,
        headers=headers,
        stream=True,
    )

    if response.ok:
        for line in response.iter_lines():
            # Server-sent events: payload lines are prefixed "data: ".
            if line.startswith(b"data: "):
                response_data = line.replace(b"data: ", b"").decode("utf-8")
                if response_data:
                    parsed_data = json.loads(response_data)
                    if "text" in parsed_data:
                        delta = parsed_data["text"]
                        chunk = ChatGenerationChunk(
                            message=AIMessageChunk(content=delta)
                        )
                        if run_manager:
                            run_manager.on_llm_new_token(delta, chunk=chunk)
                        yield chunk
    else:
        raise MaritalkHTTPError(response)
async def _astream(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
    """Asynchronously stream the MariTalk response as it is generated.

    Async counterpart of ``_stream``: sends the conversation with
    ``"stream": True`` via httpx and yields one ``ChatGenerationChunk``
    per text delta from the server-sent-events response.

    Args:
        messages: Conversation history to send to the model.
        stop: Optional list of stopping tokens; an empty list is sent
            when not provided.
        run_manager: Optional async callback manager notified per token.
        **kwargs: Extra fields merged into the request payload.

    Yields:
        ``ChatGenerationChunk`` objects wrapping each text delta.

    Raises:
        MaritalkHTTPError: If the API responds with a non-200 status.
        ImportError: If the optional ``httpx`` dependency is not installed.
    """
    try:
        # httpx is an optional dependency; imported lazily so the
        # synchronous code paths keep working without it.
        import httpx

        headers = {"Authorization": f"Key {self.api_key}"}
        stopping_tokens = stop if stop is not None else []

        parsed_messages = self.parse_messages_for_model(messages)

        data = {
            "messages": parsed_messages,
            "model": self.model,
            "do_sample": self.do_sample,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stopping_tokens": stopping_tokens,
            "stream": True,
            **kwargs,
        }

        async with httpx.AsyncClient() as client:
            # json= serializes the payload and sets the Content-Type
            # header; passing a raw string to data= is deprecated in httpx.
            async with client.stream(
                "POST",
                "https://chat.maritaca.ai/api/chat/inference",
                json=data,
                headers=headers,
                timeout=None,
            ) as response:
                if response.status_code == 200:
                    async for line in response.aiter_lines():
                        # Server-sent events: payload lines are
                        # prefixed "data: ".
                        if line.startswith("data: "):
                            response_data = line.replace("data: ", "")
                            if response_data:
                                parsed_data = json.loads(response_data)
                                if "text" in parsed_data:
                                    delta = parsed_data["text"]
                                    chunk = ChatGenerationChunk(
                                        message=AIMessageChunk(content=delta)
                                    )
                                    if run_manager:
                                        await run_manager.on_llm_new_token(
                                            delta, chunk=chunk
                                        )
                                    yield chunk
                else:
                    raise MaritalkHTTPError(response)
    except ImportError:
        raise ImportError(
            "Could not import httpx python package. "
            "Please install it with `pip install httpx`."
        )
def _generate(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> ChatResult:
    """Generate a chat completion by delegating to the synchronous ``_call``.

    Wraps the raw string answer from ``_call`` in the ``ChatResult`` /
    ``ChatGeneration`` / ``AIMessage`` structure expected by
    ``BaseChatModel``.

    Args:
        messages: Conversation history to send to the model.
        stop: Optional list of stopping tokens, forwarded to ``_call``.
        run_manager: Optional callback manager, forwarded to ``_call``.
        **kwargs: Extra fields forwarded to ``_call``.

    Returns:
        A ``ChatResult`` holding a single generation with the model's
        answer.
    """
    # NOTE: the stray "fallback return" that followed this return was
    # unreachable (and returned a str where a ChatResult is required),
    # so it has been removed.
    output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
    message = AIMessage(content=output_str)
    generation = ChatGeneration(message=message)
    return ChatResult(generations=[generation])
async def _agenerate(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> ChatResult:
    """Asynchronously generate a chat completion via ``_acall``.

    Async counterpart of ``_generate``: awaits the raw string answer and
    wraps it in the ``ChatResult`` structure expected by ``BaseChatModel``.

    Args:
        messages: Conversation history to send to the model.
        stop: Optional list of stopping tokens, forwarded to ``_acall``.
        run_manager: Optional async callback manager, forwarded to
            ``_acall``.
        **kwargs: Extra fields forwarded to ``_acall``.

    Returns:
        A ``ChatResult`` holding a single generation with the model's
        answer.
    """
    answer = await self._acall(
        messages, stop=stop, run_manager=run_manager, **kwargs
    )
    return ChatResult(
        generations=[ChatGeneration(message=AIMessage(content=answer))]
    )
@property
def _identifying_params ( self ) - > Dict [ str , Any ] :