mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
added support for inference from Model Garden (#9367)
#8850

---------

Co-authored-by: Leonid Kuligin <kuligin@google.com>
This commit is contained in:
parent
54a8df87b9
commit
30239b3025
@@ -206,6 +206,68 @@
     "\n",
     "llm_chain.run(question)"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Using models deployed on Vertex Model Garden"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Vertex Model Garden [exposes](https://cloud.google.com/vertex-ai/docs/start/explore-models) open-sourced models that can be deployed and served on Vertex AI. If you have successfully deployed a model from Vertex Model Garden, you can find a corresponding Vertex AI [endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment#what_happens_when_you_deploy_a_model) in the console or via API."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import VertexAIModelGarden"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm_oss = VertexAIModelGarden(\n",
+    "    project=\"YOUR PROJECT\",\n",
+    "    endpoint_id=\"YOUR ENDPOINT_ID\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm_oss(\"What is the meaning of life?\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can also use it as a chain:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm_oss_chain = LLMChain(prompt=prompt, llm=llm_oss)\n",
+    "llm_oss_chain.run(question)"
+   ]
   }
  ],
  "metadata": {
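Note that the chain cell above reuses `prompt` and `question` defined in earlier cells of the notebook, outside this hunk. A minimal sketch of what those assumed definitions typically look like (values illustrative, not part of this commit):

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate.from_template(template)
question = "What NFL team won the Super Bowl in the year Justin Bieber was born?"

# With these in scope, the Model Garden cell builds and runs the chain:
# llm_oss_chain = LLMChain(prompt=prompt, llm=llm_oss)
# llm_oss_chain.run(question)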
@@ -81,7 +81,7 @@ from langchain.llms.symblai_nebula import Nebula
 from langchain.llms.textgen import TextGen
 from langchain.llms.titan_takeoff import TitanTakeoff
 from langchain.llms.tongyi import Tongyi
-from langchain.llms.vertexai import VertexAI
+from langchain.llms.vertexai import VertexAI, VertexAIModelGarden
 from langchain.llms.vllm import VLLM, VLLMOpenAI
 from langchain.llms.writer import Writer
 from langchain.llms.xinference import Xinference
@@ -152,6 +152,7 @@ __all__ = [
     "TitanTakeoff",
     "Tongyi",
     "VertexAI",
+    "VertexAIModelGarden",
     "VLLM",
     "VLLMOpenAI",
     "Writer",
@@ -217,6 +218,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "tongyi": Tongyi,
     "titan_takeoff": TitanTakeoff,
     "vertexai": VertexAI,
+    "vertexai_model_garden": VertexAIModelGarden,
     "openllm": OpenLLM,
     "openllm_client": OpenLLM,
     "vllm": VLLM,
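The new `type_to_cls_dict` entry mirrors the class's `_llm_type` string ("vertexai_model_garden"), which is the key config-driven loaders use to resolve an LLM class. A small sketch of that lookup (project and endpoint ids are placeholders, and constructing the class assumes google-cloud-aiplatform is installed):

from langchain.llms import type_to_cls_dict

llm_cls = type_to_cls_dict["vertexai_model_garden"]  # -> VertexAIModelGarden
llm = llm_cls(project="my-gcp-project", endpoint_id="1234567890")  # hypothetical ids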
@@ -11,12 +11,17 @@ from langchain.callbacks.manager import (
 from langchain.llms.base import LLM, create_base_retry_decorator
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.pydantic_v1 import BaseModel, root_validator
+from langchain.schema import (
+    Generation,
+    LLMResult,
+)
 from langchain.utilities.vertexai import (
     init_vertexai,
     raise_vertex_import_error,
 )

 if TYPE_CHECKING:
+    from google.cloud.aiplatform.gapic import PredictionServiceClient
     from vertexai.language_models._language_models import _LanguageModel

@@ -57,10 +62,40 @@ def completion_with_retry(llm: VertexAI, *args: Any, **kwargs: Any) -> Any:
     return _completion_with_retry(*args, **kwargs)


-class _VertexAICommon(BaseModel):
+class _VertexAIBase(BaseModel):
+    project: Optional[str] = None
+    "The default GCP project to use when making Vertex API calls."
+    location: str = "us-central1"
+    "The default location to use when making API calls."
+    request_parallelism: int = 5
+    "The amount of parallelism allowed for requests issued to VertexAI models. "
+    "Default is 5."
+    max_retries: int = 6
+    """The maximum number of retries to make when generating."""
+    task_executor: ClassVar[Optional[Executor]] = None
+    stop: Optional[List[str]] = None
+    "Optional list of stop words to use when generating."
+    model_name: Optional[str] = None
+    "Underlying model name."
+
+    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
+        if stop is None and self.stop is not None:
+            stop = self.stop
+        if stop:
+            return enforce_stop_tokens(text, stop)
+        return text
+
+    @classmethod
+    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
+        if cls.task_executor is None:
+            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
+        return cls.task_executor
+
+
+class _VertexAICommon(_VertexAIBase):
     client: "_LanguageModel" = None  #: :meta private:
     model_name: str
-    "Model name to use."
+    "Underlying model name."
     temperature: float = 0.0
     "Sampling temperature, it controls the degree of randomness in token selection."
     max_output_tokens: int = 128
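`_enforce_stop_words`, now shared via `_VertexAIBase`, delegates to `enforce_stop_tokens`, which truncates generated text at the first stop sequence. A rough standalone sketch of that behaviour (an approximation for illustration, not the library code itself):

import re
from typing import List

def truncate_at_stop(text: str, stop: List[str]) -> str:
    # Split on the first occurrence of any stop sequence and keep only the prefix.
    return re.split("|".join(re.escape(s) for s in stop), text)[0]

print(truncate_at_stop("42 is the answer.\nObservation: extra", ["\nObservation:"]))
# -> "42 is the answer."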
@@ -71,27 +106,20 @@ class _VertexAICommon(BaseModel):
     top_k: int = 40
     "How the model selects tokens for output, the next token is selected from "
     "among the top-k most probable tokens. Top-k is ignored for Codey models."
-    stop: Optional[List[str]] = None
-    "Optional list of stop words to use when generating."
-    project: Optional[str] = None
-    "The default GCP project to use when making Vertex API calls."
-    location: str = "us-central1"
-    "The default location to use when making API calls."
     credentials: Any = None
     "The default custom credentials (google.auth.credentials.Credentials) to use "
     "when making API calls. If not provided, credentials will be ascertained from "
     "the environment."
-    request_parallelism: int = 5
-    "The amount of parallelism allowed for requests issued to VertexAI models. "
-    "Default is 5."
-    max_retries: int = 6
-    """The maximum number of retries to make when generating."""
-    task_executor: ClassVar[Optional[Executor]] = None

     @property
     def is_codey_model(self) -> bool:
         return is_codey_model(self.model_name)

+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model_name": self.model_name}, **self._default_params}
+
     @property
     def _default_params(self) -> Dict[str, Any]:
         if self.is_codey_model:
@@ -114,28 +142,10 @@ class _VertexAICommon(BaseModel):
         res = completion_with_retry(self, prompt, **params)  # type: ignore
         return self._enforce_stop_words(res.text, stop)

-    def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
-        if stop is None and self.stop is not None:
-            stop = self.stop
-        if stop:
-            return enforce_stop_tokens(text, stop)
-        return text
-
-    @property
-    def _identifying_params(self) -> Dict[str, Any]:
-        """Get the identifying parameters."""
-        return {**{"model_name": self.model_name}, **self._default_params}
-
     @property
     def _llm_type(self) -> str:
         return "vertexai"

-    @classmethod
-    def _get_task_executor(cls, request_parallelism: int = 5) -> Executor:
-        if cls.task_executor is None:
-            cls.task_executor = ThreadPoolExecutor(max_workers=request_parallelism)
-        return cls.task_executor
-
     @classmethod
     def _try_init_vertexai(cls, values: Dict) -> None:
         allowed_params = ["project", "location", "credentials"]
@@ -176,27 +186,6 @@ class VertexAI(_VertexAICommon, LLM):
             raise_vertex_import_error()
         return values

-    async def _acall(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        """Call Vertex model to get predictions based on the prompt.
-
-        Args:
-            prompt: The prompt to pass into the model.
-            stop: A list of stop words (optional).
-            run_manager: A callback manager for async interaction with LLMs.
-
-        Returns:
-            The string generated by the model.
-        """
-        return await asyncio.wrap_future(
-            self._get_task_executor().submit(self._predict, prompt, stop)
-        )
-
     def _call(
         self,
         prompt: str,
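The `_acall` removed here, and the ones re-added below, all follow the same pattern: submit the blocking call to the class-level `ThreadPoolExecutor` and await it through `asyncio.wrap_future`. A minimal standalone sketch of that pattern (names illustrative, not part of this diff):

import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=5)  # plays the role of _VertexAIBase.task_executor

def blocking_predict(prompt: str) -> str:
    # Stand-in for the synchronous Vertex AI prediction call.
    return f"echo: {prompt}"

async def async_predict(prompt: str) -> str:
    # wrap_future turns the executor's concurrent.futures.Future into an awaitable,
    # which is how _acall delegates to the synchronous call path.
    return await asyncio.wrap_future(executor.submit(blocking_predict, prompt))

print(asyncio.run(async_predict("hello")))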
@@ -215,3 +204,145 @@ class VertexAI(_VertexAICommon, LLM):
             The string generated by the model.
         """
         return self._predict(prompt, stop, **kwargs)
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call Vertex model to get predictions based on the prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of stop words (optional).
+            run_manager: A callback manager for async interaction with LLMs.
+
+        Returns:
+            The string generated by the model.
+        """
+        return await asyncio.wrap_future(
+            self._get_task_executor().submit(self._call, prompt, stop)
+        )
+
+
+class VertexAIModelGarden(_VertexAIBase, LLM):
+    """Large language models served from Vertex AI Model Garden."""
+
+    client: "PredictionServiceClient" = None  #: :meta private:
+    endpoint_id: str
+    "A name of an endpoint where the model has been deployed."
+    allowed_model_args: Optional[List[str]] = None
+    """Allowed optional args to be passed to the model."""
+    prompt_arg: str = "prompt"
+    result_arg: str = "generated_text"
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the python package exists in environment."""
+        try:
+            from google.cloud.aiplatform.gapic import PredictionServiceClient
+        except ImportError:
+            raise_vertex_import_error()
+
+        if values["project"] is None:
+            raise ValueError(
+                "A GCP project should be provided to run inference on Model Garden!"
+            )
+
+        client_options = {
+            "api_endpoint": f"{values['location']}-aiplatform.googleapis.com"
+        }
+        values["client"] = PredictionServiceClient(client_options=client_options)
+        return values
+
+    @property
+    def _llm_type(self) -> str:
+        return "vertexai_model_garden"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call Vertex model to get predictions based on the prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of stop words (optional).
+            run_manager: A callback manager for the LLM run, optional.
+
+        Returns:
+            The string generated by the model.
+        """
+        result = self._generate(
+            prompts=[prompt], stop=stop, run_manager=run_manager, **kwargs
+        )
+        return result.generations[0][0].text
+
+    def _generate(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> LLMResult:
+        """Run the LLM on the given prompt and input."""
+        try:
+            from google.protobuf import json_format
+            from google.protobuf.struct_pb2 import Value
+        except ImportError:
+            raise ImportError(
+                "protobuf package not found, please install it with"
+                " `pip install protobuf`"
+            )
+
+        instances = []
+        for prompt in prompts:
+            if self.allowed_model_args:
+                instance = {
+                    k: v for k, v in kwargs.items() if k in self.allowed_model_args
+                }
+            else:
+                instance = {}
+            instance[self.prompt_arg] = prompt
+            instances.append(instance)
+
+        predict_instances = [
+            json_format.ParseDict(instance_dict, Value()) for instance_dict in instances
+        ]
+
+        endpoint = self.client.endpoint_path(
+            project=self.project, location=self.location, endpoint=self.endpoint_id
+        )
+        response = self.client.predict(endpoint=endpoint, instances=predict_instances)
+        generations: List[List[Generation]] = []
+        for result in response.predictions:
+            generations.append(
+                [Generation(text=prediction[self.result_arg]) for prediction in result]
+            )
+        return LLMResult(generations=generations)
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call Vertex model to get predictions based on the prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of stop words (optional).
+            run_manager: A callback manager for async interaction with LLMs.
+
+        Returns:
+            The string generated by the model.
+        """
+        return await asyncio.wrap_future(
+            self._get_task_executor().submit(self._call, prompt, stop)
+        )
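`_generate` sends each prompt as a JSON instance keyed by `prompt_arg` and reads every prediction back through `result_arg`, so deployed models with different request/response schemas can be adapted without subclassing. A hedged usage sketch (project, endpoint id, and argument names are made up for illustration):

from langchain.llms import VertexAIModelGarden

llm = VertexAIModelGarden(
    project="my-gcp-project",            # hypothetical project
    endpoint_id="1234567890",            # hypothetical endpoint
    prompt_arg="text_input",             # key the deployed model expects the prompt under
    result_arg="output",                 # key the generated text comes back under
    allowed_model_args=["temperature"],  # optional per-request args forwarded into each instance
)
# print(llm("What is the meaning of life?"))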
@@ -7,7 +7,10 @@ pip install google-cloud-aiplatform>=1.25.0
 Your end-user credentials would be used to make the calls (make sure you've run
 `gcloud auth login` first).
 """
-from langchain.llms import VertexAI
+import os
+
+from langchain.llms import VertexAI, VertexAIModelGarden
+from langchain.schema import LLMResult


 def test_vertex_call() -> None:
@@ -16,3 +19,34 @@ def test_vertex_call() -> None:
     assert isinstance(output, str)
     assert llm._llm_type == "vertexai"
     assert llm.model_name == llm.client._model_id
+
+
+def test_model_garden() -> None:
+    """In order to run this test, you should provide an endpoint id and a project id.
+
+    Example:
+    export ENDPOINT_ID=...
+    export PROJECT=...
+    """
+    endpoint_id = os.environ["ENDPOINT_ID"]
+    project = os.environ["PROJECT"]
+    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    output = llm("What is the meaning of life?")
+    print(output)
+    assert isinstance(output, str)
+    assert llm._llm_type == "vertexai_model_garden"
+
+
+def test_model_garden_batch() -> None:
+    """In order to run this test, you should provide an endpoint id and a project id.
+
+    Example:
+    export ENDPOINT_ID=...
+    export PROJECT=...
+    """
+    endpoint_id = os.environ["ENDPOINT_ID"]
+    project = os.environ["PROJECT"]
+    llm = VertexAIModelGarden(endpoint_id=endpoint_id, project=project)
+    output = llm._generate(["What is the meaning of life?", "How much is 2+2"])
+    assert isinstance(output, LLMResult)
+    assert len(output.generations) == 2
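The batch test exercises the private `_generate`; application code would normally go through the public `generate` entry point, which returns the same `LLMResult`. A sketch (ids hypothetical, endpoint assumed to be deployed already):

from langchain.llms import VertexAIModelGarden

llm = VertexAIModelGarden(project="my-gcp-project", endpoint_id="1234567890")
result = llm.generate(["What is the meaning of life?", "How much is 2+2?"])
for generations in result.generations:
    print(generations[0].text)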