diff --git a/libs/langchain/langchain/llms/sagemaker_endpoint.py b/libs/langchain/langchain/llms/sagemaker_endpoint.py index 9505f17330..b38cbd40d4 100644 --- a/libs/langchain/langchain/llms/sagemaker_endpoint.py +++ b/libs/langchain/langchain/llms/sagemaker_endpoint.py @@ -80,6 +80,22 @@ class SagemakerEndpoint(LLM): """ """ + Args: + + region_name: The aws region e.g., `us-west-2`. + Fallsback to AWS_DEFAULT_REGION env variable + or region specified in ~/.aws/config. + + credentials_profile_name: The name of the profile in the ~/.aws/credentials + or ~/.aws/config files, which has either access keys or role information + specified. If not specified, the default credential profile or, if on an + EC2 instance, credentials from IMDS will be used. + + client: boto3 client for Sagemaker Endpoint + + content_handler: Implementation for model specific LLMContentHandler + + Example: .. code-block:: python @@ -98,8 +114,21 @@ class SagemakerEndpoint(LLM): region_name=region_name, credentials_profile_name=credentials_profile_name ) + + #Use with boto3 client + client = boto3.client( + "sagemaker-runtime", + region_name=region_name + ) + + se = SagemakerEndpoint( + endpoint_name=endpoint_name, + client=client + ) + """ - client: Any #: :meta private: + client: Any = None + """Boto3 client for sagemaker runtime""" endpoint_name: str = "" """The name of the endpoint from the deployed Sagemaker model. @@ -157,6 +186,10 @@ class SagemakerEndpoint(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: + """Dont do anything if client provided externally""" + if values.get("client") is not None: + return values + """Validate that AWS credentials to and python package exists in environment.""" try: import boto3