Add OCI Generative AI new model support (#22880)

- [x] PR title:
  community: Add OCI Generative AI new model support

- [x] PR message:
    - Description: Adds support for new models offered by OCI Generative AI services. This is a moderate update of our initial integration (PR 16548) and includes a new integration for our chat models under /langchain_community/chat_models/oci_generative_ai.py.
    - Issue: N/A
    - Dependencies: No new dependencies, just the latest version of our OCI SDK.
    - Twitter handle: N/A

- [x] Add tests and docs:
    1. We have updated our unit tests.
    2. We have updated our documentation, including a new ipynb for the new chat integration.

- [x] Lint and test:
  `make format`, `make lint`, and `make test` run successfully.

---------

Co-authored-by: RHARPAZ <RHARPAZ@RHARPAZ-5750.us.oracle.com>
Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
Rave Harpaz 3 months ago committed by GitHub
parent 753edf9c80
commit f5ff7f178b

@ -0,0 +1,190 @@
{
"cells": [
{
"cell_type": "raw",
"id": "afaf8039",
"metadata": {},
"source": [
"---\n",
"sidebar_label: OCIGenAI\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "e49f1e0d",
"metadata": {},
"source": [
"# ChatOCIGenAI\n",
"\n",
"This notebook provides a quick overview for getting started with OCIGenAI [chat models](/docs/concepts/#chat-models). For detailed documentation of all ChatOCIGenAI features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html).\n",
"\n",
"Oracle Cloud Infrastructure (OCI) Generative AI is a fully managed service that provides a set of state-of-the-art, customizable large language models (LLMs) that cover a wide range of use cases, and which is available through a single API.\n",
"Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned custom models based on your own data on dedicated AI clusters. Detailed documentation of the service and API is available __[here](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)__ and __[here](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai/20231130/)__.\n",
"\n",
"\n",
"## Overview\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | [JS support](https://js.langchain.com/v0.2/docs/integrations/chat/oci_generative_ai) | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatOCIGenAI](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html) | [langchain-community](https://api.python.langchain.com/en/latest/community_api_reference.html) | ❌ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-oci-generative-ai?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-oci-generative-ai?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | \n",
"\n",
"## Setup\n",
"\n",
"To access OCIGenAI models you'll need to install the `oci` and `langchain-community` packages.\n",
"\n",
"### Credentials\n",
"\n",
"The credentials and authentication methods supported for this integration are equivalent to those used with other OCI services and follow the __[standard SDK authentication](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__ methods, specifically API Key, session token, instance principal, and resource principal.\n",
"\n",
"API key is the default authentication method used in the examples above. The following example demonstrates how to use a different authentication method (session token)"
]
},
{
"cell_type": "markdown",
"id": "0730d6a1-c893-4840-9817-5e5251676d5d",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"The LangChain OCIGenAI integration lives in the `langchain-community` package and you will also need to install the `oci` package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "652d6238-1f87-422a-b135-f5abbb8652fc",
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-community oci"
]
},
{
"cell_type": "markdown",
"id": "a38cde65-254d-4219-a441-068766c0d4b5",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI\n",
"from langchain_core.messages import AIMessage, HumanMessage, SystemMessage\n",
"\n",
"chat = ChatOCIGenAI(\n",
" model_id=\"cohere.command-r-16k\",\n",
" service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
" compartment_id=\"MY_OCID\",\n",
" model_kwargs={\"temperature\": 0.7, \"max_tokens\": 500},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2b4f3e15",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62e0dbc3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"messages = [\n",
" SystemMessage(content=\"your are an AI assistant.\"),\n",
" AIMessage(content=\"Hi there human!\"),\n",
" HumanMessage(content=\"tell me a joke.\"),\n",
"]\n",
"response = chat.invoke(messages)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d86145b3-bfef-46e8-b227-4dda5c9c2705",
"metadata": {},
"outputs": [],
"source": [
"print(response.content)"
]
},
{
"cell_type": "markdown",
"id": "18e2bfc0-7e78-4528-a73f-499ac150dca8",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_template(\"Tell me a joke about {topic}\")\n",
"chain = prompt | chat\n",
"\n",
"response = chain.invoke({\"topic\": \"dogs\"})\n",
"print(response.content)"
]
},
{
"cell_type": "markdown",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatOCIGenAI features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.oci_generative_ai.ChatOCIGenAI.html"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
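
The Credentials cell above promises a session-token example that the chat notebook itself does not include. A minimal sketch follows, assuming a valid session-token profile named `MY_PROFILE` exists in `~/.oci/config`; the `auth_type` and `auth_profile` parameters are the same ones used in the OCIGenAI LLM notebook below.

```python
# Sketch only: session-token authentication for ChatOCIGenAI.
# Assumes a session-token profile "MY_PROFILE" is configured in ~/.oci/config.
from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI

chat = ChatOCIGenAI(
    model_id="cohere.command-r-16k",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",
    auth_type="SECURITY_TOKEN",
    auth_profile="MY_PROFILE",  # replace with your profile name
    model_kwargs={"temperature": 0.7, "max_tokens": 500},
)
```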

@ -14,15 +14,15 @@
"Oracle Cloud Infrastructure (OCI) Generative AI is a fully managed service that provides a set of state-of-the-art, customizable large language models (LLMs) that cover a wide range of use cases, and which is available through a single API.\n",
"Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned custom models based on your own data on dedicated AI clusters. Detailed documentation of the service and API is available __[here](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)__ and __[here](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai/20231130/)__.\n",
"\n",
"This notebook explains how to use OCI's Genrative AI models with LangChain."
"This notebook explains how to use OCI's Generative AI complete models with LangChain."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prerequisite\n",
"We will need to install the oci sdk"
"## Setup\n",
"Ensure that the oci sdk and the langchain-community package are installed"
]
},
{
@ -31,38 +31,40 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install -U oci"
"!pip install -U oci langchain-community"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### OCI Generative AI API endpoint \n",
"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com"
"## Usage"
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## Authentication\n",
"The authentication methods supported for this langchain integration are:\n",
"from langchain_community.llms.oci_generative_ai import OCIGenAI\n",
"\n",
"1. API Key\n",
"2. Session token\n",
"3. Instance principal\n",
"4. Resource principal \n",
"llm = OCIGenAI(\n",
" model_id=\"cohere.command\",\n",
" service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
" compartment_id=\"MY_OCID\",\n",
" model_kwargs={\"temperature\": 0, \"max_tokens\": 500},\n",
")\n",
"\n",
"These follows the standard SDK authentication methods detailed __[here](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__.\n",
" "
"response = llm.invoke(\"Tell me one fact about earth\", temperature=0.7)\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Usage"
"#### Chaining with prompt templates"
]
},
{
@ -71,44 +73,54 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.llms import OCIGenAI\n",
"from langchain_core.prompts import PromptTemplate\n",
"\n",
"# use default authN method API-key\n",
"llm = OCIGenAI(\n",
" model_id=\"MY_MODEL\",\n",
" model_id=\"cohere.command\",\n",
" service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
" compartment_id=\"MY_OCID\",\n",
" model_kwargs={\"temperature\": 0, \"max_tokens\": 500},\n",
")\n",
"\n",
"response = llm.invoke(\"Tell me one fact about earth\", temperature=0.7)\n",
"prompt = PromptTemplate(input_variables=[\"query\"], template=\"{query}\")\n",
"llm_chain = prompt | llm\n",
"\n",
"response = llm_chain.invoke(\"what is the capital of france?\")\n",
"print(response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Streaming"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain_core.prompts import PromptTemplate\n",
"\n",
"# Use Session Token to authN\n",
"llm = OCIGenAI(\n",
" model_id=\"MY_MODEL\",\n",
" model_id=\"cohere.command\",\n",
" service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
" compartment_id=\"MY_OCID\",\n",
" auth_type=\"SECURITY_TOKEN\",\n",
" auth_profile=\"MY_PROFILE\", # replace with your profile name\n",
" model_kwargs={\"temperature\": 0.7, \"top_p\": 0.75, \"max_tokens\": 200},\n",
" model_kwargs={\"temperature\": 0, \"max_tokens\": 500},\n",
")\n",
"\n",
"prompt = PromptTemplate(input_variables=[\"query\"], template=\"{query}\")\n",
"\n",
"llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
"for chunk in llm.stream(\"Write me a song about sparkling water.\"):\n",
" print(chunk, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Authentication\n",
"The authentication methods supported for LlamaIndex are equivalent to those used with other OCI services and follow the __[standard SDK authentication](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm)__ methods, specifically API Key, session token, instance principal, and resource principal.\n",
"\n",
"response = llm_chain.invoke(\"what is the capital of france?\")\n",
"print(response)"
"API key is the default authentication method used in the examples above. The following example demonstrates how to use a different authentication method (session token)"
]
},
{
@ -117,49 +129,39 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.embeddings import OCIGenAIEmbeddings\n",
"from langchain_community.vectorstores import FAISS\n",
"from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"embeddings = OCIGenAIEmbeddings(\n",
" model_id=\"MY_EMBEDDING_MODEL\",\n",
"llm = OCIGenAI(\n",
" model_id=\"cohere.command\",\n",
" service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
" compartment_id=\"MY_OCID\",\n",
")\n",
"\n",
"vectorstore = FAISS.from_texts(\n",
" [\n",
" \"Larry Ellison co-founded Oracle Corporation in 1977 with Bob Miner and Ed Oates.\",\n",
" \"Oracle Corporation is an American multinational computer technology company headquartered in Austin, Texas, United States.\",\n",
" ],\n",
" embedding=embeddings,\n",
")\n",
"\n",
"retriever = vectorstore.as_retriever()\n",
"\n",
"template = \"\"\"Answer the question based only on the following context:\n",
"{context}\n",
" \n",
"Question: {question}\n",
"\"\"\"\n",
"prompt = PromptTemplate.from_template(template)\n",
" auth_type=\"SECURITY_TOKEN\",\n",
" auth_profile=\"MY_PROFILE\", # replace with your profile name\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dedicated AI Cluster\n",
"To access models hosted in a dedicated AI cluster __[create an endpoint](https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/)__ whose assigned OCID (currently prefixed by ocid1.generativeaiendpoint.oc1.us-chicago-1) is used as your model ID.\n",
"\n",
"When accessing models hosted in a dedicated AI cluster you will need to initialize the OCIGenAI interface with two extra required params (\"provider\" and \"context_size\")."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = OCIGenAI(\n",
" model_id=\"MY_MODEL\",\n",
" model_id=\"ocid1.generativeaiendpoint.oc1.us-chicago-1....\",\n",
" service_endpoint=\"https://inference.generativeai.us-chicago-1.oci.oraclecloud.com\",\n",
" compartment_id=\"MY_OCID\",\n",
")\n",
"\n",
"chain = (\n",
" {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
" | prompt\n",
" | llm\n",
" | StrOutputParser()\n",
")\n",
"\n",
"print(chain.invoke(\"when was oracle founded?\"))\n",
"print(chain.invoke(\"where is oracle headquartered?\"))"
" compartment_id=\"DEDICATED_COMPARTMENT_OCID\",\n",
" auth_profile=\"MY_PROFILE\", # replace with your profile name,\n",
" provider=\"MODEL_PROVIDER\", # e.g., \"cohere\" or \"meta\"\n",
" context_size=\"MODEL_CONTEXT_SIZE\", # e.g., 128000\n",
")"
]
}
],

@ -2,27 +2,29 @@
The `LangChain` integrations related to [Oracle Cloud Infrastructure](https://www.oracle.com/artificial-intelligence/).
## LLMs
### OCI Generative AI
## OCI Generative AI
> Oracle Cloud Infrastructure (OCI) [Generative AI](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm) is a fully managed service that provides a set of state-of-the-art,
> customizable large language models (LLMs) that cover a wide range of use cases, and which are available through a single API.
> Using the OCI Generative AI service you can access ready-to-use pretrained models, or create and host your own fine-tuned
> custom models based on your own data on dedicated AI clusters.
To use, you should have the latest `oci` python SDK installed.
To use, you should have the latest `oci` python SDK and the `langchain-community` package installed.
```bash
pip install -U oci
pip install -U oci langchain-community
```
See [usage examples](/docs/integrations/llms/oci_generative_ai).
See [chat](/docs/integrations/chat/oci_generative_ai), [completion](/docs/integrations/llms/oci_generative_ai), and [embedding](/docs/integrations/text_embedding/oci_generative_ai) usage examples.
```python
from langchain_community.chat_models import ChatOCIGenAI
from langchain_community.llms import OCIGenAI
from langchain_community.embeddings import OCIGenAIEmbeddings
```
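
As a quick orientation for the embeddings import above, a hedged usage sketch (the `MY_EMBEDDING_MODEL` and `MY_OCID` values are placeholders; see the embeddings notebook linked above for supported models):

```python
# Sketch only: embedding a query with OCIGenAIEmbeddings; placeholders as in the notebooks.
from langchain_community.embeddings import OCIGenAIEmbeddings

embeddings = OCIGenAIEmbeddings(
    model_id="MY_EMBEDDING_MODEL",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",
)

vector = embeddings.embed_query("What is Oracle Cloud Infrastructure?")
```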
### OCI Data Science Model Deployment Endpoint
## OCI Data Science Model Deployment Endpoint
> [OCI Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/home.htm) is a
> fully managed and serverless platform for data science teams. Using the OCI Data Science
@ -47,12 +49,3 @@ from langchain_community.llms import OCIModelDeploymentVLLM
from langchain_community.llms import OCIModelDeploymentTGI
```
## Text Embedding Models
### OCI Generative AI
See [usage examples](/docs/integrations/text_embedding/oci_generative_ai).
```python
from langchain_community.embeddings import OCIGenAIEmbeddings
```

@ -46,7 +46,7 @@ mwxml>=0.3.3,<0.4
newspaper3k>=0.2.8,<0.3
numexpr>=2.8.6,<3
nvidia-riva-client>=2.14.0,<3
oci>=2.119.1,<3
oci>=2.128.0,<3
openai<2
openapi-pydantic>=0.3.2,<0.4
oracle-ads>=2.9.1,<3

@ -121,6 +121,9 @@ if TYPE_CHECKING:
from langchain_community.chat_models.mlx import (
ChatMLX,
)
from langchain_community.chat_models.oci_generative_ai import (
ChatOCIGenAI, # noqa: F401
)
from langchain_community.chat_models.octoai import ChatOctoAI
from langchain_community.chat_models.ollama import (
ChatOllama,
@ -194,6 +197,7 @@ __all__ = [
"ChatMLflowAIGateway",
"ChatMaritalk",
"ChatMlflow",
"ChatOCIGenAI",
"ChatOllama",
"ChatOpenAI",
"ChatPerplexity",
@ -248,6 +252,7 @@ _module_lookup = {
"ChatMaritalk": "langchain_community.chat_models.maritalk",
"ChatMlflow": "langchain_community.chat_models.mlflow",
"ChatOctoAI": "langchain_community.chat_models.octoai",
"ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai",
"ChatOllama": "langchain_community.chat_models.ollama",
"ChatOpenAI": "langchain_community.chat_models.openai",
"ChatPerplexity": "langchain_community.chat_models.perplexity",

@ -0,0 +1,363 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra
from langchain_community.llms.oci_generative_ai import OCIGenAIBase
from langchain_community.llms.utils import enforce_stop_tokens
CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
class Provider(ABC):
@property
@abstractmethod
def stop_sequence_key(self) -> str:
...
@abstractmethod
def chat_response_to_text(self, response: Any) -> str:
...
@abstractmethod
def chat_stream_to_text(self, event_data: Dict) -> str:
...
@abstractmethod
def chat_generation_info(self, response: Any) -> Dict[str, Any]:
...
@abstractmethod
def get_role(self, message: BaseMessage) -> str:
...
@abstractmethod
def messages_to_oci_params(self, messages: Any) -> Dict[str, Any]:
...
class CohereProvider(Provider):
stop_sequence_key = "stop_sequences"
def __init__(self) -> None:
from oci.generative_ai_inference import models
self.oci_chat_request = models.CohereChatRequest
self.oci_chat_message = {
"USER": models.CohereUserMessage,
"CHATBOT": models.CohereChatBotMessage,
"SYSTEM": models.CohereSystemMessage,
}
self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE
def chat_response_to_text(self, response: Any) -> str:
return response.data.chat_response.text
def chat_stream_to_text(self, event_data: Dict) -> str:
if "text" in event_data and "finishReason" not in event_data:
return event_data["text"]
else:
return ""
def chat_generation_info(self, response: Any) -> Dict[str, Any]:
return {
"finish_reason": response.data.chat_response.finish_reason,
}
def get_role(self, message: BaseMessage) -> str:
if isinstance(message, HumanMessage):
return "USER"
elif isinstance(message, AIMessage):
return "CHATBOT"
elif isinstance(message, SystemMessage):
return "SYSTEM"
else:
raise ValueError(f"Got unknown type {message}")
def messages_to_oci_params(self, messages: Sequence[ChatMessage]) -> Dict[str, Any]:
oci_chat_history = [
self.oci_chat_message[self.get_role(msg)](message=msg.content)
for msg in messages[:-1]
]
oci_params = {
"message": messages[-1].content,
"chat_history": oci_chat_history,
"api_format": self.chat_api_format,
}
return oci_params
class MetaProvider(Provider):
stop_sequence_key = "stop"
def __init__(self) -> None:
from oci.generative_ai_inference import models
self.oci_chat_request = models.GenericChatRequest
self.oci_chat_message = {
"USER": models.UserMessage,
"SYSTEM": models.SystemMessage,
"ASSISTANT": models.AssistantMessage,
}
self.oci_chat_message_content = models.TextContent
self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC
def chat_response_to_text(self, response: Any) -> str:
return response.data.chat_response.choices[0].message.content[0].text
def chat_stream_to_text(self, event_data: Dict) -> str:
if "message" in event_data:
return event_data["message"]["content"][0]["text"]
else:
return ""
def chat_generation_info(self, response: Any) -> Dict[str, Any]:
return {
"finish_reason": response.data.chat_response.choices[0].finish_reason,
"time_created": str(response.data.chat_response.time_created),
}
def get_role(self, message: BaseMessage) -> str:
# meta only supports alternating user/assistant roles
if isinstance(message, HumanMessage):
return "USER"
elif isinstance(message, AIMessage):
return "ASSISTANT"
elif isinstance(message, SystemMessage):
return "SYSTEM"
else:
raise ValueError(f"Got unknown type {message}")
def messages_to_oci_params(self, messages: List[BaseMessage]) -> Dict[str, Any]:
oci_messages = [
self.oci_chat_message[self.get_role(msg)](
content=[self.oci_chat_message_content(text=msg.content)]
)
for msg in messages
]
oci_params = {
"messages": oci_messages,
"api_format": self.chat_api_format,
"top_k": -1,
}
return oci_params
class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
"""ChatOCIGenAI chat model integration.
Setup:
Install ``langchain-community`` and the ``oci`` sdk.
.. code-block:: bash
pip install -U langchain-community oci
Key init args completion params:
model_id: str
Id of the OCIGenAI chat model to use, e.g., cohere.command-r-16k.
is_stream: bool
Whether to stream back partial progress
model_kwargs: Optional[Dict]
Keyword arguments to pass to the specific model used, e.g., temperature, max_tokens.
Key init args client params:
service_endpoint: str
The endpoint URL for the OCIGenAI service, e.g., https://inference.generativeai.us-chicago-1.oci.oraclecloud.com.
compartment_id: str
The compartment OCID.
auth_type: str
The authentication type to use, e.g., API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL.
auth_profile: Optional[str]
The name of the profile in ~/.oci/config. If not specified, DEFAULT will be used.
provider: str
Provider name of the model. Defaults to None; if not set, it is derived from the model_id prefix, otherwise it must be supplied by the user.
See full list of supported init args and their descriptions in the params section.
Instantiate:
.. code-block:: python
from langchain_community.chat_models import ChatOCIGenAI
chat = ChatOCIGenAI(
model_id="cohere.command-r-16k",
service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
compartment_id="MY_OCID",
model_kwargs={"temperature": 0.7, "max_tokens": 500},
)
Invoke:
.. code-block:: python
messages = [
SystemMessage(content="your are an AI assistant."),
AIMessage(content="Hi there human!"),
HumanMessage(content="tell me a joke."),
]
response = chat.invoke(messages)
Stream:
.. code-block:: python
for r in chat.stream(messages):
print(r.content, end="", flush=True)
Response metadata:
.. code-block:: python
response = chat.invoke(messages)
print(response.response_metadata)
""" # noqa: E501
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "oci_generative_ai_chat"
@property
def _provider_map(self) -> Mapping[str, Any]:
"""Get the provider map"""
return {
"cohere": CohereProvider(),
"meta": MetaProvider(),
}
@property
def _provider(self) -> Any:
"""Get the internal provider object"""
return self._get_provider(provider_map=self._provider_map)
def _prepare_request(
self,
messages: List[BaseMessage],
stop: Optional[List[str]],
kwargs: Dict[str, Any],
stream: bool,
) -> Dict[str, Any]:
try:
from oci.generative_ai_inference import models
except ImportError as ex:
raise ModuleNotFoundError(
"Could not import oci python package. "
"Please make sure you have the oci package installed."
) from ex
oci_params = self._provider.messages_to_oci_params(messages)
oci_params["is_stream"] = stream # self.is_stream
_model_kwargs = self.model_kwargs or {}
if stop is not None:
_model_kwargs[self._provider.stop_sequence_key] = stop
chat_params = {**_model_kwargs, **kwargs, **oci_params}
if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
else:
serving_mode = models.OnDemandServingMode(model_id=self.model_id)
request = models.ChatDetails(
compartment_id=self.compartment_id,
serving_mode=serving_mode,
chat_request=self._provider.oci_chat_request(**chat_params),
)
return request
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to a OCIGenAI chat model.
Args:
messages: list of LangChain messages
stop: Optional list of stop words to use.
Returns:
LangChain ChatResult
Example:
.. code-block:: python
messages = [
HumanMessage(content="hello!"),
AIMessage(content="Hi there human!"),
HumanMessage(content="Meow!")
]
response = llm.invoke(messages)
"""
if self.is_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
request = self._prepare_request(messages, stop, kwargs, stream=False)
response = self.client.chat(request)
content = self._provider.chat_response_to_text(response)
if stop is not None:
content = enforce_stop_tokens(content, stop)
generation_info = self._provider.chat_generation_info(response)
llm_output = {
"model_id": response.data.model_id,
"model_version": response.data.model_version,
"request_id": response.request_id,
"content-length": response.headers["content-length"],
}
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(content=content), generation_info=generation_info
)
],
llm_output=llm_output,
)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = self._prepare_request(messages, stop, kwargs, stream=True)
response = self.client.chat(request)
for event in response.data.events():
delta = self._provider.chat_stream_to_text(json.loads(event.data))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
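
To make the Provider abstraction concrete, a small sketch of how `CohereProvider.messages_to_oci_params` shapes a conversation (assuming the `oci` SDK is installed, since the provider's `__init__` imports `oci.generative_ai_inference.models`; `CohereProvider` is imported straight from the module because it is not re-exported):

```python
# Sketch only: mapping LangChain messages to OCI Cohere chat parameters.
from langchain_community.chat_models.oci_generative_ai import CohereProvider
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are an AI assistant."),
    AIMessage(content="Hi there human!"),
    HumanMessage(content="Tell me a joke."),
]

params = CohereProvider().messages_to_oci_params(messages)
# params["message"] is the latest turn ("Tell me a joke."),
# params["chat_history"] holds the earlier turns as CohereSystemMessage/CohereChatBotMessage objects,
# and params["api_format"] selects the Cohere chat API format.
```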

@ -1,17 +1,53 @@
from __future__ import annotations
from abc import ABC
import json
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_community.llms.utils import enforce_stop_tokens
CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
VALID_PROVIDERS = ("cohere", "meta")
class Provider(ABC):
@property
@abstractmethod
def stop_sequence_key(self) -> str:
...
@abstractmethod
def completion_response_to_text(self, response: Any) -> str:
...
class CohereProvider(Provider):
stop_sequence_key = "stop_sequences"
def __init__(self) -> None:
from oci.generative_ai_inference import models
self.llm_inference_request = models.CohereLlmInferenceRequest
def completion_response_to_text(self, response: Any) -> str:
return response.data.inference_response.generated_texts[0].text
class MetaProvider(Provider):
stop_sequence_key = "stop"
def __init__(self) -> None:
from oci.generative_ai_inference import models
self.llm_inference_request = models.LlamaLlmInferenceRequest
def completion_response_to_text(self, response: Any) -> str:
return response.data.inference_response.choices[0].text
class OCIAuthType(Enum):
@ -33,8 +69,8 @@ class OCIGenAIBase(BaseModel, ABC):
API_KEY,
SECURITY_TOKEN,
INSTANCE_PRINCIPLE,
RESOURCE_PRINCIPLE
INSTANCE_PRINCIPAL,
RESOURCE_PRINCIPAL
If not specified, API_KEY will be used
"""
@ -65,11 +101,6 @@ class OCIGenAIBase(BaseModel, ABC):
is_stream: bool = False
"""Whether to stream back partial progress"""
llm_stop_sequence_mapping: Mapping[str, str] = {
"cohere": "stop_sequences",
"meta": "stop",
}
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that OCI config and python package exists in environment."""
@ -121,24 +152,28 @@ class OCIGenAIBase(BaseModel, ABC):
"signer"
] = oci.auth.signers.get_resource_principals_signer()
else:
raise ValueError("Please provide valid value to auth_type")
raise ValueError(
"Please provide valid value to auth_type, "
f"{values['auth_type']} is not valid."
)
values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
**client_kwargs
)
except ImportError as ex:
raise ImportError(
raise ModuleNotFoundError(
"Could not import oci python package. "
"Please make sure you have the oci package installed."
) from ex
except Exception as e:
raise ValueError(
"Could not authenticate with OCI client. "
"Please check if ~/.oci/config exists. "
"If INSTANCE_PRINCIPLE or RESOURCE_PRINCIPLE is used, "
"Please check the specified "
"auth_profile and auth_type are valid."
"""Could not authenticate with OCI client.
Please check if ~/.oci/config exists.
If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used,
please check the specified
auth_profile and auth_type are valid.""",
e,
) from e
return values
@ -151,19 +186,19 @@ class OCIGenAIBase(BaseModel, ABC):
**{"model_kwargs": _model_kwargs},
}
def _get_provider(self) -> str:
def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
if self.provider is not None:
provider = self.provider
else:
provider = self.model_id.split(".")[0].lower()
if provider not in VALID_PROVIDERS:
if provider not in provider_map:
raise ValueError(
f"Invalid provider derived from model_id: {self.model_id} "
"Please explicitly pass in the supported provider "
"when using custom endpoint"
)
return provider
return provider_map[provider]
class OCIGenAI(LLM, OCIGenAIBase):
@ -173,7 +208,7 @@ class OCIGenAI(LLM, OCIGenAIBase):
https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm
The authentication method is passed through auth_type and should be one of:
API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPLE, RESOURCE_PRINCIPLE
API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL
Make sure you have the required policies (profile/roles) to
access the OCI Generative AI service.
@ -204,21 +239,29 @@ class OCIGenAI(LLM, OCIGenAIBase):
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "oci"
return "oci_generative_ai_completion"
@property
def _provider_map(self) -> Mapping[str, Any]:
"""Get the provider map"""
return {
"cohere": CohereProvider(),
"meta": MetaProvider(),
}
@property
def _provider(self) -> Any:
"""Get the internal provider object"""
return self._get_provider(provider_map=self._provider_map)
def _prepare_invocation_object(
self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
from oci.generative_ai_inference import models
oci_llm_request_mapping = {
"cohere": models.CohereLlmInferenceRequest,
"meta": models.LlamaLlmInferenceRequest,
}
provider = self._get_provider()
_model_kwargs = self.model_kwargs or {}
if stop is not None:
_model_kwargs[self.llm_stop_sequence_mapping[provider]] = stop
_model_kwargs[self._provider.stop_sequence_key] = stop
if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
@ -232,19 +275,13 @@ class OCIGenAI(LLM, OCIGenAIBase):
invocation_obj = models.GenerateTextDetails(
compartment_id=self.compartment_id,
serving_mode=serving_mode,
inference_request=oci_llm_request_mapping[provider](**inference_params),
inference_request=self._provider.llm_inference_request(**inference_params),
)
return invocation_obj
def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
provider = self._get_provider()
if provider == "cohere":
text = response.data.inference_response.generated_texts[0].text
elif provider == "meta":
text = response.data.inference_response.choices[0].text
else:
raise ValueError(f"Invalid provider: {provider}")
text = self._provider.completion_response_to_text(response)
if stop is not None:
text = enforce_stop_tokens(text, stop)
@ -272,7 +309,51 @@ class OCIGenAI(LLM, OCIGenAIBase):
response = llm.invoke("Tell me a joke.")
"""
if self.is_stream:
text = ""
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
text += chunk.text
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
response = self.client.generate_text(invocation_obj)
return self._process_response(response, stop)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Stream OCIGenAI LLM on given prompt.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
An iterator of GenerationChunks.
Example:
.. code-block:: python
response = llm.stream("Tell me a joke.")
"""
self.is_stream = True
invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
response = self.client.generate_text(invocation_obj)
for event in response.data.events():
json_load = json.loads(event.data)
if "text" in json_load:
event_data_text = json_load["text"]
else:
event_data_text = ""
chunk = GenerationChunk(text=event_data_text)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk
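
A short usage sketch for the streaming path added above, assuming valid OCI credentials via the default API-key method (placeholders as in the notebooks):

```python
# Sketch only: token streaming with the OCIGenAI completion model.
from langchain_community.llms.oci_generative_ai import OCIGenAI

llm = OCIGenAI(
    model_id="cohere.command",
    service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    compartment_id="MY_OCID",
)

for chunk in llm.stream("Write me a song about sparkling water."):
    print(chunk, end="", flush=True)
```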

@ -27,6 +27,7 @@ EXPECTED_ALL = [
"ChatMlflow",
"ChatMLflowAIGateway",
"ChatMLX",
"ChatOCIGenAI",
"ChatOllama",
"ChatOpenAI",
"ChatPerplexity",

@ -0,0 +1,105 @@
"""Test OCI Generative AI LLM service"""
from unittest.mock import MagicMock
import pytest
from langchain_core.messages import HumanMessage
from pytest import MonkeyPatch
from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI
class MockResponseDict(dict):
def __getattr__(self, val): # type: ignore[no-untyped-def]
return self[val]
@pytest.mark.requires("oci")
@pytest.mark.parametrize(
"test_model_id", ["cohere.command-r-16k", "meta.llama-3-70b-instruct"]
)
def test_llm_chat(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
"""Test valid chat call to OCI Generative AI LLM service."""
oci_gen_ai_client = MagicMock()
llm = ChatOCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)
provider = llm.model_id.split(".")[0].lower()
def mocked_response(*args): # type: ignore[no-untyped-def]
response_text = "Assistant chat reply."
response = None
if provider == "cohere":
response = MockResponseDict(
{
"status": 200,
"data": MockResponseDict(
{
"chat_response": MockResponseDict(
{
"text": response_text,
"finish_reason": "completed",
}
),
"model_id": "cohere.command-r-16k",
"model_version": "1.0.0",
}
),
"request_id": "1234567890",
"headers": MockResponseDict(
{
"content-length": "123",
}
),
}
)
elif provider == "meta":
response = MockResponseDict(
{
"status": 200,
"data": MockResponseDict(
{
"chat_response": MockResponseDict(
{
"choices": [
MockResponseDict(
{
"message": MockResponseDict(
{
"content": [
MockResponseDict(
{
"text": response_text, # noqa: E501
}
)
]
}
),
"finish_reason": "completed",
}
)
],
"time_created": "2024-09-01T00:00:00Z",
}
),
"model_id": "cohere.command-r-16k",
"model_version": "1.0.0",
}
),
"request_id": "1234567890",
"headers": MockResponseDict(
{
"content-length": "123",
}
),
}
)
return response
monkeypatch.setattr(llm.client, "chat", mocked_response)
messages = [
HumanMessage(content="User message"),
]
expected = "Assistant chat reply."
actual = llm.invoke(messages, temperature=0.2)
assert actual.content == expected

@ -4,7 +4,7 @@ from unittest.mock import MagicMock
import pytest
from pytest import MonkeyPatch
from langchain_community.llms import OCIGenAI
from langchain_community.llms.oci_generative_ai import OCIGenAI
class MockResponseDict(dict):
@ -16,12 +16,12 @@ class MockResponseDict(dict):
@pytest.mark.parametrize(
"test_model_id", ["cohere.command", "cohere.command-light", "meta.llama-2-70b-chat"]
)
def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
"""Test valid call to OCI Generative AI LLM service."""
def test_llm_complete(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
"""Test valid completion call to OCI Generative AI LLM service."""
oci_gen_ai_client = MagicMock()
llm = OCIGenAI(model_id=test_model_id, client=oci_gen_ai_client)
provider = llm._get_provider()
provider = llm.model_id.split(".")[0].lower()
def mocked_response(*args): # type: ignore[no-untyped-def]
response_text = "This is the completion."
@ -71,6 +71,5 @@ def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
)
monkeypatch.setattr(llm.client, "generate_text", mocked_response)
output = llm.invoke("This is a prompt.", temperature=0.2)
assert output == "This is the completion."
