added support for inference from Model Garden (#9367)

#8850

---------

Co-authored-by: Leonid Kuligin <kuligin@google.com>
Leonid Kuligin 2023-09-02 00:58:21 +02:00 committed by GitHub
parent 54a8df87b9
commit 30239b3025
4 changed files with 284 additions and 55 deletions


@@ -206,6 +206,68 @@
"\n",
"llm_chain.run(question)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using models deployed on Vertex Model Garden"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Vertex Model Garden [exposes](https://cloud.google.com/vertex-ai/docs/start/explore-models) open-sourced models that can be deployed and served on Vertex AI. If you have successfully deployed a model from Vertex Model Garden, you can find a corresponding Vertex AI [endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#what_happens_when_you_deploy_a_model) in the console or via API."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import VertexAIModelGarden"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm_oss = VertexAIModelGarden(\n",
" project=\"YOUR PROJECT\",\n",
" endpoint_id=\"YOUR ENDPOINT_ID\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm_oss(\"What is the meaning of life?\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also use it as a chain:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm_oss_chain = LLMChain(prompt=prompt, llm=llm_oss(\"What is the meaning of life?\")\n",
")\n",
"llm_oss_chain.run(question)"
]
}
],
"metadata": {


@@ -81,7 +81,7 @@ from langchain.llms.symblai_nebula import Nebula
from langchain.llms.textgen import TextGen
from langchain.llms.titan_takeoff import TitanTakeoff
from langchain.llms.tongyi import Tongyi
from langchain.llms.vertexai import VertexAI, VertexAIModelGarden
from langchain.llms.vllm import VLLM, VLLMOpenAI
from langchain.llms.writer import Writer
from langchain.llms.xinference import Xinference
@@ -152,6 +152,7 @@ __all__ = [
"TitanTakeoff",
"Tongyi",
"VertexAI",
"VertexAIModelGarden",
"VLLM",
"VLLMOpenAI",
"Writer",
@@ -217,6 +218,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"tongyi": Tongyi,
"titan_takeoff": TitanTakeoff,
"vertexai": VertexAI,
"vertexai_model_garden": VertexAIModelGarden,
"openllm": OpenLLM,
"openllm_client": OpenLLM,
"vllm": VLLM,


@@ -11,12 +11,17 @@ from langchain.callbacks.manager import (
from langchain.llms.base import LLM, create_base_retry_decorator
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import (
Generation,
LLMResult,
)
from langchain.utilities.vertexai import (
init_vertexai,
raise_vertex_import_error,
)
if TYPE_CHECKING:
from google.cloud.aiplatform.gapic import PredictionServiceClient
from vertexai.language_models._language_models import _LanguageModel
@@ -57,10 +62,40 @@ def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
return _completion_with_retry(*args, **kwargs)
class _VertexAIBase(BaseModel):
project: Optional[str] = None
"The default GCP project to use when making Vertex API calls."
location: str = "us-central1"
"The default location to use when making API calls."
request_parallelism: int = 5
"The amount of parallelism allowed for requests issued to VertexAI models. "
"Default is 5."
max_retries: int = 6
"""The maximum number of retries to make when generating."""
task_executor: ClassVar[Optional[Executor]] = None
stop: Optional[List[str]] = None
"Optional list of stop words to use when generating."
model_name: Optional[str] = None
"Underlying model name."
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
if stop is None and self.stop is not None:
stop = self.stop
if stop:
return enforce_stop_tokens(text, stop)
return text
@classmethod
def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
if cls.task_executor is None:
cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
return cls.task_executor
class _VertexAICommon(_VertexAIBase):
client: "_LanguageModel" = None #: :meta private: client: "_LanguageModel" = None #: :meta private:
model_name: str model_name: str
"Model name to use." "Underlying model name."
temperature: float = 0.0 temperature: float = 0.0
"Sampling temperature, it controls the degree of randomness in token selection." "Sampling temperature, it controls the degree of randomness in token selection."
max_output_tokens: int = 128 max_output_tokens: int = 128
@ -71,27 +106,20 @@ class _VertexAICommon(BaseModel):
top_k: int = 40 top_k: int = 40
"How the model selects tokens for output, the next token is selected from " "How the model selects tokens for output, the next token is selected from "
"among the top-k most probable tokens. Top-k is ignored for Codey models." "among the top-k most probable tokens. Top-k is ignored for Codey models."
stop: Optional[List[str]] = None
"Optional list of stop words to use when generating."
project: Optional[str] = None
"The default GCP project to use when making Vertex API calls."
location: str = "us-central1"
"The default location to use when making API calls."
credentials: Any = None
"The default custom credentials (google.auth.credentials.Credentials) to use "
"when making API calls. If not provided, credentials will be ascertained from "
"the environment."
request_parallelism: int = 5
"The amount of parallelism allowed for requests issued to VertexAI models. "
"Default is 5."
max_retries: int = 6
"""The maximum number of retries to make when generating."""
task_executor: ClassVar[Optional[Executor]] = None
@property
def is_codey_model(self) -> bool:
return is_codey_model(self.model_name)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}
@property
def _default_params(self) -> Dict[str, Any]:
if self.is_codey_model:
@@ -114,28 +142,10 @@ class _VertexAICommon(BaseModel):
res = completion_with_retry(self, prompt, **params) # type: ignore
return self._enforce_stop_words(res.text, stop)
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
if stop is None and self.stop is not None:
stop = self.stop
if stop:
return enforce_stop_tokens(text, stop)
return text
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params}
@property
def _llm_type(self) -> str:
return "vertexai"
@classmethod
def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
if cls.task_executor is None:
cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
return cls.task_executor
@classmethod
def _try_init_vertexai(cls, values: Dict) -> None:
allowed_params = ["project", "location", "credentials"]
@@ -176,27 +186,6 @@ class VertexAI(_VertexAICommon, LLM):
raise_vertex_import_error()
return values
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A callback manager for async interaction with LLMs.
Returns:
The string generated by the model.
"""
return await asyncio.wrap_future(
self._get_task_executor().submit(self._predict, prompt, stop)
)
def _call(
self,
prompt: str,
@@ -215,3 +204,145 @@ class VertexAI(_VertexAICommon, LLM):
The string generated by the model.
"""
return self._predict(prompt, stop, **kwargs)
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A callback manager for async interaction with LLMs.
Returns:
The string generated by the model.
"""
return await asyncio.wrap_future(
self._get_task_executor().submit(self._call, prompt, stop)
)
class VertexAIModelGarden(_VertexAIBase, LLM):
"""Large language models served from Vertex AI Model Garden."""
client: "PredictionServiceClient" = None #: :meta private:
endpoint_id: str
"A name of an endpoint where the model has been deployed."
allowed_model_args: Optional[List[str]] = None
"""Allowed optional args to be passed to the model."""
prompt_arg: str = "prompt"
result_arg: str = "generated_text"
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""
try:
from google.cloud.aiplatform.gapic import PredictionServiceClient
except ImportError:
raise_vertex_import_error()
if values["project"] is None:
raise ValueError(
"A GCP project should be provided to run inference on Model Garden!"
)
client_options = {
"api_endpoint": f"{values['location']}-aiplatform.googleapis.com"
}
values["client"] = PredictionServiceClient(client_options=client_options)
return values
@property
def _llm_type(self) -> str:
return "vertexai_model_garden"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A callback manager for the LLM run (optional).
Returns:
The string generated by the model.
"""
result = self._generate(
prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
)
return result.generations[0][0].text
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
try:
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
except ImportError:
raise ImportError(
"protobuf package not found, please install it with"
" `pip install protobuf`"
)
instances = []
for prompt in prompts:
if self.allowed_model_args:
instance = {
k: v for k, v in kwargs.items() if k in self.allowed_model_args
}
else:
instance = {}
instance[self.prompt_arg] = prompt
instances.append(instance)
predict_instances = [
json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
]
endpoint = self.client.endpoint_path(
project=self.project, location=self.location, endpoint=self.endpoint_id
)
response = self.client.predict(endpoint=endpoint, instances=predict_instances)
generations: List[List[Generation]] = []
for result in response.predictions:
generations.append(
[Generation(text=prediction[self.result_arg]) for prediction in result]
)
return LLMResult(generations=generations)
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
Args:
prompt: The prompt to pass into the model.
stop: A list of stop words (optional).
run_manager: A callback manager for async interaction with LLMs.
Returns:
The string generated by the model.
"""
return await asyncio.wrap_future(
self._get_task_executor().submit(self._call, prompt, stop)
)
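The `prompt_arg`, `result_arg`, and `allowed_model_args` fields above determine how `_generate` builds each prediction instance and how it reads the response. A minimal sketch under assumed field names (a deployment that reads the prompt from "text", accepts a "max_tokens" argument, and returns its output under "output"); the project and endpoint IDs are placeholders.

```python
# Sketch: mapping VertexAIModelGarden's fields onto a hypothetical deployment.
# The field names "text", "max_tokens", and "output" are assumptions, not
# defaults; project and endpoint_id are placeholders.
from langchain.llms import VertexAIModelGarden

llm = VertexAIModelGarden(
    project="my-project",
    endpoint_id="1234567890",
    prompt_arg="text",                   # instance key that carries the prompt
    result_arg="output",                 # key read back from each prediction
    allowed_model_args=["max_tokens"],   # extra kwargs allowed into the instance
)

# With these settings, _generate sends an instance like
# {"text": "Tell me a joke", "max_tokens": 64} to the endpoint and reads
# prediction["output"] from each returned prediction.
llm("Tell me a joke", max_tokens=64)
```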


@@ -7,7 +7,10 @@ pip install google-cloud-aiplatform>=1.25.0
Your end-user credentials would be used to make the calls (make sure you've run
`gcloud auth login` first).
"""
import os
from langchain.llms import VertexAI, VertexAIModelGarden
from langchain.schema import LLMResult
def test_vertex_call() -> None:
@@ -16,3 +19,34 @@ def test_vertex_call() -> None:
assert isinstance(output, str)
assert llm._llm_type == "vertexai"
assert llm.model_name == llm.client._model_id
def test_model_garden() -> None:
"""In order to run this test, you should provide an endpoint name.
Example:
export ENDPOINT_ID=...
export PROJECT=...
"""
endpoint_id = os.environ["ENDPOINT_ID"]
project = os.environ["PROJECT"]
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
output = llm("What is the meaning of life?")
print(output)
assert isinstance(output, str)
assert llm._llm_type == "vertexai_model_garden"
def test_model_garden_batch() -> None:
"""In order to run this test, you should provide an endpoint name.
Example:
export ENDPOINT_ID=...
export PROJECT=...
"""
endpoint_id = os.environ["ENDPOINT_ID"]
project = os.environ["PROJECT"]
llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
output = llm._generate(["What is the meaning of life?", "How much is 2+2"])
assert isinstance(output, LLMResult)
assert len(output.generations) == 2
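For reference, `output.generations` in the batch test is a nested list: one inner list per input prompt, each holding the `Generation` objects parsed from `result_arg`. A small self-contained sketch of that layout (the texts are invented placeholders so the snippet runs without a deployed endpoint):

```python
# Sketch of the LLMResult layout produced by VertexAIModelGarden._generate:
# one inner list per input prompt. The texts here are invented placeholders.
from langchain.schema import Generation, LLMResult

output = LLMResult(
    generations=[
        [Generation(text="42")],               # first prompt
        [Generation(text="2 + 2 equals 4.")],  # second prompt
    ]
)
for prompt_generations in output.generations:
    print(prompt_generations[0].text)
```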