@@ -1,11 +1,12 @@
 import json
 from abc import ABC
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, Extra, root_validator
+from langchain.schema.output import GenerationChunk


 class LLMInputOutputAdapter:
@@ -15,6 +16,11 @@ class LLMInputOutputAdapter:
     It also provides a helper function to extract
     the generated text from the model response."""

+    provider_to_output_key_map = {
+        "anthropic": "completion",
+        "amazon": "outputText",
+    }
+
     @classmethod
     def prepare_input(
         cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
@@ -30,7 +36,7 @@ class LLMInputOutputAdapter:
             input_body["inputText"] = prompt

         if provider == "anthropic" and "max_tokens_to_sample" not in input_body:
-            input_body["max_tokens_to_sample"] = 50
+            input_body["max_tokens_to_sample"] = 256

         return input_body
@@ -47,6 +53,30 @@ class LLMInputOutputAdapter:
         else:
             return response_body.get("results")[0].get("outputText")

+    @classmethod
+    def prepare_output_stream(
+        cls, provider: str, response: Any, stop: Optional[List[str]] = None
+    ) -> Iterator[GenerationChunk]:
+        stream = response.get("body")
+
+        if not stream:
+            return
+
+        if provider not in cls.provider_to_output_key_map:
+            raise ValueError(
+                f"Unknown streaming response output key for provider: {provider}"
+            )
+
+        for event in stream:
+            chunk = event.get("chunk")
+            if chunk:
+                chunk_obj = json.loads(chunk.get("bytes").decode())
+
+                # chunk obj format varies with provider
+                yield GenerationChunk(
+                    text=chunk_obj[cls.provider_to_output_key_map[provider]]
+                )
+

 class BedrockBase(BaseModel, ABC):
     client: Any  #: :meta private:
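For context, each event on the Bedrock response stream wraps raw JSON bytes, which is why prepare_output_stream above decodes event["chunk"]["bytes"] and then indexes the result via provider_to_output_key_map. A hypothetical Anthropic event would look roughly like this (illustrative shape, not taken from this diff):

    # assumed shape of one streaming event
    event = {"chunk": {"bytes": b'{"completion": " Hello"}'}}
    # json.loads(event["chunk"]["bytes"].decode()) -> {"completion": " Hello"},
    # so the adapter yields GenerationChunk(text=" Hello")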
@@ -74,6 +104,15 @@ class BedrockBase(BaseModel, ABC):
     endpoint_url: Optional[str] = None
     """Needed if you don't want to default to us-east-1 endpoint"""

+    streaming: bool = False
+    """Whether to stream the results."""
+
+    provider_stop_sequence_key_name_map: Mapping[str, str] = {
+        "anthropic": "stop_sequences",
+        "amazon": "stopSequences",
+        "ai21": "stop_sequences",
+    }
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials and python package exist in the environment."""
@@ -154,6 +193,49 @@ class BedrockBase(BaseModel, ABC):
         return text

+    def _prepare_input_and_invoke_stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        _model_kwargs = self.model_kwargs or {}
+        provider = self._get_provider()
+
+        if stop:
+            if provider not in self.provider_stop_sequence_key_name_map:
+                raise ValueError(
+                    f"Stop sequence key name for {provider} is not supported."
+                )
+
+            # stop sequence from _generate() overrides
+            # stop sequences in the class attribute
+            _model_kwargs[
+                self.provider_stop_sequence_key_name_map.get(provider)
+            ] = stop
+
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
+        body = json.dumps(input_body)
+
+        try:
+            response = self.client.invoke_model_with_response_stream(
+                body=body,
+                modelId=self.model_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+        except Exception as e:
+            raise ValueError(f"Error raised by bedrock service: {e}")
+
+        for chunk in LLMInputOutputAdapter.prepare_output_stream(
+            provider, response, stop
+        ):
+            yield chunk
+            if run_manager is not None:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+

 class Bedrock(LLM, BedrockBase):
     """Bedrock models.
@@ -177,7 +259,8 @@ class Bedrock(LLM, BedrockBase):
         llm = BedrockLLM(
             credentials_profile_name="default",
-            model_id="amazon.titan-tg1-large"
+            model_id="amazon.titan-tg1-large",
+            streaming=True
         )

     """
@@ -192,6 +275,33 @@ class Bedrock(LLM, BedrockBase):
         extra = Extra.forbid

+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        """Call out to Bedrock service with streaming.
+
+        Args:
+            prompt (str): The prompt to pass into the model
+            stop (Optional[List[str]], optional): Stop sequences. These will
+                override any stop sequences in the `model_kwargs` attribute.
+                Defaults to None.
+            run_manager (Optional[CallbackManagerForLLMRun], optional): Callback
+                run managers used to process the output. Defaults to None.
+
+        Returns:
+            Iterator[GenerationChunk]: Generator that yields the streamed responses.
+
+        Yields:
+            Iterator[GenerationChunk]: Responses from the model.
+        """
+        return self._prepare_input_and_invoke_stream(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        )
+
     def _call(
         self,
         prompt: str,
@@ -211,9 +321,15 @@ class Bedrock(LLM, BedrockBase):
         Example:
             .. code-block:: python

-                response = se("Tell me a joke.")
+                response = llm("Tell me a joke.")
         """

-        text = self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+        if self.streaming:
+            completion = ""
+            for chunk in self._stream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                completion += chunk.text
+            return completion

-        return text
+        return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)