Adds support for llama2 and fixes MPT-7b url (#11465)

- **Description:** This updates the OctoAI LLM provider to add
support for llama2 endpoints hosted on OctoAI, and updates the MPT-7b URL
to the current one.
@baskaryan
Thanks!
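
For reviewers: a minimal usage sketch of the new llama2 chat support, mirroring the docstring example added in this PR (the API token, endpoint URL, and prompt below are placeholders):

```python
from langchain.llms.octoai_endpoint import OctoAIEndpoint

# Placeholder token and demo endpoint URL, as in the docstring example below.
llm = OctoAIEndpoint(
    octoai_api_token="octoai-api-key",
    endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
    model_kwargs={
        "model": "llama-2-7b-chat",
        "messages": [
            {
                "role": "system",
                "content": "Below is an instruction that describes a task. "
                "Write a response that completes the request.",
            }
        ],
        "stream": False,
        "max_tokens": 256,
    },
)

# The user prompt is appended to `messages` internally (see the _call diff below).
print(llm("Who was Leonardo da Vinci?"))
```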

---------

Co-authored-by: ML Wiz <bassemgeorgi@gmail.com>
Bassem Yacoube, committed by GitHub
commit 5451b724fc (parent 0bff399af1)

@@ -33,7 +33,7 @@
     "import os\n",
     "\n",
     "os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\"\n",
-    "os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate\""
+    "os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate\""
    ]
   },
   {

@@ -23,7 +23,7 @@ class OctoAIEndpoint(LLM):
             from langchain.llms.octoai_endpoint import OctoAIEndpoint
             OctoAIEndpoint(
                 octoai_api_token="octoai-api-key",
-                endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+                endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
                 model_kwargs={
                     "max_new_tokens": 200,
                     "temperature": 0.75,
@@ -34,6 +34,24 @@ class OctoAIEndpoint(LLM):
                 },
             )
 
+            from langchain.llms.octoai_endpoint import OctoAIEndpoint
+            OctoAIEndpoint(
+                octoai_api_token="octoai-api-key",
+                endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
+                model_kwargs={
+                    "model": "llama-2-7b-chat",
+                    "messages": [
+                        {
+                            "role": "system",
+                            "content": "Below is an instruction that describes a task.
+                            Write a response that completes the request."
+                        }
+                    ],
+                    "stream": False,
+                    "max_tokens": 256
+                }
+            )
     """
     endpoint_url: Optional[str] = None
@@ -45,6 +63,9 @@ class OctoAIEndpoint(LLM):
     octoai_api_token: Optional[str] = None
     """OCTOAI API Token"""
 
+    streaming: bool = False
+    """Whether to generate a stream of tokens asynchronously"""
+
     class Config:
         """Configuration for this pydantic object."""
@@ -96,18 +117,27 @@ class OctoAIEndpoint(LLM):
         """
         _model_kwargs = self.model_kwargs or {}
 
-        # Prepare the payload JSON
-        parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
-
         try:
             # Initialize the OctoAI client
             from octoai import client
 
             octoai_client = client.Client(token=self.octoai_api_token)
 
-            # Send the request using the OctoAI client
-            resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
-            text = resp_json["generated_text"]
+            if "model" in _model_kwargs and "llama-2" in _model_kwargs["model"]:
+                parameter_payload = _model_kwargs
+                parameter_payload["messages"].append(
+                    {"role": "user", "content": prompt}
+                )
+
+                # Send the request using the OctoAI client
+                output = octoai_client.infer(self.endpoint_url, parameter_payload)
+                text = output.get("choices")[0].get("message").get("content")
+            else:
+                # Prepare the payload JSON
+                parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
+                # Send the request using the OctoAI client
+                resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
+                text = resp_json["generated_text"]
 
         except Exception as e:
             # Handle any errors raised by the inference endpoint
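
For review context, a minimal sketch of the two payload shapes the branched `_call` now sends, and the field each response is read from (the prompt and system message below are illustrative; the field names come from the diff above):

```python
prompt = "Who was Leonardo da Vinci?"  # illustrative

# llama-2 chat path: model_kwargs already holds a chat-completions payload;
# _call appends the user prompt as one more message.
chat_payload = {
    "model": "llama-2-7b-chat",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},  # illustrative
        {"role": "user", "content": prompt},  # appended by _call
    ],
    "stream": False,
    "max_tokens": 256,
}
# Response text is read as: output["choices"][0]["message"]["content"]

# Default text-generation path (e.g. the MPT-7b endpoint):
text_payload = {"inputs": prompt, "parameters": {"max_new_tokens": 200}}
# Response text is read as: resp_json["generated_text"]
```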

@@ -12,7 +12,7 @@ from tests.integration_tests.llms.utils import assert_llm_equality
 def test_octoai_endpoint_text_generation() -> None:
     """Test valid call to OctoAI text generation model."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         octoai_api_token="<octoai_api_token>",
         model_kwargs={
             "max_new_tokens": 200,
@@ -32,7 +32,7 @@ def test_octoai_endpoint_text_generation() -> None:
 def test_octoai_endpoint_call_error() -> None:
     """Test valid call to OctoAI that errors."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         model_kwargs={"max_new_tokens": -1},
     )
     with pytest.raises(ValueError):
@@ -42,7 +42,7 @@ def test_octoai_endpoint_call_error() -> None:
 def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
     """Test saving/loading an OctoAIHub LLM."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         octoai_api_token="<octoai_api_token>",
         model_kwargs={
             "max_new_tokens": 200,
