diff --git a/docs/docs/integrations/chat/oci_generative_ai.ipynb b/docs/docs/integrations/chat/oci_generative_ai.ipynb new file mode 100644 index 0000000000..4ce58a13fb --- /dev/null +++ b/docs/docs/integrations/chat/oci_generative_ai.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "raw", + "id": "afaf8039", + "metadata": {}, + "source": [ + "---\n", + "sidebar_label: OCIGenAI\n", + "---" + ] + }, + { + "cell_type": "markdown", + "id": "e49f1e0d", + "metadata": {}, + "source": [ + "# ChatOCIGenAI\n", + "\n", + "This notebook provides a quick overview for getting started with OCIGenAI [chat models](/docs/concepts/#chat-models). For detailed documentation of all ChatOCIGenAI features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html).\n", + "\n", + "Oracle Cloud Infrastructure (OCI) Generative AI is a fully managed service that provides a set of state-of-the-art, customizable large language models (LLMs) that cover a wide range of use cases, and which is available through a single API.\n", + "Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned custom models based on your own data on dedicated AI clusters. Detailed documentation of the service and API is available __[here](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)__ and __[here](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai/20231130/)__.\n", + "\n", + "\n", + "## Overview\n", + "### Integration details\n", + "\n", + "| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/v0.2/docs/integrations/chat/oci_generative_ai) | Package downloads | Package latest |\n", + "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", + "| [ChatOCIGenAI](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://api.python.langchain.com/en/latest/community_api_reference.html) | ❌ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-oci-generative-ai?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-oci-generative-ai?style=flat-square&label=%20) |\n", + "\n", + "### Model features\n", + "| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", + "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", + "| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | \n", + "\n", + "## Setup\n", + "\n", + "To access OCIGenAI models you'll need to install the `oci` and `langchain-community` packages.\n", + "\n", + "### Credentials\n", + "\n", + "The credentials and authentication methods supported for this integration are equivalent to those used with other OCI services and follow the __[standard SDK authentication](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__ methods, specifically API Key, session token, instance principal, and resource principal.\n", + "\n", + "API key is the default authentication method used in the examples above. 
The following example demonstrates how to use a different authentication method (session token)" + ] + }, + { + "cell_type": "markdown", + "id": "0730d6a1-c893-4840-9817-5e5251676d5d", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "The LangChain OCIGenAI integration lives in the `langchain-community` package and you will also need to install the `oci` package:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "652d6238-1f87-422a-b135-f5abbb8652fc", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -qU langchain-community oci" + ] + }, + { + "cell_type": "markdown", + "id": "a38cde65-254d-4219-a441-068766c0d4b5", + "metadata": {}, + "source": [ + "## Instantiation\n", + "\n", + "Now we can instantiate our model object and generate chat completions:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI\n", + "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage\n", + "\n", + "chat = ChatOCIGenAI(\n", + " model_id=\"cohere.command-r-16k\",\n", + " service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n", + " compartment_id=\"MY_OCID\",\n", + " model_kwargs={\"temperature\": 0.7, \"max_tokens\": 500},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2b4f3e15", + "metadata": {}, + "source": [ + "## Invocation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62e0dbc3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "messages = [\n", + " SystemMessage(content=\"your are an AI assistant.\"),\n", + " AIMessage(content=\"Hi there human!\"),\n", + " HumanMessage(content=\"tell me a joke.\"),\n", + "]\n", + "response = chat.invoke(messages)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d86145b3-bfef-46e8-b227-4dda5c9c2705", + "metadata": {}, + "outputs": [], + "source": [ + "print(response.content)" + ] + }, + { + "cell_type": "markdown", + "id": "18e2bfc0-7e78-4528-a73f-499ac150dca8", + "metadata": {}, + "source": [ + "## Chaining\n", + "\n", + "We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "prompt = ChatPromptTemplate.from_template(\"Tell me a joke about {topic}\")\n", + "chain = prompt | chat\n", + "\n", + "response = chain.invoke({\"topic\": \"dogs\"})\n", + "print(response.content)" + ] + }, + { + "cell_type": "markdown", + "id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3", + "metadata": {}, + "source": [ + "## API reference\n", + "\n", + "For detailed documentation of all ChatOCIGenAI features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 
4, + "nbformat_minor": 5 +} diff --git a/docs/docs/integrations/llms/oci_generative_ai.ipynb b/docs/docs/integrations/llms/oci_generative_ai.ipynb index 0c2368efdc..3da80aef0e 100644 --- a/docs/docs/integrations/llms/oci_generative_ai.ipynb +++ b/docs/docs/integrations/llms/oci_generative_ai.ipynb @@ -14,15 +14,15 @@ "Oracle Cloud Infrastructure (OCI) Generative AI is a fully managed service that provides a set of state-of-the-art, customizable large language models (LLMs) that cover a wide range of use cases, and which is available through a single API.\n", "Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned custom models based on your own data on dedicated AI clusters. Detailed documentation of the service and API is available __[here](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)__ and __[here](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai/20231130/)__.\n", "\n", - "This notebook explains how to use OCI's Genrative AI models with LangChain." + "This notebook explains how to use OCI's Generative AI complete models with LangChain." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Prerequisite\n", - "We will need to install the oci sdk" + "## Setup\n", + "Ensure that the oci sdk and the langchain-community package are installed" ] }, { @@ -31,31 +31,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -U oci" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### OCI Generative AI API endpoint \n", - "https://inference.generativeai.us-chicago-1.oci.oraclecloud.com" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Authentication\n", - "The authentication methods supported for this langchain integration are:\n", - "\n", - "1. API Key\n", - "2. Session token\n", - "3. Instance principal\n", - "4. 
Resource principal \n", - "\n", - "These follows the standard SDK authentication methods detailed __[here](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__.\n", - " " + "!pip install -U oci langchain-community" ] }, { @@ -71,13 +47,13 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_community.llms import OCIGenAI\n", + "from langchain_community.llms.oci_generative_ai import OCIGenAI\n", "\n", - "# use default authN method API-key\n", "llm = OCIGenAI(\n", - " model_id=\"MY_MODEL\",\n", + " model_id=\"cohere.command\",\n", " service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n", " compartment_id=\"MY_OCID\",\n", + " model_kwargs={\"temperature\": 0, \"max_tokens\": 500},\n", ")\n", "\n", "response = llm.invoke(\"Tell me one fact about earth\", temperature=0.7)\n", @@ -85,30 +61,10 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "from langchain.chains import LLMChain\n", - "from langchain_core.prompts import PromptTemplate\n", - "\n", - "# Use Session Token to authN\n", - "llm = OCIGenAI(\n", - " model_id=\"MY_MODEL\",\n", - " service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n", - " compartment_id=\"MY_OCID\",\n", - " auth_type=\"SECURITY_TOKEN\",\n", - " auth_profile=\"MY_PROFILE\", # replace with your profile name\n", - " model_kwargs={\"temperature\": 0.7, \"top_p\": 0.75, \"max_tokens\": 200},\n", - ")\n", - "\n", - "prompt = PromptTemplate(input_variables=[\"query\"], template=\"{query}\")\n", - "\n", - "llm_chain = LLMChain(llm=llm, prompt=prompt)\n", - "\n", - "response = llm_chain.invoke(\"what is the capital of france?\")\n", - "print(response)" + "#### Chaining with prompt templates" ] }, { @@ -117,49 +73,95 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain_community.embeddings import OCIGenAIEmbeddings\n", - "from langchain_community.vectorstores import FAISS\n", - "from langchain_core.output_parsers import StrOutputParser\n", - "from langchain_core.runnables import RunnablePassthrough\n", - "\n", - "embeddings = OCIGenAIEmbeddings(\n", - " model_id=\"MY_EMBEDDING_MODEL\",\n", - " service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n", - " compartment_id=\"MY_OCID\",\n", - ")\n", - "\n", - "vectorstore = FAISS.from_texts(\n", - " [\n", - " \"Larry Ellison co-founded Oracle Corporation in 1977 with Bob Miner and Ed Oates.\",\n", - " \"Oracle Corporation is an American multinational computer technology company headquartered in Austin, Texas, United States.\",\n", - " ],\n", - " embedding=embeddings,\n", - ")\n", - "\n", - "retriever = vectorstore.as_retriever()\n", - "\n", - "template = \"\"\"Answer the question based only on the following context:\n", - "{context}\n", - " \n", - "Question: {question}\n", - "\"\"\"\n", - "prompt = PromptTemplate.from_template(template)\n", + "from langchain_core.prompts import PromptTemplate\n", "\n", "llm = OCIGenAI(\n", - " model_id=\"MY_MODEL\",\n", + " model_id=\"cohere.command\",\n", " service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n", " compartment_id=\"MY_OCID\",\n", + " model_kwargs={\"temperature\": 0, \"max_tokens\": 500},\n", ")\n", "\n", - "chain = (\n", - " {\"context\": retriever, \"question\": RunnablePassthrough()}\n", - " | prompt\n", - " | llm\n", - " | StrOutputParser()\n", + "prompt = PromptTemplate(input_variables=[\"query\"], 
template=\"{query}\")\n", + "llm_chain = prompt | llm\n", + "\n", + "response = llm_chain.invoke(\"what is the capital of france?\")\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Streaming" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm = OCIGenAI(\n", + " model_id=\"cohere.command\",\n", + " service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n", + " compartment_id=\"MY_OCID\",\n", + " model_kwargs={\"temperature\": 0, \"max_tokens\": 500},\n", ")\n", "\n", - "print(chain.invoke(\"when was oracle founded?\"))\n", - "print(chain.invoke(\"where is oracle headquartered?\"))" + "for chunk in llm.stream(\"Write me a song about sparkling water.\"):\n", + " print(chunk, end=\"\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Authentication\n", + "The authentication methods supported for LlamaIndex are equivalent to those used with other OCI services and follow the __[standard SDK authentication](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__ methods, specifically API Key, session token, instance principal, and resource principal.\n", + "\n", + "API key is the default authentication method used in the examples above. The following example demonstrates how to use a different authentication method (session token)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm = OCIGenAI(\n", + " model_id=\"cohere.command\",\n", + " service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n", + " compartment_id=\"MY_OCID\",\n", + " auth_type=\"SECURITY_TOKEN\",\n", + " auth_profile=\"MY_PROFILE\", # replace with your profile name\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dedicated AI Cluster\n", + "To access models hosted in a dedicated AI cluster __[create an endpoint](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/)__ whose assigned OCID (currently prefixed by ‘ocid1.generativeaiendpoint.oc1.us-chicago-1’) is used as your model ID.\n", + "\n", + "When accessing models hosted in a dedicated AI cluster you will need to initialize the OCIGenAI interface with two extra required params (\"provider\" and \"context_size\")." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm = OCIGenAI(\n", + " model_id=\"ocid1.generativeaiendpoint.oc1.us-chicago-1....\",\n", + " service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n", + " compartment_id=\"DEDICATED_COMPARTMENT_OCID\",\n", + " auth_profile=\"MY_PROFILE\", # replace with your profile name,\n", + " provider=\"MODEL_PROVIDER\", # e.g., \"cohere\" or \"meta\"\n", + " context_size=\"MODEL_CONTEXT_SIZE\", # e.g., 128000\n", + ")" ] } ], diff --git a/docs/docs/integrations/providers/oci.mdx b/docs/docs/integrations/providers/oci.mdx index e0b3570028..5037fb86f1 100644 --- a/docs/docs/integrations/providers/oci.mdx +++ b/docs/docs/integrations/providers/oci.mdx @@ -2,27 +2,29 @@ The `LangChain` integrations related to [Oracle Cloud Infrastructure](https://www.oracle.com/artificial-intelligence/). 
-## LLMs
-
-### OCI Generative AI
+## OCI Generative AI

> Oracle Cloud Infrastructure (OCI) [Generative AI](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm) is a fully managed service that provides a set of state-of-the-art,
> customizable large language models (LLMs) that cover a wide range of use cases, and which are available through a single API.
> Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned
> custom models based on your own data on dedicated AI clusters.

-To use, you should have the latest `oci` python SDK installed.
+To use, you should have the latest `oci` python SDK and the `langchain-community` package installed.

```bash
-pip install -U oci
+pip install -U oci langchain-community
```

-See [usage examples](/docs/integrations/llms/oci_generative_ai).
+See [chat](/docs/integrations/chat/oci_generative_ai), [complete](/docs/integrations/llms/oci_generative_ai), and [embedding](/docs/integrations/text_embedding/oci_generative_ai) usage examples.

```python
+from langchain_community.chat_models import ChatOCIGenAI
+
 from langchain_community.llms import OCIGenAI
+
+from langchain_community.embeddings import OCIGenAIEmbeddings
```

-### OCI Data Science Model Deployment Endpoint
+## OCI Data Science Model Deployment Endpoint

> [OCI Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/home.htm) is a
> fully managed and serverless platform for data science teams. Using the OCI Data Science
@@ -47,12 +49,3 @@ from langchain_community.llms import OCIModelDeploymentVLLM
 from langchain_community.llms import OCIModelDeploymentTGI

-## Text Embedding Models
-
-### OCI Generative AI
-
-See [usage examples](/docs/integrations/text_embedding/oci_generative_ai).
- -```python -from langchain_community.embeddings import OCIGenAIEmbeddings -``` \ No newline at end of file diff --git a/libs/community/extended_testing_deps.txt b/libs/community/extended_testing_deps.txt index 9f5e8284af..db8d9cdfd0 100644 --- a/libs/community/extended_testing_deps.txt +++ b/libs/community/extended_testing_deps.txt @@ -46,7 +46,7 @@ mwxml>=0.3.3,<0.4 newspaper3k>=0.2.8,<0.3 numexpr>=2.8.6,<3 nvidia-riva-client>=2.14.0,<3 -oci>=2.119.1,<3 +oci>=2.128.0,<3 openai<2 openapi-pydantic>=0.3.2,<0.4 oracle-ads>=2.9.1,<3 diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 7b942a26ca..af25b60184 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -121,6 +121,9 @@ if TYPE_CHECKING: from langchain_community.chat_models.mlx import ( ChatMLX, ) + from langchain_community.chat_models.oci_generative_ai import ( + ChatOCIGenAI, # noqa: F401 + ) from langchain_community.chat_models.octoai import ChatOctoAI from langchain_community.chat_models.ollama import ( ChatOllama, @@ -194,6 +197,7 @@ __all__ = [ "ChatMLflowAIGateway", "ChatMaritalk", "ChatMlflow", + "ChatOCIGenAI", "ChatOllama", "ChatOpenAI", "ChatPerplexity", @@ -248,6 +252,7 @@ _module_lookup = { "ChatMaritalk": "langchain_community.chat_models.maritalk", "ChatMlflow": "langchain_community.chat_models.mlflow", "ChatOctoAI": "langchain_community.chat_models.octoai", + "ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai", "ChatOllama": "langchain_community.chat_models.ollama", "ChatOpenAI": "langchain_community.chat_models.openai", "ChatPerplexity": "langchain_community.chat_models.perplexity", diff --git a/libs/community/langchain_community/chat_models/oci_generative_ai.py b/libs/community/langchain_community/chat_models/oci_generative_ai.py new file mode 100644 index 0000000000..9409b1a2fb --- /dev/null +++ b/libs/community/langchain_community/chat_models/oci_generative_ai.py @@ -0,0 +1,363 @@ +import json +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models.chat_models import ( + BaseChatModel, + generate_from_stream, +) +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.pydantic_v1 import Extra + +from langchain_community.llms.oci_generative_ai import OCIGenAIBase +from langchain_community.llms.utils import enforce_stop_tokens + +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" + + +class Provider(ABC): + @property + @abstractmethod + def stop_sequence_key(self) -> str: + ... + + @abstractmethod + def chat_response_to_text(self, response: Any) -> str: + ... + + @abstractmethod + def chat_stream_to_text(self, event_data: Dict) -> str: + ... + + @abstractmethod + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + ... + + @abstractmethod + def get_role(self, message: BaseMessage) -> str: + ... + + @abstractmethod + def messages_to_oci_params(self, messages: Any) -> Dict[str, Any]: + ... 
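The `Provider` abstraction above is the per-vendor hook that turns LangChain messages into OCI chat request parameters; the concrete Cohere and Meta providers follow next. As an illustrative sketch only (it assumes the `oci` SDK is installed, since `CohereProvider` constructs `oci.generative_ai_inference` request models), this is how a short conversation maps to Cohere-format parameters: the last message becomes `message` and everything before it becomes `chat_history`.

```python
# Illustrative only, not part of this PR. Requires the `oci` SDK so that
# CohereProvider can build its oci.generative_ai_inference request models.
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_community.chat_models.oci_generative_ai import CohereProvider

provider = CohereProvider()
oci_params = provider.messages_to_oci_params(
    [
        SystemMessage(content="You are an AI assistant."),
        AIMessage(content="Hi there human!"),
        HumanMessage(content="Tell me a joke."),
    ]
)

# The trailing human turn is the prompt; earlier turns become Cohere chat history
# (CohereSystemMessage and CohereChatBotMessage objects).
assert oci_params["message"] == "Tell me a joke."
assert len(oci_params["chat_history"]) == 2
```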
+ + +class CohereProvider(Provider): + stop_sequence_key = "stop_sequences" + + def __init__(self) -> None: + from oci.generative_ai_inference import models + + self.oci_chat_request = models.CohereChatRequest + self.oci_chat_message = { + "USER": models.CohereUserMessage, + "CHATBOT": models.CohereChatBotMessage, + "SYSTEM": models.CohereSystemMessage, + } + self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE + + def chat_response_to_text(self, response: Any) -> str: + return response.data.chat_response.text + + def chat_stream_to_text(self, event_data: Dict) -> str: + if "text" in event_data and "finishReason" not in event_data: + return event_data["text"] + else: + return "" + + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + return { + "finish_reason": response.data.chat_response.finish_reason, + } + + def get_role(self, message: BaseMessage) -> str: + if isinstance(message, HumanMessage): + return "USER" + elif isinstance(message, AIMessage): + return "CHATBOT" + elif isinstance(message, SystemMessage): + return "SYSTEM" + else: + raise ValueError(f"Got unknown type {message}") + + def messages_to_oci_params(self, messages: Sequence[ChatMessage]) -> Dict[str, Any]: + oci_chat_history = [ + self.oci_chat_message[self.get_role(msg)](message=msg.content) + for msg in messages[:-1] + ] + oci_params = { + "message": messages[-1].content, + "chat_history": oci_chat_history, + "api_format": self.chat_api_format, + } + + return oci_params + + +class MetaProvider(Provider): + stop_sequence_key = "stop" + + def __init__(self) -> None: + from oci.generative_ai_inference import models + + self.oci_chat_request = models.GenericChatRequest + self.oci_chat_message = { + "USER": models.UserMessage, + "SYSTEM": models.SystemMessage, + "ASSISTANT": models.AssistantMessage, + } + self.oci_chat_message_content = models.TextContent + self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC + + def chat_response_to_text(self, response: Any) -> str: + return response.data.chat_response.choices[0].message.content[0].text + + def chat_stream_to_text(self, event_data: Dict) -> str: + if "message" in event_data: + return event_data["message"]["content"][0]["text"] + else: + return "" + + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + return { + "finish_reason": response.data.chat_response.choices[0].finish_reason, + "time_created": str(response.data.chat_response.time_created), + } + + def get_role(self, message: BaseMessage) -> str: + # meta only supports alternating user/assistant roles + if isinstance(message, HumanMessage): + return "USER" + elif isinstance(message, AIMessage): + return "ASSISTANT" + elif isinstance(message, SystemMessage): + return "SYSTEM" + else: + raise ValueError(f"Got unknown type {message}") + + def messages_to_oci_params(self, messages: List[BaseMessage]) -> Dict[str, Any]: + oci_messages = [ + self.oci_chat_message[self.get_role(msg)]( + content=[self.oci_chat_message_content(text=msg.content)] + ) + for msg in messages + ] + oci_params = { + "messages": oci_messages, + "api_format": self.chat_api_format, + "top_k": -1, + } + + return oci_params + + +class ChatOCIGenAI(BaseChatModel, OCIGenAIBase): + """ChatOCIGenAI chat model integration. + + Setup: + Install ``langchain-community`` and the ``oci`` sdk. + + .. code-block:: bash + + pip install -U langchain-community oci + + Key init args — completion params: + model_id: str + Id of the OCIGenAI chat model to use, e.g., cohere.command-r-16k. 
+ is_stream: bool + Whether to stream back partial progress + model_kwargs: Optional[Dict] + Keyword arguments to pass to the specific model used, e.g., temperature, max_tokens. + + Key init args — client params: + service_endpoint: str + The endpoint URL for the OCIGenAI service, e.g., https://inference.generativeai.us-chicago-1.oci.oraclecloud.com. + compartment_id: str + The compartment OCID. + auth_type: str + The authentication type to use, e.g., API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL. + auth_profile: Optional[str] + The name of the profile in ~/.oci/config, if not specified , DEFAULT will be used. + provider: str + Provider name of the model. Default to None, will try to be derived from the model_id otherwise, requires user input. + See full list of supported init args and their descriptions in the params section. + + Instantiate: + .. code-block:: python + + from langchain_community.chat_models import ChatOCIGenAI + + chat = ChatOCIGenAI( + model_id="cohere.command-r-16k", + service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com", + compartment_id="MY_OCID", + model_kwargs={"temperature": 0.7, "max_tokens": 500}, + ) + + Invoke: + .. code-block:: python + messages = [ + SystemMessage(content="your are an AI assistant."), + AIMessage(content="Hi there human!"), + HumanMessage(content="tell me a joke."), + ] + response = chat.invoke(messages) + + Stream: + .. code-block:: python + + for r in chat.stream(messages): + print(r.content, end="", flush=True) + + Response metadata + .. code-block:: python + + response = chat.invoke(messages) + print(response.response_metadata) + + """ # noqa: E501 + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "oci_generative_ai_chat" + + @property + def _provider_map(self) -> Mapping[str, Any]: + """Get the provider map""" + return { + "cohere": CohereProvider(), + "meta": MetaProvider(), + } + + @property + def _provider(self) -> Any: + """Get the internal provider object""" + return self._get_provider(provider_map=self._provider_map) + + def _prepare_request( + self, + messages: List[BaseMessage], + stop: Optional[List[str]], + kwargs: Dict[str, Any], + stream: bool, + ) -> Dict[str, Any]: + try: + from oci.generative_ai_inference import models + + except ImportError as ex: + raise ModuleNotFoundError( + "Could not import oci python package. " + "Please make sure you have the oci package installed." + ) from ex + oci_params = self._provider.messages_to_oci_params(messages) + oci_params["is_stream"] = stream # self.is_stream + _model_kwargs = self.model_kwargs or {} + + if stop is not None: + _model_kwargs[self._provider.stop_sequence_key] = stop + + chat_params = {**_model_kwargs, **kwargs, **oci_params} + + if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX): + serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id) + else: + serving_mode = models.OnDemandServingMode(model_id=self.model_id) + + request = models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=serving_mode, + chat_request=self._provider.oci_chat_request(**chat_params), + ) + + return request + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Call out to a OCIGenAI chat model. 
+ + Args: + messages: list of LangChain messages + stop: Optional list of stop words to use. + + Returns: + LangChain ChatResult + + Example: + .. code-block:: python + + messages = [ + HumanMessage(content="hello!"), + AIMessage(content="Hi there human!"), + HumanMessage(content="Meow!") + ] + + response = llm.invoke(messages) + """ + if self.is_stream: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return generate_from_stream(stream_iter) + + request = self._prepare_request(messages, stop, kwargs, stream=False) + response = self.client.chat(request) + + content = self._provider.chat_response_to_text(response) + + if stop is not None: + content = enforce_stop_tokens(content, stop) + + generation_info = self._provider.chat_generation_info(response) + + llm_output = { + "model_id": response.data.model_id, + "model_version": response.data.model_version, + "request_id": response.request_id, + "content-length": response.headers["content-length"], + } + + return ChatResult( + generations=[ + ChatGeneration( + message=AIMessage(content=content), generation_info=generation_info + ) + ], + llm_output=llm_output, + ) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + request = self._prepare_request(messages, stop, kwargs, stream=True) + response = self.client.chat(request) + + for event in response.data.events(): + delta = self._provider.chat_stream_to_text(json.loads(event.data)) + chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) + if run_manager: + run_manager.on_llm_new_token(delta, chunk=chunk) + yield chunk diff --git a/libs/community/langchain_community/llms/oci_generative_ai.py b/libs/community/langchain_community/llms/oci_generative_ai.py index 2c2935cc76..178694656e 100644 --- a/libs/community/langchain_community/llms/oci_generative_ai.py +++ b/libs/community/langchain_community/llms/oci_generative_ai.py @@ -1,17 +1,53 @@ from __future__ import annotations -from abc import ABC +import json +from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, Iterator, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM +from langchain_core.outputs import GenerationChunk from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain_community.llms.utils import enforce_stop_tokens CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" -VALID_PROVIDERS = ("cohere", "meta") + + +class Provider(ABC): + @property + @abstractmethod + def stop_sequence_key(self) -> str: + ... + + @abstractmethod + def completion_response_to_text(self, response: Any) -> str: + ... 
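This mirrors the chat module: `stop_sequence_key` names the vendor parameter that stop words are routed to ("stop_sequences" for Cohere, "stop" for Meta), and `completion_response_to_text` extracts the generated text from the vendor-specific response shape. A minimal usage sketch, assuming a valid `~/.oci/config` API-key profile and a placeholder compartment OCID:

```python
from langchain_community.llms.oci_generative_ai import OCIGenAI

# Placeholder endpoint and compartment OCID; default API key auth is assumed.
llm = OCIGenAI(
    model_id="cohere.command",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",
)

# The stop list is forwarded under the provider's stop_sequence_key, and
# enforce_stop_tokens trims the returned text at the first match.
print(llm.invoke("List three OCI services, one per line.", stop=["\n"]))
```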
+ + +class CohereProvider(Provider): + stop_sequence_key = "stop_sequences" + + def __init__(self) -> None: + from oci.generative_ai_inference import models + + self.llm_inference_request = models.CohereLlmInferenceRequest + + def completion_response_to_text(self, response: Any) -> str: + return response.data.inference_response.generated_texts[0].text + + +class MetaProvider(Provider): + stop_sequence_key = "stop" + + def __init__(self) -> None: + from oci.generative_ai_inference import models + + self.llm_inference_request = models.LlamaLlmInferenceRequest + + def completion_response_to_text(self, response: Any) -> str: + return response.data.inference_response.choices[0].text class OCIAuthType(Enum): @@ -33,8 +69,8 @@ class OCIGenAIBase(BaseModel, ABC): API_KEY, SECURITY_TOKEN, - INSTANCE_PRINCIPLE, - RESOURCE_PRINCIPLE + INSTANCE_PRINCIPAL, + RESOURCE_PRINCIPAL If not specified, API_KEY will be used """ @@ -65,11 +101,6 @@ class OCIGenAIBase(BaseModel, ABC): is_stream: bool = False """Whether to stream back partial progress""" - llm_stop_sequence_mapping: Mapping[str, str] = { - "cohere": "stop_sequences", - "meta": "stop", - } - @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that OCI config and python package exists in environment.""" @@ -121,24 +152,28 @@ class OCIGenAIBase(BaseModel, ABC): "signer" ] = oci.auth.signers.get_resource_principals_signer() else: - raise ValueError("Please provide valid value to auth_type") + raise ValueError( + "Please provide valid value to auth_type, " + f"{values['auth_type']} is not valid." + ) values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient( **client_kwargs ) except ImportError as ex: - raise ImportError( + raise ModuleNotFoundError( "Could not import oci python package. " "Please make sure you have the oci package installed." ) from ex except Exception as e: raise ValueError( - "Could not authenticate with OCI client. " - "Please check if ~/.oci/config exists. " - "If INSTANCE_PRINCIPLE or RESOURCE_PRINCIPLE is used, " - "Please check the specified " - "auth_profile and auth_type are valid." + """Could not authenticate with OCI client. + Please check if ~/.oci/config exists. + If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, + please check the specified + auth_profile and auth_type are valid.""", + e, ) from e return values @@ -151,19 +186,19 @@ class OCIGenAIBase(BaseModel, ABC): **{"model_kwargs": _model_kwargs}, } - def _get_provider(self) -> str: + def _get_provider(self, provider_map: Mapping[str, Any]) -> Any: if self.provider is not None: provider = self.provider else: provider = self.model_id.split(".")[0].lower() - if provider not in VALID_PROVIDERS: + if provider not in provider_map: raise ValueError( f"Invalid provider derived from model_id: {self.model_id} " "Please explicitly pass in the supported provider " "when using custom endpoint" ) - return provider + return provider_map[provider] class OCIGenAI(LLM, OCIGenAIBase): @@ -173,7 +208,7 @@ class OCIGenAI(LLM, OCIGenAIBase): https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm The authentifcation method is passed through auth_type and should be one of: - API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPLE, RESOURCE_PRINCIPLE + API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL Make sure you have the required policies (profile/roles) to access the OCI Generative AI service. 
@@ -204,21 +239,29 @@ class OCIGenAI(LLM, OCIGenAIBase): @property def _llm_type(self) -> str: """Return type of llm.""" - return "oci" + return "oci_generative_ai_completion" + + @property + def _provider_map(self) -> Mapping[str, Any]: + """Get the provider map""" + return { + "cohere": CohereProvider(), + "meta": MetaProvider(), + } + + @property + def _provider(self) -> Any: + """Get the internal provider object""" + return self._get_provider(provider_map=self._provider_map) def _prepare_invocation_object( self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any] ) -> Dict[str, Any]: from oci.generative_ai_inference import models - oci_llm_request_mapping = { - "cohere": models.CohereLlmInferenceRequest, - "meta": models.LlamaLlmInferenceRequest, - } - provider = self._get_provider() _model_kwargs = self.model_kwargs or {} if stop is not None: - _model_kwargs[self.llm_stop_sequence_mapping[provider]] = stop + _model_kwargs[self._provider.stop_sequence_key] = stop if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX): serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id) @@ -232,19 +275,13 @@ class OCIGenAI(LLM, OCIGenAIBase): invocation_obj = models.GenerateTextDetails( compartment_id=self.compartment_id, serving_mode=serving_mode, - inference_request=oci_llm_request_mapping[provider](**inference_params), + inference_request=self._provider.llm_inference_request(**inference_params), ) return invocation_obj def _process_response(self, response: Any, stop: Optional[List[str]]) -> str: - provider = self._get_provider() - if provider == "cohere": - text = response.data.inference_response.generated_texts[0].text - elif provider == "meta": - text = response.data.inference_response.choices[0].text - else: - raise ValueError(f"Invalid provider: {provider}") + text = self._provider.completion_response_to_text(response) if stop is not None: text = enforce_stop_tokens(text, stop) @@ -272,7 +309,51 @@ class OCIGenAI(LLM, OCIGenAIBase): response = llm.invoke("Tell me a joke.") """ + if self.is_stream: + text = "" + for chunk in self._stream(prompt, stop, run_manager, **kwargs): + text += chunk.text + if stop is not None: + text = enforce_stop_tokens(text, stop) + return text invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs) response = self.client.generate_text(invocation_obj) return self._process_response(response, stop) + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + """Stream OCIGenAI LLM on given prompt. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + An iterator of GenerationChunks. + + Example: + .. 
code-block:: python + + response = llm.stream("Tell me a joke.") + """ + + self.is_stream = True + invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs) + response = self.client.generate_text(invocation_obj) + + for event in response.data.events(): + json_load = json.loads(event.data) + if "text" in json_load: + event_data_text = json_load["text"] + else: + event_data_text = "" + chunk = GenerationChunk(text=event_data_text) + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index a0e573068c..3c9b5e2254 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -27,6 +27,7 @@ EXPECTED_ALL = [ "ChatMlflow", "ChatMLflowAIGateway", "ChatMLX", + "ChatOCIGenAI", "ChatOllama", "ChatOpenAI", "ChatPerplexity", diff --git a/libs/community/tests/unit_tests/chat_models/test_oci_generative_ai.py b/libs/community/tests/unit_tests/chat_models/test_oci_generative_ai.py new file mode 100644 index 0000000000..b7d80d19c4 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_oci_generative_ai.py @@ -0,0 +1,105 @@ +"""Test OCI Generative AI LLM service""" +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import HumanMessage +from pytest import MonkeyPatch + +from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI + + +class MockResponseDict(dict): + def __getattr__(self, val): # type: ignore[no-untyped-def] + return self[val] + + +@pytest.mark.requires("oci") +@pytest.mark.parametrize( + "test_model_id", ["cohere.command-r-16k", "meta.llama-3-70b-instruct"] +) +def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None: + """Test valid chat call to OCI Generative AI LLM service.""" + oci_gen_ai_client = MagicMock() + llm = ChatOCIGenAI(model_id=test_model_id, client=oci_gen_ai_client) + + provider = llm.model_id.split(".")[0].lower() + + def mocked_response(*args): # type: ignore[no-untyped-def] + response_text = "Assistant chat reply." + response = None + if provider == "cohere": + response = MockResponseDict( + { + "status": 200, + "data": MockResponseDict( + { + "chat_response": MockResponseDict( + { + "text": response_text, + "finish_reason": "completed", + } + ), + "model_id": "cohere.command-r-16k", + "model_version": "1.0.0", + } + ), + "request_id": "1234567890", + "headers": MockResponseDict( + { + "content-length": "123", + } + ), + } + ) + elif provider == "meta": + response = MockResponseDict( + { + "status": 200, + "data": MockResponseDict( + { + "chat_response": MockResponseDict( + { + "choices": [ + MockResponseDict( + { + "message": MockResponseDict( + { + "content": [ + MockResponseDict( + { + "text": response_text, # noqa: E501 + } + ) + ] + } + ), + "finish_reason": "completed", + } + ) + ], + "time_created": "2024-09-01T00:00:00Z", + } + ), + "model_id": "cohere.command-r-16k", + "model_version": "1.0.0", + } + ), + "request_id": "1234567890", + "headers": MockResponseDict( + { + "content-length": "123", + } + ), + } + ) + return response + + monkeypatch.setattr(llm.client, "chat", mocked_response) + + messages = [ + HumanMessage(content="User message"), + ] + + expected = "Assistant chat reply." 
+ actual = llm.invoke(messages, temperature=0.2) + assert actual.content == expected diff --git a/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py b/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py index b1c36ec7a5..cc3599abe6 100644 --- a/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py +++ b/libs/community/tests/unit_tests/llms/test_oci_generative_ai.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest from pytest import MonkeyPatch -from langchain_community.llms import OCIGenAI +from langchain_community.llms.oci_generative_ai import OCIGenAI class MockResponseDict(dict): @@ -16,12 +16,12 @@ class MockResponseDict(dict): @pytest.mark.parametrize( "test_model_id", ["cohere.command", "cohere.command-light", "meta.llama-2-70b-chat"] ) -def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None: - """Test valid call to OCI Generative AI LLM service.""" +def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None: + """Test valid completion call to OCI Generative AI LLM service.""" oci_gen_ai_client = MagicMock() llm = OCIGenAI(model_id=test_model_id, client=oci_gen_ai_client) - provider = llm._get_provider() + provider = llm.model_id.split(".")[0].lower() def mocked_response(*args): # type: ignore[no-untyped-def] response_text = "This is the completion." @@ -71,6 +71,5 @@ def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None: ) monkeypatch.setattr(llm.client, "generate_text", mocked_response) - output = llm.invoke("This is a prompt.", temperature=0.2) assert output == "This is the completion."
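
Not part of this change, but as a sketch of how the new completion streaming path could be covered with the same MagicMock pattern used above; the fake event payloads mirror the `{"text": ...}` JSON that `_stream` parses, and the test name is illustrative.

```python
import json
from unittest.mock import MagicMock

import pytest

from langchain_community.llms.oci_generative_ai import OCIGenAI


@pytest.mark.requires("oci")
def test_llm_stream() -> None:
    """With is_stream=True, invoke() accumulates text from the SSE event stream."""
    oci_gen_ai_client = MagicMock()
    llm = OCIGenAI(model_id="cohere.command", client=oci_gen_ai_client, is_stream=True)

    # Fake server-sent events: each event's `data` is a JSON payload carrying a
    # partial "text" field, which is what _stream reads from response.data.events().
    events = [MagicMock(data=json.dumps({"text": text})) for text in ["Hello", " world"]]
    response = MagicMock()
    response.data.events.return_value = events
    oci_gen_ai_client.generate_text.return_value = response

    output = llm.invoke("Say hello")
    assert output == "Hello world"
```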