mirror of https://github.com/hwchase17/langchain
feat: add bedrock chat model (#8017)
Description: Add a Bedrock implementation of Anthropic Claude for chat. Tag maintainer: @hwchase17, @baskaryan. Twitter handle: @bwmatson. Co-authored-by: Bagatur <baskaryan@gmail.com> (pull/10133/head)
parent
a7c9bd30d4
commit
58d7d86e51
@ -0,0 +1,106 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "bf733a38-db84-4363-89e2-de6735c37230",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Bedrock Chat\n",
|
||||||
|
"\n",
|
||||||
|
"[Amazon Bedrock](https://aws.amazon.com/bedrock/) is a fully managed service that makes FMs from leading AI startups and Amazon available via an API, so you can choose from a wide range of FMs to find the model that is best suited for your use case."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "d51edc81",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install boto3"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "d4a7c55d-b235-4ca4-a579-c90cc9570da9",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chat_models import BedrockChat\n",
|
||||||
|
"from langchain.schema import HumanMessage"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"chat = BedrockChat(model_id=\"anthropic.claude-v2\", model_kwargs={\"temperature\":0.1})"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
|
||||||
|
"metadata": {
|
||||||
|
"tags": []
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"AIMessage(content=\" Voici la traduction en français : J'adore programmer.\", additional_kwargs={}, example=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"messages = [\n",
|
||||||
|
" HumanMessage(\n",
|
||||||
|
" content=\"Translate this sentence from English to French. I love programming.\"\n",
|
||||||
|
" )\n",
|
||||||
|
"]\n",
|
||||||
|
"chat(messages)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "c253883f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.4"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
@ -0,0 +1,98 @@
|
|||||||
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain.chat_models.anthropic import convert_messages_to_prompt_anthropic
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.llms.bedrock import BedrockBase
|
||||||
|
from langchain.pydantic_v1 import Extra
|
||||||
|
from langchain.schema.messages import AIMessage, BaseMessage
|
||||||
|
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
|
|
||||||
|
|
||||||
|
class ChatPromptAdapter:
    """Adapter that turns LangChain chat messages into the provider-specific
    prompt string a Bedrock-hosted chat model expects.
    """

    @classmethod
    def convert_messages_to_prompt(
        cls, provider: str, messages: List[BaseMessage]
    ) -> str:
        """Render *messages* as a single prompt string for *provider*.

        Raises:
            NotImplementedError: if the provider has no chat prompt format.
        """
        # Guard clause: only Anthropic models currently have a chat
        # prompt format wired up here.
        if provider != "anthropic":
            raise NotImplementedError(
                f"Provider {provider} model does not support chat."
            )
        return convert_messages_to_prompt_anthropic(messages=messages)
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockChat(BaseChatModel, BedrockBase):
    """Chat model backed by Amazon Bedrock hosted foundation models.

    Request preparation and invocation are inherited from ``BedrockBase``
    (``_get_provider``, ``_prepare_input_and_invoke``); message-to-prompt
    conversion is delegated to ``ChatPromptAdapter``. Only synchronous,
    non-streaming generation is implemented — the streaming and async
    entry points raise ``NotImplementedError``.
    """

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "amazon_bedrock_chat"

    class Config:
        """Configuration for this pydantic object."""

        # Reject unknown constructor kwargs instead of silently ignoring them.
        extra = Extra.forbid

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Streaming is not supported by this integration."""
        raise NotImplementedError(
            """Bedrock doesn't support stream requests at the moment."""
        )

    def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async streaming is not supported by this integration."""
        raise NotImplementedError(
            """Bedrock doesn't support stream requests at the moment."""
        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Invoke the Bedrock model once and wrap the completion.

        Args:
            messages: Chat history to render into a provider prompt.
            stop: Optional stop sequences, forwarded both to the request
                body (``stop_sequences``) and to the invoke helper.
            run_manager: Callback manager for this LLM run.
            **kwargs: Extra model parameters passed through to Bedrock.

        Returns:
            A ``ChatResult`` holding a single ``AIMessage`` generation.
        """
        provider = self._get_provider()
        prompt = ChatPromptAdapter.convert_messages_to_prompt(
            provider=provider, messages=messages
        )

        params: Dict[str, Any] = {**kwargs}
        if stop:
            params["stop_sequences"] = stop

        completion = self._prepare_input_and_invoke(
            prompt=prompt, stop=stop, run_manager=run_manager, **params
        )

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async generation is not supported by this integration."""
        # BUGFIX: the message previously said "async stream requests",
        # copy-pasted from _astream — this is the plain async path.
        raise NotImplementedError(
            """Bedrock doesn't support async requests at the moment."""
        )
|
Loading…
Reference in New Issue