From 064be93edf76db06c1cc0d97acc83939e3f9a7c9 Mon Sep 17 00:00:00 2001 From: Philipp Schmid <32632186+philschmid@users.noreply.github.com> Date: Wed, 22 Mar 2023 05:51:48 +0100 Subject: [PATCH] [Embeddings] Add SageMaker Endpoint Embedding class (#1859) # What does this PR do? This PR adds a SageMaker-powered `embeddings` class, similar to the existing one in `llms`. This is helpful if you want to leverage Hugging Face models on SageMaker for creating your indexes. I added an example to the [docs/modules/indexes/examples/embeddings.ipynb](https://github.com/hwchase17/langchain/compare/master...philschmid:add-sm-embeddings?expand=1#diff-e82629e2894974ec87856aedd769d4bdfe400314b03734f32bee5990bc7e8062) document. The example currently includes some `_### TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging Face model ###_ ` code showing how you can deploy a sentence-transformers model to SageMaker and then run the methods of the embeddings class. @hwchase17 please let me know if/when I should remove the `_### TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging Face model ###_` section. In the description I linked to a detailed blog post on how to deploy a Sentence Transformers model, so I think we don't need to include those steps here. I also reused the `ContentHandlerBase` from `langchain.llms.sagemaker_endpoint` and changed the output type to `any`, since it depends on the implementation. 
--- .../modules/indexes/examples/embeddings.ipynb | 1448 ++++++++++++++++- langchain/embeddings/__init__.py | 2 + langchain/embeddings/sagemaker_endpoint.py | 194 +++ langchain/llms/sagemaker_endpoint.py | 8 +- 4 files changed, 1646 insertions(+), 6 deletions(-) create mode 100644 langchain/embeddings/sagemaker_endpoint.py diff --git a/docs/modules/indexes/examples/embeddings.ipynb b/docs/modules/indexes/examples/embeddings.ipynb index d1d1fb3a..20985679 100644 --- a/docs/modules/indexes/examples/embeddings.ipynb +++ b/docs/modules/indexes/examples/embeddings.ipynb @@ -654,18 +654,1460 @@ "doc_results = embeddings.embed_documents([\"foo\"])" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1f83f273", + "metadata": {}, + "source": [ + "## SageMaker Endpoint Embeddings\n", + "\n", + "Let's load the SageMaker Endpoint Embeddings class. It can be used if you host, e.g., your own Hugging Face model on SageMaker. Learn more [here](https://www.philschmid.de/custom-inference-huggingface-sagemaker)" + ] + }, { "cell_type": "code", "execution_count": null, "id": "88d366bd", "metadata": {}, "outputs": [], + "source": [ + "!pip3 install langchain boto3" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c5855922", + "metadata": {}, + "source": [ + "## _### TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging Face model ###_ " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "e0ddd9b4", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install sagemaker --quiet\n", + "\n", + "import os \n", + "os.environ[\"AWS_DEFAULT_REGION\"] = \"us-east-1\"\n", + "import boto3\n", + "from sagemaker import Session\n", + "# get sagemaker execution role to deploy\n", + "iam = boto3.client('iam')\n", + "role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n", + "sess = Session()\n", + "# create code/ dir\n", + "os.makedirs(\"model/code\", exist_ok=True)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 7, + "id": "86ce76c6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing model/code/inference.py\n" + ] + } + ], + "source": [ + "%%writefile model/code/inference.py\n", + "\n", + "from transformers import AutoTokenizer, AutoModel\n", + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "# Helper: Mean Pooling - Take attention mask into account for correct averaging\n", + "def mean_pooling(model_output, attention_mask):\n", + " token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n", + " input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n", + " return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n", + "\n", + "\n", + "def model_fn(model_dir):\n", + " # Load model from HuggingFace Hub\n", + " tokenizer = AutoTokenizer.from_pretrained(\"sentence-transformers/all-MiniLM-L6-v2\")\n", + " model = AutoModel.from_pretrained(\"sentence-transformers/all-MiniLM-L6-v2\")\n", + " return model, tokenizer\n", + "\n", + "def predict_fn(data, model_and_tokenizer):\n", + " # destruct model and tokenizer\n", + " model, tokenizer = model_and_tokenizer\n", + "\n", + " # Tokenize sentences\n", + " sentences = data.pop(\"inputs\", data)\n", + " encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')\n", + "\n", + " # Compute token embeddings\n", + " with torch.no_grad():\n", + " model_output = model(**encoded_input)\n", + "\n", + " # Perform pooling\n", + " sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])\n", + "\n", + " # Normalize embeddings\n", + " sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)\n", + "\n", + " # return dictonary, which will be json serializable\n", + " return {\"embeddings\": sentence_embeddings[0].tolist()}\n" + ] + }, + { + "cell_type": "code", + 
"execution_count": 9, + "id": "24b809d4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "code/\n", + "code/inference.py\n", + "----!" + ] + } + ], + "source": [ + "from sagemaker.s3 import S3Uploader\n", + "from sagemaker.huggingface.model import HuggingFaceModel\n", + "\n", + "# create model.tar.gz and upload to s3 \n", + "parent_dir=os.getcwd()\n", + "# change to model dir\n", + "os.chdir(\"model\")\n", + "# use pigz for faster and parallel compression\n", + "!tar zcvf model.tar.gz *\n", + "# change back to parent dir\n", + "os.chdir(parent_dir)\n", + "\n", + "\n", + "# upload model.tar.gz to s3\n", + "s3_model_uri = S3Uploader.upload(local_path=\"model/model.tar.gz\", desired_s3_uri=f\"s3://{sess.default_bucket()}/embeddings\")\n", + "\n", + "# create Hugging Face Model Class\n", + "huggingface_model = HuggingFaceModel(\n", + " model_data=s3_model_uri, # path to your model and script\n", + " role=role, # iam role with permissions to create an Endpoint\n", + " transformers_version=\"4.26\", # transformers version used\n", + " pytorch_version=\"1.13\", # pytorch version used\n", + " py_version='py39', # python version used\n", + ")\n", + "\n", + "# deploy the endpoint endpoint\n", + "predictor = huggingface_model.deploy(1,\"ml.m5.2xlarge\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "324213fd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'huggingface-pytorch-inference-2023-03-21-16-14-03-834'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predictor.endpoint_name" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "3dff3efa", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'embeddings': [-0.03833858296275139,\n", + " 0.12346473336219788,\n", + " -0.028642961755394936,\n", + " 0.05365271493792534,\n", + " 0.008845399133861065,\n", + " 
-0.039839327335357666,\n", + " -0.07300589978694916,\n", + " 0.04777129739522934,\n", + " -0.03046245686709881,\n", + " 0.054979756474494934,\n", + " 0.08505291491746902,\n", + " 0.03665667772293091,\n", + " -0.0053200023248791695,\n", + " -0.002233208389952779,\n", + " -0.06071101501584053,\n", + " -0.027237888425588608,\n", + " -0.011351668275892735,\n", + " -0.04243773967027664,\n", + " 0.009129947051405907,\n", + " 0.10081552714109421,\n", + " 0.075787253677845,\n", + " 0.06911724805831909,\n", + " 0.009857476688921452,\n", + " -0.0018377384403720498,\n", + " 0.02624901942908764,\n", + " 0.03290242329239845,\n", + " -0.07177436351776123,\n", + " 0.028384245932102203,\n", + " 0.06170952320098877,\n", + " -0.05252952501177788,\n", + " 0.033661700785160065,\n", + " 0.07446815073490143,\n", + " 0.07536035776138306,\n", + " 0.03538404032588005,\n", + " 0.0671340748667717,\n", + " 0.01079804077744484,\n", + " 0.08167019486427307,\n", + " 0.01656288281083107,\n", + " 0.03283063322305679,\n", + " 0.03632563352584839,\n", + " 0.002172857290133834,\n", + " -0.09895739704370499,\n", + " 0.005046747159212828,\n", + " 0.050896503031253815,\n", + " 0.009287566877901554,\n", + " 0.024507733061909676,\n", + " -0.0644078254699707,\n", + " 0.0019362837774679065,\n", + " -0.079103484749794,\n", + " 0.020850397646427155,\n", + " -0.01922827586531639,\n", + " -0.02805466018617153,\n", + " -0.07059794664382935,\n", + " -0.007083615753799677,\n", + " 0.01040570717304945,\n", + " 0.038834139704704285,\n", + " 0.01765601523220539,\n", + " -0.019606105983257294,\n", + " -0.020058417692780495,\n", + " 0.018083792179822922,\n", + " -0.00017212114471476525,\n", + " 0.013043343089520931,\n", + " -0.09337250143289566,\n", + " 0.08453577756881714,\n", + " 0.11705499142408371,\n", + " 0.057413410395383835,\n", + " -0.022439058870077133,\n", + " -0.03677624836564064,\n", + " -0.03434618189930916,\n", + " -0.06383830308914185,\n", + " -0.06846101582050323,\n", + " -0.005553076509386301,\n", + " 
0.044378429651260376,\n", + " 0.016669290140271187,\n", + " 0.030911751091480255,\n", + " -0.01975969970226288,\n", + " -0.024855101481080055,\n", + " -0.05904391035437584,\n", + " 0.0945875272154808,\n", + " -0.06530515849590302,\n", + " -0.05597255751490593,\n", + " -0.03284724801778793,\n", + " 0.00811521615833044,\n", + " -0.002234684070572257,\n", + " 0.002023296197876334,\n", + " 0.07942128926515579,\n", + " 0.08518771082162857,\n", + " 0.007815245538949966,\n", + " -0.01374559011310339,\n", + " 0.031104223802685738,\n", + " 0.010080904699862003,\n", + " -0.03275560960173607,\n", + " 0.007714808918535709,\n", + " -0.006191879045218229,\n", + " -0.05613413453102112,\n", + " 0.004364899825304747,\n", + " -0.01403757743537426,\n", + " -0.039304714649915695,\n", + " 0.07822350412607193,\n", + " 0.07393720000982285,\n", + " 0.05619140341877937,\n", + " 0.003301335731521249,\n", + " 0.04155803844332695,\n", + " -0.010387539863586426,\n", + " -0.13272696733474731,\n", + " -0.10473112016916275,\n", + " 0.018451808020472527,\n", + " -0.07520624995231628,\n", + " 0.04954085499048233,\n", + " -0.028530888259410858,\n", + " -0.01358408946543932,\n", + " -0.037112679332494736,\n", + " -0.06756578385829926,\n", + " -0.019552525132894516,\n", + " -0.010211824439466,\n", + " -0.051934875547885895,\n", + " -0.05941231921315193,\n", + " 0.016754044219851494,\n", + " 0.04098018631339073,\n", + " 0.001522376318462193,\n", + " 0.08095283806324005,\n", + " 0.002651068614795804,\n", + " -0.03870720788836479,\n", + " -0.04703034833073616,\n", + " -0.05854427441954613,\n", + " -0.029478492215275764,\n", + " 0.03882651776075363,\n", + " -8.102625254868425e-33,\n", + " -0.012914206832647324,\n", + " -0.014458492398262024,\n", + " -0.022368784993886948,\n", + " 0.1056450605392456,\n", + " 0.0037274654023349285,\n", + " 0.005939559079706669,\n", + " -0.023657256737351418,\n", + " 0.041163913905620575,\n", + " -0.07411694526672363,\n", + " 0.007076926529407501,\n", + " 
0.0018349214224144816,\n", + " -0.03314222767949104,\n", + " 0.006818821653723717,\n", + " 0.04693515598773956,\n", + " -0.03836120665073395,\n", + " 0.05861291661858559,\n", + " -0.0840379074215889,\n", + " 0.11954139918088913,\n", + " -0.025204092264175415,\n", + " 0.02761165052652359,\n", + " 0.0244757030159235,\n", + " 0.014137371443212032,\n", + " 0.0128665491938591,\n", + " -0.05779572203755379,\n", + " -0.031691741198301315,\n", + " -0.0029006320983171463,\n", + " -0.027254171669483185,\n", + " -0.027451230213046074,\n", + " -0.03404244780540466,\n", + " 0.020136823877692223,\n", + " 0.022654512897133827,\n", + " 0.030933434143662453,\n", + " -0.045505885034799576,\n", + " -0.0025163793470710516,\n", + " 0.01510235108435154,\n", + " 0.09668111801147461,\n", + " 0.001809411682188511,\n", + " -0.05403870716691017,\n", + " 0.0025403527542948723,\n", + " 0.006051000207662582,\n", + " -0.056302234530448914,\n", + " -0.028254246339201927,\n", + " 0.06966646015644073,\n", + " 0.04410792514681816,\n", + " 0.039832279086112976,\n", + " -0.0419430211186409,\n", + " -0.0038099137600511312,\n", + " -0.04156690835952759,\n", + " 0.09482309967279434,\n", + " 0.019028929993510246,\n", + " -0.04011702537536621,\n", + " 0.0324222669005394,\n", + " 0.012565849348902702,\n", + " -0.056325893849134445,\n", + " 0.04461190849542618,\n", + " 0.04928917437791824,\n", + " 0.017442630603909492,\n", + " 0.05323149263858795,\n", + " -0.020876457914710045,\n", + " 0.061462536454200745,\n", + " -0.014837260358035564,\n", + " 0.07423629611730576,\n", + " -0.0576944537460804,\n", + " 0.049852192401885986,\n", + " -0.05890402942895889,\n", + " -0.0006539729074575007,\n", + " -0.10970547795295715,\n", + " -0.06829895824193954,\n", + " 0.13056595623493195,\n", + " -0.011906635947525501,\n", + " -0.0159984789788723,\n", + " -0.0211041159927845,\n", + " -0.007144191302359104,\n", + " -0.0164438858628273,\n", + " -0.016906214877963066,\n", + " -0.04813709110021591,\n", + " 
0.015731733292341232,\n", + " 0.030654815956950188,\n", + " -0.004599860403686762,\n", + " -0.03823969140648842,\n", + " -0.04718682914972305,\n", + " -0.08068915456533432,\n", + " -0.011494779027998447,\n", + " -0.05190776288509369,\n", + " -0.04332379251718521,\n", + " -0.019109943881630898,\n", + " 0.036341868340969086,\n", + " -0.06575313955545425,\n", + " -0.014969361014664173,\n", + " -0.0911363959312439,\n", + " 0.035127948969602585,\n", + " 0.019904181361198425,\n", + " -0.055992890149354935,\n", + " -0.04273851588368416,\n", + " 0.11667020618915558,\n", + " 4.7537233992963164e-33,\n", + " -0.04277687147259712,\n", + " 0.010693217627704144,\n", + " -0.08699914813041687,\n", + " 0.11428382992744446,\n", + " 0.026194244623184204,\n", + " 0.008768039755523205,\n", + " 0.08940346539020538,\n", + " -0.0019060149788856506,\n", + " -0.0455072745680809,\n", + " 0.08432017266750336,\n", + " 0.011060485616326332,\n", + " 0.000260289350990206,\n", + " -0.00023178635456133634,\n", + " -0.0015942883910611272,\n", + " 0.0015580946346744895,\n", + " -0.025324126705527306,\n", + " -0.03786805272102356,\n", + " -0.0546313114464283,\n", + " 0.004270816687494516,\n", + " 0.016222011297941208,\n", + " -0.04763113334774971,\n", + " 0.11077607423067093,\n", + " 0.045782990753650665,\n", + " 0.07989457994699478,\n", + " -0.006792569998651743,\n", + " -0.010313649661839008,\n", + " 0.006975427269935608,\n", + " -0.09530742466449738,\n", + " -0.014356936328113079,\n", + " -0.013479162007570267,\n", + " -0.009381195530295372,\n", + " -0.0026153195649385452,\n", + " -0.12162390351295471,\n", + " 0.07765249162912369,\n", + " 0.009094372391700745,\n", + " -0.10183481127023697,\n", + " 0.13146239519119263,\n", + " -0.04587067291140556,\n", + " -0.009605005383491516,\n", + " 0.024302706122398376,\n", + " 0.045921340584754944,\n", + " 0.08771276473999023,\n", + " 0.055159058421850204,\n", + " 0.047116719186306,\n", + " -0.022800585255026817,\n", + " 0.05540422350168228,\n", + " 
0.03942396119236946,\n", + " -0.06854791939258575,\n", + " 0.07696892321109772,\n", + " 0.0264807790517807,\n", + " 0.013421732001006603,\n", + " -0.03159027546644211,\n", + " 0.02122318185865879,\n", + " -0.02458374947309494,\n", + " -0.09490033239126205,\n", + " 0.05001789703965187,\n", + " -0.07885674387216568,\n", + " -0.0469261035323143,\n", + " -0.009405327029526234,\n", + " 0.06844945251941681,\n", + " -0.019532756879925728,\n", + " 0.08325397968292236,\n", + " -0.0020212731324136257,\n", + " 0.07861411571502686,\n", + " 0.009707036428153515,\n", + " -0.08329329639673233,\n", + " -0.08883728086948395,\n", + " 0.026159727945923805,\n", + " -0.0036121727898716927,\n", + " 0.0021212503779679537,\n", + " 0.06756487488746643,\n", + " -0.04351912811398506,\n", + " -0.031103378161787987,\n", + " -0.1055448055267334,\n", + " 0.08162888139486313,\n", + " -0.11693760007619858,\n", + " 0.0012153959833085537,\n", + " -0.042226288467645645,\n", + " -0.025040708482265472,\n", + " -0.05382077395915985,\n", + " 0.046688906848430634,\n", + " -0.004659516736865044,\n", + " -0.049144256860017776,\n", + " 0.05339549481868744,\n", + " -0.016824593767523766,\n", + " -0.018911045044660568,\n", + " 0.0021526776254177094,\n", + " 0.010545731522142887,\n", + " -0.02843359299004078,\n", + " 0.06319320946931839,\n", + " -0.041760899126529694,\n", + " 0.03648762032389641,\n", + " -0.028613677248358727,\n", + " 0.012441876344382763,\n", + " -0.030993392691016197,\n", + " -1.827941886745066e-08,\n", + " -0.03364746645092964,\n", + " -0.010457276366651058,\n", + " 0.006326176226139069,\n", + " -0.03394529968500137,\n", + " -0.03437081351876259,\n", + " 0.043725401163101196,\n", + " 0.07607871294021606,\n", + " -0.05076980963349342,\n", + " -0.06551552563905716,\n", + " -0.023710858076810837,\n", + " 0.05217289924621582,\n", + " 0.008229373954236507,\n", + " -0.05053586885333061,\n", + " -0.0046344115398824215,\n", + " 0.04596329480409622,\n", + " -0.048263613134622574,\n", + " 
-0.007646505255252123,\n", + " -0.0246701892465353,\n", + " -0.05899248272180557,\n", + " 0.02179579623043537,\n", + " -0.033197544515132904,\n", + " 0.026267115026712418,\n", + " 0.019565267488360405,\n", + " 0.022036483511328697,\n", + " -0.02707892283797264,\n", + " 0.07815380394458771,\n", + " 0.03259186074137688,\n", + " 0.10126295685768127,\n", + " 0.007166724652051926,\n", + " -0.031028350815176964,\n", + " 0.04080115631222725,\n", + " 0.10805943608283997,\n", + " -0.00941381324082613,\n", + " -0.01028114091604948,\n", + " 0.037279773503541946,\n", + " 0.11904413253068924,\n", + " 0.04982069879770279,\n", + " 0.05209505558013916,\n", + " 0.020246144384145737,\n", + " 0.05551902949810028,\n", + " -0.10270132124423981,\n", + " -0.009933318942785263,\n", + " -0.022510290145874023,\n", + " 0.03311152011156082,\n", + " 0.05227212607860565,\n", + " -0.029383286833763123,\n", + " -0.1383359581232071,\n", + " -0.014143865555524826,\n", + " -0.037659481167793274,\n", + " -0.08339183777570724,\n", + " -0.0034869578666985035,\n", + " -0.0415429063141346,\n", + " 0.04902830719947815,\n", + " 0.02155115082859993,\n", + " -0.040210600942373276,\n", + " 0.008557669818401337,\n", + " 0.046616844832897186,\n", + " -0.004114149138331413,\n", + " -0.03815949708223343,\n", + " -0.015223635360598564,\n", + " 0.12486445158720016,\n", + " 0.08800436556339264,\n", + " 0.08585748821496964,\n", + " -0.015338928438723087]}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "predictor.predict({\"inputs\": \"This is a test document.\"})" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "3764f159", + "metadata": {}, + "source": [ + "## _### END TEMPORARY: Showing how to deploy a SageMaker Endpoint from a Hugging Face model ###_ " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1e9b926a", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Dict\n", + "from 
langchain.embeddings import SagemakerEndpointEmbeddings\n", + "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n", + "import json\n", + "\n", + "\n", + "class ContentHandler(ContentHandlerBase):\n", + " content_type = \"application/json\"\n", + " accepts = \"application/json\"\n", + "\n", + " def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n", + " input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n", + " return input_str.encode('utf-8')\n", + " \n", + " def transform_output(self, output: bytes) -> str:\n", + " response_json = json.loads(output.read().decode(\"utf-8\"))\n", + " return response_json[\"embeddings\"]\n", + "\n", + "content_handler = ContentHandler()\n", + "\n", + "\n", + "embeddings = SagemakerEndpointEmbeddings(\n", + " # endpoint_name=\"endpoint-name\", \n", + " # credentials_profile_name=\"credentials-profile-name\", \n", + " endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\", \n", + " region_name=\"us-east-1\", \n", + " content_handler=content_handler\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "836e3ea5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.01623339205980301,\n", + " -0.007662336342036724,\n", + " 0.018606489524245262,\n", + " 0.031968992203474045,\n", + " -0.031003747135400772,\n", + " 0.008777972310781479,\n", + " 0.1594553291797638,\n", + " -0.009521624073386192,\n", + " 0.020200366154313087,\n", + " -0.04545809328556061,\n", + " 0.013985812664031982,\n", + " -0.017674963921308517,\n", + " -0.03616964817047119,\n", + " -0.02194339968264103,\n", + " 0.021387653425335884,\n", + " 0.06459270417690277,\n", + " -0.03659535571932793,\n", + " -0.01213359646499157,\n", + " -0.043666232377290726,\n", + " -0.03515005484223366,\n", + " -0.032629866153001785,\n", + " 0.07834123075008392,\n", + " -0.021041689440608025,\n", + " 0.03372766822576523,\n", + " -0.024157941341400146,\n", + " -0.010767146944999695,\n", + " 
-0.042864806950092316,\n", + " 0.013539575971662998,\n", + " 0.05039731785655022,\n", + " -0.091956727206707,\n", + " 0.035494621843099594,\n", + " 0.18029741942882538,\n", + " 0.01576363667845726,\n", + " -0.04949156939983368,\n", + " -0.003976485226303339,\n", + " 0.00032106428989209235,\n", + " 0.021849628537893295,\n", + " 0.035368386656045914,\n", + " 0.04185418039560318,\n", + " 0.04899369180202484,\n", + " -0.026651302352547646,\n", + " -0.05650882050395012,\n", + " -0.03276852145791054,\n", + " -0.020723465830087662,\n", + " -0.011230835691094398,\n", + " 0.02798161283135414,\n", + " -0.010538998991250992,\n", + " 0.030317796394228935,\n", + " 0.017697133123874664,\n", + " 0.003633821150287986,\n", + " -0.008708533830940723,\n", + " -0.04946836829185486,\n", + " -0.029903240501880646,\n", + " 0.022750651463866234,\n", + " 0.09276428818702698,\n", + " 0.05072581395506859,\n", + " 0.02917262725532055,\n", + " 0.00728880288079381,\n", + " -0.011496285907924175,\n", + " -0.05313197895884514,\n", + " -0.027890320867300034,\n", + " 0.030064044520258904,\n", + " -0.06029842048883438,\n", + " -0.043088313192129135,\n", + " 0.05004483461380005,\n", + " 0.0015685138059780002,\n", + " -0.01834270916879177,\n", + " 0.046504270285367966,\n", + " -0.043405696749687195,\n", + " 0.08440472185611725,\n", + " 0.022881966084241867,\n", + " 0.013790522702038288,\n", + " 0.03525456413626671,\n", + " 0.08282686769962311,\n", + " 0.031224273145198822,\n", + " -0.032255761325359344,\n", + " 0.033190108835697174,\n", + " -0.02879202552139759,\n", + " 0.09641945362091064,\n", + " 0.014541308395564556,\n", + " -0.0425688773393631,\n", + " 0.007836293429136276,\n", + " -0.07434573024511337,\n", + " -0.03844423592090607,\n", + " 0.007907251827418804,\n", + " 0.005604865029454231,\n", + " 0.014666788280010223,\n", + " -0.015787949785590172,\n", + " -0.011632969602942467,\n", + " 0.06502652168273926,\n", + " -0.09462911635637283,\n", + " -0.05418006703257561,\n", + " 
0.07266692817211151,\n", + " -0.0059609077870845795,\n", + " 0.015150884166359901,\n", + " 0.033904850482940674,\n", + " 0.04719925299286842,\n", + " -0.006713803857564926,\n", + " -0.0628967359662056,\n", + " 0.2842618525028229,\n", + " -0.007497070357203484,\n", + " 0.11969559639692307,\n", + " 0.047007955610752106,\n", + " -0.023929651826620102,\n", + " 0.015414181165397167,\n", + " -0.029861856251955032,\n", + " -0.014299873262643814,\n", + " 0.018457207828760147,\n", + " 0.05915089696645737,\n", + " -0.03441261500120163,\n", + " -0.01635487750172615,\n", + " -0.021376853808760643,\n", + " -0.01367877796292305,\n", + " -0.04958583787083626,\n", + " 0.01463531143963337,\n", + " 0.01211540587246418,\n", + " -0.005459521431475878,\n", + " 0.005695596802979708,\n", + " -0.04747362807393074,\n", + " -0.05998056381940842,\n", + " -0.042840663343667984,\n", + " 0.04042218253016472,\n", + " -0.006624337285757065,\n", + " 0.025121480226516724,\n", + " -0.054944958537817,\n", + " -0.06516158580780029,\n", + " 0.007337308023124933,\n", + " -6.10322324013495e-33,\n", + " 0.002179093426093459,\n", + " -0.073213592171669,\n", + " -0.014703890308737755,\n", + " 0.00238825217820704,\n", + " 0.02046307921409607,\n", + " -0.06456342339515686,\n", + " 0.014286896213889122,\n", + " 0.02082856185734272,\n", + " -0.07692538946866989,\n", + " 0.09246989339590073,\n", + " -0.03469334542751312,\n", + " 0.022259987890720367,\n", + " -0.0369521826505661,\n", + " -0.0876070111989975,\n", + " 0.13785682618618011,\n", + " -0.000683621852658689,\n", + " 0.0018552240217104554,\n", + " 0.07194776087999344,\n", + " -0.0633404403924942,\n", + " -0.01646324060857296,\n", + " -0.0361541211605072,\n", + " -0.006936112884432077,\n", + " 0.003252814756706357,\n", + " 0.02627389132976532,\n", + " -0.0014277833979576826,\n", + " -0.09001296013593674,\n", + " 0.008833721280097961,\n", + " -0.07455790787935257,\n", + " 0.10064911842346191,\n", + " 0.03227779641747475,\n", + " -0.016069436445832253,\n", + 
" 0.024673042818903923,\n", + " 0.04188213497400284,\n", + " 0.03961843252182007,\n", + " -0.028469551354646683,\n", + " -0.05262545496225357,\n", + " -0.006966825108975172,\n", + " -0.0033834113273769617,\n", + " -0.038578882813453674,\n", + " -0.010265848599374294,\n", + " -0.033789463341236115,\n", + " 0.0030778711661696434,\n", + " -0.05088731646537781,\n", + " -0.019024258479475975,\n", + " 0.05421010032296181,\n", + " 0.015494044870138168,\n", + " -0.009311210364103317,\n", + " 0.0050599598325788975,\n", + " -0.04918931797146797,\n", + " 0.03970836102962494,\n", + " 0.06579958647489548,\n", + " 0.014110234566032887,\n", + " -0.04829266294836998,\n", + " 0.05065532773733139,\n", + " 0.021345015615224838,\n", + " -0.02805492654442787,\n", + " -0.013115333393216133,\n", + " -0.03833610191941261,\n", + " 0.0081633934751153,\n", + " 0.0020320340991020203,\n", + " 0.025601046159863472,\n", + " 0.046745311468839645,\n", + " -0.07602663338184357,\n", + " 0.08589514344930649,\n", + " -0.09630884975194931,\n", + " 0.01156257651746273,\n", + " 0.047838304191827774,\n", + " -0.03707060590386391,\n", + " 0.05717772990465164,\n", + " -0.028168894350528717,\n", + " -0.06691361963748932,\n", + " 0.003909755032509565,\n", + " -0.01265989150851965,\n", + " -0.024667585268616676,\n", + " -0.04399942606687546,\n", + " 0.013469734229147434,\n", + " 0.013298758305609226,\n", + " 0.0409042127430439,\n", + " -0.012081797234714031,\n", + " -0.009779289364814758,\n", + " 0.021113228052854538,\n", + " -0.06191551312804222,\n", + " -0.010964356362819672,\n", + " 0.027119100093841553,\n", + " -0.03144009783864021,\n", + " 0.037719033658504486,\n", + " 0.02421882562339306,\n", + " -0.13700149953365326,\n", + " 0.0038421833887696266,\n", + " -0.06574120372533798,\n", + " -0.12629178166389465,\n", + " 0.018397213891148567,\n", + " 0.0019562605302780867,\n", + " -0.06581622362136841,\n", + " 0.0056412797421216965,\n", + " 6.17423926197194e-33,\n", + " 0.11609319597482681,\n", + " 
0.023075049743056297,\n", + " -0.02540658414363861,\n", + " 0.021112393587827682,\n", + " -0.010050611570477486,\n", + " 0.0045014130882918835,\n", + " 0.02216450683772564,\n", + " 0.03083667904138565,\n", + " -0.065506212413311,\n", + " -0.028498610481619835,\n", + " -0.08708083629608154,\n", + " -0.027195820584893227,\n", + " 0.04075731709599495,\n", + " -0.00738579360768199,\n", + " 0.031747449189424515,\n", + " 0.020246611908078194,\n", + " 0.03285415843129158,\n", + " -0.037579674273729324,\n", + " -0.025780295953154564,\n", + " 0.044498566538095474,\n", + " -0.01600523293018341,\n", + " -0.1110200434923172,\n", + " 0.10275887697935104,\n", + " -0.044455550611019135,\n", + " -0.043082430958747864,\n", + " 0.04361744225025177,\n", + " 0.09388253092765808,\n", + " 0.03668423742055893,\n", + " -0.08740879595279694,\n", + " -0.015174541622400284,\n", + " 0.035617802292108536,\n", + " -0.056008175015449524,\n", + " -0.07729960232973099,\n", + " -0.055068857967853546,\n", + " 0.011802412569522858,\n", + " 0.0005090870545245707,\n", + " -0.04531490430235863,\n", + " 0.009635107591748238,\n", + " 0.0066973078064620495,\n", + " -0.08850639313459396,\n", + " 0.07926266640424728,\n", + " 0.03328349068760872,\n", + " 0.02206319011747837,\n", + " 0.08003410696983337,\n", + " -0.004926585592329502,\n", + " -0.012855191715061665,\n", + " -0.030001787468791008,\n", + " 0.0038301211316138506,\n", + " 0.09513995796442032,\n", + " -0.023254361003637314,\n", + " -0.01384524442255497,\n", + " -0.0006733545451425016,\n", + " 0.004949721973389387,\n", + " -0.03836912661790848,\n", + " -0.0484086349606514,\n", + " -0.04300595819950104,\n", + " -0.03302333503961563,\n", + " -0.011142191477119923,\n", + " -0.021775009110569954,\n", + " 0.009556151926517487,\n", + " -0.014081847853958607,\n", + " 0.01725372113287449,\n", + " -0.002208009362220764,\n", + " 0.043982796370983124,\n", + " -0.12186389416456223,\n", + " -0.03109029121696949,\n", + " -0.0648212656378746,\n", + " 
-0.03446059674024582,\n", + " -0.0009474779944866896,\n", + " 0.019224559888243675,\n", + " 0.030093936249613762,\n", + " 0.011459640227258205,\n", + " -0.031019125133752823,\n", + " 0.11076018959283829,\n", + " -0.08466918021440506,\n", + " -0.028721166774630547,\n", + " -0.006525673437863588,\n", + " 0.05877530202269554,\n", + " 0.021319882944226265,\n", + " 0.08542844653129578,\n", + " 0.05103899911046028,\n", + " -0.02113465592265129,\n", + " 0.01493119541555643,\n", + " 0.010513859800994396,\n", + " -0.023147936910390854,\n", + " -0.044208601117134094,\n", + " -0.0010544550605118275,\n", + " 0.0656798928976059,\n", + " -0.013098708353936672,\n", + " 0.0029119630344212055,\n", + " 0.03165023773908615,\n", + " 0.06931225955486298,\n", + " -0.02299979329109192,\n", + " 0.022364258766174316,\n", + " -0.04974697157740593,\n", + " -1.3714036128931184e-08,\n", + " -0.0205942764878273,\n", + " 0.047028280794620514,\n", + " -0.032210975885391235,\n", + " 0.049078319221735,\n", + " 0.0394253209233284,\n", + " 0.10298289358615875,\n", + " 0.013628372922539711,\n", + " -0.07071256637573242,\n", + " -0.001111415564082563,\n", + " 0.045793063938617706,\n", + " 0.010663686320185661,\n", + " 0.022661199793219566,\n", + " -0.00039414051570929587,\n", + " 0.04868670925498009,\n", + " 0.08181674033403397,\n", + " -0.06234998628497124,\n", + " -0.017647461965680122,\n", + " -0.05699630081653595,\n", + " -0.035604529082775116,\n", + " -0.002848744625225663,\n", + " -0.07433759421110153,\n", + " 0.05970819666981697,\n", + " -0.03040698915719986,\n", + " -0.03587964177131653,\n", + " -0.05538871884346008,\n", + " -0.007939192466437817,\n", + " -0.015285325236618519,\n", + " 0.08461211621761322,\n", + " 0.01166541874408722,\n", + " 0.03213988244533539,\n", + " 0.05643611401319504,\n", + " 0.2006419152021408,\n", + " -0.07411110401153564,\n", + " -0.018009720370173454,\n", + " 0.016179822385311127,\n", + " -0.0028461480978876352,\n", + " 0.0402149073779583,\n", + " 
0.0006247136043384671,\n", + " 0.0006973804556764662,\n", + " 0.09922358393669128,\n", + " -0.029822450131177902,\n", + " -0.005783025175333023,\n", + " -0.0028224103152751923,\n", + " -0.11175407469272614,\n", + " 0.012009709142148495,\n", + " -0.009956827387213707,\n", + " 0.011468647047877312,\n", + " -0.054449401795864105,\n", + " -0.016370657831430435,\n", + " 0.022106735035777092,\n", + " 0.03950563445687294,\n", + " 0.005319684278219938,\n", + " 0.042190469801425934,\n", + " 0.08844445645809174,\n", + " 0.0810166597366333,\n", + " 0.06980433315038681,\n", + " -0.04784897342324257,\n", + " 0.01753094792366028,\n", + " -0.10126522183418274,\n", + " -0.016526369377970695,\n", + " 0.11310216039419174,\n", + " 0.0874418243765831,\n", + " 0.09520682692527771,\n", + " 0.10083616524934769]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "query_result = embeddings.embed_query(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "76f1b752", + "metadata": {}, + "outputs": [], + "source": [ + "doc_results = embeddings.embed_documents([\"foo\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "221f2f0e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[0.01623339205980301,\n", + " -0.007662336342036724,\n", + " 0.018606489524245262,\n", + " 0.031968992203474045,\n", + " -0.031003747135400772,\n", + " 0.008777972310781479,\n", + " 0.1594553291797638,\n", + " -0.009521624073386192,\n", + " 0.020200366154313087,\n", + " -0.04545809328556061,\n", + " 0.013985812664031982,\n", + " -0.017674963921308517,\n", + " -0.03616964817047119,\n", + " -0.02194339968264103,\n", + " 0.021387653425335884,\n", + " 0.06459270417690277,\n", + " -0.03659535571932793,\n", + " -0.01213359646499157,\n", + " -0.043666232377290726,\n", + " -0.03515005484223366,\n", + " -0.032629866153001785,\n", + " 0.07834123075008392,\n", + " -0.021041689440608025,\n", + " 
0.03372766822576523,\n", + " -0.024157941341400146,\n", + " -0.010767146944999695,\n", + " -0.042864806950092316,\n", + " 0.013539575971662998,\n", + " 0.05039731785655022,\n", + " -0.091956727206707,\n", + " 0.035494621843099594,\n", + " 0.18029741942882538,\n", + " 0.01576363667845726,\n", + " -0.04949156939983368,\n", + " -0.003976485226303339,\n", + " 0.00032106428989209235,\n", + " 0.021849628537893295,\n", + " 0.035368386656045914,\n", + " 0.04185418039560318,\n", + " 0.04899369180202484,\n", + " -0.026651302352547646,\n", + " -0.05650882050395012,\n", + " -0.03276852145791054,\n", + " -0.020723465830087662,\n", + " -0.011230835691094398,\n", + " 0.02798161283135414,\n", + " -0.010538998991250992,\n", + " 0.030317796394228935,\n", + " 0.017697133123874664,\n", + " 0.003633821150287986,\n", + " -0.008708533830940723,\n", + " -0.04946836829185486,\n", + " -0.029903240501880646,\n", + " 0.022750651463866234,\n", + " 0.09276428818702698,\n", + " 0.05072581395506859,\n", + " 0.02917262725532055,\n", + " 0.00728880288079381,\n", + " -0.011496285907924175,\n", + " -0.05313197895884514,\n", + " -0.027890320867300034,\n", + " 0.030064044520258904,\n", + " -0.06029842048883438,\n", + " -0.043088313192129135,\n", + " 0.05004483461380005,\n", + " 0.0015685138059780002,\n", + " -0.01834270916879177,\n", + " 0.046504270285367966,\n", + " -0.043405696749687195,\n", + " 0.08440472185611725,\n", + " 0.022881966084241867,\n", + " 0.013790522702038288,\n", + " 0.03525456413626671,\n", + " 0.08282686769962311,\n", + " 0.031224273145198822,\n", + " -0.032255761325359344,\n", + " 0.033190108835697174,\n", + " -0.02879202552139759,\n", + " 0.09641945362091064,\n", + " 0.014541308395564556,\n", + " -0.0425688773393631,\n", + " 0.007836293429136276,\n", + " -0.07434573024511337,\n", + " -0.03844423592090607,\n", + " 0.007907251827418804,\n", + " 0.005604865029454231,\n", + " 0.014666788280010223,\n", + " -0.015787949785590172,\n", + " -0.011632969602942467,\n", + " 
0.06502652168273926,\n", + " -0.09462911635637283,\n", + " -0.05418006703257561,\n", + " 0.07266692817211151,\n", + " -0.0059609077870845795,\n", + " 0.015150884166359901,\n", + " 0.033904850482940674,\n", + " 0.04719925299286842,\n", + " -0.006713803857564926,\n", + " -0.0628967359662056,\n", + " 0.2842618525028229,\n", + " -0.007497070357203484,\n", + " 0.11969559639692307,\n", + " 0.047007955610752106,\n", + " -0.023929651826620102,\n", + " 0.015414181165397167,\n", + " -0.029861856251955032,\n", + " -0.014299873262643814,\n", + " 0.018457207828760147,\n", + " 0.05915089696645737,\n", + " -0.03441261500120163,\n", + " -0.01635487750172615,\n", + " -0.021376853808760643,\n", + " -0.01367877796292305,\n", + " -0.04958583787083626,\n", + " 0.01463531143963337,\n", + " 0.01211540587246418,\n", + " -0.005459521431475878,\n", + " 0.005695596802979708,\n", + " -0.04747362807393074,\n", + " -0.05998056381940842,\n", + " -0.042840663343667984,\n", + " 0.04042218253016472,\n", + " -0.006624337285757065,\n", + " 0.025121480226516724,\n", + " -0.054944958537817,\n", + " -0.06516158580780029,\n", + " 0.007337308023124933,\n", + " -6.10322324013495e-33,\n", + " 0.002179093426093459,\n", + " -0.073213592171669,\n", + " -0.014703890308737755,\n", + " 0.00238825217820704,\n", + " 0.02046307921409607,\n", + " -0.06456342339515686,\n", + " 0.014286896213889122,\n", + " 0.02082856185734272,\n", + " -0.07692538946866989,\n", + " 0.09246989339590073,\n", + " -0.03469334542751312,\n", + " 0.022259987890720367,\n", + " -0.0369521826505661,\n", + " -0.0876070111989975,\n", + " 0.13785682618618011,\n", + " -0.000683621852658689,\n", + " 0.0018552240217104554,\n", + " 0.07194776087999344,\n", + " -0.0633404403924942,\n", + " -0.01646324060857296,\n", + " -0.0361541211605072,\n", + " -0.006936112884432077,\n", + " 0.003252814756706357,\n", + " 0.02627389132976532,\n", + " -0.0014277833979576826,\n", + " -0.09001296013593674,\n", + " 0.008833721280097961,\n", + " -0.07455790787935257,\n", + 
" 0.10064911842346191,\n", + " 0.03227779641747475,\n", + " -0.016069436445832253,\n", + " 0.024673042818903923,\n", + " 0.04188213497400284,\n", + " 0.03961843252182007,\n", + " -0.028469551354646683,\n", + " -0.05262545496225357,\n", + " -0.006966825108975172,\n", + " -0.0033834113273769617,\n", + " -0.038578882813453674,\n", + " -0.010265848599374294,\n", + " -0.033789463341236115,\n", + " 0.0030778711661696434,\n", + " -0.05088731646537781,\n", + " -0.019024258479475975,\n", + " 0.05421010032296181,\n", + " 0.015494044870138168,\n", + " -0.009311210364103317,\n", + " 0.0050599598325788975,\n", + " -0.04918931797146797,\n", + " 0.03970836102962494,\n", + " 0.06579958647489548,\n", + " 0.014110234566032887,\n", + " -0.04829266294836998,\n", + " 0.05065532773733139,\n", + " 0.021345015615224838,\n", + " -0.02805492654442787,\n", + " -0.013115333393216133,\n", + " -0.03833610191941261,\n", + " 0.0081633934751153,\n", + " 0.0020320340991020203,\n", + " 0.025601046159863472,\n", + " 0.046745311468839645,\n", + " -0.07602663338184357,\n", + " 0.08589514344930649,\n", + " -0.09630884975194931,\n", + " 0.01156257651746273,\n", + " 0.047838304191827774,\n", + " -0.03707060590386391,\n", + " 0.05717772990465164,\n", + " -0.028168894350528717,\n", + " -0.06691361963748932,\n", + " 0.003909755032509565,\n", + " -0.01265989150851965,\n", + " -0.024667585268616676,\n", + " -0.04399942606687546,\n", + " 0.013469734229147434,\n", + " 0.013298758305609226,\n", + " 0.0409042127430439,\n", + " -0.012081797234714031,\n", + " -0.009779289364814758,\n", + " 0.021113228052854538,\n", + " -0.06191551312804222,\n", + " -0.010964356362819672,\n", + " 0.027119100093841553,\n", + " -0.03144009783864021,\n", + " 0.037719033658504486,\n", + " 0.02421882562339306,\n", + " -0.13700149953365326,\n", + " 0.0038421833887696266,\n", + " -0.06574120372533798,\n", + " -0.12629178166389465,\n", + " 0.018397213891148567,\n", + " 0.0019562605302780867,\n", + " -0.06581622362136841,\n", + " 
0.0056412797421216965,\n", + " 6.17423926197194e-33,\n", + " 0.11609319597482681,\n", + " 0.023075049743056297,\n", + " -0.02540658414363861,\n", + " 0.021112393587827682,\n", + " -0.010050611570477486,\n", + " 0.0045014130882918835,\n", + " 0.02216450683772564,\n", + " 0.03083667904138565,\n", + " -0.065506212413311,\n", + " -0.028498610481619835,\n", + " -0.08708083629608154,\n", + " -0.027195820584893227,\n", + " 0.04075731709599495,\n", + " -0.00738579360768199,\n", + " 0.031747449189424515,\n", + " 0.020246611908078194,\n", + " 0.03285415843129158,\n", + " -0.037579674273729324,\n", + " -0.025780295953154564,\n", + " 0.044498566538095474,\n", + " -0.01600523293018341,\n", + " -0.1110200434923172,\n", + " 0.10275887697935104,\n", + " -0.044455550611019135,\n", + " -0.043082430958747864,\n", + " 0.04361744225025177,\n", + " 0.09388253092765808,\n", + " 0.03668423742055893,\n", + " -0.08740879595279694,\n", + " -0.015174541622400284,\n", + " 0.035617802292108536,\n", + " -0.056008175015449524,\n", + " -0.07729960232973099,\n", + " -0.055068857967853546,\n", + " 0.011802412569522858,\n", + " 0.0005090870545245707,\n", + " -0.04531490430235863,\n", + " 0.009635107591748238,\n", + " 0.0066973078064620495,\n", + " -0.08850639313459396,\n", + " 0.07926266640424728,\n", + " 0.03328349068760872,\n", + " 0.02206319011747837,\n", + " 0.08003410696983337,\n", + " -0.004926585592329502,\n", + " -0.012855191715061665,\n", + " -0.030001787468791008,\n", + " 0.0038301211316138506,\n", + " 0.09513995796442032,\n", + " -0.023254361003637314,\n", + " -0.01384524442255497,\n", + " -0.0006733545451425016,\n", + " 0.004949721973389387,\n", + " -0.03836912661790848,\n", + " -0.0484086349606514,\n", + " -0.04300595819950104,\n", + " -0.03302333503961563,\n", + " -0.011142191477119923,\n", + " -0.021775009110569954,\n", + " 0.009556151926517487,\n", + " -0.014081847853958607,\n", + " 0.01725372113287449,\n", + " -0.002208009362220764,\n", + " 0.043982796370983124,\n", + " 
-0.12186389416456223,\n", + " -0.03109029121696949,\n", + " -0.0648212656378746,\n", + " -0.03446059674024582,\n", + " -0.0009474779944866896,\n", + " 0.019224559888243675,\n", + " 0.030093936249613762,\n", + " 0.011459640227258205,\n", + " -0.031019125133752823,\n", + " 0.11076018959283829,\n", + " -0.08466918021440506,\n", + " -0.028721166774630547,\n", + " -0.006525673437863588,\n", + " 0.05877530202269554,\n", + " 0.021319882944226265,\n", + " 0.08542844653129578,\n", + " 0.05103899911046028,\n", + " -0.02113465592265129,\n", + " 0.01493119541555643,\n", + " 0.010513859800994396,\n", + " -0.023147936910390854,\n", + " -0.044208601117134094,\n", + " -0.0010544550605118275,\n", + " 0.0656798928976059,\n", + " -0.013098708353936672,\n", + " 0.0029119630344212055,\n", + " 0.03165023773908615,\n", + " 0.06931225955486298,\n", + " -0.02299979329109192,\n", + " 0.022364258766174316,\n", + " -0.04974697157740593,\n", + " -1.3714036128931184e-08,\n", + " -0.0205942764878273,\n", + " 0.047028280794620514,\n", + " -0.032210975885391235,\n", + " 0.049078319221735,\n", + " 0.0394253209233284,\n", + " 0.10298289358615875,\n", + " 0.013628372922539711,\n", + " -0.07071256637573242,\n", + " -0.001111415564082563,\n", + " 0.045793063938617706,\n", + " 0.010663686320185661,\n", + " 0.022661199793219566,\n", + " -0.00039414051570929587,\n", + " 0.04868670925498009,\n", + " 0.08181674033403397,\n", + " -0.06234998628497124,\n", + " -0.017647461965680122,\n", + " -0.05699630081653595,\n", + " -0.035604529082775116,\n", + " -0.002848744625225663,\n", + " -0.07433759421110153,\n", + " 0.05970819666981697,\n", + " -0.03040698915719986,\n", + " -0.03587964177131653,\n", + " -0.05538871884346008,\n", + " -0.007939192466437817,\n", + " -0.015285325236618519,\n", + " 0.08461211621761322,\n", + " 0.01166541874408722,\n", + " 0.03213988244533539,\n", + " 0.05643611401319504,\n", + " 0.2006419152021408,\n", + " -0.07411110401153564,\n", + " -0.018009720370173454,\n", + " 
0.016179822385311127,\n", + " -0.0028461480978876352,\n", + " 0.0402149073779583,\n", + " 0.0006247136043384671,\n", + " 0.0006973804556764662,\n", + " 0.09922358393669128,\n", + " -0.029822450131177902,\n", + " -0.005783025175333023,\n", + " -0.0028224103152751923,\n", + " -0.11175407469272614,\n", + " 0.012009709142148495,\n", + " -0.009956827387213707,\n", + " 0.011468647047877312,\n", + " -0.054449401795864105,\n", + " -0.016370657831430435,\n", + " 0.022106735035777092,\n", + " 0.03950563445687294,\n", + " 0.005319684278219938,\n", + " 0.042190469801425934,\n", + " 0.08844445645809174,\n", + " 0.0810166597366333,\n", + " 0.06980433315038681,\n", + " -0.04784897342324257,\n", + " 0.01753094792366028,\n", + " -0.10126522183418274,\n", + " -0.016526369377970695,\n", + " 0.11310216039419174,\n", + " 0.0874418243765831,\n", + " 0.09520682692527771,\n", + " 0.10083616524934769]]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "doc_results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aaad49f8", + "metadata": {}, + "outputs": [], "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "langchain", "language": "python", "name": "python3" }, @@ -679,11 +2121,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.9.16" }, "vscode": { "interpreter": { - "hash": "ce6f9b0d7cdac41515b0e0c38d0e6e153a2edce81d579281cb1ab99da6e8ea6d" + "hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885" } } }, diff --git a/langchain/embeddings/__init__.py b/langchain/embeddings/__init__.py index 261447b1..acfdb141 100644 --- a/langchain/embeddings/__init__.py +++ b/langchain/embeddings/__init__.py @@ -10,6 +10,7 @@ from langchain.embeddings.huggingface import ( ) from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings from 
langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.embeddings.sagemaker_endpoint import SagemakerEndpointEmbeddings
 from langchain.embeddings.self_hosted import SelfHostedEmbeddings
 from langchain.embeddings.self_hosted_hugging_face import (
     SelfHostedHuggingFaceEmbeddings,
@@ -25,6 +26,7 @@ __all__ = [
     "CohereEmbeddings",
     "HuggingFaceHubEmbeddings",
     "TensorflowHubEmbeddings",
+    "SagemakerEndpointEmbeddings",
     "HuggingFaceInstructEmbeddings",
     "SelfHostedEmbeddings",
     "SelfHostedHuggingFaceEmbeddings",
diff --git a/langchain/embeddings/sagemaker_endpoint.py b/langchain/embeddings/sagemaker_endpoint.py
new file mode 100644
index 00000000..6fc22187
--- /dev/null
+++ b/langchain/embeddings/sagemaker_endpoint.py
@@ -0,0 +1,224 @@
+"""Wrapper around Sagemaker InvokeEndpoint API."""
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, Extra, root_validator
+
+from langchain.embeddings.base import Embeddings
+from langchain.llms.sagemaker_endpoint import ContentHandlerBase
+
+
+class SagemakerEndpointEmbeddings(BaseModel, Embeddings):
+    """Wrapper around custom Sagemaker Inference Endpoints.
+
+    To use, you must supply the endpoint name from your deployed
+    Sagemaker model & the region where it is deployed.
+
+    To authenticate, the AWS client uses the following methods to
+    automatically load credentials:
+    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+
+    If a specific credential profile should be used, you must pass
+    the name of the profile from the ~/.aws/credentials file that is to be used.
+
+    Make sure the credentials / roles used have the required policies to
+    access the Sagemaker endpoint.
+    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
+
+    Example:
+        .. code-block:: python
+
+            from langchain.embeddings import SagemakerEndpointEmbeddings
+            endpoint_name = (
+                "my-endpoint-name"
+            )
+            region_name = (
+                "us-west-2"
+            )
+            credentials_profile_name = (
+                "default"
+            )
+            se = SagemakerEndpointEmbeddings(
+                endpoint_name=endpoint_name,
+                region_name=region_name,
+                credentials_profile_name=credentials_profile_name
+            )
+    """
+    client: Any  #: :meta private:
+
+    endpoint_name: str = ""
+    """The name of the endpoint from the deployed Sagemaker model.
+    Must be unique within an AWS Region."""
+
+    region_name: str = ""
+    """The aws region where the Sagemaker model is deployed, eg. `us-west-2`."""
+
+    credentials_profile_name: Optional[str] = None
+    """The name of the profile in the ~/.aws/credentials or ~/.aws/config files, which
+    has either access keys or role information specified.
+    If not specified, the default credential profile or, if on an EC2 instance,
+    credentials from IMDS will be used.
+    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
+    """
+
+    content_handler: ContentHandlerBase
+    """The content handler class that provides an input and
+    output transform functions to handle formats between LLM
+    and the endpoint.
+    """
+
+    """
+    Example:
+        .. code-block:: python
+
+            from langchain.llms.sagemaker_endpoint import ContentHandlerBase
+
+            class ContentHandler(ContentHandlerBase):
+                content_type = "application/json"
+                accepts = "application/json"
+
+                def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
+                    input_str = json.dumps({prompt: prompt, **model_kwargs})
+                    return input_str.encode('utf-8')
+
+                def transform_output(self, output: bytes) -> str:
+                    response_json = json.loads(output.read().decode("utf-8"))
+                    return response_json[0]["generated_text"]
+    """
+
+    model_kwargs: Optional[Dict] = None
+    """Key word arguments to pass to the model."""
+
+    endpoint_kwargs: Optional[Dict] = None
+    """Optional attributes passed to the invoke_endpoint
+    function. See `boto3`_ docs for more info.
+    .. _boto3: https://boto3.amazonaws.com/v1/documentation/api/latest/index.html
+    """
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that boto3 is importable and create the AWS runtime client.
+
+        Args:
+            values: The field values collected by pydantic for this model.
+
+        Returns:
+            The same ``values`` with ``client`` set to a boto3
+            ``sagemaker-runtime`` client.
+
+        Raises:
+            ValueError: If boto3 is not installed, or if the session or
+                client could not be created.
+        """
+        try:
+            import boto3
+        except ImportError:
+            raise ValueError(
+                "Could not import boto3 python package. "
+                "Please install it with `pip install boto3`."
+            )
+
+        try:
+            if values["credentials_profile_name"] is not None:
+                session = boto3.Session(
+                    profile_name=values["credentials_profile_name"]
+                )
+            else:
+                # use default credentials
+                session = boto3.Session()
+
+            values["client"] = session.client(
+                "sagemaker-runtime", region_name=values["region_name"]
+            )
+        except Exception as e:
+            raise ValueError(
+                "Could not load credentials to authenticate with AWS client. "
+                "Please check that credentials in the specified "
+                "profile name are valid."
+            ) from e
+        return values
+
+    def _embedding_func(self, texts: List[str]) -> List[float]:
+        """Call out to SageMaker Inference embedding endpoint.
+
+        Args:
+            texts: The texts to send to the endpoint in a single request.
+
+        Returns:
+            Whatever ``content_handler.transform_output`` returns for the
+            response body.
+
+        Raises:
+            ValueError: If invoking the endpoint raises any exception.
+        """
+        # replace newlines, which can negatively affect performance.
+        texts = [text.replace("\n", " ") for text in texts]
+        _model_kwargs = self.model_kwargs or {}
+        _endpoint_kwargs = self.endpoint_kwargs or {}
+
+        body = self.content_handler.transform_input(texts, _model_kwargs)
+        content_type = self.content_handler.content_type
+        accepts = self.content_handler.accepts
+
+        # send request
+        try:
+            response = self.client.invoke_endpoint(
+                EndpointName=self.endpoint_name,
+                Body=body,
+                ContentType=content_type,
+                Accept=accepts,
+                **_endpoint_kwargs,
+            )
+        except Exception as e:
+            raise ValueError(f"Error raised by inference endpoint: {e}") from e
+
+        return self.content_handler.transform_output(response["Body"])
+
+    def embed_documents(
+        self, texts: List[str], chunk_size: int = 64
+    ) -> List[List[float]]:
+        """Compute doc embeddings using a SageMaker Inference Endpoint.
+
+        Args:
+            texts: The list of texts to embed.
+            chunk_size: The chunk size defines how many input texts will
+                be grouped together as one request. Must be positive.
+
+        Returns:
+            One entry per request, i.e. per chunk of at most ``chunk_size``
+            texts. An empty ``texts`` returns an empty list.
+
+        Raises:
+            ValueError: If ``chunk_size`` is not positive, or if invoking
+                the endpoint fails.
+        """
+        # Nothing to embed: skip the request (and a zero step in range()).
+        if not texts:
+            return []
+        if chunk_size <= 0:
+            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
+        results = []
+        _chunk_size = min(chunk_size, len(texts))
+        for i in range(0, len(texts), _chunk_size):
+            response = self._embedding_func(texts[i : i + _chunk_size])
+            # NOTE(review): one entry is appended per request, not per text;
+            # whether that matches List[List[float]] depends on what the content
+            # handler's transform_output returns -- confirm the intended contract.
+            results.append(response)
+        return results
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using a SageMaker inference endpoint.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        return self._embedding_func([text])
diff --git a/langchain/llms/sagemaker_endpoint.py b/langchain/llms/sagemaker_endpoint.py
index 246f38e4..926e1718 100644
--- a/langchain/llms/sagemaker_endpoint.py
+++ b/langchain/llms/sagemaker_endpoint.py
@@ -1,6 +1,6 @@
 """Wrapper around Sagemaker InvokeEndpoint API."""
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 from pydantic import BaseModel, Extra, root_validator
 
@@ -39,7 +39,9 @@ class ContentHandlerBase(ABC):
     """The MIME type of the response data returned from endpoint"""
 
     @abstractmethod
-    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
+    def transform_input(
+        self, prompt: Union[str, List[str]], model_kwargs: Dict
+    ) -> bytes:
         """Transforms the input to a format that model can accept
         as the request Body. Should return bytes or seekable file
         like object in the format specified in the content_type
@@ -47,7 +49,7 @@ class ContentHandlerBase(ABC):
         """
 
     @abstractmethod
-    def transform_output(self, output: bytes) -> Any:
+    def transform_output(self, output: bytes) -> Any:
         """Transforms the output from the model to string that
         the LLM class expects.
         """