community[patch]: update OctoAI endpoint to subclass BaseOpenAI (#19757)

This PR updates OctoAIEndpoint LLM to subclass BaseOpenAI as OctoAI is
an OpenAI-compatible service. The documentation and tests have also been
updated.
pull/18775/head
Sevin F. Varoglu 2 months ago committed by GitHub
parent 0c95ddbcd8
commit 54d388d898
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -18,7 +18,7 @@
" \n",
"2. Paste your API key in in the code cell below.\n",
"\n",
"Note: If you want to use a different LLM model, you can containerize the model and make a custom OctoAI endpoint yourself, by following [Build a Container from Python](https://octo.ai/docs/bring-your-own-model/advanced-build-a-container-from-scratch-in-python) and [Create a Custom Endpoint from a Container](https://octo.ai/docs/bring-your-own-model/create-custom-endpoints-from-a-container/create-custom-endpoints-from-a-container) and then update your Endpoint URL in the code cell below.\n"
"Note: If you want to use a different LLM model, you can containerize the model and make a custom OctoAI endpoint yourself, by following [Build a Container from Python](https://octo.ai/docs/bring-your-own-model/advanced-build-a-container-from-scratch-in-python) and [Create a Custom Endpoint from a Container](https://octo.ai/docs/bring-your-own-model/create-custom-endpoints-from-a-container/create-custom-endpoints-from-a-container) and then updating your `OCTOAI_API_BASE` environment variable.\n"
]
},
{
@ -29,8 +29,7 @@
"source": [
"import os\n",
"\n",
"os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\"\n",
"os.environ[\"ENDPOINT_URL\"] = \"https://text.octoai.run/v1/chat/completions\""
"os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\""
]
},
{
@ -68,44 +67,33 @@
"outputs": [],
"source": [
"llm = OctoAIEndpoint(\n",
" model_kwargs={\n",
" \"model\": \"llama-2-13b-chat-fp16\",\n",
" \"max_tokens\": 128,\n",
" \"presence_penalty\": 0,\n",
" \"temperature\": 0.1,\n",
" \"top_p\": 0.9,\n",
" \"messages\": [\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a helpful assistant. Keep your responses limited to one short paragraph if possible.\",\n",
" },\n",
" ],\n",
" },\n",
" model=\"llama-2-13b-chat-fp16\",\n",
" max_tokens=200,\n",
" presence_penalty=0,\n",
" temperature=0.1,\n",
" top_p=0.9,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Sure thing! Here's my response:\n",
"\n",
"Leonardo da Vinci was a true Renaissance man - an Italian polymath who excelled in various fields, including painting, sculpture, engineering, mathematics, anatomy, and geology. He is widely considered one of the greatest painters of all time, and his inventive and innovative works continue to inspire and influence artists and thinkers to this day. Some of his most famous works include the Mona Lisa, The Last Supper, and Vitruvian Man. \n"
]
}
],
"outputs": [],
"source": [
"question = \"Who was leonardo davinci?\"\n",
"question = \"Who was Leonardo da Vinci?\"\n",
"\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"\n",
"print(llm_chain.run(question))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Leonardo da Vinci was a true Renaissance man. He was born in 1452 in Vinci, Italy and was known for his work in various fields, including art, science, engineering, and mathematics. He is considered one of the greatest painters of all time, and his most famous works include the Mona Lisa and The Last Supper. In addition to his art, da Vinci made significant contributions to engineering and anatomy, and his designs for machines and inventions were centuries ahead of his time. He is also known for his extensive journals and drawings, which provide valuable insights into his thoughts and ideas. Da Vinci's legacy continues to inspire and influence artists, scientists, and thinkers around the world today."
]
}
],
"metadata": {

@ -1003,6 +1003,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"oci_model_deployment_tgi_endpoint": _import_oci_md_tgi,
"oci_model_deployment_vllm_endpoint": _import_oci_md_vllm,
"oci_generative_ai": _import_oci_gen_ai,
"octoai_endpoint": _import_octoai_endpoint,
"ollama": _import_ollama,
"openai": _import_openai,
"openlm": _import_openlm,

@ -1,166 +1,117 @@
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_community.llms.utils import enforce_stop_tokens
from langchain_community.llms.openai import BaseOpenAI
from langchain_community.utils.openai import is_openai_v1
DEFAULT_BASE_URL = "https://text.octoai.run/v1/"
DEFAULT_MODEL = "codellama-7b-instruct"
class OctoAIEndpoint(LLM):
"""OctoAI LLM Endpoints.
OctoAIEndpoint is a class to interact with OctoAI
Compute Service large language model endpoints.
class OctoAIEndpoint(BaseOpenAI):
"""OctoAI LLM Endpoints - OpenAI compatible.
To use, you should have the ``octoai`` python package installed, and the
environment variable ``OCTOAI_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
OctoAIEndpoint is a class to interact with OctoAI Compute Service large
language model endpoints.
To use, you should have the environment variable ``OCTOAI_API_TOKEN`` set
with your API token, or pass it as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
OctoAIEndpoint(
octoai_api_token="octoai-api-key",
endpoint_url="https://text.octoai.run/v1/chat/completions",
model_kwargs={
"model": "llama-2-13b-chat-fp16",
"messages": [
{
"role": "system",
"content": "Below is an instruction that describes a task.
Write a response that completes the request."
}
],
"stream": False,
"max_tokens": 256,
"presence_penalty": 0,
"temperature": 0.1,
"top_p": 0.9
}
llm = OctoAIEndpoint(
model="llama-2-13b-chat-fp16",
max_tokens=200,
presence_penalty=0,
temperature=0.1,
top_p=0.9,
)
"""
endpoint_url: Optional[str] = None
"""Endpoint URL to use."""
model_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the model."""
octoai_api_token: Optional[str] = None
"""OCTOAI API Token"""
streaming: bool = False
"""Whether to generate a stream of tokens asynchronously"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
"""Key word arguments to pass to the model."""
octoai_api_base: str = Field(default=DEFAULT_BASE_URL)
octoai_api_token: SecretStr = Field(default=None)
model_name: str = Field(default=DEFAULT_MODEL)
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
octoai_api_token = get_from_dict_or_env(
values, "octoai_api_token", "OCTOAI_API_TOKEN"
)
values["endpoint_url"] = get_from_dict_or_env(
values, "endpoint_url", "ENDPOINT_URL"
)
values["octoai_api_token"] = octoai_api_token
return values
@classmethod
def is_lc_serializable(cls) -> bool:
return False
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"endpoint_url": self.endpoint_url},
**{"model_kwargs": _model_kwargs},
def _invocation_params(self) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
params: Dict[str, Any] = {
"model": self.model_name,
**self._default_params,
}
if not is_openai_v1():
params.update(
{
"api_key": self.octoai_api_token.get_secret_value(),
"api_base": self.octoai_api_base,
}
)
return {**params, **super()._invocation_params}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "octoai_endpoint"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to OctoAI's inference endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
"""
_model_kwargs = self.model_kwargs or {}
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["octoai_api_base"] = get_from_dict_or_env(
values,
"octoai_api_base",
"OCTOAI_API_BASE",
default=DEFAULT_BASE_URL,
)
values["octoai_api_token"] = convert_to_secret_str(
get_from_dict_or_env(values, "octoai_api_token", "OCTOAI_API_TOKEN")
)
values["model_name"] = get_from_dict_or_env(
values,
"model_name",
"MODEL_NAME",
default=DEFAULT_MODEL,
)
try:
from octoai import client
# Initialize the OctoAI client
octoai_client = client.Client(token=self.octoai_api_token)
if "model" in _model_kwargs:
parameter_payload = _model_kwargs
sys_msg = None
if "messages" in parameter_payload:
msgs = parameter_payload.get("messages", [])
for msg in msgs:
if msg.get("role") == "system":
sys_msg = msg.get("content")
# Reset messages list
parameter_payload["messages"] = []
# Append system message if exists
if sys_msg:
parameter_payload["messages"].append(
{"role": "system", "content": sys_msg}
)
# Append user message
parameter_payload["messages"].append(
{"role": "user", "content": prompt}
)
# Send the request using the OctoAI client
try:
output = octoai_client.infer(self.endpoint_url, parameter_payload)
if output and "choices" in output and len(output["choices"]) > 0:
text = output["choices"][0].get("message", {}).get("content")
else:
text = "Error: Invalid response format or empty choices."
except Exception as e:
text = f"Error during API call: {str(e)}"
import openai
if is_openai_v1():
client_params = {
"api_key": values["octoai_api_token"].get_secret_value(),
"base_url": values["octoai_api_base"],
}
if not values.get("client"):
values["client"] = openai.OpenAI(**client_params).completions
if not values.get("async_client"):
values["async_client"] = openai.AsyncOpenAI(
**client_params
).completions
else:
# Prepare the payload JSON
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
# Send the request using the OctoAI client
resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
text = resp_json["generated_text"]
except Exception as e:
# Handle any errors raised by the inference endpoint
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
values["openai_api_base"] = values["octoai_api_base"]
values["openai_api_key"] = values["octoai_api_token"].get_secret_value()
values["client"] = openai.Completion
except ImportError:
raise ImportError(
"Could not import openai python package. "
"Please install it with `pip install openai`."
)
if stop is not None:
# Apply stop tokens when making calls to OctoAI
text = enforce_stop_tokens(text, stop)
if "endpoint_url" in values["model_kwargs"]:
raise ValueError(
"`endpoint_url` was deprecated, please use `octoai_api_base`."
)
return text
return values

@ -1,58 +1,11 @@
"""Test OctoAI API wrapper."""
from pathlib import Path
import pytest
from langchain_community.llms.loading import load_llm
from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
from tests.integration_tests.llms.utils import assert_llm_equality
def test_octoai_endpoint_text_generation() -> None:
"""Test valid call to OctoAI text generation model."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
octoai_api_token="<octoai_api_token>",
model_kwargs={
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
"seed": None,
"stop": [],
},
)
def test_octoai_endpoint_call() -> None:
"""Test valid call to OctoAI endpoint."""
llm = OctoAIEndpoint()
output = llm("Which state is Los Angeles in?")
print(output) # noqa: T201
assert isinstance(output, str)
def test_octoai_endpoint_call_error() -> None:
"""Test valid call to OctoAI that errors."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
model_kwargs={"max_new_tokens": -1},
)
with pytest.raises(ValueError):
llm("Which state is Los Angeles in?")
def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
"""Test saving/loading an OctoAIHub LLM."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
octoai_api_token="<octoai_api_token>",
model_kwargs={
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
"seed": None,
"stop": [],
},
)
llm.save(file_path=tmp_path / "octoai.yaml")
loaded_llm = load_llm(tmp_path / "octoai.yaml")
assert_llm_equality(llm, loaded_llm)

Loading…
Cancel
Save