mirror of https://github.com/hwchase17/langchain
Arcee.ai LLM & Retriever integration (#11579)
- **Description:** This PR introduces a new LLM and Retriever API to https://arcee.ai for the python client - **Issue:** implements the integrations as requested in #11578 , - **Dependencies:** no dependencies are required, - **Tag maintainer:** @hwchase17 - **Twitter handle:** shwooobham **✅ `make format`, `make lint` and `make test` runs locally.** ```shell =========== 1245 passed, 277 skipped, 20 warnings in 16.26s =========== ./scripts/check_pydantic.sh . ./scripts/check_imports.sh poetry run ruff . [ "." = "" ] || poetry run black . --check All done! ✨ 🍰 ✨ 1818 files would be left unchanged. [ "." = "" ] || poetry run mypy . Success: no issues found in 1815 source files [ "." = "" ] || poetry run black . All done! ✨ 🍰 ✨ 1818 files left unchanged. [ "." = "" ] || poetry run ruff --select I --fix . poetry run codespell --toml pyproject.toml poetry run codespell --toml pyproject.toml -w ``` **Contributions** 1. Arcee (langchain/llms), ArceeRetriever (langchain/retrievers), ArceeWrapper (langchain/utilities) 2. docs for Arcee (llms/arcee.py) and ArceeRetriever(retrievers/arcee.py) 3. cc: @jacobsolawetz @ben-epstein --------- Co-authored-by: Shubham <shubham@sORo.local>pull/11621/head
parent
b6a2507794
commit
49de862076
@ -0,0 +1,146 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Arcee\n",
|
||||
"This notebook demonstrates how to use the `Arcee` class for generating text using Arcee's Domain Adapted Language Models (DALMs)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Setup\n",
|
||||
"\n",
|
||||
"Before using Arcee, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import Arcee\n",
|
||||
"\n",
|
||||
"# Create an instance of the Arcee class\n",
|
||||
"arcee = Arcee(\n",
|
||||
" model=\"DALM-PubMed\",\n",
|
||||
" # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Additional Configuration\n",
|
||||
"\n",
|
||||
"You can also configure Arcee's parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n",
|
||||
"Setting the `model_kwargs` at the object initialization uses the parameters as default for all the subsequent calls to the generate response."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"arcee = Arcee(\n",
|
||||
" model=\"DALM-Patent\",\n",
|
||||
" # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n",
|
||||
" arcee_api_url=\"https://custom-api.arcee.ai\", # default is https://api.arcee.ai\n",
|
||||
" arcee_app_url=\"https://custom-app.arcee.ai\", # default is https://app.arcee.ai\n",
|
||||
" model_kwargs={\n",
|
||||
" \"size\": 5,\n",
|
||||
" \"filters\": [\n",
|
||||
" {\n",
|
||||
" \"field_name\": \"document\",\n",
|
||||
" \"filter_type\": \"fuzzy_search\",\n",
|
||||
" \"value\": \"Einstein\"\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Generating Text\n",
|
||||
"\n",
|
||||
"You can generate text from Arcee by providing a prompt. Here's an example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Generate text\n",
|
||||
"prompt = \"Can AI-driven music therapy contribute to the rehabilitation of patients with disorders of consciousness?\"\n",
|
||||
"response = arcee(prompt)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Additional parameters\n",
|
||||
"\n",
|
||||
"Arcee allows you to apply `filters` and set the `size` (in terms of count) of retrieved document(s) to aid text generation. Filters help narrow down the results. Here's how to use these parameters:\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define filters\n",
|
||||
"filters = [\n",
|
||||
" {\n",
|
||||
" \"field_name\": \"document\",\n",
|
||||
" \"filter_type\": \"fuzzy_search\",\n",
|
||||
" \"value\": \"Einstein\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"field_name\": \"year\",\n",
|
||||
" \"filter_type\": \"strict_search\",\n",
|
||||
" \"value\": \"1905\"\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Generate text with filters and size params\n",
|
||||
"response = arcee(prompt, size=5, filters=filters)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,141 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Arcee Retriever\n",
|
||||
"This notebook demonstrates how to use the `ArceeRetriever` class to retrieve relevant document(s) for Arcee's Domain Adapted Language Models (DALMs)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Setup\n",
|
||||
"\n",
|
||||
"Before using `ArceeRetriever`, make sure the Arcee API key is set as `ARCEE_API_KEY` environment variable. You can also pass the api key as a named parameter."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.retrievers import ArceeRetriever\n",
|
||||
"\n",
|
||||
"retriever = ArceeRetriever(\n",
|
||||
" model=\"DALM-PubMed\",\n",
|
||||
" # arcee_api_key=\"ARCEE-API-KEY\" # if not already set in the environment\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Additional Configuration\n",
|
||||
"\n",
|
||||
"You can also configure `ArceeRetriever`'s parameters such as `arcee_api_url`, `arcee_app_url`, and `model_kwargs` as needed.\n",
|
||||
"Setting the `model_kwargs` at the object initialization uses the filters and size as default for all the subsequent retrievals."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = ArceeRetriever(\n",
|
||||
" model=\"DALM-PubMed\",\n",
|
||||
" # arcee_api_key=\"ARCEE-API-KEY\", # if not already set in the environment\n",
|
||||
" arcee_api_url=\"https://custom-api.arcee.ai\", # default is https://api.arcee.ai\n",
|
||||
" arcee_app_url=\"https://custom-app.arcee.ai\", # default is https://app.arcee.ai\n",
|
||||
" model_kwargs={\n",
|
||||
" \"size\": 5,\n",
|
||||
" \"filters\": [\n",
|
||||
" {\n",
|
||||
" \"field_name\": \"document\",\n",
|
||||
" \"filter_type\": \"fuzzy_search\",\n",
|
||||
" \"value\": \"Einstein\"\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Retrieving documents\n",
|
||||
"You can retrieve relevant documents from uploaded contexts by providing a query. Here's an example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"Can AI-driven music therapy contribute to the rehabilitation of patients with disorders of consciousness?\"\n",
|
||||
"documents = retriever.get_relevant_documents(query=query)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Additional parameters\n",
|
||||
"\n",
|
||||
"Arcee allows you to apply `filters` and set the `size` (in terms of count) of retrieved document(s). Filters help narrow down the results. Here's how to use these parameters:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define filters\n",
|
||||
"filters = [\n",
|
||||
" {\n",
|
||||
" \"field_name\": \"document\",\n",
|
||||
" \"filter_type\": \"fuzzy_search\",\n",
|
||||
" \"value\": \"Music\"\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"field_name\": \"year\",\n",
|
||||
" \"filter_type\": \"strict_search\",\n",
|
||||
" \"value\": \"1905\"\n",
|
||||
" }\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Retrieve documents with filters and size params\n",
|
||||
"documents = retriever.get_relevant_documents(query=query, size=5, filters=filters)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,147 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.utilities.arcee import ArceeWrapper, DALMFilter
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class Arcee(LLM):
|
||||
"""Arcee's Domain Adapted Language Models (DALMs).
|
||||
|
||||
To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
|
||||
or pass ``arcee_api_key`` as a named parameter.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import Arcee
|
||||
|
||||
arcee = Arcee(
|
||||
model="DALM-PubMed",
|
||||
arcee_api_key="ARCEE-API-KEY"
|
||||
)
|
||||
|
||||
response = arcee("AI-driven music therapy")
|
||||
"""
|
||||
|
||||
_client: Optional[ArceeWrapper] = None #: :meta private:
|
||||
"""Arcee _client."""
|
||||
|
||||
arcee_api_key: str = ""
|
||||
"""Arcee API Key"""
|
||||
|
||||
model: str
|
||||
"""Arcee DALM name"""
|
||||
|
||||
arcee_api_url: str = "https://api.arcee.ai"
|
||||
"""Arcee API URL"""
|
||||
|
||||
arcee_api_version: str = "v2"
|
||||
"""Arcee API Version"""
|
||||
|
||||
arcee_app_url: str = "https://app.arcee.ai"
|
||||
"""Arcee App URL"""
|
||||
|
||||
model_id: str = ""
|
||||
"""Arcee Model ID"""
|
||||
|
||||
model_kwargs: Optional[Dict[str, Any]] = None
|
||||
"""Keyword arguments to pass to the model."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
underscore_attrs_are_private = True
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "arcee"
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
"""Initializes private fields."""
|
||||
|
||||
super().__init__(**data)
|
||||
self._client = None
|
||||
self._client = ArceeWrapper(
|
||||
arcee_api_key=self.arcee_api_key,
|
||||
arcee_api_url=self.arcee_api_url,
|
||||
arcee_api_version=self.arcee_api_version,
|
||||
model_kwargs=self.model_kwargs,
|
||||
model_name=self.model,
|
||||
)
|
||||
|
||||
self._client.validate_model_training_status()
|
||||
|
||||
@root_validator()
|
||||
def validate_environments(cls, values: Dict) -> Dict:
|
||||
"""Validate Arcee environment variables."""
|
||||
|
||||
# validate env vars
|
||||
values["arcee_api_key"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_key",
|
||||
"ARCEE_API_KEY",
|
||||
)
|
||||
|
||||
values["arcee_api_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_url",
|
||||
"ARCEE_API_URL",
|
||||
)
|
||||
|
||||
values["arcee_app_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_app_url",
|
||||
"ARCEE_APP_URL",
|
||||
)
|
||||
|
||||
values["arcee_api_version"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_version",
|
||||
"ARCEE_API_VERSION",
|
||||
)
|
||||
|
||||
# validate model kwargs
|
||||
if values["model_kwargs"]:
|
||||
kw = values["model_kwargs"]
|
||||
|
||||
# validate size
|
||||
if kw.get("size") is not None:
|
||||
if not kw.get("size") >= 0:
|
||||
raise ValueError("`size` must be positive")
|
||||
|
||||
# validate filters
|
||||
if kw.get("filters") is not None:
|
||||
if not isinstance(kw.get("filters"), List):
|
||||
raise ValueError("`filters` must be a list")
|
||||
for f in kw.get("filters"):
|
||||
DALMFilter(**f)
|
||||
|
||||
return values
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Generate text from Arcee DALM.
|
||||
|
||||
Args:
|
||||
prompt: Prompt to generate text from.
|
||||
size: The max number of context results to retrieve.
|
||||
Defaults to 3. (Can be less if filters are provided).
|
||||
filters: Filters to apply to the context dataset.
|
||||
"""
|
||||
|
||||
try:
|
||||
if not self._client:
|
||||
raise ValueError("Client is not initialized.")
|
||||
return self._client.generate(prompt=prompt, **kwargs)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to generate text: {e}") from e
|
@ -0,0 +1,136 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain.utilities.arcee import ArceeWrapper, DALMFilter
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class ArceeRetriever(BaseRetriever):
|
||||
"""Document retriever for Arcee's Domain Adapted Language Models (DALMs).
|
||||
|
||||
To use, set the ``ARCEE_API_KEY`` environment variable with your Arcee API key,
|
||||
or pass ``arcee_api_key`` as a named parameter.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.retrievers import ArceeRetriever
|
||||
|
||||
retriever = ArceeRetriever(
|
||||
model="DALM-PubMed",
|
||||
arcee_api_key="ARCEE-API-KEY"
|
||||
)
|
||||
|
||||
documents = retriever.get_relevant_documents("AI-driven music therapy")
|
||||
"""
|
||||
|
||||
_client: Optional[ArceeWrapper] = None #: :meta private:
|
||||
"""Arcee client."""
|
||||
|
||||
arcee_api_key: str = ""
|
||||
"""Arcee API Key"""
|
||||
|
||||
model: str
|
||||
"""Arcee DALM name"""
|
||||
|
||||
arcee_api_url: str = "https://api.arcee.ai"
|
||||
"""Arcee API URL"""
|
||||
|
||||
arcee_api_version: str = "v2"
|
||||
"""Arcee API Version"""
|
||||
|
||||
arcee_app_url: str = "https://app.arcee.ai"
|
||||
"""Arcee App URL"""
|
||||
|
||||
model_kwargs: Optional[Dict[str, Any]] = None
|
||||
"""Keyword arguments to pass to the model."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
underscore_attrs_are_private = True
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
"""Initializes private fields."""
|
||||
|
||||
super().__init__(**data)
|
||||
|
||||
self._client = ArceeWrapper(
|
||||
arcee_api_key=self.arcee_api_key,
|
||||
arcee_api_url=self.arcee_api_url,
|
||||
arcee_api_version=self.arcee_api_version,
|
||||
model_kwargs=self.model_kwargs,
|
||||
model_name=self.model,
|
||||
)
|
||||
|
||||
self._client.validate_model_training_status()
|
||||
|
||||
@root_validator()
|
||||
def validate_environments(cls, values: Dict) -> Dict:
|
||||
"""Validate Arcee environment variables."""
|
||||
|
||||
# validate env vars
|
||||
values["arcee_api_key"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_key",
|
||||
"ARCEE_API_KEY",
|
||||
)
|
||||
|
||||
values["arcee_api_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_url",
|
||||
"ARCEE_API_URL",
|
||||
)
|
||||
|
||||
values["arcee_app_url"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_app_url",
|
||||
"ARCEE_APP_URL",
|
||||
)
|
||||
|
||||
values["arcee_api_version"] = get_from_dict_or_env(
|
||||
values,
|
||||
"arcee_api_version",
|
||||
"ARCEE_API_VERSION",
|
||||
)
|
||||
|
||||
# validate model kwargs
|
||||
if values["model_kwargs"]:
|
||||
kw = values["model_kwargs"]
|
||||
|
||||
# validate size
|
||||
if kw.get("size") is not None:
|
||||
if not kw.get("size") >= 0:
|
||||
raise ValueError("`size` must not be negative.")
|
||||
|
||||
# validate filters
|
||||
if kw.get("filters") is not None:
|
||||
if not isinstance(kw.get("filters"), List):
|
||||
raise ValueError("`filters` must be a list.")
|
||||
for f in kw.get("filters"):
|
||||
DALMFilter(**f)
|
||||
|
||||
return values
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
"""Retrieve {size} contexts with your retriever for a given query
|
||||
|
||||
Args:
|
||||
query: Query to submit to the model
|
||||
size: The max number of context results to retrieve.
|
||||
Defaults to 3. (Can be less if filters are provided).
|
||||
filters: Filters to apply to the context dataset.
|
||||
"""
|
||||
|
||||
try:
|
||||
if not self._client:
|
||||
raise ValueError("Client is not initialized.")
|
||||
return self._client.retrieve(query=query, **kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error while retrieving documents: {e}") from e
|
@ -0,0 +1,189 @@
|
||||
# This module contains utility classes and functions for interacting with Arcee API.
|
||||
# For more information and updates, refer to the Arcee utils page:
|
||||
# [https://github.com/arcee-ai/arcee-python/blob/main/arcee/dalm.py]
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.retriever import Document
|
||||
|
||||
|
||||
class ArceeRoute(str, Enum):
|
||||
generate = "models/generate"
|
||||
retrieve = "models/retrieve"
|
||||
model_training_status = "models/status/{id_or_name}"
|
||||
|
||||
|
||||
class DALMFilterType(str, Enum):
|
||||
fuzzy_search = "fuzzy_search"
|
||||
strict_search = "strict_search"
|
||||
|
||||
|
||||
class DALMFilter(BaseModel):
|
||||
"""Filters available for a dalm retrieval and generation
|
||||
|
||||
Arguments:
|
||||
field_name: The field to filter on. Can be 'document' or 'name' to filter
|
||||
on your document's raw text or title. Any other field will be presumed
|
||||
to be a metadata field you included when uploading your context data
|
||||
filter_type: Currently 'fuzzy_search' and 'strict_search' are supported.
|
||||
'fuzzy_search' means a fuzzy search on the provided field is performed.
|
||||
The exact strict doesn't need to exist in the document
|
||||
for this to find a match.
|
||||
Very useful for scanning a document for some keyword terms.
|
||||
'strict_search' means that the exact string must appear
|
||||
in the provided field.
|
||||
This is NOT an exact eq filter. ie a document with content
|
||||
"the happy dog crossed the street" will match on a strict_search of
|
||||
"dog" but won't match on "the dog".
|
||||
Python equivalent of `return search_string in full_string`.
|
||||
value: The actual value to search for in the context data/metadata
|
||||
"""
|
||||
|
||||
field_name: str
|
||||
filter_type: DALMFilterType
|
||||
value: str
|
||||
_is_metadata: bool = False
|
||||
|
||||
@root_validator()
|
||||
def set_meta(cls, values: Dict) -> Dict:
|
||||
"""document and name are reserved arcee keys. Anything else is metadata"""
|
||||
values["_is_meta"] = values.get("field_name") not in ["document", "name"]
|
||||
return values
|
||||
|
||||
|
||||
class ArceeWrapper:
|
||||
def __init__(
|
||||
self,
|
||||
arcee_api_key: str,
|
||||
arcee_api_url: str,
|
||||
arcee_api_version: str,
|
||||
model_kwargs: Optional[Dict[str, Any]],
|
||||
model_name: str,
|
||||
):
|
||||
self.arcee_api_key = arcee_api_key
|
||||
self.model_kwargs = model_kwargs
|
||||
self.arcee_api_url = arcee_api_url
|
||||
self.arcee_api_version = arcee_api_version
|
||||
|
||||
try:
|
||||
route = ArceeRoute.model_training_status.value.format(id_or_name=model_name)
|
||||
response = self._make_request("get", route)
|
||||
self.model_id = response.get("model_id")
|
||||
self.model_training_status = response.get("status")
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error while validating model training status for '{model_name}': {e}"
|
||||
) from e
|
||||
|
||||
def validate_model_training_status(self) -> None:
|
||||
if self.model_training_status != "training_complete":
|
||||
raise Exception(
|
||||
f"Model {self.model_id} is not ready. "
|
||||
"Please wait for training to complete."
|
||||
)
|
||||
|
||||
def _make_request(
|
||||
self,
|
||||
method: Literal["post", "get"],
|
||||
route: Union[ArceeRoute, str],
|
||||
body: Optional[Mapping[str, Any]] = None,
|
||||
params: Optional[dict] = None,
|
||||
headers: Optional[dict] = None,
|
||||
) -> dict:
|
||||
"""Make a request to the Arcee API
|
||||
Args:
|
||||
method: The HTTP method to use
|
||||
route: The route to call
|
||||
body: The body of the request
|
||||
params: The query params of the request
|
||||
headers: The headers of the request
|
||||
"""
|
||||
headers = self._make_request_headers(headers=headers)
|
||||
url = self._make_request_url(route=route)
|
||||
|
||||
req_type = getattr(requests, method)
|
||||
|
||||
response = req_type(url, json=body, params=params, headers=headers)
|
||||
if response.status_code not in (200, 201):
|
||||
raise Exception(f"Failed to make request. Response: {response.text}")
|
||||
return response.json()
|
||||
|
||||
def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict:
|
||||
headers = headers or {}
|
||||
internal_headers = {
|
||||
"X-Token": self.arcee_api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
headers.update(internal_headers)
|
||||
return headers
|
||||
|
||||
def _make_request_url(self, route: Union[ArceeRoute, str]) -> str:
|
||||
return f"{self.arcee_api_url}/{self.arcee_api_version}/{route}"
|
||||
|
||||
def _make_request_body_for_models(
|
||||
self, prompt: str, **kwargs: Mapping[str, Any]
|
||||
) -> Mapping[str, Any]:
|
||||
"""Make the request body for generate/retrieve models endpoint"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_params = {**_model_kwargs, **kwargs}
|
||||
|
||||
filters = [DALMFilter(**f) for f in _params.get("filters", [])]
|
||||
return dict(
|
||||
model_id=self.model_id,
|
||||
query=prompt,
|
||||
size=_params.get("size", 3),
|
||||
filters=filters,
|
||||
id=self.model_id,
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Generate text from Arcee DALM.
|
||||
|
||||
Args:
|
||||
prompt: Prompt to generate text from.
|
||||
size: The max number of context results to retrieve. Defaults to 3.
|
||||
(Can be less if filters are provided).
|
||||
filters: Filters to apply to the context dataset.
|
||||
"""
|
||||
|
||||
response = self._make_request(
|
||||
method="post",
|
||||
route=ArceeRoute.generate,
|
||||
body=self._make_request_body_for_models(
|
||||
prompt=prompt,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
return response["text"]
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Retrieve {size} contexts with your retriever for a given query
|
||||
|
||||
Args:
|
||||
query: Query to submit to the model
|
||||
size: The max number of context results to retrieve. Defaults to 3.
|
||||
(Can be less if filters are provided).
|
||||
filters: Filters to apply to the context dataset.
|
||||
"""
|
||||
|
||||
response = self._make_request(
|
||||
method="post",
|
||||
route=ArceeRoute.retrieve,
|
||||
body=self._make_request_body_for_models(
|
||||
prompt=query,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
return [Document(**doc) for doc in response["documents"]]
|
Loading…
Reference in New Issue