|
|
|
@ -1,6 +1,9 @@
|
|
|
|
|
from typing import Any, Dict, List, Mapping, Optional
|
|
|
|
|
|
|
|
|
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain.llms.base import LLM
|
|
|
|
|
from langchain.llms.utils import enforce_stop_tokens
|
|
|
|
|
from langchain.load.serializable import Serializable
|
|
|
|
@ -128,3 +131,75 @@ class YandexGPT(_BaseYandexGPT, LLM):
|
|
|
|
|
if stop is not None:
|
|
|
|
|
text = enforce_stop_tokens(text, stop)
|
|
|
|
|
return text
|
|
|
|
|
|
|
|
|
|
async def _acall(
|
|
|
|
|
self,
|
|
|
|
|
prompt: str,
|
|
|
|
|
stop: Optional[List[str]] = None,
|
|
|
|
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> str:
|
|
|
|
|
"""Async call the Yandex GPT model and return the output.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
prompt: The prompt to pass into the model.
|
|
|
|
|
stop: Optional list of stop words to use when generating.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
The string generated by the model.
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
import asyncio
|
|
|
|
|
|
|
|
|
|
import grpc
|
|
|
|
|
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
|
|
|
|
|
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions
|
|
|
|
|
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import (
|
|
|
|
|
InstructRequest,
|
|
|
|
|
InstructResponse,
|
|
|
|
|
)
|
|
|
|
|
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
|
|
|
|
|
TextGenerationAsyncServiceStub,
|
|
|
|
|
)
|
|
|
|
|
from yandex.cloud.operation.operation_service_pb2 import GetOperationRequest
|
|
|
|
|
from yandex.cloud.operation.operation_service_pb2_grpc import (
|
|
|
|
|
OperationServiceStub,
|
|
|
|
|
)
|
|
|
|
|
except ImportError as e:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
|
|
|
|
|
) from e
|
|
|
|
|
operation_api_url = "operation.api.cloud.yandex.net:443"
|
|
|
|
|
channel_credentials = grpc.ssl_channel_credentials()
|
|
|
|
|
async with grpc.aio.secure_channel(self.url, channel_credentials) as channel:
|
|
|
|
|
request = InstructRequest(
|
|
|
|
|
model=self.model_name,
|
|
|
|
|
request_text=prompt,
|
|
|
|
|
generation_options=GenerationOptions(
|
|
|
|
|
temperature=DoubleValue(value=self.temperature),
|
|
|
|
|
max_tokens=Int64Value(value=self.max_tokens),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
stub = TextGenerationAsyncServiceStub(channel)
|
|
|
|
|
if self.iam_token:
|
|
|
|
|
metadata = (("authorization", f"Bearer {self.iam_token}"),)
|
|
|
|
|
else:
|
|
|
|
|
metadata = (("authorization", f"Api-Key {self.api_key}"),)
|
|
|
|
|
operation = await stub.Instruct(request, metadata=metadata)
|
|
|
|
|
async with grpc.aio.secure_channel(
|
|
|
|
|
operation_api_url, channel_credentials
|
|
|
|
|
) as operation_channel:
|
|
|
|
|
operation_stub = OperationServiceStub(operation_channel)
|
|
|
|
|
while not operation.done:
|
|
|
|
|
await asyncio.sleep(1)
|
|
|
|
|
operation_request = GetOperationRequest(operation_id=operation.id)
|
|
|
|
|
operation = await operation_stub.Get(
|
|
|
|
|
operation_request, metadata=metadata
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
instruct_response = InstructResponse()
|
|
|
|
|
operation.response.Unpack(instruct_response)
|
|
|
|
|
text = instruct_response.alternatives[0].text
|
|
|
|
|
if stop is not None:
|
|
|
|
|
text = enforce_stop_tokens(text, stop)
|
|
|
|
|
return text
|
|
|
|
|