add _acall method to YandexGPT (#12029)

- **Description:** Add async support for YandexGPT LLM model

Co-authored-by: Dmitry Tyumentsev <dmitry.tyumentsev@raftds.com>
pull/12040/head
Dmitry Tyumentsev 12 months ago committed by GitHub
parent 720ecacb1c
commit 5dd2161c4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,9 @@
from typing import Any, Dict, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.load.serializable import Serializable
@ -128,3 +131,75 @@ class YandexGPT(_BaseYandexGPT, LLM):
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Async call the Yandex GPT model and return the output.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
"""
try:
import asyncio
import grpc
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import (
InstructRequest,
InstructResponse,
)
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
TextGenerationAsyncServiceStub,
)
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
from yandex.cloud.operation.operation_service_pb2_grpc import (
OperationServiceStub,
)
except ImportError as e:
raise ImportError(
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
) from e
operation_api_url = "operation.api.cloud.yandex.net:443"
channel_credentials = grpc.ssl_channel_credentials()
async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
request = InstructRequest(
model=self.model_name,
request_text=prompt,
generation_options=GenerationOptions(
temperature=DoubleValue(value=self.temperature),
max_tokens=Int64Value(value=self.max_tokens),
),
)
stub = TextGenerationAsyncServiceStub(channel)
if self.iam_token:
metadata = (("authorization", f"Bearer {self.iam_token}"),)
else:
metadata = (("authorization", f"Api-Key {self.api_key}"),)
operation = await stub.Instruct(request, metadata=metadata)
async with grpc.aio.secure_channel(
operation_api_url, channel_credentials
) as operation_channel:
operation_stub = OperationServiceStub(operation_channel)
while not operation.done:
await asyncio.sleep(1)
operation_request = GetOperationRequest(operation_id=operation.id)
operation = await operation_stub.Get(
operation_request, metadata=metadata
)
instruct_response = InstructResponse()
operation.response.Unpack(instruct_response)
text = instruct_response.alternatives[0].text
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text

Loading…
Cancel
Save