|
|
|
@ -48,10 +48,31 @@
|
|
|
|
|
" accepts = \"application/json\"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes:\n",
|
|
|
|
|
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
|
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
" Transforms the input into bytes that can be consumed by SageMaker endpoint.\n",
|
|
|
|
|
" Args:\n",
|
|
|
|
|
" inputs: List of input strings.\n",
|
|
|
|
|
" model_kwargs: Additional keyword arguments to be passed to the endpoint.\n",
|
|
|
|
|
" Returns:\n",
|
|
|
|
|
" The transformed bytes input.\n",
|
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
" # Example: inference.py expects a JSON string with a \"inputs\" key:\n",
|
|
|
|
|
" input_str = json.dumps({\"inputs\": inputs, **model_kwargs}) \n",
|
|
|
|
|
" return input_str.encode(\"utf-8\")\n",
|
|
|
|
|
"\n",
|
|
|
|
|
" def transform_output(self, output: bytes) -> List[List[float]]:\n",
|
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
" Transforms the bytes output from the endpoint into a list of embeddings.\n",
|
|
|
|
|
" Args:\n",
|
|
|
|
|
" output: The bytes output from SageMaker endpoint.\n",
|
|
|
|
|
" Returns:\n",
|
|
|
|
|
" The transformed output - list of embeddings\n",
|
|
|
|
|
" Note:\n",
|
|
|
|
|
" The length of the outer list is the number of input strings.\n",
|
|
|
|
|
" The length of the inner lists is the embedding dimension.\n",
|
|
|
|
|
" \"\"\"\n",
|
|
|
|
|
" # Example: inference.py returns a JSON string with the list of\n",
|
|
|
|
|
" # embeddings in a \"vectors\" key:\n",
|
|
|
|
|
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
|
|
|
|
|
" return response_json[\"vectors\"]\n",
|
|
|
|
|
"\n",
|
|
|
|
@ -60,7 +81,6 @@
|
|
|
|
|
"\n",
|
|
|
|
|
"\n",
|
|
|
|
|
"embeddings = SagemakerEndpointEmbeddings(\n",
|
|
|
|
|
" # endpoint_name=\"endpoint-name\",\n",
|
|
|
|
|
" # credentials_profile_name=\"credentials-profile-name\",\n",
|
|
|
|
|
" endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
|
|
|
|
|
" region_name=\"us-east-1\",\n",
|
|
|
|
|