mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Docs: enrich SageMaker endpoint embeddings with docstrings and examples (#9924)
Description: Added docstrings and comments that explain how the input/output transformations relate to the customised inference.py script.
This commit is contained in:
parent
8dbf4cbe80
commit
0fb95ebe66
@ -48,10 +48,31 @@
|
||||
" accepts = \"application/json\"\n",
|
||||
"\n",
|
||||
" def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
|
||||
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
|
||||
" \"\"\"\n",
|
||||
" Transforms the input into bytes that can be consumed by the SageMaker endpoint.\n",
|
||||
" Args:\n",
|
||||
" inputs: List of input strings.\n",
|
||||
" model_kwargs: Additional keyword arguments to be passed to the endpoint.\n",
|
||||
" Returns:\n",
|
||||
" The transformed bytes input.\n",
|
||||
" \"\"\"\n",
|
||||
" # Example: inference.py expects a JSON string with an \"inputs\" key:\n",
|
||||
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs}) \n",
|
||||
" return input_str.encode(\"utf-8\")\n",
|
||||
"\n",
|
||||
" def transform_output(self, output: bytes) -> List[List[float]]:\n",
|
||||
" \"\"\"\n",
|
||||
" Transforms the bytes output from the endpoint into a list of embeddings.\n",
|
||||
" Args:\n",
|
||||
" output: The bytes output from the SageMaker endpoint.\n",
|
||||
" Returns:\n",
|
||||
" The transformed output: a list of embeddings.\n",
|
||||
" Note:\n",
|
||||
" The length of the outer list is the number of input strings.\n",
|
||||
" The length of the inner lists is the embedding dimension.\n",
|
||||
" \"\"\"\n",
|
||||
" # Example: inference.py returns a JSON string with the list of\n",
|
||||
" # embeddings in a \"vectors\" key:\n",
|
||||
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
||||
" return response_json[\"vectors\"]\n",
|
||||
"\n",
|
||||
@ -60,7 +81,6 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"embeddings = SagemakerEndpointEmbeddings(\n",
|
||||
" # endpoint_name=\"endpoint-name\",\n",
|
||||
" # credentials_profile_name=\"credentials-profile-name\",\n",
|
||||
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
|
||||
" region_name=\"us-east-1\",\n",
|
||||
|
Loading…
Reference in New Issue
Block a user