langchain/libs/community/tests/integration_tests/llms/test_bedrock.py
Harel Gal a91181fe6d
community[minor]: add support for Guardrails for Amazon Bedrock (#15099)
Added support for optionally supplying 'Guardrails for Amazon Bedrock'
on both types of model invocations (batch/regular and streaming) and for
all models supported by the Amazon Bedrock service.

@baskaryan  @hwchase17

```python 
llm = Bedrock(model_id="<model_id>", client=bedrock,
              model_kwargs={},
              guardrails={"id": "<guardrail_id>",
                          "version": "<guardrail_version>",
                          "trace": True},
              callbacks=[BedrockAsyncCallbackHandler()])

class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
    """Async callback handler that can be used to handle callbacks from langchain."""

    async def on_llm_error(
            self,
            error: BaseException,
            **kwargs: Any,
    ) -> Any:
        reason = kwargs.get("reason")
        if reason == "GUARDRAIL_INTERVENED":
            # kwargs contains additional trace information sent by the 'Guardrails for Amazon Bedrock' service.
            print(f"""Guardrails: {kwargs}""")


# streaming 
llm = Bedrock(model_id="<model_id>", client=bedrock,
                  model_kwargs={},
                  streaming=True,
                  guardrails={"id": "<guardrail_id>",
                              "version": "<guardrail_version>"})
```

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
2024-01-24 14:44:19 -08:00

137 lines
4.0 KiB
Python

"""
Test the Amazon Bedrock API wrapper and related services,
e.g. 'Guardrails for Amazon Bedrock'.
The `bedrock_models` fixture lists the available models (model ID -> model name)
via the Bedrock client's `list_foundation_models` call.
"""
import os
from typing import Any
import pytest
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_community.llms.bedrock import Bedrock
# Guardrail configuration exercised by the guardrail tests below; override it
# with the GUARDRAILS_ID environment variable to point at your own guardrail.
GUARDRAILS_ID = os.getenv("GUARDRAILS_ID", "7jarelix77")
# Version of that guardrail configuration (env var GUARDRAILS_VERSION).
GUARDRAILS_VERSION = os.getenv("GUARDRAILS_VERSION", "1")
# Prompt that is expected to make the guardrail intervene. Override it with
# GUARDRAILS_TRIGGERING_QUERY so it matches whatever your guardrail blocks.
GUARDRAILS_TRIGGER = os.getenv(
    "GUARDRAILS_TRIGGERING_QUERY", "I want to talk about politics."
)
class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
    """Async callback handler that records Guardrails for Amazon Bedrock interventions.

    ``guardrails_intervened`` starts out False (class-level default) and is set to
    True on the instance once ``on_llm_error`` receives
    ``reason="GUARDRAIL_INTERVENED"``; nothing ever resets it.
    """

    guardrails_intervened = False

    async def on_llm_error(
        self,
        error: BaseException,
        **kwargs: Any,
    ) -> Any:
        """Flag a guardrail intervention; any other LLM error is ignored."""
        if kwargs.get("reason") != "GUARDRAIL_INTERVENED":
            return None
        self.guardrails_intervened = True

    def get_response(self):
        """Return whether a guardrail intervention has been recorded."""
        return self.guardrails_intervened
def _boto3_client(service_name: str) -> Any:
    """Build a boto3 client for ``service_name`` in the configured AWS region.

    The region is read from the ``AWS_REGION`` environment variable and
    defaults to ``us-east-1``. The requesting test is skipped when boto3 is not
    installed. It is failed (without a traceback) when the client cannot be
    created, e.g. when the installed botocore does not know the service name.
    """
    boto3 = pytest.importorskip("boto3")
    try:
        return boto3.client(
            service_name,
            region_name=os.environ.get("AWS_REGION", "us-east-1"),
        )
    except Exception as e:
        # Client construction is what failed here, not a network connection,
        # so say so in the message.
        pytest.fail(f"can not create {service_name} client: {e}", pytrace=False)


@pytest.fixture
def bedrock_runtime_client():
    """Client for the ``bedrock-runtime`` API used to invoke models."""
    return _boto3_client("bedrock-runtime")


@pytest.fixture
def bedrock_client():
    """Client for the ``bedrock`` API (used here to list foundation models)."""
    return _boto3_client("bedrock")
@pytest.fixture
def bedrock_models(bedrock_client):
    """Map Bedrock foundation-model IDs to their model names.

    Calls ``list_foundation_models`` on the Bedrock client. A response with no
    ``modelSummaries`` entry yields an empty mapping instead of raising
    ``TypeError`` while iterating ``None``. If two summaries share an ID, the
    last one wins.
    """
    summaries = bedrock_client.list_foundation_models().get("modelSummaries") or []
    return {model.get("modelId"): model.get("modelName") for model in summaries}
def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
    """A plain (guardrail-free) Claude Instant invocation returns a string."""
    try:
        llm = Bedrock(
            model_id="anthropic.claude-instant-v1",
            client=bedrock_runtime_client,
            model_kwargs={},
        )
    except Exception as e:
        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
    else:
        # Invocation and assertion sit outside the try so their failures are
        # reported as themselves rather than as an instantiation error.
        output = llm.invoke("Say something positive:")
        assert isinstance(output, str)
def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
    bedrock_runtime_client, bedrock_models
):
    """With guardrails enabled, a benign prompt still yields a string response."""
    try:
        llm = Bedrock(
            model_id="anthropic.claude-instant-v1",
            client=bedrock_runtime_client,
            model_kwargs={},
            guardrails={
                "id": GUARDRAILS_ID,
                "version": GUARDRAILS_VERSION,
                "trace": False,
            },
        )
    except Exception as e:
        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
    else:
        # Invocation and assertion sit outside the try so their failures are
        # reported as themselves rather than as an instantiation error.
        output = llm.invoke("Say something positive:")
        assert isinstance(output, str)
def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
    bedrock_runtime_client, bedrock_models
):
    """A guardrail-triggering prompt is reported to the callback handler.

    The handler's ``on_llm_error`` must receive ``reason="GUARDRAIL_INTERVENED"``
    during the call for ``get_response()`` to return True.
    """
    try:
        handler = BedrockAsyncCallbackHandler()
        llm = Bedrock(
            model_id="anthropic.claude-instant-v1",
            client=bedrock_runtime_client,
            model_kwargs={},
            guardrails={
                "id": GUARDRAILS_ID,
                "version": GUARDRAILS_VERSION,
                "trace": True,
            },
            callbacks=[handler],
        )
    except Exception as e:
        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
    else:
        llm.invoke(GUARDRAILS_TRIGGER)
        guardrails_intervened = handler.get_response()
        assert guardrails_intervened is True, (
            f"guardrail {GUARDRAILS_ID!r} (version {GUARDRAILS_VERSION!r}) did not "
            f"intervene for prompt {GUARDRAILS_TRIGGER!r}"
        )