@@ -1,11 +1,25 @@
 from __future__ import annotations

+import asyncio
 import json
 import warnings
 from abc import ABC
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Mapping, Optional
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+)

-from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
@@ -128,26 +142,56 @@ class LLMInputOutputAdapter:
         if not stream:
             return

-        if provider not in cls.provider_to_output_key_map:
+        output_key = cls.provider_to_output_key_map.get(provider, None)
+
+        if not output_key:
             raise ValueError(
                 f"Unknown streaming response output key for provider: {provider}"
             )

         for event in stream:
             chunk = event.get("chunk")
-            if chunk:
-                chunk_obj = json.loads(chunk.get("bytes").decode())
-                if provider == "cohere" and (
-                    chunk_obj["is_finished"]
-                    or chunk_obj[cls.provider_to_output_key_map[provider]]
-                    == "<EOS_TOKEN>"
-                ):
-                    return
-
-                # chunk obj format varies with provider
-                yield GenerationChunk(
-                    text=chunk_obj[cls.provider_to_output_key_map[provider]]
-                )
+
+            if not chunk:
+                continue
+
+            chunk_obj = json.loads(chunk.get("bytes").decode())
+
+            if provider == "cohere" and (
+                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
+            ):
+                return
+
+            # chunk obj format varies with provider
+            yield GenerationChunk(text=chunk_obj[output_key])
+
+    @classmethod
+    async def aprepare_output_stream(
+        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+    ) -> AsyncIterator[GenerationChunk]:
+        stream = response.get("body")
+
+        if not stream:
+            return
+
+        output_key = cls.provider_to_output_key_map.get(provider, None)
+
+        if not output_key:
+            raise ValueError(
+                f"Unknown streaming response output key for provider: {provider}"
+            )
+
+        for event in stream:
+            chunk = event.get("chunk")
+
+            if not chunk:
+                continue
+
+            chunk_obj = json.loads(chunk.get("bytes").decode())
+
+            if provider == "cohere" and (
+                chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
+            ):
+                return
+
+            yield GenerationChunk(text=chunk_obj[output_key])


 class BedrockBase(BaseModel, ABC):
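
A minimal sketch of how the new aprepare_output_stream can be exercised in
isolation. The fabricated event list stands in for the boto3 EventStream, and
it assumes the adapter's provider_to_output_key_map maps "anthropic" to
"completion", as defined earlier in this file:

import asyncio
import json

from langchain.llms.bedrock import LLMInputOutputAdapter

async def demo() -> None:
    # Each event mimics one chunk from invoke_model_with_response_stream.
    fake_response = {
        "body": [
            {"chunk": {"bytes": json.dumps({"completion": "Hello"}).encode()}},
            {"chunk": {"bytes": json.dumps({"completion": ", world"}).encode()}},
        ]
    }
    async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
        "anthropic", fake_response
    ):
        print(chunk.text, end="")

asyncio.run(demo())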
@@ -332,6 +376,51 @@ class BedrockBase(BaseModel, ABC):
             if run_manager is not None:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)

+    async def _aprepare_input_and_invoke_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        provider = self._get_provider()
+
+        if stop:
+            if provider not in self.provider_stop_sequence_key_name_map:
+                raise ValueError(
+                    f"Stop sequence key name for {provider} is not supported."
+                )
+            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+
+        if provider == "cohere":
+            _model_kwargs["stream"] = True
+
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+
+        response = await asyncio.get_running_loop().run_in_executor(
+            None,
+            lambda: self.client.invoke_model_with_response_stream(
+                body=body,
+                modelId=self.model_id,
+                accept="application/json",
+                contentType="application/json",
+            ),
+        )
+
+        async for chunk in LLMInputOutputAdapter.aprepare_output_stream(
+            provider, response, stop
+        ):
+            yield chunk
+            if run_manager is not None and asyncio.iscoroutinefunction(
+                run_manager.on_llm_new_token
+            ):
+                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            elif run_manager is not None:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+

 class Bedrock(LLM, BedrockBase):
     """Bedrock models.
@@ -449,6 +538,65 @@ class Bedrock(LLM, BedrockBase):

         return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)

+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncGenerator[GenerationChunk, None]:
+        """Call out to Bedrock service with streaming.
+
+        Args:
+            prompt (str): The prompt to pass into the model.
+            stop (Optional[List[str]], optional): Stop sequences. These will
+                override any stop sequences in the `model_kwargs` attribute.
+                Defaults to None.
+            run_manager (Optional[AsyncCallbackManagerForLLMRun], optional):
+                Callback run managers used to process the output. Defaults to None.
+
+        Yields:
+            AsyncGenerator[GenerationChunk, None]: Generator that asynchronously
+            yields the streamed responses.
+        """
+        async for chunk in self._aprepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        ):
+            yield chunk
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to Bedrock service model.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = await llm._acall("Tell me a joke.")
+        """
+        if not self.streaming:
+            raise ValueError("Streaming must be set to True for async operations.")
+
+        chunks = [
+            chunk.text
+            async for chunk in self._astream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            )
+        ]
+        return "".join(chunks)
+
     def get_num_tokens(self, text: str) -> int:
         if self._model_is_anthropic:
             return get_num_tokens_anthropic(text)
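
With these methods in place, the standard async entry points on LLM route
through the Bedrock implementation. A usage sketch, assuming AWS credentials
and Bedrock model access are already configured and that the public
astream/ainvoke wrappers dispatch to _astream/_acall as usual:

import asyncio

from langchain.llms import Bedrock

async def main() -> None:
    llm = Bedrock(model_id="anthropic.claude-v2", streaming=True)
    # Streamed tokens arrive as they are produced...
    async for token in llm.astream("Tell me a joke."):
        print(token, end="", flush=True)
    # ...or collect the full completion in one call.
    print(await llm.ainvoke("Tell me a joke."))

asyncio.run(main())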