From 57fbc6bdf162e09394b6392f69d06ec11c1f8898 Mon Sep 17 00:00:00 2001 From: Qiu Qin Date: Tue, 15 Oct 2024 11:32:54 -0400 Subject: [PATCH] community: Update OCI data science integration (#27083) This PR updates the integration with OCI data science model deployment service. - Update LLM to support streaming and async calls. - Added chat model. - Updated tests and docs. - Updated `libs/community/scripts/check_pydantic.sh` since the use of `@pre_init` is removed from existing integration. - Updated `libs/community/extended_testing_deps.txt` as this integration requires `langchain_openai`. --------- Co-authored-by: MING KANG Co-authored-by: Dmitrii Cherkasov Co-authored-by: Erick Friis --- .../integrations/chat/oci_data_science.ipynb | 460 ++++++++ .../llms/oci_model_deployment_endpoint.ipynb | 105 +- docs/docs/integrations/providers/oci.mdx | 10 +- libs/community/extended_testing_deps.txt | 1 + .../chat_models/__init__.py | 11 + .../chat_models/oci_data_science.py | 998 ++++++++++++++++++ .../langchain_community/llms/__init__.py | 12 + ..._data_science_model_deployment_endpoint.py | 882 +++++++++++++--- libs/community/scripts/check_pydantic.sh | 2 +- .../unit_tests/chat_models/test_imports.py | 3 + .../chat_models/test_oci_data_science.py | 193 ++++ .../test_oci_model_deployment_endpoint.py | 98 ++ .../tests/unit_tests/llms/test_imports.py | 1 + .../test_oci_model_deployment_endpoint.py | 195 +++- 14 files changed, 2762 insertions(+), 209 deletions(-) create mode 100644 docs/docs/integrations/chat/oci_data_science.ipynb create mode 100644 libs/community/langchain_community/chat_models/oci_data_science.py create mode 100644 libs/community/tests/unit_tests/chat_models/test_oci_data_science.py create mode 100644 libs/community/tests/unit_tests/chat_models/test_oci_model_deployment_endpoint.py diff --git a/docs/docs/integrations/chat/oci_data_science.ipynb b/docs/docs/integrations/chat/oci_data_science.ipynb new file mode 100644 index 0000000000..fdc1d8cda4 --- /dev/null +++ b/docs/docs/integrations/chat/oci_data_science.ipynb @@ -0,0 +1,460 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "---\n", + "sidebar_label: ChatOCIModelDeployment\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ChatOCIModelDeployment\n", + "\n", + "This will help you getting started with OCIModelDeployment [chat models](/docs/concepts/#chat-models). For detailed documentation of all ChatOCIModelDeployment features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.ChatOCIModelDeployment.html).\n", + "\n", + "[OCI Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/home.htm) is a fully managed and serverless platform for data science teams to build, train, and manage machine learning models in the Oracle Cloud Infrastructure. You can use [AI Quick Actions](https://blogs.oracle.com/ai-and-datascience/post/ai-quick-actions-in-oci-data-science) to easily deploy LLMs on [OCI Data Science Model Deployment Service](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm). You may choose to deploy the model with popular inference frameworks such as vLLM or TGI. 
By default, the model deployment endpoint mimics the OpenAI API protocol.\n", + "\n", + "> For the latest updates, examples and experimental features, please see [ADS LangChain Integration](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/large_language_model/langchain_models.html).\n", + "\n", + "## Overview\n", + "### Integration details\n", + "\n", + "| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n", + "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", + "| [ChatOCIModelDeployment](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.ChatOCIModelDeployment.html) | [langchain-community](https://api.python.langchain.com/en/latest/community_api_reference.html) | ❌ | beta | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-community?style=flat-square&label=%20) |\n", + "\n", + "### Model features\n", + "\n", + "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", + "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", + "| depends | depends | depends | depends | depends | depends | ✅ | ✅ | ✅ | ✅ | \n", + "\n", + "Some model features, including tool calling, structured output, JSON mode and multi-modal inputs, are depending on deployed model.\n", + "\n", + "\n", + "## Setup\n", + "\n", + "To use ChatOCIModelDeployment you'll need to deploy a chat model with chat completion endpoint and install the `langchain-community`, `langchain-openai` and `oracle-ads` integration packages.\n", + "\n", + "You can easily deploy foundation models using the [AI Quick Actions](https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/model-deployment-tips.md) on OCI Data Science Model deployment. For additional deployment examples, please visit the [Oracle GitHub samples repository](https://github.com/oracle-samples/oci-data-science-ai-samples/tree/main/ai-quick-actions).\n", + "\n", + "### Policies\n", + "Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint) to access the OCI Data Science Model Deployment endpoint.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Credentials\n", + "\n", + "You can set authentication through Oracle ADS. When you are working in OCI Data Science Notebook Session, you can leverage resource principal to access other OCI resources." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "\n", + "# Set authentication through ads\n", + "# Use resource principal are operating within a\n", + "# OCI service that has resource principal based\n", + "# authentication configured\n", + "ads.set_auth(\"resource_principal\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternatively, you can configure the credentials using the following environment variables. 
For example, to use API key with specific profile:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# Set authentication through environment variables\n", + "# Use API Key setup when you are working from a local\n", + "# workstation or on platform which does not support\n", + "# resource principals.\n", + "os.environ[\"OCI_IAM_TYPE\"] = \"api_key\"\n", + "os.environ[\"OCI_CONFIG_PROFILE\"] = \"default\"\n", + "os.environ[\"OCI_CONFIG_LOCATION\"] = \"~/.oci\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check out [Oracle ADS docs](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) to see more options." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "The LangChain OCIModelDeployment integration lives in the `langchain-community` package. The following command will install `langchain-community` and the required dependencies." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -qU langchain-community langchain-openai oracle-ads" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiation\n", + "\n", + "You may instantiate the model with the generic `ChatOCIModelDeployment` or framework specific class like `ChatOCIModelDeploymentVLLM`.\n", + "\n", + "* Using `ChatOCIModelDeployment` when you need a generic entry point for deploying models. You can pass model parameters through `model_kwargs` during the instantiation of this class. This allows for flexibility and ease of configuration without needing to rely on framework-specific details." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models import ChatOCIModelDeployment\n", + "\n", + "# Create an instance of OCI Model Deployment Endpoint\n", + "# Replace the endpoint uri with your own\n", + "# Using generic class as entry point, you will be able\n", + "# to pass model parameters through model_kwargs during\n", + "# instantiation.\n", + "chat = ChatOCIModelDeployment(\n", + " endpoint=\"https://modeldeployment..oci.customer-oci.com//predict\",\n", + " streaming=True,\n", + " max_retries=1,\n", + " model_kwargs={\n", + " \"temperature\": 0.2,\n", + " \"max_tokens\": 512,\n", + " }, # other model params...\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* Using framework specific class like `ChatOCIModelDeploymentVLLM`: This is suitable when you are working with a specific framework (e.g. `vLLM`) and need to pass model parameters directly through the constructor, streamlining the setup process." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models import ChatOCIModelDeploymentVLLM\n", + "\n", + "# Create an instance of OCI Model Deployment Endpoint\n", + "# Replace the endpoint uri with your own\n", + "# Using framework specific class as entry point, you will\n", + "# be able to pass model parameters in constructor.\n", + "chat = ChatOCIModelDeploymentVLLM(\n", + " endpoint=\"https://modeldeployment..oci.customer-oci.com//predict\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Invocation" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"J'adore programmer.\", response_metadata={'token_usage': {'prompt_tokens': 44, 'total_tokens': 52, 'completion_tokens': 8}, 'model_name': 'odsc-llm', 'system_fingerprint': '', 'finish_reason': 'stop'}, id='run-ca145168-efa9-414c-9dd1-21d10766fdd3-0')" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages = [\n", + " (\n", + " \"system\",\n", + " \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n", + " ),\n", + " (\"human\", \"I love programming.\"),\n", + "]\n", + "\n", + "ai_msg = chat.invoke(messages)\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "J'adore programmer.\n" + ] + } + ], + "source": [ + "print(ai_msg.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chaining" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='Ich liebe Programmierung.', response_metadata={'token_usage': {'prompt_tokens': 38, 'total_tokens': 48, 'completion_tokens': 10}, 'model_name': 'odsc-llm', 'system_fingerprint': '', 'finish_reason': 'stop'}, id='run-5dd936b0-b97e-490e-9869-2ad3dd524234-0')" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\n", + " \"system\",\n", + " \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n", + " ),\n", + " (\"human\", \"{input}\"),\n", + " ]\n", + ")\n", + "\n", + "chain = prompt | chat\n", + "chain.invoke(\n", + " {\n", + " \"input_language\": \"English\",\n", + " \"output_language\": \"German\",\n", + " \"input\": \"I love programming.\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Asynchronous calls" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='我喜欢编程', response_metadata={'token_usage': {'prompt_tokens': 37, 'total_tokens': 50, 'completion_tokens': 13}, 'model_name': 'odsc-llm', 'system_fingerprint': '', 'finish_reason': 'stop'}, id='run-a2dc9393-f269-41a4-b908-b1d8a92cf827-0')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_community.chat_models import ChatOCIModelDeployment\n", + "\n", + "system = \"You are a helpful translator that translates 
{input_language} to {output_language}.\"\n", + "human = \"{text}\"\n", + "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n", + "\n", + "chat = ChatOCIModelDeployment(\n", + " endpoint=\"https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict\"\n", + ")\n", + "chain = prompt | chat\n", + "\n", + "await chain.ainvoke(\n", + " {\n", + " \"input_language\": \"English\",\n", + " \"output_language\": \"Chinese\",\n", + " \"text\": \"I love programming\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Streaming calls" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "1. California\n", + "2. Texas\n", + "3. Florida\n", + "4. New York\n", + "5. Illinois" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "\n", + "from langchain_community.chat_models import ChatOCIModelDeployment\n", + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [(\"human\", \"List out the 5 states in the United State.\")]\n", + ")\n", + "\n", + "chat = ChatOCIModelDeployment(\n", + " endpoint=\"https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict\"\n", + ")\n", + "\n", + "chain = prompt | chat\n", + "\n", + "for chunk in chain.stream({}):\n", + " sys.stdout.write(chunk.content)\n", + " sys.stdout.flush()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Structured output" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'setup': 'Why did the cat get stuck in the tree?',\n", + " 'punchline': 'Because it was chasing its tail!'}" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_community.chat_models import ChatOCIModelDeployment\n", + "from pydantic import BaseModel\n", + "\n", + "\n", + "class Joke(BaseModel):\n", + " \"\"\"A setup to a joke and the punchline.\"\"\"\n", + "\n", + " setup: str\n", + " punchline: str\n", + "\n", + "\n", + "chat = ChatOCIModelDeployment(\n", + " endpoint=\"https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict\",\n", + ")\n", + "structured_llm = chat.with_structured_output(Joke, method=\"json_mode\")\n", + "output = structured_llm.invoke(\n", + " \"Tell me a joke about cats, respond in JSON with `setup` and `punchline` keys\"\n", + ")\n", + "\n", + "output.dict()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## API reference\n", + "\n", + "For comprehensive details on all features and configurations, please refer to the API reference documentation for each class:\n", + "\n", + "* [ChatOCIModelDeployment](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_data_science.ChatOCIModelDeployment.html)\n", + "* [ChatOCIModelDeploymentVLLM](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_data_science.ChatOCIModelDeploymentVLLM.html)\n", + "* [ChatOCIModelDeploymentTGI](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_data_science.ChatOCIModelDeploymentTGI.html)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "langchain", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + 
"name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/docs/integrations/llms/oci_model_deployment_endpoint.ipynb b/docs/docs/integrations/llms/oci_model_deployment_endpoint.ipynb index b7be0db497..0110a9d575 100644 --- a/docs/docs/integrations/llms/oci_model_deployment_endpoint.ipynb +++ b/docs/docs/integrations/llms/oci_model_deployment_endpoint.ipynb @@ -8,9 +8,11 @@ "\n", "[OCI Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/home.htm) is a fully managed and serverless platform for data science teams to build, train, and manage machine learning models in the Oracle Cloud Infrastructure.\n", "\n", + "> For the latest updates, examples and experimental features, please see [ADS LangChain Integration](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/large_language_model/langchain_models.html).\n", + "\n", "This notebooks goes over how to use an LLM hosted on a [OCI Data Science Model Deployment](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm).\n", "\n", - "To authenticate, [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) has been used to automatically load credentials for invoking endpoint." + "For authentication, the [oracle-ads](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) library is used to automatically load credentials required for invoking the endpoint." ] }, { @@ -29,29 +31,52 @@ "## Prerequisite\n", "\n", "### Deploy model\n", - "Check [Oracle GitHub samples repository](https://github.com/oracle-samples/oci-data-science-ai-samples/tree/main/model-deployment/containers/llama2) on how to deploy your llm on OCI Data Science Model deployment.\n", + "You can easily deploy, fine-tune, and evaluate foundation models using the [AI Quick Actions](https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions.htm) on OCI Data Science Model deployment. For additional deployment examples, please visit the [Oracle GitHub samples repository](https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/llama3-with-smc.md). \n", "\n", "### Policies\n", "Make sure to have the required [policies](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint) to access the OCI Data Science Model Deployment endpoint.\n", "\n", "## Set up\n", "\n", - "### vLLM\n", - "After having deployed model, you have to set up following required parameters of the `OCIModelDeploymentVLLM` call:\n", + "After having deployed model, you have to set up following required parameters of the call:\n", "\n", - "- **`endpoint`**: The model HTTP endpoint from the deployed model, e.g. `https:///predict`. \n", - "- **`model`**: The location of the model.\n", + "- **`endpoint`**: The model HTTP endpoint from the deployed model, e.g. `https://modeldeployment..oci.customer-oci.com//predict`. \n", "\n", - "### Text generation inference (TGI)\n", - "You have to set up following required parameters of the `OCIModelDeploymentTGI` call:\n", - "\n", - "- **`endpoint`**: The model HTTP endpoint from the deployed model, e.g. `https:///predict`. \n", "\n", "### Authentication\n", "\n", "You can set authentication through either ads or environment variables. 
When you are working in OCI Data Science Notebook Session, you can leverage resource principal to access other OCI resources. Check out [here](https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html) to see more options. \n", "\n", - "## Example" + "## Examples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import ads\n", + "from langchain_community.llms import OCIModelDeploymentLLM\n", + "\n", + "# Set authentication through ads\n", + "# Use resource principal are operating within a\n", + "# OCI service that has resource principal based\n", + "# authentication configured\n", + "ads.set_auth(\"resource_principal\")\n", + "\n", + "# Create an instance of OCI Model Deployment Endpoint\n", + "# Replace the endpoint uri and model name with your own\n", + "# Using generic class as entry point, you will be able\n", + "# to pass model parameters through model_kwargs during\n", + "# instantiation.\n", + "llm = OCIModelDeploymentLLM(\n", + " endpoint=\"https://modeldeployment..oci.customer-oci.com//predict\",\n", + " model=\"odsc-llm\",\n", + ")\n", + "\n", + "# Run the LLM\n", + "llm.invoke(\"Who is the first president of United States?\")" ] }, { @@ -71,7 +96,11 @@ "\n", "# Create an instance of OCI Model Deployment Endpoint\n", "# Replace the endpoint uri and model name with your own\n", - "llm = OCIModelDeploymentVLLM(endpoint=\"https:///predict\", model=\"model_name\")\n", + "# Using framework specific class as entry point, you will\n", + "# be able to pass model parameters in constructor.\n", + "llm = OCIModelDeploymentVLLM(\n", + " endpoint=\"https://modeldeployment..oci.customer-oci.com//predict\",\n", + ")\n", "\n", "# Run the LLM\n", "llm.invoke(\"Who is the first president of United States?\")" @@ -97,14 +126,64 @@ "\n", "# Set endpoint through environment variables\n", "# Replace the endpoint uri with your own\n", - "os.environ[\"OCI_LLM_ENDPOINT\"] = \"https:///predict\"\n", + "os.environ[\"OCI_LLM_ENDPOINT\"] = (\n", + " \"https://modeldeployment..oci.customer-oci.com//predict\"\n", + ")\n", "\n", "# Create an instance of OCI Model Deployment Endpoint\n", + "# Using framework specific class as entry point, you will\n", + "# be able to pass model parameters in constructor.\n", "llm = OCIModelDeploymentTGI()\n", "\n", "# Run the LLM\n", "llm.invoke(\"Who is the first president of United States?\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Asynchronous calls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "await llm.ainvoke(\"Tell me a joke.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Streaming calls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for chunk in llm.stream(\"Tell me a joke.\"):\n", + " print(chunk, end=\"\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## API reference\n", + "\n", + "For comprehensive details on all features and configurations, please refer to the API reference documentation for each class:\n", + "\n", + "* [OCIModelDeploymentLLM](https://api.python.langchain.com/en/latest/llms/langchain_community.llms.oci_data_science_model_deployment_endpoint.OCIModelDeploymentLLM.html)\n", + "* 
[OCIModelDeploymentVLLM](https://api.python.langchain.com/en/latest/llms/langchain_community.llms.oci_data_science_model_deployment_endpoint.OCIModelDeploymentVLLM.html)\n", + "* [OCIModelDeploymentTGI](https://api.python.langchain.com/en/latest/llms/langchain_community.llms.oci_data_science_model_deployment_endpoint.OCIModelDeploymentTGI.html)" + ] } ], "metadata": { diff --git a/docs/docs/integrations/providers/oci.mdx b/docs/docs/integrations/providers/oci.mdx index 5037fb86f1..58c167995b 100644 --- a/docs/docs/integrations/providers/oci.mdx +++ b/docs/docs/integrations/providers/oci.mdx @@ -32,20 +32,18 @@ from langchain_community.embeddings import OCIGenAIEmbeddings > as an OCI Model Deployment Endpoint using the > [OCI Data Science Model Deployment Service](https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm). -If you deployed a LLM with the VLLM or TGI framework, you can use the -`OCIModelDeploymentVLLM` or `OCIModelDeploymentTGI` classes to interact with it. - To use, you should have the latest `oracle-ads` python SDK installed. ```bash pip install -U oracle-ads ``` -See [usage examples](/docs/integrations/llms/oci_model_deployment_endpoint). +See [chat](/docs/integrations/chat/oci_data_science) and [complete](/docs/integrations/llms/oci_model_deployment_endpoint) usage examples. + ```python -from langchain_community.llms import OCIModelDeploymentVLLM +from langchain_community.chat_models import ChatOCIModelDeployment -from langchain_community.llms import OCIModelDeploymentTGI +from langchain_community.llms import OCIModelDeploymentLLM ``` diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 3cdf3d52c6..b2548b2219 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -38,6 +38,7 @@ jinja2>=3,<4 jq>=1.4.1,<2 jsonschema>1 keybert>=0.8.5 +langchain_openai>=0.2.1 litellm>=1.30,<=1.39.5 lxml>=4.9.3,<6.0 markdownify>=0.11.6,<0.12 diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index db50763752..88df9d4330 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -125,6 +125,11 @@ if TYPE_CHECKING: from langchain_community.chat_models.moonshot import ( MoonshotChat, ) + from langchain_community.chat_models.oci_data_science import ( + ChatOCIModelDeployment, + ChatOCIModelDeploymentTGI, + ChatOCIModelDeploymentVLLM, + ) from langchain_community.chat_models.oci_generative_ai import ( ChatOCIGenAI, # noqa: F401 ) @@ -211,6 +216,9 @@ __all__ = [ "ChatMlflow", "ChatNebula", "ChatOCIGenAI", + "ChatOCIModelDeployment", + "ChatOCIModelDeploymentVLLM", + "ChatOCIModelDeploymentTGI", "ChatOllama", "ChatOpenAI", "ChatPerplexity", @@ -272,6 +280,9 @@ _module_lookup = { "ChatNebula": "langchain_community.chat_models.symblai_nebula", "ChatOctoAI": "langchain_community.chat_models.octoai", "ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai", + "ChatOCIModelDeployment": "langchain_community.chat_models.oci_data_science", + "ChatOCIModelDeploymentVLLM": "langchain_community.chat_models.oci_data_science", + "ChatOCIModelDeploymentTGI": "langchain_community.chat_models.oci_data_science", "ChatOllama": "langchain_community.chat_models.ollama", "ChatOpenAI": "langchain_community.chat_models.openai", "ChatPerplexity": "langchain_community.chat_models.perplexity", diff --git 
a/libs/community/langchain_community/chat_models/oci_data_science.py b/libs/community/langchain_community/chat_models/oci_data_science.py new file mode 100644 index 0000000000..cdb181df89 --- /dev/null +++ b/libs/community/langchain_community/chat_models/oci_data_science.py @@ -0,0 +1,998 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. + +"""Chat model for OCI data science model deployment endpoint.""" + +import importlib +import json +import logging +from operator import itemgetter +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Type, + Union, +) + +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models import LanguageModelInput +from langchain_core.language_models.chat_models import ( + BaseChatModel, + agenerate_from_stream, + generate_from_stream, +) +from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk +from langchain_core.output_parsers import ( + JsonOutputParser, + PydanticOutputParser, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_tool +from pydantic import BaseModel, Field, model_validator + +from langchain_community.llms.oci_data_science_model_deployment_endpoint import ( + DEFAULT_MODEL_NAME, + BaseOCIModelDeployment, +) + +logger = logging.getLogger(__name__) + + +def _is_pydantic_class(obj: Any) -> bool: + return isinstance(obj, type) and issubclass(obj, BaseModel) + + +class ChatOCIModelDeployment(BaseChatModel, BaseOCIModelDeployment): + """OCI Data Science Model Deployment chat model integration. + + Setup: + Install ``oracle-ads`` and ``langchain-openai``. + + .. code-block:: bash + + pip install -U oracle-ads langchain-openai + + Use `ads.set_auth()` to configure authentication. + For example, to use OCI resource_principal for authentication: + + .. code-block:: python + + import ads + ads.set_auth("resource_principal") + + For more details on authentication, see: + https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html + + Make sure to have the required policies to access the OCI Data + Science Model Deployment endpoint. See: + https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm + + + Key init args - completion params: + endpoint: str + The OCI model deployment endpoint. + temperature: float + Sampling temperature. + max_tokens: Optional[int] + Max number of tokens to generate. + + Key init args — client params: + auth: dict + ADS auth dictionary for OCI authentication. + + Instantiate: + .. code-block:: python + + from langchain_community.chat_models import ChatOCIModelDeployment + + chat = ChatOCIModelDeployment( + endpoint="https://modeldeployment..oci.customer-oci.com//predict", + model="odsc-llm", + streaming=True, + max_retries=3, + model_kwargs={ + "max_token": 512, + "temperature": 0.2, + # other model parameters ... + }, + ) + + Invocation: + .. code-block:: python + + messages = [ + ("system", "Translate the user sentence to French."), + ("human", "Hello World!"), + ] + chat.invoke(messages) + + .. 
code-block:: python + + AIMessage( + content='Bonjour le monde!', + response_metadata={ + 'token_usage': { + 'prompt_tokens': 40, + 'total_tokens': 50, + 'completion_tokens': 10 + }, + 'model_name': 'odsc-llm', + 'system_fingerprint': '', + 'finish_reason': 'stop' + }, + id='run-cbed62da-e1b3-4abd-9df3-ec89d69ca012-0' + ) + + Streaming: + .. code-block:: python + + for chunk in chat.stream(messages): + print(chunk) + + .. code-block:: python + + content='' id='run-02c6-c43f-42de' + content='\n' id='run-02c6-c43f-42de' + content='B' id='run-02c6-c43f-42de' + content='on' id='run-02c6-c43f-42de' + content='j' id='run-02c6-c43f-42de' + content='our' id='run-02c6-c43f-42de' + content=' le' id='run-02c6-c43f-42de' + content=' monde' id='run-02c6-c43f-42de' + content='!' id='run-02c6-c43f-42de' + content='' response_metadata={'finish_reason': 'stop'} id='run-02c6-c43f-42de' + + Async: + .. code-block:: python + + await chat.ainvoke(messages) + + # stream: + # async for chunk in (await chat.astream(messages)) + + .. code-block:: python + + AIMessage( + content='Bonjour le monde!', + response_metadata={'finish_reason': 'stop'}, + id='run-8657a105-96b7-4bb6-b98e-b69ca420e5d1-0' + ) + + Structured output: + .. code-block:: python + + from typing import Optional + from pydantic import BaseModel, Field + + class Joke(BaseModel): + setup: str = Field(description="The setup of the joke") + punchline: str = Field(description="The punchline to the joke") + + structured_llm = chat.with_structured_output(Joke, method="json_mode") + structured_llm.invoke( + "Tell me a joke about cats, " + "respond in JSON with `setup` and `punchline` keys" + ) + + .. code-block:: python + + Joke( + setup='Why did the cat get stuck in the tree?', + punchline='Because it was chasing its tail!' + ) + + See ``ChatOCIModelDeployment.with_structured_output()`` for more. + + Customized Usage: + You can inherit from base class and overwrite the `_process_response`, + `_process_stream_response`, `_construct_json_body` for customized usage. + + .. code-block:: python + + class MyChatModel(ChatOCIModelDeployment): + def _process_stream_response(self, response_json: dict) -> ChatGenerationChunk: + print("My customized streaming result handler.") + return GenerationChunk(...) + + def _process_response(self, response_json:dict) -> ChatResult: + print("My customized output handler.") + return ChatResult(...) + + def _construct_json_body(self, messages: list, params: dict) -> dict: + print("My customized payload handler.") + return { + "messages": messages, + **params, + } + + chat = MyChatModel( + endpoint=f"https://modeldeployment..oci.customer-oci.com/{ocid}/predict", + model="odsc-llm", + } + + chat.invoke("tell me a joke") + + Response metadata + .. code-block:: python + + ai_msg = chat.invoke(messages) + ai_msg.response_metadata + + .. code-block:: python + + { + 'token_usage': { + 'prompt_tokens': 40, + 'total_tokens': 50, + 'completion_tokens': 10 + }, + 'model_name': 'odsc-llm', + 'system_fingerprint': '', + 'finish_reason': 'stop' + } + + """ # noqa: E501 + + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass to the model.""" + + model: str = DEFAULT_MODEL_NAME + """The name of the model.""" + + stop: Optional[List[str]] = None + """Stop words to use when generating. 
Model output is cut off + at the first occurrence of any of these substrings.""" + + @model_validator(mode="before") + @classmethod + def validate_openai(cls, values: Any) -> Any: + """Checks if langchain_openai is installed.""" + if not importlib.util.find_spec("langchain_openai"): + raise ImportError( + "Could not import langchain_openai package. " + "Please install it with `pip install langchain_openai`." + ) + return values + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "oci_model_depolyment_chat_endpoint" + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + _model_kwargs = self.model_kwargs or {} + return { + **{"endpoint": self.endpoint, "model_kwargs": _model_kwargs}, + **self._default_params, + } + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters.""" + return { + "model": self.model, + "stop": self.stop, + "stream": self.streaming, + } + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Call out to an OCI Model Deployment Online endpoint. + + Args: + messages: The messages in the conversation with the chat model. + stop: Optional list of stop words to use when generating. + + Returns: + LangChain ChatResult + + Raises: + RuntimeError: + Raise when invoking endpoint fails. + + Example: + + .. code-block:: python + + messages = [ + ( + "system", + "You are a helpful assistant that translates English to French. Translate the user sentence.", + ), + ("human", "Hello World!"), + ] + + response = chat.invoke(messages) + """ # noqa: E501 + if self.streaming: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) + + requests_kwargs = kwargs.pop("requests_kwargs", {}) + params = self._invocation_params(stop, **kwargs) + body = self._construct_json_body(messages, params) + res = self.completion_with_retry( + data=body, run_manager=run_manager, **requests_kwargs + ) + return self._process_response(res.json()) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + """Stream OCI Data Science Model Deployment endpoint on given messages. + + Args: + messages (List[BaseMessage]): + The messagaes to pass into the model. + stop (List[str], Optional): + List of stop words to use when generating. + kwargs: + requests_kwargs: + Additional ``**kwargs`` to pass to requests.post + + Returns: + An iterator of ChatGenerationChunk. + + Raises: + RuntimeError: + Raise when invoking endpoint fails. + + Example: + + .. code-block:: python + + messages = [ + ( + "system", + "You are a helpful assistant that translates English to French. 
Translate the user sentence.", + ), + ("human", "Hello World!"), + ] + + chunk_iter = chat.stream(messages) + + """ # noqa: E501 + requests_kwargs = kwargs.pop("requests_kwargs", {}) + self.streaming = True + params = self._invocation_params(stop, **kwargs) + body = self._construct_json_body(messages, params) # request json body + + response = self.completion_with_retry( + data=body, run_manager=run_manager, stream=True, **requests_kwargs + ) + default_chunk_class = AIMessageChunk + for line in self._parse_stream(response.iter_lines()): + chunk = self._handle_sse_line(line, default_chunk_class) + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Asynchronously call out to OCI Data Science Model Deployment + endpoint on given messages. + + Args: + messages (List[BaseMessage]): + The messagaes to pass into the model. + stop (List[str], Optional): + List of stop words to use when generating. + kwargs: + requests_kwargs: + Additional ``**kwargs`` to pass to requests.post + + Returns: + LangChain ChatResult. + + Raises: + ValueError: + Raise when invoking endpoint fails. + + Example: + + .. code-block:: python + + messages = [ + ( + "system", + "You are a helpful assistant that translates English to French. Translate the user sentence.", + ), + ("human", "I love programming."), + ] + + resp = await chat.ainvoke(messages) + + """ # noqa: E501 + if self.streaming: + stream_iter = self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await agenerate_from_stream(stream_iter) + + requests_kwargs = kwargs.pop("requests_kwargs", {}) + params = self._invocation_params(stop, **kwargs) + body = self._construct_json_body(messages, params) + response = await self.acompletion_with_retry( + data=body, + run_manager=run_manager, + **requests_kwargs, + ) + return self._process_response(response) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + """Asynchronously streaming OCI Data Science Model Deployment + endpoint on given messages. + + Args: + messages (List[BaseMessage]): + The messagaes to pass into the model. + stop (List[str], Optional): + List of stop words to use when generating. + kwargs: + requests_kwargs: + Additional ``**kwargs`` to pass to requests.post + + Returns: + An Asynciterator of ChatGenerationChunk. + + Raises: + ValueError: + Raise when invoking endpoint fails. + + Example: + + .. code-block:: python + + messages = [ + ( + "system", + "You are a helpful assistant that translates English to French. 
Translate the user sentence.", + ), + ("human", "I love programming."), + ] + + chunk_iter = await chat.astream(messages) + + """ # noqa: E501 + requests_kwargs = kwargs.pop("requests_kwargs", {}) + self.streaming = True + params = self._invocation_params(stop, **kwargs) + body = self._construct_json_body(messages, params) # request json body + + default_chunk_class = AIMessageChunk + async for line in await self.acompletion_with_retry( + data=body, run_manager=run_manager, stream=True, **requests_kwargs + ): + chunk = self._handle_sse_line(line, default_chunk_class) + if run_manager: + await run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk + + def with_structured_output( + self, + schema: Optional[Union[Dict, Type[BaseModel]]] = None, + *, + method: Literal["json_mode"] = "json_mode", + include_raw: bool = False, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + """Model wrapper that returns outputs formatted to match the given schema. + + Args: + schema: The output schema as a dict or a Pydantic class. If a Pydantic class + then the model output will be an object of that class. If a dict then + the model output will be a dict. With a Pydantic class the returned + attributes will be validated, whereas with a dict they will not be. If + `method` is "function_calling" and `schema` is a dict, then the dict + must match the OpenAI function-calling spec. + method: The method for steering model generation, currently only support + for "json_mode". If "json_mode" then JSON mode will be used. Note that + if using "json_mode" then you must include instructions for formatting + the output into the desired schema into the model call. + include_raw: If False then only the parsed structured output is returned. If + an error occurs during model output parsing it will be raised. If True + then both the raw model response (a BaseMessage) and the parsed model + response will be returned. If an error occurs during output parsing it + will be caught and returned as well. The final output is always a dict + with keys "raw", "parsed", and "parsing_error". + + Returns: + A Runnable that takes any ChatModel input and returns as output: + + If include_raw is True then a dict with keys: + raw: BaseMessage + parsed: Optional[_DictOrPydantic] + parsing_error: Optional[BaseException] + + If include_raw is False then just _DictOrPydantic is returned, + where _DictOrPydantic depends on the schema: + + If schema is a Pydantic class then _DictOrPydantic is the Pydantic + class. + + If schema is a dict then _DictOrPydantic is a dict. + + """ # noqa: E501 + if kwargs: + raise ValueError(f"Received unsupported arguments {kwargs}") + is_pydantic_schema = _is_pydantic_class(schema) + if method == "json_mode": + llm = self.bind(response_format={"type": "json_object"}) + output_parser = ( + PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type] + if is_pydantic_schema + else JsonOutputParser() + ) + else: + raise ValueError( + f"Unrecognized method argument. Expected `json_mode`." + f"Received: `{method}`." 
+ ) + + if include_raw: + parser_assign = RunnablePassthrough.assign( + parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None + ) + parser_none = RunnablePassthrough.assign(parsed=lambda _: None) + parser_with_fallback = parser_assign.with_fallbacks( + [parser_none], exception_key="parsing_error" + ) + return RunnableMap(raw=llm) | parser_with_fallback + else: + return llm | output_parser + + def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict: + """Combines the invocation parameters with default parameters.""" + params = self._default_params + _model_kwargs = self.model_kwargs or {} + params["stop"] = stop or params.get("stop", []) + return {**params, **_model_kwargs, **kwargs} + + def _handle_sse_line( + self, line: str, default_chunk_cls: Type[BaseMessageChunk] = AIMessageChunk + ) -> ChatGenerationChunk: + """Handle a single Server-Sent Events (SSE) line and process it into + a chat generation chunk. + + Args: + line (str): A single line from the SSE stream in string format. + default_chunk_cls (AIMessageChunk): The default class for message + chunks to be used during the processing of the stream response. + + Returns: + ChatGenerationChunk: The processed chat generation chunk. If an error + occurs, an empty `ChatGenerationChunk` is returned. + """ + try: + obj = json.loads(line) + return self._process_stream_response(obj, default_chunk_cls) + except Exception as e: + logger.debug(f"Error occurs when processing line={line}: {str(e)}") + return ChatGenerationChunk(message=AIMessageChunk(content="")) + + def _construct_json_body(self, messages: list, params: dict) -> dict: + """Constructs the request body as a dictionary (JSON). + + Args: + messages (list): A list of message objects to be included in the + request body. + params (dict): A dictionary of additional parameters to be included + in the request body. + + Returns: + dict: A dictionary representing the JSON request body, including + converted messages and additional parameters. + + """ + from langchain_openai.chat_models.base import _convert_message_to_dict + + return { + "messages": [_convert_message_to_dict(m) for m in messages], + **params, + } + + def _process_stream_response( + self, + response_json: dict, + default_chunk_cls: Type[BaseMessageChunk] = AIMessageChunk, + ) -> ChatGenerationChunk: + """Formats streaming response in OpenAI spec. + + Args: + response_json (dict): The JSON response from the streaming endpoint. + default_chunk_cls (type, optional): The default class to use for + creating message chunks. Defaults to `AIMessageChunk`. + + Returns: + ChatGenerationChunk: An object containing the processed message + chunk and any relevant generation information such as finish + reason and usage. + + Raises: + ValueError: If the response JSON is not well-formed or does not + contain the expected structure. 
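+
+        Example:
+            The payload is expected to follow the OpenAI-style streaming
+            format that the model deployment endpoint mimics. An
+            illustrative (not exhaustive) chunk:
+
+            .. code-block:: python
+
+                {
+                    "choices": [
+                        {
+                            "delta": {"role": "assistant", "content": "Bonjour"},
+                            "finish_reason": None,
+                        }
+                    ],
+                }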
+ """ + from langchain_openai.chat_models.base import _convert_delta_to_message_chunk + + try: + choice = response_json["choices"][0] + if not isinstance(choice, dict): + raise TypeError("Endpoint response is not well formed.") + except (KeyError, IndexError, TypeError) as e: + raise ValueError( + "Error while formatting response payload for chat model of type" + ) from e + + chunk = _convert_delta_to_message_chunk(choice["delta"], default_chunk_cls) + default_chunk_cls = chunk.__class__ + finish_reason = choice.get("finish_reason") + usage = choice.get("usage") + gen_info = {} + if finish_reason is not None: + gen_info.update({"finish_reason": finish_reason}) + if usage is not None: + gen_info.update({"usage": usage}) + + return ChatGenerationChunk( + message=chunk, generation_info=gen_info if gen_info else None + ) + + def _process_response(self, response_json: dict) -> ChatResult: + """Formats response in OpenAI spec. + + Args: + response_json (dict): The JSON response from the chat model endpoint. + + Returns: + ChatResult: An object containing the list of `ChatGeneration` objects + and additional LLM output information. + + Raises: + ValueError: If the response JSON is not well-formed or does not + contain the expected structure. + + """ + from langchain_openai.chat_models.base import _convert_dict_to_message + + generations = [] + try: + choices = response_json["choices"] + if not isinstance(choices, list): + raise TypeError("Endpoint response is not well formed.") + except (KeyError, TypeError) as e: + raise ValueError( + "Error while formatting response payload for chat model of type" + ) from e + + for choice in choices: + message = _convert_dict_to_message(choice["message"]) + generation_info = dict(finish_reason=choice.get("finish_reason")) + if "logprobs" in choice: + generation_info["logprobs"] = choice["logprobs"] + + gen = ChatGeneration( + message=message, + generation_info=generation_info, + ) + generations.append(gen) + + token_usage = response_json.get("usage", {}) + llm_output = { + "token_usage": token_usage, + "model_name": self.model, + "system_fingerprint": response_json.get("system_fingerprint", ""), + } + return ChatResult(generations=generations, llm_output=llm_output) + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + return super().bind(tools=formatted_tools, **kwargs) + + +class ChatOCIModelDeploymentVLLM(ChatOCIModelDeployment): + """OCI large language chat models deployed with vLLM. + + To use, you must provide the model HTTP endpoint from your deployed + model, e.g. https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict. + + To authenticate, `oracle-ads` has been used to automatically load + credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html + + Make sure to have the required policies to access the OCI Data + Science Model Deployment endpoint. See: + https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint + + Example: + + .. 
code-block:: python + + from langchain_community.chat_models import ChatOCIModelDeploymentVLLM + + chat = ChatOCIModelDeploymentVLLM( + endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", + frequency_penalty=0.1, + max_tokens=512, + temperature=0.2, + top_p=1.0, + # other model parameters... + ) + + """ # noqa: E501 + + frequency_penalty: float = 0.0 + """Penalizes repeated tokens according to frequency. Between 0 and 1.""" + + logit_bias: Optional[Dict[str, float]] = None + """Adjust the probability of specific tokens being generated.""" + + max_tokens: Optional[int] = 256 + """The maximum number of tokens to generate in the completion.""" + + n: int = 1 + """Number of output sequences to return for the given prompt.""" + + presence_penalty: float = 0.0 + """Penalizes repeated tokens. Between 0 and 1.""" + + temperature: float = 0.2 + """What sampling temperature to use.""" + + top_p: float = 1.0 + """Total probability mass of tokens to consider at each step.""" + + best_of: Optional[int] = None + """Generates best_of completions server-side and returns the "best" + (the one with the highest log probability per token). + """ + + use_beam_search: Optional[bool] = False + """Whether to use beam search instead of sampling.""" + + top_k: Optional[int] = -1 + """Number of most likely tokens to consider at each step.""" + + min_p: Optional[float] = 0.0 + """Float that represents the minimum probability for a token to be considered. + Must be in [0,1]. 0 to disable this.""" + + repetition_penalty: Optional[float] = 1.0 + """Float that penalizes new tokens based on their frequency in the + generated text. Values > 1 encourage the model to use new tokens.""" + + length_penalty: Optional[float] = 1.0 + """Float that penalizes sequences based on their length. Used only + when `use_beam_search` is True.""" + + early_stopping: Optional[bool] = False + """Controls the stopping condition for beam search. It accepts the + following values: `True`, where the generation stops as soon as there + are `best_of` complete candidates; `False`, where a heuristic is applied + to the generation stops when it is very unlikely to find better candidates; + `never`, where the beam search procedure only stops where there cannot be + better candidates (canonical beam search algorithm).""" + + ignore_eos: Optional[bool] = False + """Whether to ignore the EOS token and continue generating tokens after + the EOS token is generated.""" + + min_tokens: Optional[int] = 0 + """Minimum number of tokens to generate per output sequence before + EOS or stop_token_ids can be generated""" + + stop_token_ids: Optional[List[int]] = None + """List of tokens that stop the generation when they are generated. + The returned output will contain the stop tokens unless the stop tokens + are special tokens.""" + + skip_special_tokens: Optional[bool] = True + """Whether to skip special tokens in the output. Defaults to True.""" + + spaces_between_special_tokens: Optional[bool] = True + """Whether to add spaces between special tokens in the output. + Defaults to True.""" + + tool_choice: Optional[str] = None + """Whether to use tool calling. + Defaults to None, tool calling is disabled. + Tool calling requires model support and the vLLM to be configured + with `--tool-call-parser`. + Set this to `auto` for the model to make tool calls automatically. + Set this to `required` to force the model to always call one or more tools. + """ + + chat_template: Optional[str] = None + """Use customized chat template. 
+ Defaults to None. The chat template from the tokenizer will be used. + """ + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "oci_model_depolyment_chat_endpoint_vllm" + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters.""" + params = { + "model": self.model, + "stop": self.stop, + "stream": self.streaming, + } + for attr_name in self._get_model_params(): + try: + value = getattr(self, attr_name) + if value is not None: + params.update({attr_name: value}) + except Exception: + pass + + return params + + def _get_model_params(self) -> List[str]: + """Gets the name of model parameters.""" + return [ + "best_of", + "early_stopping", + "frequency_penalty", + "ignore_eos", + "length_penalty", + "logit_bias", + "logprobs", + "max_tokens", + "min_p", + "min_tokens", + "n", + "presence_penalty", + "repetition_penalty", + "skip_special_tokens", + "spaces_between_special_tokens", + "stop_token_ids", + "temperature", + "top_k", + "top_p", + "use_beam_search", + "tool_choice", + "chat_template", + ] + + +class ChatOCIModelDeploymentTGI(ChatOCIModelDeployment): + """OCI large language chat models deployed with Text Generation Inference. + + To use, you must provide the model HTTP endpoint from your deployed + model, e.g. https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict. + + To authenticate, `oracle-ads` has been used to automatically load + credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html + + Make sure to have the required policies to access the OCI Data + Science Model Deployment endpoint. See: + https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint + + Example: + + .. code-block:: python + + from langchain_community.chat_models import ChatOCIModelDeploymentTGI + + chat = ChatOCIModelDeploymentTGI( + endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", + max_token=512, + temperature=0.2, + frequency_penalty=0.1, + seed=42, + # other model parameters... + ) + + """ # noqa: E501 + + frequency_penalty: Optional[float] = None + """Penalizes repeated tokens according to frequency. Between 0 and 1.""" + + logit_bias: Optional[Dict[str, float]] = None + """Adjust the probability of specific tokens being generated.""" + + logprobs: Optional[bool] = None + """Whether to return log probabilities of the output tokens or not.""" + + max_tokens: int = 256 + """The maximum number of tokens to generate in the completion.""" + + n: int = 1 + """Number of output sequences to return for the given prompt.""" + + presence_penalty: Optional[float] = None + """Penalizes repeated tokens. Between 0 and 1.""" + + seed: Optional[int] = None + """To sample deterministically,""" + + temperature: float = 0.2 + """What sampling temperature to use.""" + + top_p: Optional[float] = None + """Total probability mass of tokens to consider at each step.""" + + top_logprobs: Optional[int] = None + """An integer between 0 and 5 specifying the number of most + likely tokens to return at each token position, each with an + associated log probability. 
logprobs must be set to true if + this parameter is used.""" + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "oci_model_depolyment_chat_endpoint_tgi" + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters.""" + params = { + "model": self.model, + "stop": self.stop, + "stream": self.streaming, + } + for attr_name in self._get_model_params(): + try: + value = getattr(self, attr_name) + if value is not None: + params.update({attr_name: value}) + except Exception: + pass + + return params + + def _get_model_params(self) -> List[str]: + """Gets the name of model parameters.""" + return [ + "frequency_penalty", + "logit_bias", + "logprobs", + "max_tokens", + "n", + "presence_penalty", + "seed", + "temperature", + "top_k", + "top_p", + "top_logprobs", + ] diff --git a/libs/community/langchain_community/llms/__init__.py b/libs/community/langchain_community/llms/__init__.py index bc1ada6982..1bf5e0a7ee 100644 --- a/libs/community/langchain_community/llms/__init__.py +++ b/libs/community/langchain_community/llms/__init__.py @@ -396,6 +396,14 @@ def _import_oci_md_vllm() -> Type[BaseLLM]: return OCIModelDeploymentVLLM +def _import_oci_md() -> Type[BaseLLM]: + from langchain_community.llms.oci_data_science_model_deployment_endpoint import ( + OCIModelDeploymentLLM, + ) + + return OCIModelDeploymentLLM + + def _import_oci_gen_ai() -> Type[BaseLLM]: from langchain_community.llms.oci_generative_ai import OCIGenAI @@ -773,6 +781,8 @@ def __getattr__(name: str) -> Any: return _import_oci_md_tgi() elif name == "OCIModelDeploymentVLLM": return _import_oci_md_vllm() + elif name == "OCIModelDeploymentLLM": + return _import_oci_md() elif name == "OCIGenAI": return _import_oci_gen_ai() elif name == "OctoAIEndpoint": @@ -928,6 +938,7 @@ __all__ = [ "OCIGenAI", "OCIModelDeploymentTGI", "OCIModelDeploymentVLLM", + "OCIModelDeploymentLLM", "OctoAIEndpoint", "Ollama", "OpaquePrompts", @@ -1029,6 +1040,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]: "nlpcloud": _import_nlpcloud, "oci_model_deployment_tgi_endpoint": _import_oci_md_tgi, "oci_model_deployment_vllm_endpoint": _import_oci_md_vllm, + "oci_model_deployment_endpoint": _import_oci_md, "oci_generative_ai": _import_oci_gen_ai, "octoai_endpoint": _import_octoai_endpoint, "ollama": _import_ollama, diff --git a/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py b/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py index 04f551ce27..c998fbc0ec 100644 --- a/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py +++ b/libs/community/langchain_community/llms/oci_data_science_model_deployment_endpoint.py @@ -1,19 +1,68 @@ -import logging -from typing import Any, Dict, List, Optional +# Copyright (c) 2023, 2024, Oracle and/or its affiliates. 
+"""LLM for OCI data science model deployment endpoint.""" + +import json +import logging +import traceback +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Union, +) + +import aiohttp import requests -from langchain_core.callbacks import CallbackManagerForLLMRun -from langchain_core.language_models.llms import LLM -from langchain_core.utils import get_from_dict_or_env, pre_init -from pydantic import Field +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator +from langchain_core.load.serializable import Serializable +from langchain_core.outputs import Generation, GenerationChunk, LLMResult +from langchain_core.utils import get_from_dict_or_env +from pydantic import Field, model_validator + +from langchain_community.utilities.requests import Requests logger = logging.getLogger(__name__) + DEFAULT_TIME_OUT = 300 DEFAULT_CONTENT_TYPE_JSON = "application/json" +DEFAULT_MODEL_NAME = "odsc-llm" -class OCIModelDeploymentLLM(LLM): +class TokenExpiredError(Exception): + """Raises when token expired.""" + + +class ServerError(Exception): + """Raises when encounter server error when making inference.""" + + +def _create_retry_decorator( + llm: "BaseOCIModelDeployment", + *, + run_manager: Optional[ + Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] + ] = None, +) -> Callable[[Any], Any]: + """Create a retry decorator.""" + errors = [requests.exceptions.ConnectTimeout, TokenExpiredError] + decorator = create_base_retry_decorator( + error_types=errors, max_retries=llm.max_retries, run_manager=run_manager + ) + return decorator + + +class BaseOCIModelDeployment(Serializable): """Base class for LLM deployed on OCI Data Science Model Deployment.""" auth: dict = Field(default_factory=dict, exclude=True) @@ -23,21 +72,366 @@ class OCIModelDeploymentLLM(LLM): or `ads.common.auth.resource_principal()`. If this is not provided then the `ads.common.default_signer()` will be used.""" + endpoint: str = "" + """The uri of the endpoint from the deployed Model Deployment model.""" + + streaming: bool = False + """Whether to stream the results or not.""" + + max_retries: int = 3 + """Maximum number of retries to make when generating.""" + + @model_validator(mode="before") + @classmethod + def validate_environment(cls, values: Dict) -> Dict: + """Checks if oracle-ads is installed and + get credentials/endpoint from environment. + """ + try: + import ads + + except ImportError as ex: + raise ImportError( + "Could not import ads python package. " + "Please install it with `pip install oracle_ads`." + ) from ex + + if not values.get("auth", None): + values["auth"] = ads.common.auth.default_signer() + + values["endpoint"] = get_from_dict_or_env( + values, + "endpoint", + "OCI_LLM_ENDPOINT", + ) + return values + + def _headers( + self, is_async: Optional[bool] = False, body: Optional[dict] = None + ) -> Dict: + """Construct and return the headers for a request. + + Args: + is_async (bool, optional): Indicates if the request is asynchronous. + Defaults to `False`. + body (optional): The request body to be included in the headers if + the request is asynchronous. + + Returns: + Dict: A dictionary containing the appropriate headers for the request. 
+ """ + if is_async: + signer = self.auth["signer"] + _req = requests.Request("POST", self.endpoint, json=body) + req = _req.prepare() + req = signer(req) + headers = {} + for key, value in req.headers.items(): + headers[key] = value + + if self.streaming: + headers.update( + {"enable-streaming": "true", "Accept": "text/event-stream"} + ) + return headers + + return ( + { + "Content-Type": DEFAULT_CONTENT_TYPE_JSON, + "enable-streaming": "true", + "Accept": "text/event-stream", + } + if self.streaming + else { + "Content-Type": DEFAULT_CONTENT_TYPE_JSON, + } + ) + + def completion_with_retry( + self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any + ) -> Any: + """Use tenacity to retry the completion call.""" + retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + + @retry_decorator + def _completion_with_retry(**kwargs: Any) -> Any: + try: + request_timeout = kwargs.pop("request_timeout", DEFAULT_TIME_OUT) + data = kwargs.pop("data") + stream = kwargs.pop("stream", self.streaming) + + request = Requests( + headers=self._headers(), auth=self.auth.get("signer") + ) + response = request.post( + url=self.endpoint, + data=data, + timeout=request_timeout, + stream=stream, + **kwargs, + ) + self._check_response(response) + return response + except TokenExpiredError as e: + raise e + except Exception as err: + traceback.print_exc() + logger.debug( + f"Requests payload: {data}. Requests arguments: " + f"url={self.endpoint},timeout={request_timeout},stream={stream}. " + f"Additional request kwargs={kwargs}." + ) + raise RuntimeError( + f"Error occurs by inference endpoint: {str(err)}" + ) from err + + return _completion_with_retry(**kwargs) + + async def acompletion_with_retry( + self, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Any: + """Use tenacity to retry the async completion call.""" + retry_decorator = _create_retry_decorator(self, run_manager=run_manager) + + @retry_decorator + async def _completion_with_retry(**kwargs: Any) -> Any: + try: + request_timeout = kwargs.pop("request_timeout", DEFAULT_TIME_OUT) + data = kwargs.pop("data") + stream = kwargs.pop("stream", self.streaming) + + request = Requests(headers=self._headers(is_async=True, body=data)) + if stream: + response = request.apost( + url=self.endpoint, + data=data, + timeout=request_timeout, + ) + return self._aiter_sse(response) + else: + async with request.apost( + url=self.endpoint, + data=data, + timeout=request_timeout, + ) as resp: + self._check_response(resp) + data = await resp.json() + return data + except TokenExpiredError as e: + raise e + except Exception as err: + traceback.print_exc() + logger.debug( + f"Requests payload: `{data}`. " + f"Stream mode={stream}. " + f"Requests kwargs: url={self.endpoint}, timeout={request_timeout}." + ) + raise RuntimeError( + f"Error occurs by inference endpoint: {str(err)}" + ) from err + + return await _completion_with_retry(**kwargs) + + def _check_response(self, response: Any) -> None: + """Handle server error by checking the response status. + + Args: + response: + The response object from either `requests` or `aiohttp` library. + + Raises: + TokenExpiredError: + If the response status code is 401 and the token refresh is successful. + ServerError: + If any other HTTP error occurs. 
+ """ + try: + response.raise_for_status() + except requests.exceptions.HTTPError as http_err: + status_code = ( + response.status_code + if hasattr(response, "status_code") + else response.status + ) + if status_code == 401 and self._refresh_signer(): + raise TokenExpiredError() from http_err + + raise ServerError( + f"Server error: {str(http_err)}. \nMessage: {response.text}" + ) from http_err + + def _parse_stream(self, lines: Iterator[bytes]) -> Iterator[str]: + """Parse a stream of byte lines and yield parsed string lines. + + Args: + lines (Iterator[bytes]): + An iterator that yields lines in byte format. + + Yields: + Iterator[str]: + An iterator that yields parsed lines as strings. + """ + for line in lines: + _line = self._parse_stream_line(line) + if _line is not None: + yield _line + + async def _parse_stream_async( + self, + lines: aiohttp.StreamReader, + ) -> AsyncIterator[str]: + """ + Asynchronously parse a stream of byte lines and yield parsed string lines. + + Args: + lines (aiohttp.StreamReader): + An `aiohttp.StreamReader` object that yields lines in byte format. + + Yields: + AsyncIterator[str]: + An asynchronous iterator that yields parsed lines as strings. + """ + async for line in lines: + _line = self._parse_stream_line(line) + if _line is not None: + yield _line + + def _parse_stream_line(self, line: bytes) -> Optional[str]: + """Parse a single byte line and return a processed string line if valid. + + Args: + line (bytes): A single line in byte format. + + Returns: + Optional[str]: + The processed line as a string if valid, otherwise `None`. + """ + line = line.strip() + if not line: + return None + _line = line.decode("utf-8") + + if _line.lower().startswith("data:"): + _line = _line[5:].lstrip() + + if _line.startswith("[DONE]"): + return None + return _line + return None + + async def _aiter_sse( + self, + async_cntx_mgr: Any, + ) -> AsyncIterator[str]: + """Asynchronously iterate over server-sent events (SSE). + + Args: + async_cntx_mgr: An asynchronous context manager that yields a client + response object. + + Yields: + AsyncIterator[str]: An asynchronous iterator that yields parsed server-sent + event lines as json string. + """ + async with async_cntx_mgr as client_resp: + self._check_response(client_resp) + async for line in self._parse_stream_async(client_resp.content): + yield line + + def _refresh_signer(self) -> bool: + """Attempt to refresh the security token using the signer. + + Returns: + bool: `True` if the token was successfully refreshed, `False` otherwise. + """ + if self.auth.get("signer", None) and hasattr( + self.auth["signer"], "refresh_security_token" + ): + self.auth["signer"].refresh_security_token() + return True + return False + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this model can be serialized by LangChain.""" + return True + + +class OCIModelDeploymentLLM(BaseLLM, BaseOCIModelDeployment): + """LLM deployed on OCI Data Science Model Deployment. + + To use, you must provide the model HTTP endpoint from your deployed + model, e.g. https://modeldeployment..oci.customer-oci.com//predict. + + To authenticate, `oracle-ads` has been used to automatically load + credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html + + Make sure to have the required policies to access the OCI Data + Science Model Deployment endpoint. 
See: + https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint + + Example: + + .. code-block:: python + + from langchain_community.llms import OCIModelDeploymentLLM + + llm = OCIModelDeploymentLLM( + endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com//predict", + model="odsc-llm", + streaming=True, + model_kwargs={"frequency_penalty": 1.0}, + ) + llm.invoke("tell me a joke.") + + Customized Usage: + + User can inherit from our base class and overrwrite the `_process_response`, `_process_stream_response`, + `_construct_json_body` for satisfying customized needed. + + .. code-block:: python + + from langchain_community.llms import OCIModelDeploymentLLM + + class MyCutomizedModel(OCIModelDeploymentLLM): + def _process_stream_response(self, response_json:dict) -> GenerationChunk: + print("My customized output stream handler.") + return GenerationChunk() + + def _process_response(self, response_json:dict) -> List[Generation]: + print("My customized output handler.") + return [Generation()] + + def _construct_json_body(self, prompt: str, param:dict) -> dict: + print("My customized input handler.") + return {} + + llm = MyCutomizedModel( + endpoint=f"https://modeldeployment.us-ashburn-1.oci.customer-oci.com/{ocid}/predict", + model="", + } + + llm.invoke("tell me a joke.") + + """ # noqa: E501 + + model: str = DEFAULT_MODEL_NAME + """The name of the model.""" + max_tokens: int = 256 """Denotes the number of tokens to predict per generation.""" temperature: float = 0.2 """A non-negative float that tunes the degree of randomness in generation.""" - k: int = 0 + k: int = -1 """Number of most likely tokens to consider at each step.""" p: float = 0.75 """Total probability mass of tokens to consider at each step.""" - endpoint: str = "" - """The uri of the endpoint from the deployed Model Deployment model.""" - best_of: int = 1 """Generates best_of completions server-side and returns the "best" (the one with the highest log probability per token). @@ -47,71 +441,130 @@ class OCIModelDeploymentLLM(LLM): """Stop words to use when generating. Model output is cut off at the first occurrence of any of these substrings.""" - @pre_init - def validate_environment( # pylint: disable=no-self-argument - cls, values: Dict - ) -> Dict: - """Validate that python package exists in environment.""" - try: - import ads + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Keyword arguments to pass to the model.""" - except ImportError as ex: - raise ImportError( - "Could not import ads python package. " - "Please install it with `pip install oracle_ads`." 
- ) from ex - if not values.get("auth", None): - values["auth"] = ads.common.auth.default_signer() - values["endpoint"] = get_from_dict_or_env( - values, - "endpoint", - "OCI_LLM_ENDPOINT", - ) - return values + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "oci_model_deployment_endpoint" @property def _default_params(self) -> Dict[str, Any]: - """Default parameters for the model.""" - raise NotImplementedError + """Get the default parameters.""" + return { + "best_of": self.best_of, + "max_tokens": self.max_tokens, + "model": self.model, + "stop": self.stop, + "stream": self.streaming, + "temperature": self.temperature, + "top_k": self.k, + "top_p": self.p, + } @property def _identifying_params(self) -> Dict[str, Any]: """Get the identifying parameters.""" + _model_kwargs = self.model_kwargs or {} return { - **{"endpoint": self.endpoint}, + **{"endpoint": self.endpoint, "model_kwargs": _model_kwargs}, **self._default_params, } - def _construct_json_body(self, prompt: str, params: dict) -> dict: - """Constructs the request body as a dictionary (JSON).""" - raise NotImplementedError + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Call out to OCI Data Science Model Deployment endpoint with k unique prompts. - def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict: - """Combines the invocation parameters with default parameters.""" - params = self._default_params - if self.stop is not None and stop is not None: - raise ValueError("`stop` found in both the input and default params.") - elif self.stop is not None: - params["stop"] = self.stop - elif stop is not None: - params["stop"] = stop - else: - # Don't set "stop" in param as None. It should be a list. - params["stop"] = [] + Args: + prompts: The prompts to pass into the service. + stop: Optional list of stop words to use when generating. - return {**params, **kwargs} + Returns: + The full LLM output. - def _process_response(self, response_json: dict) -> str: - raise NotImplementedError + Example: + .. code-block:: python - def _call( + response = llm.invoke("Tell me a joke.") + response = llm.generate(["Tell me a joke."]) + """ + generations: List[List[Generation]] = [] + params = self._invocation_params(stop, **kwargs) + for prompt in prompts: + body = self._construct_json_body(prompt, params) + if self.streaming: + generation = GenerationChunk(text="") + for chunk in self._stream( + prompt, stop=stop, run_manager=run_manager, **kwargs + ): + generation += chunk + generations.append([generation]) + else: + res = self.completion_with_retry( + data=body, + run_manager=run_manager, + **kwargs, + ) + generations.append(self._process_response(res.json())) + return LLMResult(generations=generations) + + async def _agenerate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + """Call out to OCI Data Science Model Deployment endpoint async with k unique prompts. + + Args: + prompts: The prompts to pass into the service. + stop: Optional list of stop words to use when generating. + + Returns: + The full LLM output. + + Example: + .. 
code-block:: python + + response = await llm.ainvoke("Tell me a joke.") + response = await llm.agenerate(["Tell me a joke."]) + """ # noqa: E501 + generations: List[List[Generation]] = [] + params = self._invocation_params(stop, **kwargs) + for prompt in prompts: + body = self._construct_json_body(prompt, params) + if self.streaming: + generation = GenerationChunk(text="") + async for chunk in self._astream( + prompt, stop=stop, run_manager=run_manager, **kwargs + ): + generation += chunk + generations.append([generation]) + else: + res = await self.acompletion_with_retry( + data=body, + run_manager=run_manager, + **kwargs, + ) + generations.append(self._process_response(res)) + return LLMResult(generations=generations) + + def _stream( self, prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> str: - """Call out to OCI Data Science Model Deployment endpoint. + ) -> Iterator[GenerationChunk]: + """Stream OCI Data Science Model Deployment endpoint on given prompt. + Args: prompt (str): @@ -123,100 +576,162 @@ class OCIModelDeploymentLLM(LLM): Additional ``**kwargs`` to pass to requests.post Returns: - The string generated by the model. + An iterator of GenerationChunks. + Example: + .. code-block:: python - response = oci_md("Tell me a joke.") + response = llm.stream("Tell me a joke.") """ requests_kwargs = kwargs.pop("requests_kwargs", {}) + self.streaming = True params = self._invocation_params(stop, **kwargs) body = self._construct_json_body(prompt, params) - logger.info(f"LLM API Request:\n{prompt}") - response = self._send_request( - data=body, endpoint=self.endpoint, **requests_kwargs - ) - completion = self._process_response(response) - logger.info(f"LLM API Completion:\n{completion}") - return completion - def _send_request( + response = self.completion_with_retry( + data=body, run_manager=run_manager, stream=True, **requests_kwargs + ) + for line in self._parse_stream(response.iter_lines()): + chunk = self._handle_sse_line(line) + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + + yield chunk + + async def _astream( self, - data: Any, - endpoint: str, - header: Optional[dict] = {}, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> Dict: - """Sends request to the oci data science model deployment endpoint. + ) -> AsyncIterator[GenerationChunk]: + """Stream OCI Data Science Model Deployment endpoint async on given prompt. + Args: - data (Json serializable): - data need to be sent to the endpoint. - endpoint (str): - The model HTTP endpoint. - header (dict, optional): - A dictionary of HTTP headers to send to the specified url. - Defaults to {}. + prompt (str): + The prompt to pass into the model. + stop (List[str], Optional): + List of stop words to use when generating. kwargs: - Additional ``**kwargs`` to pass to requests.post. - Raises: - Exception: - Raise when invoking fails. + requests_kwargs: + Additional ``**kwargs`` to pass to requests.post Returns: - A JSON representation of a requests.Response object. + An iterator of GenerationChunks. + + + Example: + + .. 
code-block:: python + + async for chunk in llm.astream(("Tell me a joke."): + print(chunk, end="", flush=True) + """ - if not header: - header = {} - header["Content-Type"] = ( - header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON) - or DEFAULT_CONTENT_TYPE_JSON - ) - request_kwargs = {"json": data} - request_kwargs["headers"] = header - timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT) + requests_kwargs = kwargs.pop("requests_kwargs", {}) + self.streaming = True + params = self._invocation_params(stop, **kwargs) + body = self._construct_json_body(prompt, params) - attempts = 0 - while attempts < 2: - request_kwargs["auth"] = self.auth.get("signer") - response = requests.post( - endpoint, timeout=timeout, **request_kwargs, **kwargs - ) - if response.status_code == 401: - self._refresh_signer() - attempts += 1 - continue - break - - try: - response.raise_for_status() - response_json = response.json() - - except Exception: - logger.error( - "DEBUG INFO: request_kwargs=%s, status_code=%s, content=%s", - request_kwargs, - response.status_code, - response.content, - ) - raise - - return response_json - - def _refresh_signer(self) -> None: - if self.auth.get("signer", None) and hasattr( - self.auth["signer"], "refresh_security_token" + async for line in await self.acompletion_with_retry( + data=body, run_manager=run_manager, stream=True, **requests_kwargs ): - self.auth["signer"].refresh_security_token() + chunk = self._handle_sse_line(line) + if run_manager: + await run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk + + def _construct_json_body(self, prompt: str, params: dict) -> dict: + """Constructs the request body as a dictionary (JSON).""" + return { + "prompt": prompt, + **params, + } + + def _invocation_params( + self, stop: Optional[List[str]] = None, **kwargs: Any + ) -> dict: + """Combines the invocation parameters with default parameters.""" + params = self._default_params + _model_kwargs = self.model_kwargs or {} + params["stop"] = stop or params.get("stop", []) + return {**params, **_model_kwargs, **kwargs} + + def _process_stream_response(self, response_json: dict) -> GenerationChunk: + """Formats streaming response for OpenAI spec into GenerationChunk.""" + try: + choice = response_json["choices"][0] + if not isinstance(choice, dict): + raise TypeError("Endpoint response is not well formed.") + except (KeyError, IndexError, TypeError) as e: + raise ValueError("Error while formatting response payload.") from e + + return GenerationChunk(text=choice.get("text", "")) + + def _process_response(self, response_json: dict) -> List[Generation]: + """Formats response in OpenAI spec. + + Args: + response_json (dict): The JSON response from the chat model endpoint. + + Returns: + ChatResult: An object containing the list of `ChatGeneration` objects + and additional LLM output information. + + Raises: + ValueError: If the response JSON is not well-formed or does not + contain the expected structure. 
+ + """ + generations = [] + try: + choices = response_json["choices"] + if not isinstance(choices, list): + raise TypeError("Endpoint response is not well formed.") + except (KeyError, TypeError) as e: + raise ValueError("Error while formatting response payload.") from e + + for choice in choices: + gen = Generation( + text=choice.get("text"), + generation_info=self._generate_info(choice), + ) + generations.append(gen) + + return generations + + def _generate_info(self, choice: dict) -> Any: + """Extracts generation info from the response.""" + gen_info = {} + finish_reason = choice.get("finish_reason", None) + logprobs = choice.get("logprobs", None) + index = choice.get("index", None) + if finish_reason: + gen_info.update({"finish_reason": finish_reason}) + if logprobs is not None: + gen_info.update({"logprobs": logprobs}) + if index is not None: + gen_info.update({"index": index}) + + return gen_info or None + + def _handle_sse_line(self, line: str) -> GenerationChunk: + try: + obj = json.loads(line) + return self._process_stream_response(obj) + except Exception: + return GenerationChunk(text="") class OCIModelDeploymentTGI(OCIModelDeploymentLLM): """OCI Data Science Model Deployment TGI Endpoint. To use, you must provide the model HTTP endpoint from your deployed - model, e.g. https:///predict. + model, e.g. https://modeldeployment..oci.customer-oci.com//predict. To authenticate, `oracle-ads` has been used to automatically load credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html @@ -228,12 +743,34 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM): Example: .. code-block:: python - from langchain_community.llms import ModelDeploymentTGI + from langchain_community.llms import OCIModelDeploymentTGI - oci_md = ModelDeploymentTGI(endpoint="https:///predict") + llm = OCIModelDeploymentTGI( + endpoint="https://modeldeployment..oci.customer-oci.com//predict", + api="/v1/completions", + streaming=True, + temperature=0.2, + seed=42, + # other model parameters ... + ) """ + api: Literal["/generate", "/v1/completions"] = "/v1/completions" + """Api spec.""" + + frequency_penalty: float = 0.0 + """Penalizes repeated tokens according to frequency. Between 0 and 1.""" + + seed: Optional[int] = None + """Random sampling seed""" + + repetition_penalty: Optional[float] = None + """The parameter for repetition penalty. 1.0 means no penalty.""" + + suffix: Optional[str] = None + """The text to append to the prompt. 
""" + do_sample: bool = True """If set to True, this parameter enables decoding strategies such as multi-nominal sampling, beam-search multi-nominal sampling, Top-K @@ -255,34 +792,78 @@ class OCIModelDeploymentTGI(OCIModelDeploymentLLM): @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for invoking OCI model deployment TGI endpoint.""" + return ( + { + "model": self.model, # can be any + "frequency_penalty": self.frequency_penalty, + "max_tokens": self.max_tokens, + "repetition_penalty": self.repetition_penalty, + "temperature": self.temperature, + "top_p": self.p, + "seed": self.seed, + "stream": self.streaming, + "suffix": self.suffix, + "stop": self.stop, + } + if self.api == "/v1/completions" + else { + "best_of": self.best_of, + "max_new_tokens": self.max_tokens, + "temperature": self.temperature, + "top_k": ( + self.k if self.k > 0 else None + ), # `top_k` must be strictly positive' + "top_p": self.p, + "do_sample": self.do_sample, + "return_full_text": self.return_full_text, + "watermark": self.watermark, + "stop": self.stop, + } + ) + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + _model_kwargs = self.model_kwargs or {} return { - "best_of": self.best_of, - "max_new_tokens": self.max_tokens, - "temperature": self.temperature, - "top_k": self.k - if self.k > 0 - else None, # `top_k` must be strictly positive' - "top_p": self.p, - "do_sample": self.do_sample, - "return_full_text": self.return_full_text, - "watermark": self.watermark, + **{ + "endpoint": self.endpoint, + "api": self.api, + "model_kwargs": _model_kwargs, + }, + **self._default_params, } def _construct_json_body(self, prompt: str, params: dict) -> dict: + """Construct request payload.""" + if self.api == "/v1/completions": + return super()._construct_json_body(prompt, params) + return { "inputs": prompt, "parameters": params, } - def _process_response(self, response_json: dict) -> str: - return str(response_json.get("generated_text", response_json)) + "\n" + def _process_response(self, response_json: dict) -> List[Generation]: + """Formats response.""" + if self.api == "/v1/completions": + return super()._process_response(response_json) + + try: + text = response_json["generated_text"] + except KeyError as e: + raise ValueError( + f"Error while formatting response payload.response_json={response_json}" + ) from e + + return [Generation(text=text)] class OCIModelDeploymentVLLM(OCIModelDeploymentLLM): """VLLM deployed on OCI Data Science Model Deployment To use, you must provide the model HTTP endpoint from your deployed - model, e.g. https:///predict. + model, e.g. https://modeldeployment..oci.customer-oci.com//predict. 
To authenticate, `oracle-ads` has been used to automatically load credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html @@ -296,16 +877,19 @@ class OCIModelDeploymentVLLM(OCIModelDeploymentLLM): from langchain_community.llms import OCIModelDeploymentVLLM - oci_md = OCIModelDeploymentVLLM( - endpoint="https:///predict", - model="mymodel" + llm = OCIModelDeploymentVLLM( + endpoint="https://modeldeployment..oci.customer-oci.com//predict", + model="odsc-llm", + streaming=False, + temperature=0.2, + max_tokens=512, + n=3, + best_of=3, + # other model parameters ) """ - model: str - """The name of the model.""" - n: int = 1 """Number of output sequences to return for the given prompt.""" @@ -346,17 +930,9 @@ class OCIModelDeploymentVLLM(OCIModelDeploymentLLM): "n": self.n, "presence_penalty": self.presence_penalty, "stop": self.stop, + "stream": self.streaming, "temperature": self.temperature, "top_k": self.k, "top_p": self.p, "use_beam_search": self.use_beam_search, } - - def _construct_json_body(self, prompt: str, params: dict) -> dict: - return { - "prompt": prompt, - **params, - } - - def _process_response(self, response_json: dict) -> str: - return response_json["choices"][0]["text"] diff --git a/libs/community/scripts/check_pydantic.sh b/libs/community/scripts/check_pydantic.sh index 1b5d616257..6d6f6dbcef 100755 --- a/libs/community/scripts/check_pydantic.sh +++ b/libs/community/scripts/check_pydantic.sh @@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini # PRs that increase the current count will not be accepted. # PRs that decrease update the code in the repository # and allow decreasing the count of are welcome! -current_count=128 +current_count=127 if [ "$count" -gt "$current_count" ]; then echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator." diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index 6be5a41d9e..035ecbca8b 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -30,6 +30,9 @@ EXPECTED_ALL = [ "ChatMLX", "ChatNebula", "ChatOCIGenAI", + "ChatOCIModelDeployment", + "ChatOCIModelDeploymentVLLM", + "ChatOCIModelDeploymentTGI", "ChatOllama", "ChatOpenAI", "ChatPerplexity", diff --git a/libs/community/tests/unit_tests/chat_models/test_oci_data_science.py b/libs/community/tests/unit_tests/chat_models/test_oci_data_science.py new file mode 100644 index 0000000000..f385f2ddb0 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_oci_data_science.py @@ -0,0 +1,193 @@ +# Copyright (c) 2024, Oracle and/or its affiliates. + +"""Test Chat model for OCI Data Science Model Deployment Endpoint.""" + +import sys +from typing import Any, AsyncGenerator, Dict, Generator +from unittest import mock + +import pytest +from langchain_core.messages import AIMessage, AIMessageChunk +from requests.exceptions import HTTPError + +from langchain_community.chat_models import ( + ChatOCIModelDeploymentTGI, + ChatOCIModelDeploymentVLLM, +) + +CONST_MODEL_NAME = "odsc-vllm" +CONST_ENDPOINT = "https://oci.endpoint/ocid/predict" +CONST_PROMPT = "This is a prompt." +CONST_COMPLETION = "This is a completion." 
+CONST_COMPLETION_RESPONSE = { + "id": "chat-123456789", + "object": "chat.completion", + "created": 123456789, + "model": "mistral", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": CONST_COMPLETION, + "tool_calls": [], + }, + "logprobs": None, + "finish_reason": "length", + "stop_reason": None, + } + ], + "usage": {"prompt_tokens": 115, "total_tokens": 371, "completion_tokens": 256}, + "prompt_logprobs": None, +} +CONST_COMPLETION_RESPONSE_TGI = {"generated_text": CONST_COMPLETION} +CONST_STREAM_TEMPLATE = ( + 'data: {"id":"chat-123456","object":"chat.completion.chunk","created":123456789,' + '"model":"odsc-llm","choices":[{"index":0,"delta":,"finish_reason":null}]}' +) +CONST_STREAM_DELTAS = ['{"role":"assistant","content":""}'] + [ + '{"content":" ' + word + '"}' for word in CONST_COMPLETION.split(" ") +] +CONST_STREAM_RESPONSE = ( + content + for content in [ + CONST_STREAM_TEMPLATE.replace("", delta).encode() + for delta in CONST_STREAM_DELTAS + ] + + [b"data: [DONE]"] +) + +CONST_ASYNC_STREAM_TEMPLATE = ( + '{"id":"chat-123456","object":"chat.completion.chunk","created":123456789,' + '"model":"odsc-llm","choices":[{"index":0,"delta":,"finish_reason":null}]}' +) +CONST_ASYNC_STREAM_RESPONSE = ( + CONST_ASYNC_STREAM_TEMPLATE.replace("", delta).encode() + for delta in CONST_STREAM_DELTAS +) + +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires Python 3.9 or higher" +) + + +class MockResponse: + """Represents a mocked response.""" + + def __init__(self, json_data: Dict, status_code: int = 200): + self.json_data = json_data + self.status_code = status_code + + def raise_for_status(self) -> None: + """Mocked raise for status.""" + if 400 <= self.status_code < 600: + raise HTTPError() + + def json(self) -> Dict: + """Returns mocked json data.""" + return self.json_data + + def iter_lines(self, chunk_size: int = 4096) -> Generator[bytes, None, None]: + """Returns a generator of mocked streaming response.""" + return CONST_STREAM_RESPONSE + + @property + def text(self) -> str: + """Returns the mocked text representation.""" + return "" + + +def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse: + """Method to mock post requests""" + + payload: dict = kwargs.get("json", {}) + messages: list = payload.get("messages", []) + prompt = messages[0].get("content") + + if prompt == CONST_PROMPT: + return MockResponse(json_data=CONST_COMPLETION_RESPONSE) + + return MockResponse( + json_data={}, + status_code=404, + ) + + +@pytest.mark.requires("ads") +@pytest.mark.requires("langchain_openai") +@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) +@mock.patch("requests.post", side_effect=mocked_requests_post) +def test_invoke_vllm(*args: Any) -> None: + """Tests invoking vLLM endpoint.""" + llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME) + output = llm.invoke(CONST_PROMPT) + assert isinstance(output, AIMessage) + assert output.content == CONST_COMPLETION + + +@pytest.mark.requires("ads") +@pytest.mark.requires("langchain_openai") +@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) +@mock.patch("requests.post", side_effect=mocked_requests_post) +def test_invoke_tgi(*args: Any) -> None: + """Tests invoking TGI endpoint using OpenAI Spec.""" + llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME) + output = llm.invoke(CONST_PROMPT) + assert isinstance(output, AIMessage) + assert output.content == CONST_COMPLETION + + 
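The streaming tests that follow consume the mocked server-sent-event lines built from `CONST_STREAM_TEMPLATE` above: each line carries a `data:` prefix, an OpenAI-style chat-completion-chunk JSON payload, and the stream ends with a `data: [DONE]` sentinel. As a rough standalone sketch of the decoding these tests exercise, combining the `data:` stripping done in `_parse_stream_line` with the `json.loads` performed when a chunk is handled (the sample payload below is illustrative, not captured from a live endpoint):

```python
import json
from typing import Optional


def parse_sse_line(raw: bytes) -> Optional[dict]:
    """Decode one SSE line; return its JSON payload, or None for blanks/[DONE]."""
    line = raw.strip()
    if not line:
        return None
    text = line.decode("utf-8")
    if text.lower().startswith("data:"):
        text = text[5:].lstrip()
        if text.startswith("[DONE]"):
            return None
        return json.loads(text)
    return None


sample = (
    b'data: {"id":"chat-123456","object":"chat.completion.chunk",'
    b'"choices":[{"index":0,"delta":{"content":" hello"},"finish_reason":null}]}'
)
event = parse_sse_line(sample)
assert event is not None
assert event["choices"][0]["delta"]["content"] == " hello"
assert parse_sse_line(b"data: [DONE]") is None
```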
+@pytest.mark.requires("ads") +@pytest.mark.requires("langchain_openai") +@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) +@mock.patch("requests.post", side_effect=mocked_requests_post) +def test_stream_vllm(*args: Any) -> None: + """Tests streaming with vLLM endpoint using OpenAI spec.""" + llm = ChatOCIModelDeploymentVLLM( + endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True + ) + output = None + count = 0 + for chunk in llm.stream(CONST_PROMPT): + assert isinstance(chunk, AIMessageChunk) + if output is None: + output = chunk + else: + output += chunk + count += 1 + assert count == 5 + assert output is not None + if output is not None: + assert str(output.content).strip() == CONST_COMPLETION + + +async def mocked_async_streaming_response( + *args: Any, **kwargs: Any +) -> AsyncGenerator[bytes, None]: + """Returns mocked response for async streaming.""" + for item in CONST_ASYNC_STREAM_RESPONSE: + yield item + + +@pytest.mark.asyncio +@pytest.mark.requires("ads") +@pytest.mark.requires("langchain_openai") +@mock.patch( + "ads.common.auth.default_signer", return_value=dict(signer=mock.MagicMock()) +) +@mock.patch( + "langchain_community.utilities.requests.Requests.apost", + mock.MagicMock(), +) +async def test_stream_async(*args: Any) -> None: + """Tests async streaming.""" + llm = ChatOCIModelDeploymentVLLM( + endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True + ) + with mock.patch.object( + llm, + "_aiter_sse", + mock.MagicMock(return_value=mocked_async_streaming_response()), + ): + chunks = [str(chunk.content) async for chunk in llm.astream(CONST_PROMPT)] + assert "".join(chunks).strip() == CONST_COMPLETION diff --git a/libs/community/tests/unit_tests/chat_models/test_oci_model_deployment_endpoint.py b/libs/community/tests/unit_tests/chat_models/test_oci_model_deployment_endpoint.py new file mode 100644 index 0000000000..62afd4828b --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_oci_model_deployment_endpoint.py @@ -0,0 +1,98 @@ +"""Test OCI Data Science Model Deployment Endpoint.""" + +import pytest +import responses +from langchain_core.messages import AIMessage, HumanMessage +from pytest_mock import MockerFixture + +from langchain_community.chat_models import ChatOCIModelDeployment + + +@pytest.mark.requires("ads") +def test_initialization(mocker: MockerFixture) -> None: + """Test chat model initialization.""" + mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) + chat = ChatOCIModelDeployment( + model="odsc", + endpoint="test_endpoint", + model_kwargs={"temperature": 0.2}, + ) + assert chat.model == "odsc" + assert chat.endpoint == "test_endpoint" + assert chat.model_kwargs == {"temperature": 0.2} + assert chat._identifying_params == { + "endpoint": chat.endpoint, + "model_kwargs": {"temperature": 0.2}, + "model": chat.model, + "stop": chat.stop, + "stream": chat.streaming, + } + + +@pytest.mark.requires("ads") +@responses.activate +def test_call(mocker: MockerFixture) -> None: + """Test valid call to oci model deployment endpoint.""" + endpoint = "https://MD_OCID/predict" + responses.add( + responses.POST, + endpoint, + json={ + "id": "cmpl-88159e77c92f46088faad75fce2e26a1", + "object": "chat.completion", + "created": 274246, + "model": "odsc-llm", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello World", + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 10, + "total_tokens": 20, + "completion_tokens": 10, + }, + }, 
+ status=200, + ) + mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) + + chat = ChatOCIModelDeployment(endpoint=endpoint) + output = chat.invoke("this is a test.") + assert isinstance(output, AIMessage) + assert output.response_metadata == { + "token_usage": { + "prompt_tokens": 10, + "total_tokens": 20, + "completion_tokens": 10, + }, + "model_name": "odsc-llm", + "system_fingerprint": "", + "finish_reason": "stop", + } + + +@pytest.mark.requires("ads") +@responses.activate +def test_construct_json_body(mocker: MockerFixture) -> None: + """Tests constructing json body that will be sent to endpoint.""" + mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) + messages = [ + HumanMessage(content="User message"), + ] + chat = ChatOCIModelDeployment( + endpoint="test_endpoint", model_kwargs={"temperature": 0.2} + ) + payload = chat._construct_json_body(messages, chat._invocation_params(stop=None)) + assert payload == { + "messages": [{"content": "User message", "role": "user"}], + "model": chat.model, + "stop": None, + "stream": chat.streaming, + "temperature": 0.2, + } diff --git a/libs/community/tests/unit_tests/llms/test_imports.py b/libs/community/tests/unit_tests/llms/test_imports.py index dd1a089fe4..df0fe68b59 100644 --- a/libs/community/tests/unit_tests/llms/test_imports.py +++ b/libs/community/tests/unit_tests/llms/test_imports.py @@ -56,6 +56,7 @@ EXPECT_ALL = [ "Modal", "MosaicML", "Nebula", + "OCIModelDeploymentLLM", "OCIModelDeploymentTGI", "OCIModelDeploymentVLLM", "OCIGenAI", diff --git a/libs/community/tests/unit_tests/llms/test_oci_model_deployment_endpoint.py b/libs/community/tests/unit_tests/llms/test_oci_model_deployment_endpoint.py index 6552316176..c87b05f12a 100644 --- a/libs/community/tests/unit_tests/llms/test_oci_model_deployment_endpoint.py +++ b/libs/community/tests/unit_tests/llms/test_oci_model_deployment_endpoint.py @@ -1,47 +1,170 @@ -"""Test OCI Data Science Model Deployment Endpoint.""" +# Copyright (c) 2023, 2024, Oracle and/or its affiliates. + +"""Test LLM for OCI Data Science Model Deployment Endpoint.""" + +import sys +from typing import Any, AsyncGenerator, Dict, Generator +from unittest import mock import pytest -import responses -from pytest_mock import MockerFixture +from requests.exceptions import HTTPError -from langchain_community.llms import OCIModelDeploymentTGI, OCIModelDeploymentVLLM +from langchain_community.llms.oci_data_science_model_deployment_endpoint import ( + OCIModelDeploymentTGI, + OCIModelDeploymentVLLM, +) + +CONST_MODEL_NAME = "odsc-vllm" +CONST_ENDPOINT = "https://oci.endpoint/ocid/predict" +CONST_PROMPT = "This is a prompt." +CONST_COMPLETION = "This is a completion." 
+CONST_COMPLETION_RESPONSE = { + "choices": [ + { + "index": 0, + "text": CONST_COMPLETION, + "logprobs": 0.1, + "finish_reason": "length", + } + ], +} +CONST_COMPLETION_RESPONSE_TGI = {"generated_text": CONST_COMPLETION} +CONST_STREAM_TEMPLATE = ( + 'data: {"id":"","object":"text_completion","created":123456,' + + '"choices":[{"index":0,"text":"","finish_reason":""}]}' +) +CONST_STREAM_RESPONSE = ( + CONST_STREAM_TEMPLATE.replace("", " " + word).encode() + for word in CONST_COMPLETION.split(" ") +) + +CONST_ASYNC_STREAM_TEMPLATE = ( + '{"id":"","object":"text_completion","created":123456,' + + '"choices":[{"index":0,"text":"","finish_reason":""}]}' +) +CONST_ASYNC_STREAM_RESPONSE = ( + CONST_ASYNC_STREAM_TEMPLATE.replace("", " " + word).encode() + for word in CONST_COMPLETION.split(" ") +) + +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires Python 3.9 or higher" +) + + +class MockResponse: + """Represents a mocked response.""" + + def __init__(self, json_data: Dict, status_code: int = 200) -> None: + self.json_data = json_data + self.status_code = status_code + + def raise_for_status(self) -> None: + """Mocked raise for status.""" + if 400 <= self.status_code < 600: + raise HTTPError() + + def json(self) -> Dict: + """Returns mocked json data.""" + return self.json_data + + def iter_lines(self, chunk_size: int = 4096) -> Generator[bytes, None, None]: + """Returns a generator of mocked streaming response.""" + return CONST_STREAM_RESPONSE + + @property + def text(self) -> str: + """Returns the mocked text representation.""" + return "" + + +def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse: + """Method to mock post requests""" + + payload: dict = kwargs.get("json", {}) + if "inputs" in payload: + prompt = payload.get("inputs") + is_tgi = True + else: + prompt = payload.get("prompt") + is_tgi = False + + if prompt == CONST_PROMPT: + if is_tgi: + return MockResponse(json_data=CONST_COMPLETION_RESPONSE_TGI) + return MockResponse(json_data=CONST_COMPLETION_RESPONSE) + + return MockResponse( + json_data={}, + status_code=404, + ) + + +async def mocked_async_streaming_response( + *args: Any, **kwargs: Any +) -> AsyncGenerator[bytes, None]: + """Returns mocked response for async streaming.""" + for item in CONST_ASYNC_STREAM_RESPONSE: + yield item @pytest.mark.requires("ads") -@responses.activate -def test_call_vllm(mocker: MockerFixture) -> None: - """Test valid call to oci model deployment endpoint.""" - endpoint = "https://MD_OCID/predict" - responses.add( - responses.POST, - endpoint, - json={ - "choices": [{"index": 0, "text": "This is a completion."}], - }, - status=200, - ) - mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) - - llm = OCIModelDeploymentVLLM(endpoint=endpoint, model="my_model") - output = llm.invoke("This is a prompt.") - assert isinstance(output, str) +@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) +@mock.patch("requests.post", side_effect=mocked_requests_post) +def test_invoke_vllm(*args: Any) -> None: + """Tests invoking vLLM endpoint.""" + llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME) + output = llm.invoke(CONST_PROMPT) + assert output == CONST_COMPLETION @pytest.mark.requires("ads") -@responses.activate -def test_call_tgi(mocker: MockerFixture) -> None: - """Test valid call to oci model deployment endpoint.""" - endpoint = "https://MD_OCID/predict" - responses.add( - responses.POST, - endpoint, - json={ - "generated_text": "This is a 
completion.", - }, - status=200, +@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) +@mock.patch("requests.post", side_effect=mocked_requests_post) +def test_stream_tgi(*args: Any) -> None: + """Tests streaming with TGI endpoint using OpenAI spec.""" + llm = OCIModelDeploymentTGI( + endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True ) - mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) + output = "" + count = 0 + for chunk in llm.stream(CONST_PROMPT): + output += chunk + count += 1 + assert count == 4 + assert output.strip() == CONST_COMPLETION - llm = OCIModelDeploymentTGI(endpoint=endpoint) - output = llm.invoke("This is a prompt.") - assert isinstance(output, str) + +@pytest.mark.requires("ads") +@mock.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) +@mock.patch("requests.post", side_effect=mocked_requests_post) +def test_generate_tgi(*args: Any) -> None: + """Tests invoking TGI endpoint using TGI generate spec.""" + llm = OCIModelDeploymentTGI( + endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME + ) + output = llm.invoke(CONST_PROMPT) + assert output == CONST_COMPLETION + + +@pytest.mark.asyncio +@pytest.mark.requires("ads") +@mock.patch( + "ads.common.auth.default_signer", return_value=dict(signer=mock.MagicMock()) +) +@mock.patch( + "langchain_community.utilities.requests.Requests.apost", + mock.MagicMock(), +) +async def test_stream_async(*args: Any) -> None: + """Tests async streaming.""" + llm = OCIModelDeploymentTGI( + endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True + ) + with mock.patch.object( + llm, + "_aiter_sse", + mock.MagicMock(return_value=mocked_async_streaming_response()), + ): + chunks = [chunk async for chunk in llm.astream(CONST_PROMPT)] + assert "".join(chunks).strip() == CONST_COMPLETION
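Taken together, the updated classes support synchronous calls, token-level streaming, and native async calls against an OCI Data Science Model Deployment endpoint. A short usage sketch, assuming `oracle-ads` is installed, authentication is already configured, and `<MD_OCID>` is a placeholder for a real model deployment OCID (the endpoint URL is illustrative):

```python
import asyncio

from langchain_community.llms import OCIModelDeploymentLLM

llm = OCIModelDeploymentLLM(
    endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<MD_OCID>/predict",
    model="odsc-llm",
    max_tokens=128,
    temperature=0.2,
)

# Synchronous call
print(llm.invoke("Tell me a joke."))

# Token-level streaming
for chunk in llm.stream("Tell me a joke."):
    print(chunk, end="", flush=True)

# Native async call
async def main() -> None:
    print(await llm.ainvoke("Tell me a joke."))

asyncio.run(main())
```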