Adds support for llama2 and fixes MPT-7b url (#11465)

- **Description:** This update to the OctoAI LLM provider adds support
for Llama 2 endpoints hosted on OctoAI and updates the MPT-7B URL to
the current one.
@baskaryan
Thanks!
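
For reviewers, a minimal usage sketch of the new Llama 2 chat path. The
endpoint URL and `model_kwargs` mirror the docstring example added in this
diff; the API token and demo endpoint are placeholders:

```python
from langchain.llms.octoai_endpoint import OctoAIEndpoint

# Passing "model" in model_kwargs triggers the new chat-completions
# payload; the prompt is appended as a "user" message at call time.
llm = OctoAIEndpoint(
    octoai_api_token="octoai-api-key",  # placeholder
    endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
    model_kwargs={
        "model": "llama-2-7b-chat",
        "messages": [
            {
                "role": "system",
                "content": "Below is an instruction that describes a task. "
                "Write a response that completes the request.",
            }
        ],
        "stream": False,
        "max_tokens": 256,
    },
)

print(llm("Who was Leonardo da Vinci?"))
```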

---------

Co-authored-by: ML Wiz <bassemgeorgi@gmail.com>

@@ -33,7 +33,7 @@
 "import os\n",
 "\n",
 "os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\"\n",
-"os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate\""
+"os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate\""
 ]
 },
 {

@@ -23,7 +23,7 @@ class OctoAIEndpoint(LLM):
         from langchain.llms.octoai_endpoint import OctoAIEndpoint
         OctoAIEndpoint(
             octoai_api_token="octoai-api-key",
-            endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+            endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
             model_kwargs={
                 "max_new_tokens": 200,
                 "temperature": 0.75,
@@ -34,6 +34,24 @@ class OctoAIEndpoint(LLM):
             },
         )
+
+        from langchain.llms.octoai_endpoint import OctoAIEndpoint
+        OctoAIEndpoint(
+            octoai_api_token="octoai-api-key",
+            endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions",
+            model_kwargs={
+                "model": "llama-2-7b-chat",
+                "messages": [
+                    {
+                        "role": "system",
+                        "content": "Below is an instruction that describes a task.
+                        Write a response that completes the request."
+                    }
+                ],
+                "stream": False,
+                "max_tokens": 256
+            }
+        )

     """
     endpoint_url: Optional[str] = None
@@ -45,6 +63,9 @@ class OctoAIEndpoint(LLM):
     octoai_api_token: Optional[str] = None
     """OCTOAI API Token"""

+    streaming: bool = False
+    """Whether to generate a stream of tokens asynchronously"""
+
     class Config:
         """Configuration for this pydantic object."""
@ -96,18 +117,27 @@ class OctoAIEndpoint(LLM):
""" """
_model_kwargs = self.model_kwargs or {} _model_kwargs = self.model_kwargs or {}
# Prepare the payload JSON
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
try: try:
# Initialize the OctoAI client # Initialize the OctoAI client
from octoai import client from octoai import client
octoai_client = client.Client(token=self.octoai_api_token) octoai_client = client.Client(token=self.octoai_api_token)
# Send the request using the OctoAI client if "model" in _model_kwargs and "llama-2" in _model_kwargs["model"]:
resp_json = octoai_client.infer(self.endpoint_url, parameter_payload) parameter_payload = _model_kwargs
text = resp_json["generated_text"] parameter_payload["messages"].append(
{"role": "user", "content": prompt}
)
# Send the request using the OctoAI client
output = octoai_client.infer(self.endpoint_url, parameter_payload)
text = output.get("choices")[0].get("message").get("content")
else:
# Prepare the payload JSON
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
# Send the request using the OctoAI client
resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
text = resp_json["generated_text"]
except Exception as e: except Exception as e:
# Handle any errors raised by the inference endpoint # Handle any errors raised by the inference endpoint
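
For clarity, a sketch of the two payload shapes the branch above sends to
`octoai_client.infer`, with values taken from the examples in this diff
(prompt text is a placeholder):

```python
# Chat path: taken when model_kwargs contains "model" with "llama-2" in it.
# The payload is model_kwargs itself, with the prompt appended to "messages";
# the completion is read from output["choices"][0]["message"]["content"].
chat_payload = {
    "model": "llama-2-7b-chat",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Who was Leonardo da Vinci?"},
    ],
    "stream": False,
    "max_tokens": 256,
}

# Default text-generation path: unchanged from before this PR;
# the completion is read from resp_json["generated_text"].
text_payload = {
    "inputs": "Who was Leonardo da Vinci?",
    "parameters": {"max_new_tokens": 200, "temperature": 0.75},
}
```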

@@ -12,7 +12,7 @@ from tests.integration_tests.llms.utils import assert_llm_equality
 def test_octoai_endpoint_text_generation() -> None:
     """Test valid call to OctoAI text generation model."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         octoai_api_token="<octoai_api_token>",
         model_kwargs={
             "max_new_tokens": 200,
@@ -32,7 +32,7 @@ def test_octoai_endpoint_text_generation() -> None:
 def test_octoai_endpoint_call_error() -> None:
     """Test valid call to OctoAI that errors."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         model_kwargs={"max_new_tokens": -1},
     )
     with pytest.raises(ValueError):
@@ -42,7 +42,7 @@ def test_octoai_endpoint_call_error() -> None:
 def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
     """Test saving/loading an OctoAIHub LLM."""
     llm = OctoAIEndpoint(
-        endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
+        endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate",
         octoai_api_token="<octoai_api_token>",
         model_kwargs={
             "max_new_tokens": 200,
