mirror of https://github.com/hwchase17/langchain
Bagatur/eden llm (#8670)
Co-authored-by: RedhaWassim <rwasssim@gmail.com> Co-authored-by: KyrianC <ckyrian@protonmail.com> Co-authored-by: sam <melaine.samy@gmail.com>pull/8720/head
parent
8022293124
commit
b2b71b0d35
File diff suppressed because one or more lines are too long
@ -0,0 +1,169 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# EDEN AI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Eden AI is an AI consulting company that was founded to use its resources to empower people and create impactful products that use AI to improve the quality of life for individuals, businesses and societies at large."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This example goes over how to use LangChain to interact with Eden AI embedding models\n",
|
||||
"\n",
|
||||
"-----------------------------------------------------------------------------------\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Accessing the EDENAI's API requires an API key, \n",
|
||||
"\n",
|
||||
"which you can get by creating an account https://app.edenai.run/user/register and heading here https://app.edenai.run/admin/account/settings\n",
|
||||
"\n",
|
||||
"Once we have a key we'll want to set it as an environment variable by running:\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"export EDENAI_API_KEY=\"...\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If you'd prefer not to set an environment variable you can pass the key in directly via the edenai_api_key named parameter\n",
|
||||
"\n",
|
||||
" when initiating the EdenAI embedding class:\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings.edenai import EdenAiEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = EdenAiEmbeddings(edenai_api_key=\"...\",provider=\"...\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Calling a model\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The EdenAI API brings together various providers.\n",
|
||||
"\n",
|
||||
"To access a specific model, you can simply use the \"provider\" when calling.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = EdenAiEmbeddings(provider=\"openai\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs = [\"It's raining right now\", \"cats are cute\"]\n",
|
||||
"document_result = embeddings.embed_documents(docs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"my umbrella is broken\"\n",
|
||||
"query_result = embeddings.embed_query(query)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Cosine similarity between \"It's raining right now\" and query: 0.849261496107252\n",
|
||||
"Cosine similarity between \"cats are cute\" and query: 0.7525900655705218\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"query_numpy = np.array(query_result)\n",
|
||||
"for doc_res, doc in zip(document_result, docs):\n",
|
||||
" document_numpy = np.array(doc_res)\n",
|
||||
" similarity = np.dot(query_numpy, document_numpy) / (\n",
|
||||
" np.linalg.norm(query_numpy) * np.linalg.norm(document_numpy)\n",
|
||||
" )\n",
|
||||
" print(f'Cosine similarity between \"{doc}\" and query: {similarity}')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,88 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.requests import Requests
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class EdenAiEmbeddings(BaseModel, Embeddings):
|
||||
"""EdenAI embedding.
|
||||
environment variable ``EDENAI_API_KEY`` set with your API key, or pass
|
||||
it as a named parameter.
|
||||
"""
|
||||
|
||||
edenai_api_key: Optional[str] = Field(None, description="EdenAI API Token")
|
||||
|
||||
provider: Optional[str] = "openai"
|
||||
"""embedding provider to use (eg: openai,google etc.)"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["edenai_api_key"] = get_from_dict_or_env(
|
||||
values, "edenai_api_key", "EDENAI_API_KEY"
|
||||
)
|
||||
return values
|
||||
|
||||
def _generate_embeddings(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Compute embeddings using EdenAi api."""
|
||||
url = "https://api.edenai.run/v2/text/embeddings"
|
||||
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {self.edenai_api_key}",
|
||||
}
|
||||
|
||||
payload = {"texts": texts, "providers": self.provider}
|
||||
request = Requests(headers=headers)
|
||||
response = request.post(url=url, data=payload)
|
||||
if response.status_code >= 500:
|
||||
raise Exception(f"EdenAI Server: Error {response.status_code}")
|
||||
elif response.status_code >= 400:
|
||||
raise ValueError(f"EdenAI received an invalid payload: {response.text}")
|
||||
elif response.status_code != 200:
|
||||
raise Exception(
|
||||
f"EdenAI returned an unexpected response with status "
|
||||
f"{response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
temp = response.json()
|
||||
|
||||
embeddings = []
|
||||
for embed_item in temp[self.provider]["items"]:
|
||||
embedding = embed_item["embedding"]
|
||||
|
||||
embeddings.append(embedding)
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed a list of documents using EdenAI.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
|
||||
return self._generate_embeddings(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a query using EdenAI.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self._generate_embeddings([text])[0]
|
@ -0,0 +1,217 @@
|
||||
"""Wrapper around EdenAI's Generation API."""
|
||||
import logging
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.requests import Requests
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EdenAI(LLM):
|
||||
"""Wrapper around edenai models.
|
||||
|
||||
To use, you should have
|
||||
the environment variable ``EDENAI_API_KEY`` set with your API token.
|
||||
You can find your token here: https://app.edenai.run/admin/account/settings
|
||||
|
||||
`feature` and `subfeature` are required, but any other model parameters can also be
|
||||
passed in with the format params={model_param: value, ...}
|
||||
|
||||
for api reference check edenai documentation: http://docs.edenai.co.
|
||||
"""
|
||||
|
||||
base_url = "https://api.edenai.run/v2"
|
||||
|
||||
edenai_api_key: Optional[str] = None
|
||||
|
||||
feature: Literal["text", "image"] = "text"
|
||||
"""Which generative feature to use, use text by default"""
|
||||
|
||||
subfeature: Literal["generation"] = "generation"
|
||||
"""Subfeature of above feature, use generation by default"""
|
||||
|
||||
provider: str
|
||||
"""Geneerative provider to use (eg: openai,stabilityai,cohere,google etc.)"""
|
||||
|
||||
params: Dict[str, Any]
|
||||
"""
|
||||
Parameters to pass to above subfeature (excluding 'providers' & 'text')
|
||||
ref text: https://docs.edenai.co/reference/text_generation_create
|
||||
ref image: https://docs.edenai.co/reference/text_generation_create
|
||||
"""
|
||||
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""extra parameters"""
|
||||
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
"""Stop sequences to use."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["edenai_api_key"] = get_from_dict_or_env(
|
||||
values, "edenai_api_key", "EDENAI_API_KEY"
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = {field.alias for field in cls.__fields__.values()}
|
||||
|
||||
extra = values.get("model_kwargs", {})
|
||||
for field_name in list(values):
|
||||
if field_name not in all_required_field_names:
|
||||
if field_name in extra:
|
||||
raise ValueError(f"Found {field_name} supplied twice.")
|
||||
logger.warning(
|
||||
f"""{field_name} was transferred to model_kwargs.
|
||||
Please confirm that {field_name} is what you intended."""
|
||||
)
|
||||
extra[field_name] = values.pop(field_name)
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of model."""
|
||||
return "edenai"
|
||||
|
||||
def _format_output(self, output: dict) -> str:
|
||||
if self.feature == "text":
|
||||
return output[self.provider]["generated_text"]
|
||||
else:
|
||||
return output[self.provider]["items"][0]["image"]
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to EdenAI's text generation endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
|
||||
Returns:
|
||||
json formatted str response.
|
||||
|
||||
"""
|
||||
stops = None
|
||||
if self.stop_sequences is not None and stop is not None:
|
||||
raise ValueError(
|
||||
"stop sequences found in both the input and default params."
|
||||
)
|
||||
elif self.stop_sequences is not None:
|
||||
stops = self.stop_sequences
|
||||
else:
|
||||
stops = stop
|
||||
|
||||
url = f"{self.base_url}/{self.feature}/{self.subfeature}"
|
||||
headers = {"Authorization": f"Bearer {self.edenai_api_key}"}
|
||||
payload = {
|
||||
**self.params,
|
||||
"providers": self.provider,
|
||||
"num_images": 1, # always limit to 1 (ignored for text)
|
||||
"text": prompt,
|
||||
**kwargs,
|
||||
}
|
||||
request = Requests(headers=headers)
|
||||
|
||||
response = request.post(url=url, data=payload)
|
||||
|
||||
if response.status_code >= 500:
|
||||
raise Exception(f"EdenAI Server: Error {response.status_code}")
|
||||
elif response.status_code >= 400:
|
||||
raise ValueError(f"EdenAI received an invalid payload: {response.text}")
|
||||
elif response.status_code != 200:
|
||||
raise Exception(
|
||||
f"EdenAI returned an unexpected response with status "
|
||||
f"{response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
output = self._format_output(response.json())
|
||||
|
||||
if stops is not None:
|
||||
output = enforce_stop_tokens(output, stops)
|
||||
|
||||
return output
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call EdenAi model to get predictions based on the prompt.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: A list of stop words (optional).
|
||||
run_manager: A callback manager for async interaction with LLMs.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
|
||||
stops = None
|
||||
if self.stop_sequences is not None and stop is not None:
|
||||
raise ValueError(
|
||||
"stop sequences found in both the input and default params."
|
||||
)
|
||||
elif self.stop_sequences is not None:
|
||||
stops = self.stop_sequences
|
||||
else:
|
||||
stops = stop
|
||||
|
||||
print("Running the acall")
|
||||
url = f"{self.base_url}/{self.feature}/{self.subfeature}"
|
||||
headers = {"Authorization": f"Bearer {self.edenai_api_key}"}
|
||||
payload = {
|
||||
**self.params,
|
||||
"providers": self.provider,
|
||||
"num_images": 1, # always limit to 1 (ignored for text)
|
||||
"text": prompt,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
async with ClientSession() as session:
|
||||
print("Requesting")
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status >= 500:
|
||||
raise Exception(f"EdenAI Server: Error {response.status}")
|
||||
elif response.status >= 400:
|
||||
raise ValueError(
|
||||
f"EdenAI received an invalid payload: {response.text}"
|
||||
)
|
||||
elif response.status != 200:
|
||||
raise Exception(
|
||||
f"EdenAI returned an unexpected response with status "
|
||||
f"{response.status}: {response.text}"
|
||||
)
|
||||
|
||||
response_json = await response.json()
|
||||
|
||||
output = self._format_output(response_json)
|
||||
if stops is not None:
|
||||
output = enforce_stop_tokens(output, stops)
|
||||
|
||||
return output
|
@ -0,0 +1,21 @@
|
||||
"""Test edenai embeddings."""
|
||||
|
||||
from langchain.embeddings.edenai import EdenAiEmbeddings
|
||||
|
||||
|
||||
def test_edenai_embedding_documents() -> None:
|
||||
"""Test edenai embeddings with openai."""
|
||||
documents = ["foo bar", "test text"]
|
||||
embedding = EdenAiEmbeddings(provider="openai")
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) == 1536
|
||||
assert len(output[1]) == 1536
|
||||
|
||||
|
||||
def test_edenai_embedding_query() -> None:
|
||||
"""Test eden ai embeddings with google."""
|
||||
document = "foo bar"
|
||||
embedding = EdenAiEmbeddings(provider="google")
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) == 768
|
@ -0,0 +1,32 @@
|
||||
"""Test EdenAi API wrapper.
|
||||
|
||||
In order to run this test, you need to have an EdenAI api key.
|
||||
You can get it by registering for free at https://app.edenai.run/user/register.
|
||||
A test key can be found at https://app.edenai.run/admin/account/settings by
|
||||
clicking on the 'sandbox' toggle.
|
||||
(calls will be free, and will return dummy results)
|
||||
|
||||
You'll then need to set EDENAI_API_KEY environment variable to your api key.
|
||||
"""
|
||||
from langchain.llms import EdenAI
|
||||
|
||||
|
||||
def test_edenai_call() -> None:
|
||||
"""Test simple call to edenai."""
|
||||
llm = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
|
||||
output = llm("Say foo:")
|
||||
|
||||
assert llm._llm_type == "edenai"
|
||||
assert llm.feature == "text"
|
||||
assert llm.subfeature == "generation"
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
async def test_edenai_acall() -> None:
|
||||
"""Test simple call to edenai."""
|
||||
llm = EdenAI(provider="openai", params={"temperature": 0.2, "max_tokens": 250})
|
||||
output = await llm.agenerate(["Say foo:"])
|
||||
assert llm._llm_type == "edenai"
|
||||
assert llm.feature == "text"
|
||||
assert llm.subfeature == "generation"
|
||||
assert isinstance(output, str)
|
Loading…
Reference in New Issue