From a4e858b11165e6d2ef10478d4caf99a185c5b9f5 Mon Sep 17 00:00:00 2001 From: Vikram Shitole Date: Tue, 19 Sep 2023 20:36:12 +0530 Subject: [PATCH] Sagemaker endpoint capability to inject boto3 client for cross account scenarios (#10728) - **Description: Allow to inject boto3 client for Cross account access type of scenarios in using Sagemaker Endpoint ** - **Issue:#10634 #10184** - **Dependencies: None** - **Tag maintainer:** - **Twitter handle:lethargicoder** Co-authored-by: Vikram(VS) --- .../langchain/llms/sagemaker_endpoint.py | 35 ++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/llms/sagemaker_endpoint.py b/libs/langchain/langchain/llms/sagemaker_endpoint.py index 9505f17330..b38cbd40d4 100644 --- a/libs/langchain/langchain/llms/sagemaker_endpoint.py +++ b/libs/langchain/langchain/llms/sagemaker_endpoint.py @@ -80,6 +80,22 @@ class SagemakerEndpoint(LLM): """ """ + Args: + + region_name: The aws region e.g., `us-west-2`. + Fallsback to AWS_DEFAULT_REGION env variable + or region specified in ~/.aws/config. + + credentials_profile_name: The name of the profile in the ~/.aws/credentials + or ~/.aws/config files, which has either access keys or role information + specified. If not specified, the default credential profile or, if on an + EC2 instance, credentials from IMDS will be used. + + client: boto3 client for Sagemaker Endpoint + + content_handler: Implementation for model specific LLMContentHandler + + Example: .. code-block:: python @@ -98,8 +114,21 @@ class SagemakerEndpoint(LLM): region_name=region_name, credentials_profile_name=credentials_profile_name ) + + #Use with boto3 client + client = boto3.client( + "sagemaker-runtime", + region_name=region_name + ) + + se = SagemakerEndpoint( + endpoint_name=endpoint_name, + client=client + ) + """ - client: Any #: :meta private: + client: Any = None + """Boto3 client for sagemaker runtime""" endpoint_name: str = "" """The name of the endpoint from the deployed Sagemaker model. @@ -157,6 +186,10 @@ class SagemakerEndpoint(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: + """Dont do anything if client provided externally""" + if values.get("client") is not None: + return values + """Validate that AWS credentials to and python package exists in environment.""" try: import boto3