Arcee.ai LLM & Retriever integration (#11579)

- **Description:** This PR introduces a new LLM and Retriever API to
https://arcee.ai for the python client
  - **Issue:** implements the integrations as requested in #11578 ,
  - **Dependencies:** no dependencies are required,
  - **Tag maintainer:** @hwchase17
  - **Twitter handle:** shwooobham 


** `make format`, `make lint` and `make test` runs locally.**
```shell
=========== 1245 passed, 277 skipped, 20 warnings in 16.26s ===========
./scripts/check_pydantic.sh .
./scripts/check_imports.sh
poetry run ruff .
[ "." = "" ] || poetry run black . --check
All done!  🍰 
1818 files would be left unchanged.
[ "." = "" ] || poetry run mypy .
Success: no issues found in 1815 source files
[ "." = "" ] || poetry run black .
All done!  🍰 
1818 files left unchanged.
[ "." = "" ] || poetry run ruff --select I --fix .
poetry run codespell --toml pyproject.toml
poetry run codespell --toml pyproject.toml -w
```


**Contributions**
1. Arcee (langchain/llms), ArceeRetriever (langchain/retrievers),
ArceeWrapper (langchain/utilities)
2. docs for Arcee (llms/arcee.py) and
ArceeRetriever(retrievers/arcee.py)
3.

cc: @jacobsolawetz @ben-epstein

---------

Co-authored-by: Shubham <shubham@sORo.local>
pull/11621/head
Shubham Kushwaha 1 year ago committed by GitHub
parent b6a2507794
commit 49de862076
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,146 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Arcee\n",
"This notebook demonstrates how to use the `Arcee` class for generating text using Arcee's Domain Adapted Language Models (DALMs)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup\n",
"\n",
"Before using Arcee, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Arcee\n",
"\n",
"# Create an instance of the Arcee class\n",
"arcee = Arcee(\n",
" model=\"DALM-PubMed\",\n",
" # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Additional Configuration\n",
"\n",
"You can also configure Arcee's parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n",
"Setting the `model_kwargs` at the object initialization uses the parameters as default for all the subsequent calls to the generate response."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"arcee = Arcee(\n",
" model=\"DALM-Patent\",\n",
" # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n",
" arcee_api_url=\"https://custom-api.arcee.ai\", # default is https://api.arcee.ai\n",
" arcee_app_url=\"https://custom-app.arcee.ai\", # default is https://app.arcee.ai\n",
" model_kwargs={\n",
" \"size\": 5,\n",
" \"filters\": [\n",
" {\n",
" \"field_name\": \"document\",\n",
" \"filter_type\": \"fuzzy_search\",\n",
" \"value\": \"Einstein\"\n",
" }\n",
" ]\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generating Text\n",
"\n",
"You can generate text from Arcee by providing a prompt. Here's an example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Generate text\n",
"prompt = \"Can AI-driven music therapy contribute to the rehabilitation of patients with disorders of consciousness?\"\n",
"response = arcee(prompt)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Additional parameters\n",
"\n",
"Arcee allows you to apply `filters` and set the `size` (in terms of count) of retrieved document(s) to aid text generation. Filters help narrow down the results. Here's how to use these parameters:\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define filters\n",
"filters = [\n",
" {\n",
" \"field_name\": \"document\",\n",
" \"filter_type\": \"fuzzy_search\",\n",
" \"value\": \"Einstein\"\n",
" },\n",
" {\n",
" \"field_name\": \"year\",\n",
" \"filter_type\": \"strict_search\",\n",
" \"value\": \"1905\"\n",
" }\n",
"]\n",
"\n",
"# Generate text with filters and size params\n",
"response = arcee(prompt, size=5, filters=filters)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -0,0 +1,141 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Arcee Retriever\n",
"This notebook demonstrates how to use the `ArceeRetriever` class to retrieve relevant document(s) for Arcee's Domain Adapted Language Models (DALMs)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup\n",
"\n",
"Before using `ArceeRetriever`, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.retrievers import ArceeRetriever\n",
"\n",
"retriever = ArceeRetriever(\n",
" model=\"DALM-PubMed\",\n",
" # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Additional Configuration\n",
"\n",
"You can also configure `ArceeRetriever`'s parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n",
"Setting the `model_kwargs` at the object initialization uses the filters and size as default for all the subsequent retrievals."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"retriever = ArceeRetriever(\n",
" model=\"DALM-PubMed\",\n",
" # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n",
" arcee_api_url=\"https://custom-api.arcee.ai\", # default is https://api.arcee.ai\n",
" arcee_app_url=\"https://custom-app.arcee.ai\", # default is https://app.arcee.ai\n",
" model_kwargs={\n",
" \"size\": 5,\n",
" \"filters\": [\n",
" {\n",
" \"field_name\": \"document\",\n",
" \"filter_type\": \"fuzzy_search\",\n",
" \"value\": \"Einstein\"\n",
" }\n",
" ]\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Retrieving documents\n",
"You can retrieve relevant documents from uploaded contexts by providing a query. Here's an example:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"Can AI-driven music therapy contribute to the rehabilitation of patients with disorders of consciousness?\"\n",
"documents = retriever.get_relevant_documents(query=query)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Additional parameters\n",
"\n",
"Arcee allows you to apply `filters` and set the `size` (in terms of count) of retrieved document(s). Filters help narrow down the results. Here's how to use these parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define filters\n",
"filters = [\n",
" {\n",
" \"field_name\": \"document\",\n",
" \"filter_type\": \"fuzzy_search\",\n",
" \"value\": \"Music\"\n",
" },\n",
" {\n",
" \"field_name\": \"year\",\n",
" \"filter_type\": \"strict_search\",\n",
" \"value\": \"1905\"\n",
" }\n",
"]\n",
"\n",
"# Retrieve documents with filters and size params\n",
"documents = retriever.get_relevant_documents(query=query, size=5, filters=filters)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -52,6 +52,12 @@ def _import_anyscale() -> Any:
return Anyscale
def _import_arcee() -> Any:
from langchain.llms.arcee import Arcee
return Arcee
def _import_aviary() -> Any:
from langchain.llms.aviary import Aviary
@ -479,6 +485,8 @@ def __getattr__(name: str) -> Any:
return _import_anthropic()
elif name == "Anyscale":
return _import_anyscale()
elif name == "Arcee":
return _import_arcee()
elif name == "Aviary":
return _import_aviary()
elif name == "AzureMLOnlineEndpoint":
@ -633,6 +641,7 @@ __all__ = [
"AmazonAPIGateway",
"Anthropic",
"Anyscale",
"Arcee",
"Aviary",
"AzureMLOnlineEndpoint",
"AzureOpenAI",
@ -713,6 +722,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"amazon_bedrock": _import_bedrock,
"anthropic": _import_anthropic,
"anyscale": _import_anyscale,
"arcee": _import_arcee,
"aviary": _import_aviary,
"azure": _import_azure_openai,
"azureml_endpoint": _import_azureml_endpoint,

@ -0,0 +1,147 @@
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Extra, root_validator
from langchain.utilities.arcee import ArceeWrapper, DALMFilter
from langchain.utils import get_from_dict_or_env
class Arcee(LLM):
"""Arcee's Domain Adapted Language Models (DALMs).
To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
or pass ``arcee_api_key`` as a named parameter.
Example:
.. code-block:: python
from langchain.llms import Arcee
arcee = Arcee(
model="DALM-PubMed",
arcee_api_key="ARCEE-API-KEY"
)
response = arcee("AI-driven music therapy")
"""
_client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee _client."""
arcee_api_key: str = ""
"""Arcee API Key"""
model: str
"""Arcee DALM name"""
arcee_api_url: str = "https://api.arcee.ai"
"""Arcee API URL"""
arcee_api_version: str = "v2"
"""Arcee API Version"""
arcee_app_url: str = "https://app.arcee.ai"
"""Arcee App URL"""
model_id: str = ""
"""Arcee Model ID"""
model_kwargs: Optional[Dict[str, Any]] = None
"""Keyword arguments to pass to the model."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
underscore_attrs_are_private = True
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "arcee"
def __init__(self, **data: Any) -> None:
"""Initializes private fields."""
super().__init__(**data)
self._client = None
self._client = ArceeWrapper(
arcee_api_key=self.arcee_api_key,
arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs,
model_name=self.model,
)
self._client.validate_model_training_status()
@root_validator()
def validate_environments(cls, values: Dict) -> Dict:
"""Validate Arcee environment variables."""
# validate env vars
values["arcee_api_key"] = get_from_dict_or_env(
values,
"arcee_api_key",
"ARCEE_API_KEY",
)
values["arcee_api_url"] = get_from_dict_or_env(
values,
"arcee_api_url",
"ARCEE_API_URL",
)
values["arcee_app_url"] = get_from_dict_or_env(
values,
"arcee_app_url",
"ARCEE_APP_URL",
)
values["arcee_api_version"] = get_from_dict_or_env(
values,
"arcee_api_version",
"ARCEE_API_VERSION",
)
# validate model kwargs
if values["model_kwargs"]:
kw = values["model_kwargs"]
# validate size
if kw.get("size") is not None:
if not kw.get("size") >= 0:
raise ValueError("`size` must be positive")
# validate filters
if kw.get("filters") is not None:
if not isinstance(kw.get("filters"), List):
raise ValueError("`filters` must be a list")
for f in kw.get("filters"):
DALMFilter(**f)
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Generate text from Arcee DALM.
Args:
prompt: Prompt to generate text from.
size: The max number of context results to retrieve.
Defaults to 3. (Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""
try:
if not self._client:
raise ValueError("Client is not initialized.")
return self._client.generate(prompt=prompt, **kwargs)
except Exception as e:
raise Exception(f"Failed to generate text: {e}") from e

@ -18,6 +18,7 @@ the backbone of a retriever, but there are other types of retrievers as well.
CallbackManagerForRetrieverRun, AsyncCallbackManagerForRetrieverRun
"""
from langchain.retrievers.arcee import ArceeRetriever
from langchain.retrievers.arxiv import ArxivRetriever
from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetriever
from langchain.retrievers.bm25 import BM25Retriever
@ -66,6 +67,7 @@ from langchain.retrievers.zilliz import ZillizRetriever
__all__ = [
"AmazonKendraRetriever",
"ArceeRetriever",
"ArxivRetriever",
"AzureCognitiveSearchRetriever",
"ChatGPTPluginRetriever",

@ -0,0 +1,136 @@
from typing import Any, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import BaseRetriever
from langchain.utilities.arcee import ArceeWrapper, DALMFilter
from langchain.utils import get_from_dict_or_env
class ArceeRetriever(BaseRetriever):
"""Document retriever for Arcee's Domain Adapted Language Models (DALMs).
To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
or pass ``arcee_api_key`` as a named parameter.
Example:
.. code-block:: python
from langchain.retrievers import ArceeRetriever
retriever = ArceeRetriever(
model="DALM-PubMed",
arcee_api_key="ARCEE-API-KEY"
)
documents = retriever.get_relevant_documents("AI-driven music therapy")
"""
_client: Optional[ArceeWrapper] = None #: :meta private:
"""Arcee client."""
arcee_api_key: str = ""
"""Arcee API Key"""
model: str
"""Arcee DALM name"""
arcee_api_url: str = "https://api.arcee.ai"
"""Arcee API URL"""
arcee_api_version: str = "v2"
"""Arcee API Version"""
arcee_app_url: str = "https://app.arcee.ai"
"""Arcee App URL"""
model_kwargs: Optional[Dict[str, Any]] = None
"""Keyword arguments to pass to the model."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
underscore_attrs_are_private = True
def __init__(self, **data: Any) -> None:
"""Initializes private fields."""
super().__init__(**data)
self._client = ArceeWrapper(
arcee_api_key=self.arcee_api_key,
arcee_api_url=self.arcee_api_url,
arcee_api_version=self.arcee_api_version,
model_kwargs=self.model_kwargs,
model_name=self.model,
)
self._client.validate_model_training_status()
@root_validator()
def validate_environments(cls, values: Dict) -> Dict:
"""Validate Arcee environment variables."""
# validate env vars
values["arcee_api_key"] = get_from_dict_or_env(
values,
"arcee_api_key",
"ARCEE_API_KEY",
)
values["arcee_api_url"] = get_from_dict_or_env(
values,
"arcee_api_url",
"ARCEE_API_URL",
)
values["arcee_app_url"] = get_from_dict_or_env(
values,
"arcee_app_url",
"ARCEE_APP_URL",
)
values["arcee_api_version"] = get_from_dict_or_env(
values,
"arcee_api_version",
"ARCEE_API_VERSION",
)
# validate model kwargs
if values["model_kwargs"]:
kw = values["model_kwargs"]
# validate size
if kw.get("size") is not None:
if not kw.get("size") >= 0:
raise ValueError("`size` must not be negative.")
# validate filters
if kw.get("filters") is not None:
if not isinstance(kw.get("filters"), List):
raise ValueError("`filters` must be a list.")
for f in kw.get("filters"):
DALMFilter(**f)
return values
def _get_relevant_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
"""Retrieve {size} contexts with your retriever for a given query
Args:
query: Query to submit to the model
size: The max number of context results to retrieve.
Defaults to 3. (Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""
try:
if not self._client:
raise ValueError("Client is not initialized.")
return self._client.retrieve(query=query, **kwargs)
except Exception as e:
raise ValueError(f"Error while retrieving documents: {e}") from e

@ -5,6 +5,7 @@ and packages.
"""
from langchain.utilities.alpha_vantage import AlphaVantageAPIWrapper
from langchain.utilities.apify import ApifyWrapper
from langchain.utilities.arcee import ArceeWrapper
from langchain.utilities.arxiv import ArxivAPIWrapper
from langchain.utilities.awslambda import LambdaWrapper
from langchain.utilities.bash import BashProcess
@ -41,6 +42,7 @@ from langchain.utilities.zapier import ZapierNLAWrapper
__all__ = [
"AlphaVantageAPIWrapper",
"ApifyWrapper",
"ArceeWrapper",
"ArxivAPIWrapper",
"BashProcess",
"BibtexparserWrapper",

@ -0,0 +1,189 @@
# This module contains utility classes and functions for interacting with Arcee API.
# For more information and updates, refer to the Arcee utils page:
# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]
from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
import requests
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema.retriever import Document
class ArceeRoute(str, Enum):
generate = "models/generate"
retrieve = "models/retrieve"
model_training_status = "models/status/{id_or_name}"
class DALMFilterType(str, Enum):
fuzzy_search = "fuzzy_search"
strict_search = "strict_search"
class DALMFilter(BaseModel):
"""Filters available for a dalm retrieval and generation
Arguments:
field_name: The field to filter on. Can be 'document' or 'name' to filter
on your document's raw text or title. Any other field will be presumed
to be a metadata field you included when uploading your context data
filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
'fuzzy_search' means a fuzzy search on the provided field is performed.
The exact strict doesn't need to exist in the document
for this to find a match.
Very useful for scanning a document for some keyword terms.
'strict_search' means that the exact string must appear
in the provided field.
This is NOT an exact eq filter. ie a document with content
"the happy dog crossed the street" will match on a strict_search of
"dog" but won't match on "the dog".
Python equivalent of `return search_string in full_string`.
value: The actual value to search for in the context data/metadata
"""
field_name: str
filter_type: DALMFilterType
value: str
_is_metadata: bool = False
@root_validator()
def set_meta(cls, values: Dict) -> Dict:
"""document and name are reserved arcee keys. Anything else is metadata"""
values["_is_meta"] = values.get("field_name") not in ["document", "name"]
return values
class ArceeWrapper:
def __init__(
self,
arcee_api_key: str,
arcee_api_url: str,
arcee_api_version: str,
model_kwargs: Optional[Dict[str, Any]],
model_name: str,
):
self.arcee_api_key = arcee_api_key
self.model_kwargs = model_kwargs
self.arcee_api_url = arcee_api_url
self.arcee_api_version = arcee_api_version
try:
route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
response = self._make_request("get", route)
self.model_id = response.get("model_id")
self.model_training_status = response.get("status")
except Exception as e:
raise ValueError(
f"Error while validating model training status for '{model_name}': {e}"
) from e
def validate_model_training_status(self) -> None:
if self.model_training_status != "training_complete":
raise Exception(
f"Model {self.model_id} is not ready. "
"Please wait for training to complete."
)
def _make_request(
self,
method: Literal["post", "get"],
route: Union[ArceeRoute, str],
body: Optional[Mapping[str, Any]] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
) -> dict:
"""Make a request to the Arcee API
Args:
method: The HTTP method to use
route: The route to call
body: The body of the request
params: The query params of the request
headers: The headers of the request
"""
headers = self._make_request_headers(headers=headers)
url = self._make_request_url(route=route)
req_type = getattr(requests, method)
response = req_type(url, json=body, params=params, headers=headers)
if response.status_code not in (200, 201):
raise Exception(f"Failed to make request. Response: {response.text}")
return response.json()
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
headers = headers or {}
internal_headers = {
"X-Token": self.arcee_api_key,
"Content-Type": "application/json",
}
headers.update(internal_headers)
return headers
def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}"
def _make_request_body_for_models(
self, prompt: str, **kwargs: Mapping[str, Any]
) -> Mapping[str, Any]:
"""Make the request body for generate/retrieve models endpoint"""
_model_kwargs = self.model_kwargs or {}
_params = {**_model_kwargs, **kwargs}
filters = [DALMFilter(**f) for f in _params.get("filters", [])]
return dict(
model_id=self.model_id,
query=prompt,
size=_params.get("size", 3),
filters=filters,
id=self.model_id,
)
def generate(
self,
prompt: str,
**kwargs: Any,
) -> str:
"""Generate text from Arcee DALM.
Args:
prompt: Prompt to generate text from.
size: The max number of context results to retrieve. Defaults to 3.
(Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""
response = self._make_request(
method="post",
route=ArceeRoute.generate,
body=self._make_request_body_for_models(
prompt=prompt,
**kwargs,
),
)
return response["text"]
def retrieve(
self,
query: str,
**kwargs: Any,
) -> List[Document]:
"""Retrieve {size} contexts with your retriever for a given query
Args:
query: Query to submit to the model
size: The max number of context results to retrieve. Defaults to 3.
(Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""
response = self._make_request(
method="post",
route=ArceeRoute.retrieve,
body=self._make_request_body_for_models(
prompt=query,
**kwargs,
),
)
return [Document(**doc) for doc in response["documents"]]
Loading…
Cancel
Save