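# NOTE(review): the lines below appear to be residue from a repository web
# view rather than part of the module; kept as comments so the file parses.
#   Banner: "You cannot select more than 25 topics. Topics must start with a
#   letter or number, can include dashes ('-') and can be up to 35 characters
#   long."
#   Path (truncated by the page):
#   langchain/libs/community/langchain_community/llms/oci_data_science_model_depl...
#   Reported size: 363 lines, 12 KiB, Python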

import logging
from typing import Any, Dict, List, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Fallback `timeout` passed to `requests.post` when the caller supplies none
# (requests interprets this value in seconds).
DEFAULT_TIME_OUT = 300
# Value used for the "Content-Type" request header unless the caller
# overrides it through a `content_type` entry in the header dict.
DEFAULT_CONTENT_TYPE_JSON = "application/json"
class OCIModelDeploymentLLM(LLM):
    """Base class for LLM deployed on OCI Data Science Model Deployment.

    This class handles environment validation, merging of stop words into the
    invocation parameters, and the signed HTTP request (with one retry after a
    401 response). Subclasses provide the server-specific pieces by overriding
    ``_default_params``, ``_construct_json_body`` and ``_process_response``.
    """

    auth: dict = Field(default_factory=dict, exclude=True)
    """ADS auth dictionary for OCI authentication:
    https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.

    This can be generated by calling `ads.common.auth.api_keys()`
    or `ads.common.auth.resource_principal()`. If this is not
    provided then `ads.common.auth.default_signer()` will be used."""

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""

    temperature: float = 0.2
    """A non-negative float that tunes the degree of randomness in generation."""

    k: int = 0
    """Number of most likely tokens to consider at each step."""

    p: float = 0.75
    """Total probability mass of tokens to consider at each step."""

    endpoint: str = ""
    """The uri of the endpoint from the deployed Model Deployment model."""

    best_of: int = 1
    """Generates best_of completions server-side and returns the "best"
    (the one with the highest log probability per token).
    """

    stop: Optional[List[str]] = None
    """Stop words to use when generating. Model output is cut off
    at the first occurrence of any of these substrings."""

    @root_validator()
    def validate_environment(  # pylint: disable=no-self-argument
        cls, values: Dict
    ) -> Dict:
        """Validate the environment and fill in ``auth`` and ``endpoint``.

        Args:
            values (Dict):
                The field values collected for the model.

        Raises:
            ImportError:
                If the ``ads`` package cannot be imported.

        Returns:
            The same ``values`` dict, with ``auth`` set to the ADS default
            signer when it was empty and ``endpoint`` resolved through
            ``get_from_dict_or_env`` (key ``endpoint``, environment variable
            ``OCI_LLM_ENDPOINT``).
        """
        try:
            import ads
        except ImportError as ex:
            raise ImportError(
                "Could not import ads python package. "
                "Please install it with `pip install oracle_ads`."
            ) from ex
        if not values.get("auth", None):
            values["auth"] = ads.common.auth.default_signer()
        values["endpoint"] = get_from_dict_or_env(
            values,
            "endpoint",
            "OCI_LLM_ENDPOINT",
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default parameters for the model. Must be overridden by subclasses."""
        raise NotImplementedError

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters: the endpoint plus the default params."""
        return {
            **{"endpoint": self.endpoint},
            **self._default_params,
        }

    def _construct_json_body(self, prompt: str, params: dict) -> dict:
        """Constructs the request body as a dictionary (JSON).

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
        """Combines the invocation parameters with default parameters.

        Args:
            stop (List[str], Optional):
                Stop words given for this call.
            kwargs:
                Extra parameters; they take precedence over the defaults.

        Raises:
            ValueError:
                If stop words are set both on the instance and in the call.

        Returns:
            A new dict holding the default params, a ``stop`` list and
            ``kwargs``.
        """
        params = self._default_params
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            params["stop"] = self.stop
        elif stop is not None:
            params["stop"] = stop
        else:
            # Don't set "stop" in param as None. It should be a list.
            params["stop"] = []
        return {**params, **kwargs}

    def _process_response(self, response_json: dict) -> str:
        """Extracts the completion text from the decoded JSON response.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to OCI Data Science Model Deployment endpoint.

        Args:
            prompt (str):
                The prompt to pass into the model.
            stop (List[str], Optional):
                List of stop words to use when generating.
            kwargs:
                requests_kwargs:
                    Additional ``**kwargs`` to pass to requests.post

        Returns:
            The string generated by the model.

        Example:

            .. code-block:: python

                response = oci_md("Tell me a joke.")

        """
        # `or {}` keeps an explicit `requests_kwargs=None` from breaking the
        # `**requests_kwargs` expansion below.
        requests_kwargs = kwargs.pop("requests_kwargs", None) or {}
        params = self._invocation_params(stop, **kwargs)
        body = self._construct_json_body(prompt, params)
        logger.info("LLM API Request:\n%s", prompt)
        response = self._send_request(
            data=body, endpoint=self.endpoint, **requests_kwargs
        )
        completion = self._process_response(response)
        logger.info("LLM API Completion:\n%s", completion)
        return completion

    def _send_request(
        self,
        data: Any,
        endpoint: str,
        header: Optional[dict] = None,
        **kwargs: Any,
    ) -> Dict:
        """Sends request to the oci data science model deployment endpoint.

        Args:
            data (Json serializable):
                data need to be sent to the endpoint.
            endpoint (str):
                The model HTTP endpoint.
            header (dict, optional):
                A dictionary of HTTP headers to send to the specified url.
                A ``content_type`` entry, if present, is moved to the
                ``Content-Type`` header. The dict passed in is not modified.
                Defaults to None, meaning no extra headers.
            kwargs:
                Additional ``**kwargs`` to pass to requests.post.
                ``timeout`` defaults to ``DEFAULT_TIME_OUT``.

        Raises:
            Exception:
                Any error raised by ``response.raise_for_status()`` or
                ``response.json()`` is logged and re-raised.

        Returns:
            A JSON representation of a requests.Response object.
        """
        # Copy so that neither a shared default nor the caller's own dict is
        # mutated by the pop/assignment below.
        header = dict(header) if header else {}
        header["Content-Type"] = (
            header.pop("content_type", DEFAULT_CONTENT_TYPE_JSON)
            or DEFAULT_CONTENT_TYPE_JSON
        )
        request_kwargs: Dict[str, Any] = {"json": data, "headers": header}
        timeout = kwargs.pop("timeout", DEFAULT_TIME_OUT)

        attempts = 0
        while attempts < 2:
            # Re-read the signer on every attempt so a refreshed one is used.
            request_kwargs["auth"] = self.auth.get("signer")
            response = requests.post(
                endpoint, timeout=timeout, **request_kwargs, **kwargs
            )
            if response.status_code == 401:
                # Unauthorized: try refreshing the security token, then retry.
                self._refresh_signer()
                attempts += 1
                continue
            break

        try:
            response.raise_for_status()
            response_json = response.json()
        except Exception:
            logger.error(
                "DEBUG INFO: request_kwargs=%s, status_code=%s, content=%s",
                request_kwargs,
                response.status_code,
                response.content,
            )
            raise

        return response_json

    def _refresh_signer(self) -> None:
        """Refresh the signer's security token when the signer supports it."""
        if self.auth.get("signer", None) and hasattr(
            self.auth["signer"], "refresh_security_token"
        ):
            self.auth["signer"].refresh_security_token()
class OCIModelDeploymentTGI(OCIModelDeploymentLLM):
    """OCI Data Science Model Deployment TGI Endpoint.

    To use, you must provide the model HTTP endpoint from your deployed
    model, e.g. https://<MD_OCID>/predict.

    To authenticate, `oracle-ads` has been used to automatically load
    credentials: https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html

    Make sure to have the required policies to access the OCI Data
    Science Model Deployment endpoint. See:
    https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint

    Example:
        .. code-block:: python

            from langchain_community.llms import OCIModelDeploymentTGI

            oci_md = OCIModelDeploymentTGI(endpoint="https://<MD_OCID>/predict")

    """

    do_sample: bool = True
    """If set to True, this parameter enables decoding strategies such as
    multinomial sampling, beam-search multinomial sampling, Top-K
    sampling and Top-p sampling.
    """

    watermark: bool = True
    """Watermarking with `A Watermark for Large Language Models <https://arxiv.org/abs/2301.10226>`_.
    Defaults to True."""

    # Annotated explicitly so the field is declared like its siblings.
    return_full_text: bool = False
    """Whether to prepend the prompt to the generated text. Defaults to False."""

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "oci_model_deployment_tgi_endpoint"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for invoking OCI model deployment TGI endpoint."""
        return {
            "best_of": self.best_of,
            "max_new_tokens": self.max_tokens,
            "temperature": self.temperature,
            # `top_k` must be strictly positive; send None otherwise.
            "top_k": self.k if self.k > 0 else None,
            "top_p": self.p,
            "do_sample": self.do_sample,
            "return_full_text": self.return_full_text,
            "watermark": self.watermark,
        }

    def _construct_json_body(self, prompt: str, params: dict) -> dict:
        """Build the request body: the prompt under ``inputs`` and the
        invocation parameters under ``parameters``."""
        return {
            "inputs": prompt,
            "parameters": params,
        }

    def _process_response(self, response_json: dict) -> str:
        """Return the ``generated_text`` value as a string, newline-terminated.

        When the response has no ``generated_text`` key, the whole response
        is stringified instead.
        """
        return str(response_json.get("generated_text", response_json)) + "\n"
class OCIModelDeploymentVLLM(OCIModelDeploymentLLM):
    """VLLM deployed on OCI Data Science Model Deployment

    Requires the HTTP endpoint of the deployed model,
    e.g. https://<MD_OCID>/predict.

    Credentials are loaded automatically through `oracle-ads`:
    https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html

    The caller needs the policies that grant access to the OCI Data
    Science Model Deployment endpoint. See:
    https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm#model_dep_policies_auth__predict-endpoint

    Example:
        .. code-block:: python

            from langchain_community.llms import OCIModelDeploymentVLLM

            oci_md = OCIModelDeploymentVLLM(
                endpoint="https://<MD_OCID>/predict",
                model="mymodel"
            )

    """

    model: str
    """Model name sent with every request."""

    n: int = 1
    """How many output sequences to return for a prompt."""

    k: int = -1
    """How many of the most likely tokens to consider at each step."""

    frequency_penalty: float = 0.0
    """Frequency-based penalty on repeated tokens. Between 0 and 1."""

    presence_penalty: float = 0.0
    """Penalty on repeated tokens. Between 0 and 1."""

    use_beam_search: bool = False
    """Use beam search rather than sampling."""

    ignore_eos: bool = False
    """Keep generating tokens after the EOS token has been produced,
    instead of stopping there."""

    logprobs: Optional[int] = None
    """How many log probabilities to return for each output token."""

    @property
    def _llm_type(self) -> str:
        """Identifier string for this LLM type."""
        llm_type = "oci_model_deployment_vllm_endpoint"
        return llm_type

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default request parameters for vllm, built from the instance fields."""
        defaults: Dict[str, Any] = dict(
            best_of=self.best_of,
            frequency_penalty=self.frequency_penalty,
            ignore_eos=self.ignore_eos,
            logprobs=self.logprobs,
            max_tokens=self.max_tokens,
            model=self.model,
            n=self.n,
            presence_penalty=self.presence_penalty,
            stop=self.stop,
            temperature=self.temperature,
            top_k=self.k,
            top_p=self.p,
            use_beam_search=self.use_beam_search,
        )
        return defaults

    def _construct_json_body(self, prompt: str, params: dict) -> dict:
        """Build a flat request body: the prompt followed by every parameter."""
        body = {"prompt": prompt}
        # Entries from `params` override same-named keys already in the body.
        body.update(params)
        return body

    def _process_response(self, response_json: dict) -> str:
        """Return the ``text`` of the first entry under ``choices``."""
        first_choice = response_json["choices"][0]
        return first_choice["text"]