mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add YandexGPT LLM and Chat model (#11703)
**Description:** Introducing an ability to work with the [YandexGPT](https://cloud.yandex.com/en/services/yandexgpt) language model.
This commit is contained in:
parent
c4341463e8
commit
e8c1850369
109
docs/docs/integrations/chat/yandex.ipynb
Normal file
109
docs/docs/integrations/chat/yandex.ipynb
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "af63c9db-e4bd-4d3b-a4d7-7927f5541734",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# YandexGPT\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook goes over how to use Langchain with [YandexGPT](https://cloud.yandex.com/en/services/yandexgpt) chat model.\n",
|
||||||
|
"\n",
|
||||||
|
"To use, you should have the `yandexcloud` python package installed."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "f3a8f9cb-ff03-4fb8-8185-ff19f2b8fc89",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install yandexcloud"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "95fa21fb-3669-43fb-bb92-91de7bc591bc",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"First, you should [create service account](https://cloud.yandex.com/en/docs/iam/operations/sa/create) with the `ai.languageModels.user` role.\n",
|
||||||
|
"\n",
|
||||||
|
"Next, you have two authentication options:\n",
|
||||||
|
"- [IAM token](https://cloud.yandex.com/en/docs/iam/operations/iam-token/create-for-sa).\n",
|
||||||
|
" You can specify the token in a constructor parameter `iam_token` or in an environment variable `YC_IAM_TOKEN`.\n",
|
||||||
|
"- [API key](https://cloud.yandex.com/en/docs/iam/operations/api-key/create)\n",
|
||||||
|
" You can specify the key in a constructor parameter `api_key` or in an environment variable `YC_API_KEY`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "eba2d63b-f871-4f61-b55f-f6092bdc297a",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chat_models import ChatYandexGPT\n",
|
||||||
|
"from langchain.schema import HumanMessage, SystemMessage"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "75905d9a-dfae-43aa-95b9-a160280e43f7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chat_model = ChatYandexGPT()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "40844fe7-7fe5-4679-b6c9-1b3238807bdc",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"AIMessage(content=\"Je t'aime programmer.\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"answer = chat_model(\n",
|
||||||
|
" [\n",
|
||||||
|
" SystemMessage(content=\"You are a helpful assistant that translates English to French.\"),\n",
|
||||||
|
" HumanMessage(content=\"I love programming.\")\n",
|
||||||
|
" ]\n",
|
||||||
|
")\n",
|
||||||
|
"answer"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.18"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
119
docs/docs/integrations/llms/yandex.ipynb
Normal file
119
docs/docs/integrations/llms/yandex.ipynb
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# YandexGPT\n",
|
||||||
|
"\n",
|
||||||
|
"This notebook goes over how to use Langchain with [YandexGPT](https://cloud.yandex.com/en/services/yandexgpt).\n",
|
||||||
|
"\n",
|
||||||
|
"To use, you should have the `yandexcloud` python package installed."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install yandexcloud"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"First, you should [create service account](https://cloud.yandex.com/en/docs/iam/operations/sa/create) with the `ai.languageModels.user` role.\n",
|
||||||
|
"\n",
|
||||||
|
"Next, you have two authentication options:\n",
|
||||||
|
"- [IAM token](https://cloud.yandex.com/en/docs/iam/operations/iam-token/create-for-sa).\n",
|
||||||
|
" You can specify the token in a constructor parameter `iam_token` or in an environment variable `YC_IAM_TOKEN`.\n",
|
||||||
|
"- [API key](https://cloud.yandex.com/en/docs/iam/operations/api-key/create)\n",
|
||||||
|
" You can specify the key in a constructor parameter `api_key` or in an environment variable `YC_API_KEY`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 246,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chains import LLMChain\n",
|
||||||
|
"from langchain.llms import YandexGPT\n",
|
||||||
|
"from langchain.prompts import PromptTemplate"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 247,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"template = \"What is the capital of {country}?\"\n",
|
||||||
|
"prompt = PromptTemplate(template=template, input_variables=[\"country\"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 248,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"llm = YandexGPT()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 249,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 250,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'Moscow'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 250,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"country = \"Russia\"\n",
|
||||||
|
"\n",
|
||||||
|
"llm_chain.run(country)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.18"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
33
docs/docs/integrations/providers/yandex.mdx
Normal file
33
docs/docs/integrations/providers/yandex.mdx
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# Yandex
|
||||||
|
|
||||||
|
All functionality related to Yandex Cloud
|
||||||
|
|
||||||
|
>[Yandex Cloud](https://cloud.yandex.com/en/) is a public cloud platform.
|
||||||
|
|
||||||
|
## Installation and Setup
|
||||||
|
|
||||||
|
Yandex Cloud SDK can be installed via pip from PyPI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install yandexcloud
|
||||||
|
```
|
||||||
|
|
||||||
|
## LLMs
|
||||||
|
|
||||||
|
### YandexGPT
|
||||||
|
|
||||||
|
See a [usage example](/docs/integrations/llms/yandex).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain.llms import YandexGPT
|
||||||
|
```
|
||||||
|
|
||||||
|
## Chat models
|
||||||
|
|
||||||
|
### YandexGPT
|
||||||
|
|
||||||
|
See a [usage example](/docs/integrations/chat/yandex).
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain.chat_models import ChatYandexGPT
|
||||||
|
```
|
@ -39,6 +39,7 @@ from langchain.chat_models.ollama import ChatOllama
|
|||||||
from langchain.chat_models.openai import ChatOpenAI
|
from langchain.chat_models.openai import ChatOpenAI
|
||||||
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
|
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
|
||||||
from langchain.chat_models.vertexai import ChatVertexAI
|
from langchain.chat_models.vertexai import ChatVertexAI
|
||||||
|
from langchain.chat_models.yandex import ChatYandexGPT
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ChatOpenAI",
|
"ChatOpenAI",
|
||||||
@ -63,4 +64,5 @@ __all__ = [
|
|||||||
"ChatKonko",
|
"ChatKonko",
|
||||||
"QianfanChatEndpoint",
|
"QianfanChatEndpoint",
|
||||||
"ChatFireworks",
|
"ChatFireworks",
|
||||||
|
"ChatYandexGPT",
|
||||||
]
|
]
|
||||||
|
131
libs/langchain/langchain/chat_models/yandex.py
Normal file
131
libs/langchain/langchain/chat_models/yandex.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
"""Wrapper around YandexGPT chat models."""
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
|
from langchain.llms.yandex import _BaseYandexGPT
|
||||||
|
from langchain.schema import (
|
||||||
|
AIMessage,
|
||||||
|
BaseMessage,
|
||||||
|
ChatGeneration,
|
||||||
|
ChatResult,
|
||||||
|
HumanMessage,
|
||||||
|
SystemMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_message(role: str, text: str) -> Dict:
|
||||||
|
return {"role": role, "text": text}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_chat_history(history: List[BaseMessage]) -> Tuple[List[Dict[str, str]], str]:
|
||||||
|
"""Parse a sequence of messages into history.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of a list of parsed messages and an instruction message for the model.
|
||||||
|
"""
|
||||||
|
chat_history = []
|
||||||
|
instruction = ""
|
||||||
|
for message in history:
|
||||||
|
if isinstance(message, HumanMessage):
|
||||||
|
chat_history.append(_parse_message("user", message.content))
|
||||||
|
if isinstance(message, AIMessage):
|
||||||
|
chat_history.append(_parse_message("assistant", message.content))
|
||||||
|
if isinstance(message, SystemMessage):
|
||||||
|
instruction = message.content
|
||||||
|
return chat_history, instruction
|
||||||
|
|
||||||
|
|
||||||
|
class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
|
||||||
|
"""Wrapper around YandexGPT large language models.
|
||||||
|
|
||||||
|
There are two authentication options for the service account
|
||||||
|
with the ``ai.languageModels.user`` role:
|
||||||
|
- You can specify the token in a constructor parameter `iam_token`
|
||||||
|
or in an environment variable `YC_IAM_TOKEN`.
|
||||||
|
- You can specify the key in a constructor parameter `api_key`
|
||||||
|
or in an environment variable `YC_API_KEY`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.chat_models import ChatYandexGPT
|
||||||
|
chat_model = ChatYandexGPT(iam_token="t1.9eu...")
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
"""Generate next turn in the conversation.
|
||||||
|
Args:
|
||||||
|
messages: The history of the conversation as a list of messages.
|
||||||
|
stop: The list of stop words (optional).
|
||||||
|
run_manager: The CallbackManager for LLM run, it's not used at the moment.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The ChatResult that contains outputs generated by the model.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if the last message in the list is not from human.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import grpc
|
||||||
|
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
|
||||||
|
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions, Message
|
||||||
|
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import ChatRequest
|
||||||
|
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
|
||||||
|
TextGenerationServiceStub,
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
|
||||||
|
) from e
|
||||||
|
if not messages:
|
||||||
|
raise ValueError(
|
||||||
|
"You should provide at least one message to start the chat!"
|
||||||
|
)
|
||||||
|
message_history, instruction = _parse_chat_history(messages)
|
||||||
|
channel_credentials = grpc.ssl_channel_credentials()
|
||||||
|
channel = grpc.secure_channel(self.url, channel_credentials)
|
||||||
|
request = ChatRequest(
|
||||||
|
model=self.model_name,
|
||||||
|
generation_options=GenerationOptions(
|
||||||
|
temperature=DoubleValue(value=self.temperature),
|
||||||
|
max_tokens=Int64Value(value=self.max_tokens),
|
||||||
|
),
|
||||||
|
instruction_text=instruction,
|
||||||
|
messages=[Message(**message) for message in message_history],
|
||||||
|
)
|
||||||
|
stub = TextGenerationServiceStub(channel)
|
||||||
|
if self.iam_token:
|
||||||
|
metadata = (("authorization", f"Bearer {self.iam_token}"),)
|
||||||
|
else:
|
||||||
|
metadata = (("authorization", f"Api-Key {self.api_key}"),)
|
||||||
|
res = stub.Chat(request, metadata=metadata)
|
||||||
|
text = list(res)[0].message.text
|
||||||
|
text = text if stop is None else enforce_stop_tokens(text, stop)
|
||||||
|
message = AIMessage(content=text)
|
||||||
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"""YandexGPT doesn't support async requests at the moment."""
|
||||||
|
)
|
@ -480,6 +480,12 @@ def _import_xinference() -> Any:
|
|||||||
return Xinference
|
return Xinference
|
||||||
|
|
||||||
|
|
||||||
|
def _import_yandex_gpt() -> Any:
|
||||||
|
from langchain.llms.yandex import YandexGPT
|
||||||
|
|
||||||
|
return YandexGPT
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str) -> Any:
|
def __getattr__(name: str) -> Any:
|
||||||
if name == "AI21":
|
if name == "AI21":
|
||||||
return _import_ai21()
|
return _import_ai21()
|
||||||
@ -633,6 +639,8 @@ def __getattr__(name: str) -> Any:
|
|||||||
return _import_writer()
|
return _import_writer()
|
||||||
elif name == "Xinference":
|
elif name == "Xinference":
|
||||||
return _import_xinference()
|
return _import_xinference()
|
||||||
|
elif name == "YandexGPT":
|
||||||
|
return _import_yandex_gpt()
|
||||||
elif name == "type_to_cls_dict":
|
elif name == "type_to_cls_dict":
|
||||||
# for backwards compatibility
|
# for backwards compatibility
|
||||||
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||||
@ -719,6 +727,7 @@ __all__ = [
|
|||||||
"Xinference",
|
"Xinference",
|
||||||
"JavelinAIGateway",
|
"JavelinAIGateway",
|
||||||
"QianfanLLMEndpoint",
|
"QianfanLLMEndpoint",
|
||||||
|
"YandexGPT",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -794,4 +803,5 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
|
|||||||
"xinference": _import_xinference,
|
"xinference": _import_xinference,
|
||||||
"javelin-ai-gateway": _import_javelin_ai_gateway,
|
"javelin-ai-gateway": _import_javelin_ai_gateway,
|
||||||
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
|
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
|
||||||
|
"yandex_gpt": _import_yandex_gpt,
|
||||||
}
|
}
|
||||||
|
130
libs/langchain/langchain/llms/yandex.py
Normal file
130
libs/langchain/langchain/llms/yandex.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||||
|
from langchain.llms.base import LLM
|
||||||
|
from langchain.llms.utils import enforce_stop_tokens
|
||||||
|
from langchain.load.serializable import Serializable
|
||||||
|
from langchain.pydantic_v1 import root_validator
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseYandexGPT(Serializable):
|
||||||
|
iam_token: str = ""
|
||||||
|
"""Yandex Cloud IAM token for service account
|
||||||
|
with the `ai.languageModels.user` role"""
|
||||||
|
api_key: str = ""
|
||||||
|
"""Yandex Cloud Api Key for service account
|
||||||
|
with the `ai.languageModels.user` role"""
|
||||||
|
model_name: str = "general"
|
||||||
|
"""Model name to use."""
|
||||||
|
temperature: float = 0.6
|
||||||
|
"""What sampling temperature to use.
|
||||||
|
Should be a double number between 0 (inclusive) and 1 (inclusive)."""
|
||||||
|
max_tokens: int = 7400
|
||||||
|
"""Sets the maximum limit on the total number of tokens
|
||||||
|
used for both the input prompt and the generated response.
|
||||||
|
Must be greater than zero and not exceed 7400 tokens."""
|
||||||
|
stop: Optional[List[str]] = None
|
||||||
|
"""Sequences when completion generation will stop."""
|
||||||
|
url: str = "llm.api.cloud.yandex.net:443"
|
||||||
|
"""The url of the API."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "yandex_gpt"
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
|
"""Validate that iam token exists in environment."""
|
||||||
|
|
||||||
|
iam_token = get_from_dict_or_env(values, "iam_token", "YC_IAM_TOKEN", "")
|
||||||
|
values["iam_token"] = iam_token
|
||||||
|
api_key = get_from_dict_or_env(values, "api_key", "YC_API_KEY", "")
|
||||||
|
values["api_key"] = api_key
|
||||||
|
if api_key == "" and iam_token == "":
|
||||||
|
raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class YandexGPT(_BaseYandexGPT, LLM):
|
||||||
|
"""Yandex large language models.
|
||||||
|
|
||||||
|
To use, you should have the ``yandexcloud`` python package installed.
|
||||||
|
|
||||||
|
There are two authentication options for the service account
|
||||||
|
with the ``ai.languageModels.user`` role:
|
||||||
|
- You can specify the token in a constructor parameter `iam_token`
|
||||||
|
or in an environment variable `YC_IAM_TOKEN`.
|
||||||
|
- You can specify the key in a constructor parameter `api_key`
|
||||||
|
or in an environment variable `YC_API_KEY`.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.llms import YandexGPT
|
||||||
|
yandex_gpt = YandexGPT(iam_token="t1.9eu...")
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {
|
||||||
|
"model_name": self.model_name,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"stop": self.stop,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _call(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Call the Yandex GPT model and return the output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: The prompt to pass into the model.
|
||||||
|
stop: Optional list of stop words to use when generating.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The string generated by the model.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
response = YandexGPT("Tell me a joke.")
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import grpc
|
||||||
|
from google.protobuf.wrappers_pb2 import DoubleValue, Int64Value
|
||||||
|
from yandex.cloud.ai.llm.v1alpha.llm_pb2 import GenerationOptions
|
||||||
|
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2 import InstructRequest
|
||||||
|
from yandex.cloud.ai.llm.v1alpha.llm_service_pb2_grpc import (
|
||||||
|
TextGenerationServiceStub,
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Please install YandexCloud SDK" " with `pip install yandexcloud`."
|
||||||
|
) from e
|
||||||
|
channel_credentials = grpc.ssl_channel_credentials()
|
||||||
|
channel = grpc.secure_channel(self.url, channel_credentials)
|
||||||
|
request = InstructRequest(
|
||||||
|
model=self.model_name,
|
||||||
|
request_text=prompt,
|
||||||
|
generation_options=GenerationOptions(
|
||||||
|
temperature=DoubleValue(value=self.temperature),
|
||||||
|
max_tokens=Int64Value(value=self.max_tokens),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
stub = TextGenerationServiceStub(channel)
|
||||||
|
if self.iam_token:
|
||||||
|
metadata = (("authorization", f"Bearer {self.iam_token}"),)
|
||||||
|
else:
|
||||||
|
metadata = (("authorization", f"Api-Key {self.api_key}"),)
|
||||||
|
res = stub.Instruct(request, metadata=metadata)
|
||||||
|
text = list(res)[0].alternatives[0].text
|
||||||
|
if stop is not None:
|
||||||
|
text = enforce_stop_tokens(text, stop)
|
||||||
|
return text
|
Loading…
Reference in New Issue
Block a user