From 5451b724fc68b481c61e80b68f4dcfa9bbbe6a64 Mon Sep 17 00:00:00 2001 From: Bassem Yacoube <125713079+AI-Bassem@users.noreply.github.com> Date: Wed, 11 Oct 2023 03:34:35 +0000 Subject: [PATCH] Adds support for llama2 and fixes MPT-7b url (#11465) - **Description:** This is an update to OctoAI LLM provider that adds support for llama2 endpoints hosted on OctoAI and updates MPT-7b url with the current one. @baskaryan Thanks! --------- Co-authored-by: ML Wiz --- docs/docs/integrations/llms/octoai.ipynb | 2 +- .../langchain/llms/octoai_endpoint.py | 44 ++++++++++++++++--- .../llms/test_octoai_endpoint.py | 6 +-- 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/docs/docs/integrations/llms/octoai.ipynb b/docs/docs/integrations/llms/octoai.ipynb index 79324cd9ac..c9bcff7abe 100644 --- a/docs/docs/integrations/llms/octoai.ipynb +++ b/docs/docs/integrations/llms/octoai.ipynb @@ -33,7 +33,7 @@ "import os\n", "\n", "os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\"\n", - "os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate\"" + "os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate\"" ] }, { diff --git a/libs/langchain/langchain/llms/octoai_endpoint.py b/libs/langchain/langchain/llms/octoai_endpoint.py index adc2bf89bc..055ab8ac75 100644 --- a/libs/langchain/langchain/llms/octoai_endpoint.py +++ b/libs/langchain/langchain/llms/octoai_endpoint.py @@ -23,7 +23,7 @@ class OctoAIEndpoint(LLM): from langchain.llms.octoai_endpoint import OctoAIEndpoint OctoAIEndpoint( octoai_api_token="octoai-api-key", - endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate", + endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate", model_kwargs={ "max_new_tokens": 200, "temperature": 0.75, @@ -34,6 +34,24 @@ class OctoAIEndpoint(LLM): }, ) + from langchain.llms.octoai_endpoint import OctoAIEndpoint + OctoAIEndpoint( + octoai_api_token="octoai-api-key", + endpoint_url="https://llama-2-7b-chat-demo-kk0powt97tmb.octoai.run/v1/chat/completions", + model_kwargs={ + "model": "llama-2-7b-chat", + "messages": [ + { + "role": "system", + "content": "Below is an instruction that describes a task. + Write a response that completes the request." + } + ], + "stream": False, + "max_tokens": 256 + } + ) + """ endpoint_url: Optional[str] = None @@ -45,6 +63,9 @@ class OctoAIEndpoint(LLM): octoai_api_token: Optional[str] = None """OCTOAI API Token""" + streaming: bool = False + """Whether to generate a stream of tokens asynchronously""" + class Config: """Configuration for this pydantic object.""" @@ -96,18 +117,27 @@ class OctoAIEndpoint(LLM): """ _model_kwargs = self.model_kwargs or {} - # Prepare the payload JSON - parameter_payload = {"inputs": prompt, "parameters": _model_kwargs} - try: # Initialize the OctoAI client from octoai import client octoai_client = client.Client(token=self.octoai_api_token) - # Send the request using the OctoAI client - resp_json = octoai_client.infer(self.endpoint_url, parameter_payload) - text = resp_json["generated_text"] + if "model" in _model_kwargs and "llama-2" in _model_kwargs["model"]: + parameter_payload = _model_kwargs + parameter_payload["messages"].append( + {"role": "user", "content": prompt} + ) + # Send the request using the OctoAI client + output = octoai_client.infer(self.endpoint_url, parameter_payload) + text = output.get("choices")[0].get("message").get("content") + else: + # Prepare the payload JSON + parameter_payload = {"inputs": prompt, "parameters": _model_kwargs} + + # Send the request using the OctoAI client + resp_json = octoai_client.infer(self.endpoint_url, parameter_payload) + text = resp_json["generated_text"] except Exception as e: # Handle any errors raised by the inference endpoint diff --git a/libs/langchain/tests/integration_tests/llms/test_octoai_endpoint.py b/libs/langchain/tests/integration_tests/llms/test_octoai_endpoint.py index 0533eac5d4..f4ea11d9c6 100644 --- a/libs/langchain/tests/integration_tests/llms/test_octoai_endpoint.py +++ b/libs/langchain/tests/integration_tests/llms/test_octoai_endpoint.py @@ -12,7 +12,7 @@ from tests.integration_tests.llms.utils import assert_llm_equality def test_octoai_endpoint_text_generation() -> None: """Test valid call to OctoAI text generation model.""" llm = OctoAIEndpoint( - endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate", + endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate", octoai_api_token="", model_kwargs={ "max_new_tokens": 200, @@ -32,7 +32,7 @@ def test_octoai_endpoint_text_generation() -> None: def test_octoai_endpoint_call_error() -> None: """Test valid call to OctoAI that errors.""" llm = OctoAIEndpoint( - endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate", + endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate", model_kwargs={"max_new_tokens": -1}, ) with pytest.raises(ValueError): @@ -42,7 +42,7 @@ def test_octoai_endpoint_call_error() -> None: def test_saving_loading_endpoint_llm(tmp_path: Path) -> None: """Test saving/loading an OctoAIHub LLM.""" llm = OctoAIEndpoint( - endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate", + endpoint_url="https://mpt-7b-demo-f1kzsig6xes9.octoai.run/generate", octoai_api_token="", model_kwargs={ "max_new_tokens": 200,