@ -11,12 +11,18 @@ from langchain_core.outputs import Generation, LLMResult
from langchain_core . pydantic_v1 import BaseModel , SecretStr , root_validator , validator
from langchain_core . utils import convert_to_secret_str , get_from_dict_or_env
DEFAULT_TIMEOUT = 50
class AzureMLEndpointClient ( object ) :
""" AzureML Managed Endpoint client. """
def __init__ (
self , endpoint_url : str , endpoint_api_key : str , deployment_name : str = " "
self ,
endpoint_url : str ,
endpoint_api_key : str ,
deployment_name : str = " " ,
timeout : int = DEFAULT_TIMEOUT ,
) - > None :
""" Initialize the class. """
if not endpoint_api_key or not endpoint_url :
@ -27,6 +33,7 @@ class AzureMLEndpointClient(object):
self . endpoint_url = endpoint_url
self . endpoint_api_key = endpoint_api_key
self . deployment_name = deployment_name
self . timeout = timeout
def call (
self ,
@ -47,7 +54,9 @@ class AzureMLEndpointClient(object):
headers [ " azureml-model-deployment " ] = self . deployment_name
req = urllib . request . Request ( self . endpoint_url , body , headers )
response = urllib . request . urlopen ( req , timeout = kwargs . get ( " timeout " , 50 ) )
response = urllib . request . urlopen (
req , timeout = kwargs . get ( " timeout " , self . timeout )
)
result = response . read ( )
return result
@ -334,6 +343,9 @@ class AzureMLBaseEndpoint(BaseModel):
""" Deployment Name for Endpoint. NOT REQUIRED to call endpoint. Should be passed
to constructor or specified as env var ` AZUREML_DEPLOYMENT_NAME ` . """
timeout : int = DEFAULT_TIMEOUT
""" Request timeout for calls to the endpoint """
http_client : Any = None #: :meta private:
content_formatter : Any = None
@ -361,6 +373,12 @@ class AzureMLBaseEndpoint(BaseModel):
" AZUREML_ENDPOINT_API_TYPE " ,
AzureMLEndpointApiType . realtime ,
)
values [ " timeout " ] = get_from_dict_or_env (
values ,
" timeout " ,
" AZUREML_TIMEOUT " ,
str ( DEFAULT_TIMEOUT ) ,
)
return values
@ -424,12 +442,15 @@ class AzureMLBaseEndpoint(BaseModel):
endpoint_url = values . get ( " endpoint_url " )
endpoint_key = values . get ( " endpoint_api_key " )
deployment_name = values . get ( " deployment_name " )
timeout = values . get ( " timeout " , DEFAULT_TIMEOUT )
http_client = AzureMLEndpointClient (
endpoint_url , # type: ignore
endpoint_key . get_secret_value ( ) , # type: ignore
deployment_name , # type: ignore
timeout , # type: ignore
)
return http_client
@ -442,6 +463,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
endpoint_url = " https://<your-endpoint>.<your_region>.inference.ml.azure.com/score " ,
endpoint_api_type = AzureMLApiType . realtime ,
endpoint_api_key = " my-api-key " ,
timeout = 120 ,
content_formatter = content_formatter ,
)
""" # noqa: E501