langchain/libs/community/langchain_community/llms/oci_generative_ai.py
Rave Harpaz f5ff7f178b
Add OCI Generative AI new model support (#22880)
- [x] PR title: 
community: Add OCI Generative AI new model support
 
- [x] PR message:
- Description: adding support for new models offered by OCI Generative
AI services. This is a moderate update of our initial integration PR
16548 and includes a new integration for our chat models under
/langchain_community/chat_models/oci_generative_ai.py
    - Issue: NA
- Dependencies: No new Dependencies, just latest version of our OCI sdk
    - Twitter handle: NA


- [x] Add tests and docs: 
  1. we have updated our unit tests
2. we have updated our documentation including a new ipynb for our new
chat integration


- [x] Lint and test: 
 `make format`, `make lint`, and `make test` run successfully

---------

Co-authored-by: RHARPAZ <RHARPAZ@RHARPAZ-5750.us.oracle.com>
Co-authored-by: Arthur Cheng <arthur.cheng@oracle.com>
2024-06-24 14:48:23 -04:00

360 lines
12 KiB
Python

from __future__ import annotations
import json
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_community.llms.utils import enforce_stop_tokens
CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
class Provider(ABC):
@property
@abstractmethod
def stop_sequence_key(self) -> str:
...
@abstractmethod
def completion_response_to_text(self, response: Any) -> str:
...
class CohereProvider(Provider):
stop_sequence_key = "stop_sequences"
def __init__(self) -> None:
from oci.generative_ai_inference import models
self.llm_inference_request = models.CohereLlmInferenceRequest
def completion_response_to_text(self, response: Any) -> str:
return response.data.inference_response.generated_texts[0].text
class MetaProvider(Provider):
stop_sequence_key = "stop"
def __init__(self) -> None:
from oci.generative_ai_inference import models
self.llm_inference_request = models.LlamaLlmInferenceRequest
def completion_response_to_text(self, response: Any) -> str:
return response.data.inference_response.choices[0].text
class OCIAuthType(Enum):
"""OCI authentication types as enumerator."""
API_KEY = 1
SECURITY_TOKEN = 2
INSTANCE_PRINCIPAL = 3
RESOURCE_PRINCIPAL = 4
class OCIGenAIBase(BaseModel, ABC):
"""Base class for OCI GenAI models"""
client: Any #: :meta private:
auth_type: Optional[str] = "API_KEY"
"""Authentication type, could be
API_KEY,
SECURITY_TOKEN,
INSTANCE_PRINCIPAL,
RESOURCE_PRINCIPAL
If not specified, API_KEY will be used
"""
auth_profile: Optional[str] = "DEFAULT"
"""The name of the profile in ~/.oci/config
If not specified , DEFAULT will be used
"""
model_id: str = None # type: ignore[assignment]
"""Id of the model to call, e.g., cohere.command"""
provider: str = None # type: ignore[assignment]
"""Provider name of the model. Default to None,
will try to be derived from the model_id
otherwise, requires user input
"""
model_kwargs: Optional[Dict] = None
"""Keyword arguments to pass to the model"""
service_endpoint: str = None # type: ignore[assignment]
"""service endpoint url"""
compartment_id: str = None # type: ignore[assignment]
"""OCID of compartment"""
is_stream: bool = False
"""Whether to stream back partial progress"""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that OCI config and python package exists in environment."""
# Skip creating new client if passed in constructor
if values["client"] is not None:
return values
try:
import oci
client_kwargs = {
"config": {},
"signer": None,
"service_endpoint": values["service_endpoint"],
"retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
"timeout": (10, 240), # default timeout config for OCI Gen AI service
}
if values["auth_type"] == OCIAuthType(1).name:
client_kwargs["config"] = oci.config.from_file(
profile_name=values["auth_profile"]
)
client_kwargs.pop("signer", None)
elif values["auth_type"] == OCIAuthType(2).name:
def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
pk = oci.signer.load_private_key_from_file(
oci_config.get("key_file"), None
)
with open(
oci_config.get("security_token_file"), encoding="utf-8"
) as f:
st_string = f.read()
return oci.auth.signers.SecurityTokenSigner(st_string, pk)
client_kwargs["config"] = oci.config.from_file(
profile_name=values["auth_profile"]
)
client_kwargs["signer"] = make_security_token_signer(
oci_config=client_kwargs["config"]
)
elif values["auth_type"] == OCIAuthType(3).name:
client_kwargs[
"signer"
] = oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
elif values["auth_type"] == OCIAuthType(4).name:
client_kwargs[
"signer"
] = oci.auth.signers.get_resource_principals_signer()
else:
raise ValueError(
"Please provide valid value to auth_type, "
f"{values['auth_type']} is not valid."
)
values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
**client_kwargs
)
except ImportError as ex:
raise ModuleNotFoundError(
"Could not import oci python package. "
"Please make sure you have the oci package installed."
) from ex
except Exception as e:
raise ValueError(
"""Could not authenticate with OCI client.
Please check if ~/.oci/config exists.
If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used,
please check the specified
auth_profile and auth_type are valid.""",
e,
) from e
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"model_kwargs": _model_kwargs},
}
def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
if self.provider is not None:
provider = self.provider
else:
provider = self.model_id.split(".")[0].lower()
if provider not in provider_map:
raise ValueError(
f"Invalid provider derived from model_id: {self.model_id} "
"Please explicitly pass in the supported provider "
"when using custom endpoint"
)
return provider_map[provider]
class OCIGenAI(LLM, OCIGenAIBase):
"""OCI large language models.
To authenticate, the OCI client uses the methods described in
https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm
The authentifcation method is passed through auth_type and should be one of:
API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL
Make sure you have the required policies (profile/roles) to
access the OCI Generative AI service.
If a specific config profile is used, you must pass
the name of the profile (from ~/.oci/config) through auth_profile.
To use, you must provide the compartment id
along with the endpoint url, and model id
as named parameters to the constructor.
Example:
.. code-block:: python
from langchain_community.llms import OCIGenAI
llm = OCIGenAI(
model_id="MY_MODEL_ID",
service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
compartment_id="MY_OCID"
)
"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "oci_generative_ai_completion"
@property
def _provider_map(self) -> Mapping[str, Any]:
"""Get the provider map"""
return {
"cohere": CohereProvider(),
"meta": MetaProvider(),
}
@property
def _provider(self) -> Any:
"""Get the internal provider object"""
return self._get_provider(provider_map=self._provider_map)
def _prepare_invocation_object(
self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
from oci.generative_ai_inference import models
_model_kwargs = self.model_kwargs or {}
if stop is not None:
_model_kwargs[self._provider.stop_sequence_key] = stop
if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
else:
serving_mode = models.OnDemandServingMode(model_id=self.model_id)
inference_params = {**_model_kwargs, **kwargs}
inference_params["prompt"] = prompt
inference_params["is_stream"] = self.is_stream
invocation_obj = models.GenerateTextDetails(
compartment_id=self.compartment_id,
serving_mode=serving_mode,
inference_request=self._provider.llm_inference_request(**inference_params),
)
return invocation_obj
def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
text = self._provider.completion_response_to_text(response)
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to OCIGenAI generate endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = llm.invoke("Tell me a joke.")
"""
if self.is_stream:
text = ""
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
text += chunk.text
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
response = self.client.generate_text(invocation_obj)
return self._process_response(response, stop)
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Stream OCIGenAI LLM on given prompt.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
An iterator of GenerationChunks.
Example:
.. code-block:: python
response = llm.stream("Tell me a joke.")
"""
self.is_stream = True
invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
response = self.client.generate_text(invocation_obj)
for event in response.data.events():
json_load = json.loads(event.data)
if "text" in json_load:
event_data_text = json_load["text"]
else:
event_data_text = ""
chunk = GenerationChunk(text=event_data_text)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
yield chunk