@@ -11,12 +11,17 @@ from langchain.callbacks.manager import (
 from langchain.llms.base import LLM, create_base_retry_decorator
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, root_validator
+from langchain.schema import (
+    Generation,
+    LLMResult,
+)
 from langchain.utilities.vertexai import (
     init_vertexai,
     raise_vertex_import_error,
 )

 if TYPE_CHECKING:
+    from google.cloud.aiplatform.gapic import PredictionServiceClient
     from vertexai.language_models._language_models import _LanguageModel
@@ -57,10 +62,40 @@ def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
     return _completion_with_retry(*args, **kwargs)


-class _VertexAICommon(BaseModel):
+class _VertexAIBase(BaseModel):
+    project: Optional[str] = None
+    "The default GCP project to use when making Vertex API calls."
+    location: str = "us-central1"
+    "The default location to use when making API calls."
+    request_parallelism: int = 5
+    "The amount of parallelism allowed for requests issued to VertexAI models. "
+    "Default is 5."
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""
+    task_executor: ClassVar[Optional[Executor]] = None
+    stop: Optional[List[str]] = None
+    "Optional list of stop words to use when generating."
+    model_name: Optional[str] = None
+    "Underlying model name."
+
+    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
+        if stop is None and self.stop is not None:
+            stop = self.stop
+        if stop:
+            return enforce_stop_tokens(text, stop)
+        return text
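+
+    # The executor is stored on the class (ClassVar), so every Vertex AI LLM
+    # instance in the process shares one thread pool; its size is fixed by
+    # whichever caller creates it first.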
+    @classmethod
+    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
+        if cls.task_executor is None:
+            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
+        return cls.task_executor
+
+
+class _VertexAICommon(_VertexAIBase):
     client: "_LanguageModel" = None  #: :meta private:
     model_name: str
-    "Model name to use."
+    "Underlying model name."
     temperature: float = 0.0
     "Sampling temperature, it controls the degree of randomness in token selection."
     max_output_tokens: int = 128
@@ -71,27 +106,20 @@ class _VertexAICommon(BaseModel):
     top_k: int = 40
     "How the model selects tokens for output, the next token is selected from "
     "among the top-k most probable tokens. Top-k is ignored for Codey models."
-    stop: Optional[List[str]] = None
-    "Optional list of stop words to use when generating."
-    project: Optional[str] = None
-    "The default GCP project to use when making Vertex API calls."
-    location: str = "us-central1"
-    "The default location to use when making API calls."
     credentials: Any = None
     "The default custom credentials (google.auth.credentials.Credentials) to use "
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
-    request_parallelism: int = 5
-    "The amount of parallelism allowed for requests issued to VertexAI models. "
-    "Default is 5."
-    max_retries: int = 6
-    """The maximum number of retries to make when generating."""
-    task_executor: ClassVar[Optional[Executor]] = None

     @property
     def is_codey_model(self) -> bool:
         return is_codey_model(self.model_name)

+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model_name": self.model_name}, **self._default_params}
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         if self.is_codey_model:
@@ -114,28 +142,10 @@ class _VertexAICommon(BaseModel):
         res = completion_with_retry(self, prompt, **params)  # type: ignore
         return self._enforce_stop_words(res.text, stop)

-    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
-        if stop is None and self.stop is not None:
-            stop = self.stop
-        if stop:
-            return enforce_stop_tokens(text, stop)
-        return text
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {**{"model_name": self.model_name}, **self._default_params}
-
     @property
     def _llm_type(self) -> str:
         return "vertexai"

-    @classmethod
-    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
-        if cls.task_executor is None:
-            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
-        return cls.task_executor
-
     @classmethod
     def _try_init_vertexai(cls, values: Dict) -> None:
         allowed_params = ["project", "location", "credentials"]
@@ -176,6 +186,25 @@ class VertexAI(_VertexAICommon, LLM):
             raise_vertex_import_error()
         return values

+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call Vertex model to get predictions based on the prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of stop words (optional).
+            run_manager: A callback manager for the LLM run, optional.
+
+        Returns:
+            The string generated by the model.
+        """
+        return self._predict(prompt, stop, **kwargs)
+
     async def _acall(
         self,
         prompt: str,
@@ -194,9 +223,44 @@ class VertexAI(_VertexAICommon, LLM):
             The string generated by the model.
         """
         return await asyncio.wrap_future(
-            self._get_task_executor().submit(self._predict, prompt, stop)
+            self._get_task_executor().submit(self._call, prompt, stop)
         )
+
+
+class VertexAIModelGarden(_VertexAIBase, LLM):
""" Large language models served from Vertex AI Model Garden. """
client : " PredictionServiceClient " = None #: :meta private:
endpoint_id : str
" A name of an endpoint where the model has been deployed. "
allowed_model_args : Optional [ List [ str ] ] = None
""" Allowed optional args to be passed to the model. """
prompt_arg : str = " prompt "
result_arg : str = " generated_text "
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the python package exists in environment."""
+        try:
+            from google.cloud.aiplatform.gapic import PredictionServiceClient
+        except ImportError:
+            raise_vertex_import_error()
+        if values["project"] is None:
+            raise ValueError(
+                "A GCP project should be provided to run inference on Model Garden!"
+            )
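+        # The prediction client must target the regional API endpoint that
+        # matches where the Model Garden endpoint is deployed.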
+        client_options = {
+            "api_endpoint": f"{values['location']}-aiplatform.googleapis.com"
+        }
+        values["client"] = PredictionServiceClient(client_options=client_options)
+        return values
+
+    @property
+    def _llm_type(self) -> str:
+        return "vertexai_model_garden"
+
+    def _call(
+        self,
+        prompt: str,
@@ -214,4 +278,71 @@ class VertexAI(_VertexAICommon, LLM):
         Returns:
             The string generated by the model.
         """
-        return self._predict(prompt, stop, **kwargs)
+        result = self._generate(
+            prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return result.generations[0][0].text
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+        try:
+            from google.protobuf import json_format
+            from google.protobuf.struct_pb2 import Value
+        except ImportError:
+            raise ImportError(
+                "protobuf package not found, please install it with "
+                "`pip install protobuf`"
+            )
+        instances = []
+        for prompt in prompts:
+            if self.allowed_model_args:
+                instance = {
+                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
+                }
+            else:
+                instance = {}
+            instance[self.prompt_arg] = prompt
+            instances.append(instance)
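+        # Each instance dict is serialized into a protobuf Value, the wire
+        # format the prediction service expects.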
+        predict_instances = [
+            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
+        ]
+
+        endpoint = self.client.endpoint_path(
+            project=self.project, location=self.location, endpoint=self.endpoint_id
+        )
+        response = self.client.predict(endpoint=endpoint, instances=predict_instances)
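+        # Each prediction is expected to be a list of dicts keyed by
+        # result_arg (by default "generated_text").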
+        generations: List[List[Generation]] = []
+        for result in response.predictions:
+            generations.append(
+                [Generation(text=prediction[self.result_arg]) for prediction in result]
+            )
+        return LLMResult(generations=generations)
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call Vertex model to get predictions based on the prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of stop words (optional).
+            run_manager: A callback manager for async interaction with LLMs.
+
+        Returns:
+            The string generated by the model.
+        """
+        return await asyncio.wrap_future(
+            self._get_task_executor().submit(self._call, prompt, stop)
+        )