mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
feat(llms): support ERNIE Embedding-V1 (#9370)
- Description: support [ERNIE Embedding-V1](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu), which is part of ERNIE ecology - Issue: None - Dependencies: None - Tag maintainer: @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
f116e10d53
commit
05aa02005b
60
docs/extras/integrations/text_embedding/ernie.ipynb
Normal file
60
docs/extras/integrations/text_embedding/ernie.ipynb
Normal file
@ -0,0 +1,60 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# ERNIE Embedding-V1\n",
|
||||
"\n",
|
||||
"[ERNIE Embedding-V1](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu) is a text representation model based on Baidu Wenxin's large-scale model technology, \n",
|
||||
"which converts text into a vector form represented by numerical values, and is used in text retrieval, information recommendation, knowledge mining and other scenarios."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings import ErnieEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = ErnieEmbeddings()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query_result = embeddings.embed_query(\"foo\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"doc_results = embeddings.embed_documents([\"foo\"])"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -28,6 +28,7 @@ from langchain.embeddings.deepinfra import DeepInfraEmbeddings
|
||||
from langchain.embeddings.edenai import EdenAiEmbeddings
|
||||
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
|
||||
from langchain.embeddings.embaas import EmbaasEmbeddings
|
||||
from langchain.embeddings.ernie import ErnieEmbeddings
|
||||
from langchain.embeddings.fake import DeterministicFakeEmbedding, FakeEmbeddings
|
||||
from langchain.embeddings.google_palm import GooglePalmEmbeddings
|
||||
from langchain.embeddings.gpt4all import GPT4AllEmbeddings
|
||||
@ -101,6 +102,7 @@ __all__ = [
|
||||
"LocalAIEmbeddings",
|
||||
"AwaEmbeddings",
|
||||
"HuggingFaceBgeEmbeddings",
|
||||
"ErnieEmbeddings",
|
||||
]
|
||||
|
||||
|
||||
|
102
libs/langchain/langchain/embeddings/ernie.py
Normal file
102
libs/langchain/langchain/embeddings/ernie.py
Normal file
@ -0,0 +1,102 @@
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ErnieEmbeddings(BaseModel, Embeddings):
    """`Ernie Embeddings V1` embedding models.

    Wraps Baidu's ERNIE Embedding-V1 REST API. Authenticates via an OAuth
    client-credentials flow: the access token is fetched lazily on first use
    and refreshed once whenever the API reports it has expired
    (``error_code == 111``).

    Credentials are taken from the constructor or from the environment
    variables ``ERNIE_CLIENT_ID`` / ``ERNIE_CLIENT_SECRET``.
    """

    # Baidu OAuth client credentials (resolved by ``validate_environment``).
    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    # Cached OAuth access token; populated lazily, refreshed on error 111.
    access_token: Optional[str] = None

    # Number of texts sent per embedding request. The service rejects
    # oversized batches, so callers raising this past the API limit will
    # get a ValueError from the server response.
    chunk_size: int = 16

    model_name = "ErnieBot-Embedding-V1"

    # Class-level lock serializing token refreshes across threads.
    # (Leading underscore keeps it out of pydantic's field set.)
    _lock = threading.Lock()

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve client id/secret from ``values`` or the environment.

        Raises:
            ValueError: if either credential is missing from both sources
                (raised by ``get_from_dict_or_env``).
        """
        values["ernie_client_id"] = get_from_dict_or_env(
            values,
            "ernie_client_id",
            "ERNIE_CLIENT_ID",
        )
        values["ernie_client_secret"] = get_from_dict_or_env(
            values,
            "ernie_client_secret",
            "ERNIE_CLIENT_SECRET",
        )
        return values

    def _embedding(self, json: object) -> dict:
        """POST *json* to the embedding endpoint; return the parsed body.

        The current ``access_token`` is passed as a query parameter; the
        caller is responsible for inspecting ``error_code`` in the result.
        """
        base_url = (
            "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
        )
        resp = requests.post(
            f"{base_url}/embedding-v1",
            headers={
                "Content-Type": "application/json",
            },
            params={"access_token": self.access_token},
            json=json,
        )
        return resp.json()

    def _refresh_access_token_with_lock(self) -> None:
        """Fetch a fresh OAuth access token, serialized under ``_lock``."""
        with self._lock:
            logger.debug("Refreshing access token")
            base_url: str = "https://aip.baidubce.com/oauth/2.0/token"
            resp = requests.post(
                base_url,
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                params={
                    "grant_type": "client_credentials",
                    "client_id": self.ernie_client_id,
                    "client_secret": self.ernie_client_secret,
                },
            )
            self.access_token = str(resp.json().get("access_token"))

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of texts, batching requests by ``chunk_size``.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding vector per input text, in input order.

        Raises:
            ValueError: if the API returns an error other than an expired
                token, or if the single retry after a token refresh fails.
        """
        if not self.access_token:
            self._refresh_access_token_with_lock()
        text_in_chunks = [
            texts[i : i + self.chunk_size]
            for i in range(0, len(texts), self.chunk_size)
        ]
        lst = []
        for chunk in text_in_chunks:
            resp = self._embedding({"input": [text for text in chunk]})
            if resp.get("error_code"):
                if resp.get("error_code") == 111:
                    # Token expired: refresh once and retry this chunk.
                    self._refresh_access_token_with_lock()
                    resp = self._embedding({"input": [text for text in chunk]})
                    # Fix: previously the retried response was used
                    # unchecked — a failed retry crashed below with an
                    # opaque KeyError on resp["data"]. Surface it clearly.
                    if resp.get("error_code"):
                        raise ValueError(f"Error from Ernie: {resp}")
                else:
                    raise ValueError(f"Error from Ernie: {resp}")
            lst.extend([i["embedding"] for i in resp["data"]])
        return lst

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            The embedding vector for ``text``.

        Raises:
            ValueError: if the API returns an error other than an expired
                token, or if the single retry after a token refresh fails.
        """
        if not self.access_token:
            self._refresh_access_token_with_lock()
        resp = self._embedding({"input": [text]})
        if resp.get("error_code"):
            if resp.get("error_code") == 111:
                # Token expired: refresh once and retry.
                self._refresh_access_token_with_lock()
                resp = self._embedding({"input": [text]})
                # Fix: re-validate the retried response (see embed_documents).
                if resp.get("error_code"):
                    raise ValueError(f"Error from Ernie: {resp}")
            else:
                raise ValueError(f"Error from Ernie: {resp}")
        return resp["data"][0]["embedding"]
@ -0,0 +1,41 @@
|
||||
import pytest
|
||||
|
||||
from langchain.embeddings.ernie import ErnieEmbeddings
|
||||
|
||||
|
||||
def test_embedding_documents_1() -> None:
    """A single document embeds to exactly one 384-dimensional vector."""
    embedding = ErnieEmbeddings()
    vectors = embedding.embed_documents(["foo bar"])
    assert len(vectors) == 1
    assert len(vectors[0]) == 384
||||
|
||||
|
||||
def test_embedding_documents_2() -> None:
    """Two documents embed to two 384-dimensional vectors."""
    embedding = ErnieEmbeddings()
    vectors = embedding.embed_documents(["foo", "bar"])
    assert len(vectors) == 2
    assert all(len(vec) == 384 for vec in vectors)
|
||||
|
||||
|
||||
def test_embedding_query() -> None:
    """A query string embeds to a 384-dimensional vector."""
    embedding = ErnieEmbeddings()
    result = embedding.embed_query("foo")
    assert len(result) == 384
||||
|
||||
|
||||
def test_max_chunks() -> None:
    """20 documents embed across multiple batches at the default chunk_size."""
    docs = [f"text-{i}" for i in range(20)]
    embedding = ErnieEmbeddings()
    assert len(embedding.embed_documents(docs)) == 20
||||
|
||||
|
||||
def test_too_many_chunks() -> None:
    """A chunk_size above the API's batch limit surfaces as ValueError."""
    docs = [f"text-{i}" for i in range(20)]
    embedding = ErnieEmbeddings(chunk_size=20)
    with pytest.raises(ValueError):
        embedding.embed_documents(docs)
|
Loading…
Reference in New Issue
Block a user