@ -11,12 +11,18 @@ from langchain_core.outputs import Generation, LLMResult
from langchain_core . pydantic_v1 import BaseModel , SecretStr , root_validator , validator
from langchain_core . pydantic_v1 import BaseModel , SecretStr , root_validator , validator
from langchain_core . utils import convert_to_secret_str , get_from_dict_or_env
from langchain_core . utils import convert_to_secret_str , get_from_dict_or_env
DEFAULT_TIMEOUT = 50
class AzureMLEndpointClient ( object ) :
class AzureMLEndpointClient ( object ) :
""" AzureML Managed Endpoint client. """
""" AzureML Managed Endpoint client. """
def __init__ (
def __init__ (
self , endpoint_url : str , endpoint_api_key : str , deployment_name : str = " "
self ,
endpoint_url : str ,
endpoint_api_key : str ,
deployment_name : str = " " ,
timeout : int = DEFAULT_TIMEOUT ,
) - > None :
) - > None :
""" Initialize the class. """
""" Initialize the class. """
if not endpoint_api_key or not endpoint_url :
if not endpoint_api_key or not endpoint_url :
@ -27,6 +33,7 @@ class AzureMLEndpointClient(object):
self . endpoint_url = endpoint_url
self . endpoint_url = endpoint_url
self . endpoint_api_key = endpoint_api_key
self . endpoint_api_key = endpoint_api_key
self . deployment_name = deployment_name
self . deployment_name = deployment_name
self . timeout = timeout
def call (
def call (
self ,
self ,
@ -47,7 +54,9 @@ class AzureMLEndpointClient(object):
headers [ " azureml-model-deployment " ] = self . deployment_name
headers [ " azureml-model-deployment " ] = self . deployment_name
req = urllib . request . Request ( self . endpoint_url , body , headers )
req = urllib . request . Request ( self . endpoint_url , body , headers )
response = urllib . request . urlopen ( req , timeout = kwargs . get ( " timeout " , 50 ) )
response = urllib . request . urlopen (
req , timeout = kwargs . get ( " timeout " , self . timeout )
)
result = response . read ( )
result = response . read ( )
return result
return result
@ -334,6 +343,9 @@ class AzureMLBaseEndpoint(BaseModel):
""" Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
""" Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
to constructor or specified as env var ` AZUREML_DEPLOYMENT_NAME ` . """
to constructor or specified as env var ` AZUREML_DEPLOYMENT_NAME ` . """
timeout : int = DEFAULT_TIMEOUT
""" Request timeout for calls to the endpoint """
http_client : Any = None #: :meta private:
http_client : Any = None #: :meta private:
content_formatter : Any = None
content_formatter : Any = None
@ -361,6 +373,12 @@ class AzureMLBaseEndpoint(BaseModel):
" AZUREML_ENDPOINT_API_TYPE " ,
" AZUREML_ENDPOINT_API_TYPE " ,
AzureMLEndpointApiType . realtime ,
AzureMLEndpointApiType . realtime ,
)
)
values [ " timeout " ] = get_from_dict_or_env (
values ,
" timeout " ,
" AZUREML_TIMEOUT " ,
str ( DEFAULT_TIMEOUT ) ,
)
return values
return values
@ -424,12 +442,15 @@ class AzureMLBaseEndpoint(BaseModel):
endpoint_url = values . get ( " endpoint_url " )
endpoint_url = values . get ( " endpoint_url " )
endpoint_key = values . get ( " endpoint_api_key " )
endpoint_key = values . get ( " endpoint_api_key " )
deployment_name = values . get ( " deployment_name " )
deployment_name = values . get ( " deployment_name " )
timeout = values . get ( " timeout " , DEFAULT_TIMEOUT )
http_client = AzureMLEndpointClient (
http_client = AzureMLEndpointClient (
endpoint_url , # type: ignore
endpoint_url , # type: ignore
endpoint_key . get_secret_value ( ) , # type: ignore
endpoint_key . get_secret_value ( ) , # type: ignore
deployment_name , # type: ignore
deployment_name , # type: ignore
timeout , # type: ignore
)
)
return http_client
return http_client
@ -442,6 +463,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
endpoint_url = " https://<your-endpoint>.<your_region>.inference.ml.azure.com/score " ,
endpoint_url = " https://<your-endpoint>.<your_region>.inference.ml.azure.com/score " ,
endpoint_api_type = AzureMLApiType . realtime ,
endpoint_api_type = AzureMLApiType . realtime ,
endpoint_api_key = " my-api-key " ,
endpoint_api_key = " my-api-key " ,
timeout = 120 ,
content_formatter = content_formatter ,
content_formatter = content_formatter ,
)
)
""" # noqa: E501
""" # noqa: E501