From 1861cc7100bc5b4dd3e4e927bbaae5052f5fc99f Mon Sep 17 00:00:00 2001 From: Erick Friis Date: Fri, 13 Oct 2023 09:48:24 -0700 Subject: [PATCH] General anthropic functions, steps towards experimental integration tests (#11727) To match change in js here https://github.com/langchain-ai/langchainjs/pull/2892 Some integration tests need a bit more work in experimental: ![Screenshot 2023-10-12 at 12 02 49 PM](https://github.com/langchain-ai/langchain/assets/9557659/262d7d22-c405-40e9-afef-669e8d585307) Pretty sure the sqldatabase ones are an actual regression or change in interface because it's returning a placeholder. --------- Co-authored-by: Bagatur --- libs/experimental/Makefile | 3 + .../llms/anthropic_functions.py | 10 +- libs/experimental/pyproject.toml | 3 +- .../integration_tests/chains/test_cpal.py | 2 +- .../integration_tests/chains/test_pal.py | 3 +- .../chains/test_sql_database.py | 5 +- .../llms/test_anthropic_functions.py | 109 ++++++++++++++++++ 7 files changed, 128 insertions(+), 7 deletions(-) create mode 100644 libs/experimental/tests/integration_tests/llms/test_anthropic_functions.py diff --git a/libs/experimental/Makefile b/libs/experimental/Makefile index a8b926d8f2..e3a12d1721 100644 --- a/libs/experimental/Makefile +++ b/libs/experimental/Makefile @@ -18,6 +18,9 @@ test_watch: extended_tests: poetry run pytest --only-extended tests/unit_tests +integration_tests: + poetry run pytest tests/integration_tests + ###################### # LINTING AND FORMATTING diff --git a/libs/experimental/langchain_experimental/llms/anthropic_functions.py b/libs/experimental/langchain_experimental/llms/anthropic_functions.py index 958ba254d5..68e0eac1f8 100644 --- a/libs/experimental/langchain_experimental/llms/anthropic_functions.py +++ b/libs/experimental/langchain_experimental/llms/anthropic_functions.py @@ -124,11 +124,17 @@ def _destrip(tool_input: Any) -> Any: class AnthropicFunctions(BaseChatModel): - model: ChatAnthropic + llm: BaseChatModel @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: - return {"model": ChatAnthropic(**values)} + values["llm"] = values.get("llm") or ChatAnthropic(**values) + return values + + @property + def model(self) -> BaseChatModel: + """For backwards compatibility.""" + return self.llm def _generate( self, diff --git a/libs/experimental/pyproject.toml b/libs/experimental/pyproject.toml index 1add2f45f1..8f5452fae7 100644 --- a/libs/experimental/pyproject.toml +++ b/libs/experimental/pyproject.toml @@ -84,5 +84,6 @@ addopts = "--strict-markers --strict-config --durations=5" # Registering custom markers. # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers markers = [ - "requires: mark tests as requiring a specific library" + "requires: mark tests as requiring a specific library", + "asyncio: mark tests as requiring asyncio" ] diff --git a/libs/experimental/tests/integration_tests/chains/test_cpal.py b/libs/experimental/tests/integration_tests/chains/test_cpal.py index cec8b8da21..570f5656d9 100644 --- a/libs/experimental/tests/integration_tests/chains/test_cpal.py +++ b/libs/experimental/tests/integration_tests/chains/test_cpal.py @@ -39,7 +39,7 @@ from langchain_experimental.cpal.templates.univariate.narrative import ( from langchain_experimental.cpal.templates.univariate.query import ( template as query_template, ) -from tests.unit_tests.llms.fake_llm import FakeLLM +from tests.unit_tests.fake_llm import FakeLLM class TestUnitCPALChain_MathWordProblems(unittest.TestCase): diff --git a/libs/experimental/tests/integration_tests/chains/test_pal.py b/libs/experimental/tests/integration_tests/chains/test_pal.py index 6f3d83b5a1..2215a72a8b 100644 --- a/libs/experimental/tests/integration_tests/chains/test_pal.py +++ b/libs/experimental/tests/integration_tests/chains/test_pal.py @@ -1,8 +1,9 @@ """Test PAL chain.""" -from langchain.chains.pal.base import PALChain from langchain.llms import OpenAI +from langchain_experimental.pal_chain.base import PALChain + def test_math_prompt() -> None: """Test math prompt.""" diff --git a/libs/experimental/tests/integration_tests/chains/test_sql_database.py b/libs/experimental/tests/integration_tests/chains/test_sql_database.py index cf1dc445ca..a6ebe2df58 100644 --- a/libs/experimental/tests/integration_tests/chains/test_sql_database.py +++ b/libs/experimental/tests/integration_tests/chains/test_sql_database.py @@ -1,11 +1,12 @@ """Test SQL Database Chain.""" from langchain.llms.openai import OpenAI from langchain.utilities.sql_database import SQLDatabase -from libs.experimental.langchain_experimental.sql.base import ( +from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert + +from langchain_experimental.sql.base import ( SQLDatabaseChain, SQLDatabaseSequentialChain, ) -from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert metadata_obj = MetaData() diff --git a/libs/experimental/tests/integration_tests/llms/test_anthropic_functions.py b/libs/experimental/tests/integration_tests/llms/test_anthropic_functions.py new file mode 100644 index 0000000000..008e268a35 --- /dev/null +++ b/libs/experimental/tests/integration_tests/llms/test_anthropic_functions.py @@ -0,0 +1,109 @@ +"""Test AnthropicFunctions""" + +import unittest + +from langchain.chat_models.anthropic import ChatAnthropic +from langchain.chat_models.bedrock import BedrockChat + +from langchain_experimental.llms.anthropic_functions import AnthropicFunctions + + +class TestAnthropicFunctions(unittest.TestCase): + """ + Test AnthropicFunctions with default llm (ChatAnthropic) as well as a passed-in llm + """ + + def test_default_chat_anthropic(self) -> None: + base_model = AnthropicFunctions(model="claude-2") + self.assertIsInstance(base_model.model, ChatAnthropic) + + # bind functions + model = base_model.bind( + functions=[ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, " + "e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + } + ], + function_call={"name": "get_current_weather"}, + ) + + res = model.invoke("What's the weather in San Francisco?") + + function_call = res.additional_kwargs.get("function_call") + assert function_call + self.assertEqual(function_call.get("name"), "get_current_weather") + self.assertEqual( + function_call.get("arguments"), + '{"location": "San Francisco, CA", "unit": "fahrenheit"}', + ) + + def test_bedrock_chat_anthropic(self) -> None: + """ + const chatBedrock = new ChatBedrock({ + region: process.env.BEDROCK_AWS_REGION ?? "us-east-1", + model: "anthropic.claude-v2", + temperature: 0.1, + credentials: { + secretAccessKey: process.env.BEDROCK_AWS_SECRET_ACCESS_KEY!, + accessKeyId: process.env.BEDROCK_AWS_ACCESS_KEY_ID!, + }, + });""" + llm = BedrockChat( + model_id="anthropic.claude-v2", + model_kwargs={"temperature": 0.1}, + region_name="us-east-1", + ) + base_model = AnthropicFunctions(llm=llm) + assert isinstance(base_model.model, BedrockChat) + + # bind functions + model = base_model.bind( + functions=[ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, " + "e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + } + ], + function_call={"name": "get_current_weather"}, + ) + + res = model.invoke("What's the weather in San Francisco?") + + function_call = res.additional_kwargs.get("function_call") + assert function_call + self.assertEqual(function_call.get("name"), "get_current_weather") + self.assertEqual( + function_call.get("arguments"), + '{"location": "San Francisco, CA", "unit": "fahrenheit"}', + )