From 04b40012774a484313f0d8ced997caa4140cf804 Mon Sep 17 00:00:00 2001
From: Alex
Date: Sat, 28 Oct 2023 19:51:12 +0100
Subject: [PATCH] anthropic working

---
 application/api/answer/routes.py |  2 ++
 application/llm/anthropic.py     | 40 ++++++++++++++++++++++
 application/llm/llm_creator.py   |  4 ++-
 application/requirements.txt     |  1 +
 docker-compose.yaml              |  2 ++
 tests/llm/test_anthropic.py      | 57 ++++++++++++++++++++++++++++++++
 6 files changed, 105 insertions(+), 1 deletion(-)
 create mode 100644 application/llm/anthropic.py
 create mode 100644 tests/llm/test_anthropic.py

diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py
index b694c4a..a932725 100644
--- a/application/api/answer/routes.py
+++ b/application/api/answer/routes.py
@@ -29,6 +29,8 @@ answer = Blueprint('answer', __name__)
 
 if settings.LLM_NAME == "gpt4":
     gpt_model = 'gpt-4'
+elif settings.LLM_NAME == "anthropic":
+    gpt_model = 'claude-2'
 else:
     gpt_model = 'gpt-3.5-turbo'
 
diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py
new file mode 100644
index 0000000..a64d71e
--- /dev/null
+++ b/application/llm/anthropic.py
@@ -0,0 +1,40 @@
+from application.llm.base import BaseLLM
+from application.core.settings import settings
+
+class AnthropicLLM(BaseLLM):
+
+    def __init__(self, api_key=None):
+        from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+        self.api_key = api_key or settings.ANTHROPIC_API_KEY  # If not provided, use a default from settings
+        self.anthropic = Anthropic(api_key=self.api_key)
+        self.HUMAN_PROMPT = HUMAN_PROMPT
+        self.AI_PROMPT = AI_PROMPT
+
+    def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs):
+        context = messages[0]['content']
+        user_question = messages[-1]['content']
+        prompt = f"### Context \n {context} \n ### Question \n {user_question}"
+        if stream:
+            return self.gen_stream(model, messages, max_tokens=max_tokens, **kwargs)
+
+        completion = self.anthropic.completions.create(
+            model=model,
+            max_tokens_to_sample=max_tokens,
+            stream=stream,
+            prompt=f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT}",
+        )
+        return completion.completion
+
+    def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs):
+        context = messages[0]['content']
+        user_question = messages[-1]['content']
+        prompt = f"### Context \n {context} \n ### Question \n {user_question}"
+        stream_response = self.anthropic.completions.create(
+            model=model,
+            prompt=f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT}",
+            max_tokens_to_sample=max_tokens,
+            stream=True,
+        )
+
+        for completion in stream_response:
+            yield completion.completion
\ No newline at end of file
diff --git a/application/llm/llm_creator.py b/application/llm/llm_creator.py
index 6a60f1b..bbc2c79 100644
--- a/application/llm/llm_creator.py
+++ b/application/llm/llm_creator.py
@@ -2,6 +2,7 @@ from application.llm.openai import OpenAILLM, AzureOpenAILLM
 from application.llm.sagemaker import SagemakerAPILLM
 from application.llm.huggingface import HuggingFaceLLM
 from application.llm.llama_cpp import LlamaCpp
+from application.llm.anthropic import AnthropicLLM
 
 
 
@@ -11,7 +12,8 @@ class LLMCreator:
         'azure_openai': AzureOpenAILLM,
         'sagemaker': SagemakerAPILLM,
         'huggingface': HuggingFaceLLM,
-        'llama.cpp': LlamaCpp
+        'llama.cpp': LlamaCpp,
+        'anthropic': AnthropicLLM
     }
 
     @classmethod
diff --git a/application/requirements.txt b/application/requirements.txt
index 693e628..e2a2b2d 100644
--- a/application/requirements.txt
+++ b/application/requirements.txt
@@ -4,6 +4,7 @@ aiohttp-retry==2.8.3
 aiosignal==1.3.1
 aleph-alpha-client==2.16.1
 amqp==5.1.1
+anthropic==0.5.0
 async-timeout==4.0.2
 attrs==22.2.0
 billiard==3.6.4.0
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 84cc568..7008b53 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -16,6 +16,7 @@ services:
     environment:
       - API_KEY=$API_KEY
       - EMBEDDINGS_KEY=$API_KEY
+      - LLM_NAME=$LLM_NAME
       - CELERY_BROKER_URL=redis://redis:6379/0
       - CELERY_RESULT_BACKEND=redis://redis:6379/1
       - MONGO_URI=mongodb://mongo:27017/docsgpt
@@ -35,6 +36,7 @@ services:
     environment:
       - API_KEY=$API_KEY
       - EMBEDDINGS_KEY=$API_KEY
+      - LLM_NAME=$LLM_NAME
       - CELERY_BROKER_URL=redis://redis:6379/0
       - CELERY_RESULT_BACKEND=redis://redis:6379/1
       - MONGO_URI=mongodb://mongo:27017/docsgpt
diff --git a/tests/llm/test_anthropic.py b/tests/llm/test_anthropic.py
new file mode 100644
index 0000000..ee4ba15
--- /dev/null
+++ b/tests/llm/test_anthropic.py
@@ -0,0 +1,57 @@
+import unittest
+from unittest.mock import patch, Mock
+from application.llm.anthropic import AnthropicLLM
+
+class TestAnthropicLLM(unittest.TestCase):
+
+    def setUp(self):
+        self.api_key = "TEST_API_KEY"
+        self.llm = AnthropicLLM(api_key=self.api_key)
+
+    @patch("application.llm.anthropic.settings")
+    def test_init_default_api_key(self, mock_settings):
+        mock_settings.ANTHROPIC_API_KEY = "DEFAULT_API_KEY"
+        llm = AnthropicLLM()
+        self.assertEqual(llm.api_key, "DEFAULT_API_KEY")
+
+    def test_gen(self):
+        messages = [
+            {"content": "context"},
+            {"content": "question"}
+        ]
+        mock_response = Mock()
+        mock_response.completion = "test completion"
+
+        with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
+            response = self.llm.gen("test_model", messages)
+            self.assertEqual(response, "test completion")
+
+            prompt_expected = "### Context \n context \n ### Question \n question"
+            mock_create.assert_called_with(
+                model="test_model",
+                max_tokens_to_sample=300,
+                stream=False,
+                prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
+            )
+
+    def test_gen_stream(self):
+        messages = [
+            {"content": "context"},
+            {"content": "question"}
+        ]
+        mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]
+
+        with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
+            responses = list(self.llm.gen_stream("test_model", messages))
+            self.assertListEqual(responses, ["response_1", "response_2"])
+
+            prompt_expected = "### Context \n context \n ### Question \n question"
+            mock_create.assert_called_with(
+                model="test_model",
+                prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
+                max_tokens_to_sample=300,
+                stream=True
+            )
+
+if __name__ == "__main__":
+    unittest.main()
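
Usage note (not part of the patch): a minimal sketch of how the new class can
be exercised once merged. It assumes ANTHROPIC_API_KEY is configured in
application.core.settings or passed explicitly (the key and message contents
below are placeholders), and that callers pass the retrieved context as the
first message and the user question as the last, which is the shape gen() and
gen_stream() expect.

    from application.llm.anthropic import AnthropicLLM

    # api_key falls back to settings.ANTHROPIC_API_KEY when omitted.
    llm = AnthropicLLM(api_key="YOUR_ANTHROPIC_API_KEY")  # placeholder

    # First message: retrieved context; last message: user question.
    messages = [
        {"content": "DocsGPT is an open-source documentation assistant."},
        {"content": "What is DocsGPT?"},
    ]

    # Blocking call: returns the full completion text.
    answer = llm.gen(model="claude-2", messages=messages)
    print(answer)

    # Streaming call: yields completion chunks as they arrive.
    for chunk in llm.gen_stream(model="claude-2", messages=messages):
        print(chunk, end="", flush=True)

With LLM_NAME=anthropic exported (see the docker-compose change), the app can
also reach this class through the new 'anthropic' entry registered in
LLMCreator, and routes.py sets the model name to 'claude-2'.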