feat(llms): support ERNIE Embedding-V1 (#9370)

- Description: support [ERNIE Embedding-V1](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu), which is part of the ERNIE ecosystem
- Issue: None
- Dependencies: None
- Tag maintainer: @baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Commit 05aa02005b (parent f116e10d53), authored by axiangcoding on 2023-08-21 22:52:25 +08:00 and committed via GitHub.
4 changed files with 205 additions and 0 deletions


@@ -0,0 +1,60 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ERNIE Embedding-V1\n",
"\n",
"[ERNIE Embedding-V1](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu) is a text representation model based on Baidu Wenxin's large-scale model technology, \n",
"which converts text into a vector form represented by numerical values, and is used in text retrieval, information recommendation, knowledge mining and other scenarios."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.embeddings import ErnieEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embeddings = ErnieEmbeddings()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"doc_results = embeddings.embed_documents([\"foo\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
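The notebook constructs `ErnieEmbeddings()` with no arguments, which relies on `ERNIE_CLIENT_ID` and `ERNIE_CLIENT_SECRET` being set in the environment (see the validator in `ernie.py` below). A minimal end-to-end sketch, assuming those credentials come from your own Baidu AI Cloud application (placeholder values shown):

```python
import os

# Placeholder credentials; replace with the client id/secret of your application.
os.environ["ERNIE_CLIENT_ID"] = "<your-client-id>"
os.environ["ERNIE_CLIENT_SECRET"] = "<your-client-secret>"

from langchain.embeddings import ErnieEmbeddings

embeddings = ErnieEmbeddings()
query_vector = embeddings.embed_query("foo")              # one 384-dimensional vector
doc_vectors = embeddings.embed_documents(["foo", "bar"])  # one vector per document
```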


@@ -28,6 +28,7 @@ from langchain.embeddings.deepinfra import DeepInfraEmbeddings
from langchain.embeddings.edenai import EdenAiEmbeddings
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
from langchain.embeddings.embaas import EmbaasEmbeddings
from langchain.embeddings.ernie import ErnieEmbeddings
from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings
from langchain.embeddings.google_palm import GooglePalmEmbeddings
from langchain.embeddings.gpt4all import GPT4AllEmbeddings
@@ -101,6 +102,7 @@ __all__ = [
"LocalAIEmbeddings",
"AwaEmbeddings",
"HuggingFaceBgeEmbeddings",
"ErnieEmbeddings",
]
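Registering `ErnieEmbeddings` here (the module import plus the `__all__` entry) is what enables the top-level import used in the notebook; a quick sanity check, as a sketch:

```python
from langchain.embeddings import ErnieEmbeddings

# The top-level name should resolve to the new module added in this commit.
print(ErnieEmbeddings.__module__)  # expected: langchain.embeddings.ernie
```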


@@ -0,0 +1,102 @@
import logging
import threading
from typing import Dict, List, Optional

import requests

from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class ErnieEmbeddings(BaseModel, Embeddings):
    """`Ernie Embeddings V1` embedding models."""

    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    access_token: Optional[str] = None

    chunk_size: int = 16

    model_name = "ErnieBot-Embedding-V1"

    _lock = threading.Lock()

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        values["ernie_client_id"] = get_from_dict_or_env(
            values,
            "ernie_client_id",
            "ERNIE_CLIENT_ID",
        )
        values["ernie_client_secret"] = get_from_dict_or_env(
            values,
            "ernie_client_secret",
            "ERNIE_CLIENT_SECRET",
        )
        return values

    def _embedding(self, json: object) -> dict:
        base_url = (
            "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
        )
        resp = requests.post(
            f"{base_url}/embedding-v1",
            headers={
                "Content-Type": "application/json",
            },
            params={"access_token": self.access_token},
            json=json,
        )
        return resp.json()

    def _refresh_access_token_with_lock(self) -> None:
        # Exchange the client id/secret for a fresh OAuth access token;
        # the lock prevents concurrent refreshes.
        with self._lock:
            logger.debug("Refreshing access token")
            base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
            resp = requests.post(
                base_url,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                params={
                    "grant_type": "client_credentials",
                    "client_id": self.ernie_client_id,
                    "client_secret": self.ernie_client_secret,
                },
            )
            self.access_token = str(resp.json().get("access_token"))

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        if not self.access_token:
            self._refresh_access_token_with_lock()
        # Split the input into batches of at most `chunk_size` texts per request.
        text_in_chunks = [
            texts[i : i + self.chunk_size]
            for i in range(0, len(texts), self.chunk_size)
        ]
        lst = []
        for chunk in text_in_chunks:
            resp = self._embedding({"input": [text for text in chunk]})
            if resp.get("error_code"):
                if resp.get("error_code") == 111:
                    # Error 111: access token expired; refresh and retry once.
                    self._refresh_access_token_with_lock()
                    resp = self._embedding({"input": [text for text in chunk]})
                else:
                    raise ValueError(f"Error from Ernie: {resp}")
            lst.extend([i["embedding"] for i in resp["data"]])
        return lst

    def embed_query(self, text: str) -> List[float]:
        if not self.access_token:
            self._refresh_access_token_with_lock()
        resp = self._embedding({"input": [text]})
        if resp.get("error_code"):
            if resp.get("error_code") == 111:
                # Error 111: access token expired; refresh and retry once.
                self._refresh_access_token_with_lock()
                resp = self._embedding({"input": [text]})
            else:
                raise ValueError(f"Error from Ernie: {resp}")
        return resp["data"][0]["embedding"]
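`embed_documents` slices the input list into batches of at most `chunk_size` texts and issues one request per batch, refreshing the token and retrying once on error code 111. A standalone sketch of just the slicing step (hypothetical helper, not part of this commit):

```python
from typing import List


def split_into_chunks(texts: List[str], chunk_size: int = 16) -> List[List[str]]:
    """Group texts into consecutive batches of at most chunk_size items."""
    return [texts[i : i + chunk_size] for i in range(0, len(texts), chunk_size)]


chunks = split_into_chunks([f"text-{i}" for i in range(20)], chunk_size=16)
print([len(c) for c in chunks])  # [16, 4]
```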


@@ -0,0 +1,41 @@
import pytest

from langchain.embeddings.ernie import ErnieEmbeddings


def test_embedding_documents_1() -> None:
    documents = ["foo bar"]
    embedding = ErnieEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 384


def test_embedding_documents_2() -> None:
    documents = ["foo", "bar"]
    embedding = ErnieEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 2
    assert len(output[0]) == 384
    assert len(output[1]) == 384


def test_embedding_query() -> None:
    query = "foo"
    embedding = ErnieEmbeddings()
    output = embedding.embed_query(query)
    assert len(output) == 384


def test_max_chunks() -> None:
    documents = [f"text-{i}" for i in range(20)]
    embedding = ErnieEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 20


def test_too_many_chunks() -> None:
    documents = [f"text-{i}" for i in range(20)]
    # A chunk_size above the service's per-request limit should make the API
    # return an error, which the client raises as a ValueError.
    embedding = ErnieEmbeddings(chunk_size=20)
    with pytest.raises(ValueError):
        embedding.embed_documents(documents)
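These tests call the live Baidu endpoint and assume real credentials are present in the environment. A hedged sketch of an offline variant that stubs the request helper to exercise the response-parsing path without network access (assumed test, not part of the commit):

```python
from unittest.mock import patch

from langchain.embeddings.ernie import ErnieEmbeddings


def test_embedding_query_offline() -> None:
    # Credentials passed directly so no environment variables are required;
    # a preset access_token skips the initial token refresh.
    embedding = ErnieEmbeddings(
        ernie_client_id="fake-id",
        ernie_client_secret="fake-secret",
        access_token="fake-token",
    )
    fake_response = {"data": [{"embedding": [0.0] * 384}]}
    with patch.object(ErnieEmbeddings, "_embedding", return_value=fake_response):
        output = embedding.embed_query("foo")
    assert len(output) == 384
```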