mirror of https://github.com/arc53/DocsGPT
commit 04b4001277 (parent e54d46aae1): anthropic working
@@ -0,0 +1,40 @@ application/llm/anthropic.py (new file)

from application.llm.base import BaseLLM
from application.core.settings import settings


class AnthropicLLM(BaseLLM):

    def __init__(self, api_key=None):
        # Import lazily so the anthropic package is only required when this backend is used.
        from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT

        self.api_key = api_key or settings.ANTHROPIC_API_KEY  # if not provided, fall back to the key from settings
        self.anthropic = Anthropic(api_key=self.api_key)
        self.HUMAN_PROMPT = HUMAN_PROMPT
        self.AI_PROMPT = AI_PROMPT

    def gen(self, model, messages, engine=None, max_tokens=300, stream=False, **kwargs):
        # Build a single prompt from the first message (context) and the last message (question).
        context = messages[0]['content']
        user_question = messages[-1]['content']
        prompt = f"### Context \n {context} \n ### Question \n {user_question}"

        if stream:
            # Delegate to the streaming generator, forwarding the original messages list
            # and passing max_tokens by keyword so it is not bound to the engine parameter.
            return self.gen_stream(model, messages, max_tokens=max_tokens, **kwargs)

        completion = self.anthropic.completions.create(
            model=model,
            max_tokens_to_sample=max_tokens,
            stream=stream,
            prompt=f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT}",
        )
        return completion.completion

    def gen_stream(self, model, messages, engine=None, max_tokens=300, **kwargs):
        context = messages[0]['content']
        user_question = messages[-1]['content']
        prompt = f"### Context \n {context} \n ### Question \n {user_question}"

        stream_response = self.anthropic.completions.create(
            model=model,
            prompt=f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT}",
            max_tokens_to_sample=max_tokens,
            stream=True,
        )

        # Yield each chunk of the completion as it arrives.
        for completion in stream_response:
            yield completion.completion
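For reference, a minimal usage sketch of the new backend follows. It is not part of the commit: it assumes a valid Anthropic API key (passed directly or available as ANTHROPIC_API_KEY in settings) and uses "claude-2" purely as an illustrative model name; neither is prescribed by the diff itself.

# Minimal sketch, not part of the commit: exercising AnthropicLLM directly.
from application.llm.anthropic import AnthropicLLM

llm = AnthropicLLM(api_key="sk-ant-...")  # placeholder key; settings.ANTHROPIC_API_KEY is used if omitted

messages = [
    {"content": "DocsGPT answers questions about project documentation."},  # first message supplies the context
    {"content": "What does DocsGPT do?"},                                   # last message supplies the question
]

# Blocking call: returns the full completion text.
answer = llm.gen("claude-2", messages)  # "claude-2" is an illustrative model name
print(answer)

# Streaming call: yields completion chunks as they arrive.
for chunk in llm.gen_stream("claude-2", messages):
    print(chunk, end="")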
@@ -0,0 +1,57 @@ (new unit tests for AnthropicLLM)

import unittest
from unittest.mock import patch, Mock
from application.llm.anthropic import AnthropicLLM


class TestAnthropicLLM(unittest.TestCase):

    def setUp(self):
        self.api_key = "TEST_API_KEY"
        self.llm = AnthropicLLM(api_key=self.api_key)

    @patch("application.llm.anthropic.settings")
    def test_init_default_api_key(self, mock_settings):
        mock_settings.ANTHROPIC_API_KEY = "DEFAULT_API_KEY"
        llm = AnthropicLLM()
        self.assertEqual(llm.api_key, "DEFAULT_API_KEY")

    def test_gen(self):
        messages = [
            {"content": "context"},
            {"content": "question"}
        ]
        mock_response = Mock()
        mock_response.completion = "test completion"

        with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create:
            response = self.llm.gen("test_model", messages)
            self.assertEqual(response, "test completion")

            prompt_expected = "### Context \n context \n ### Question \n question"
            mock_create.assert_called_with(
                model="test_model",
                max_tokens_to_sample=300,
                stream=False,
                prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}"
            )

    def test_gen_stream(self):
        messages = [
            {"content": "context"},
            {"content": "question"}
        ]
        mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")]

        with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create:
            responses = list(self.llm.gen_stream("test_model", messages))
            self.assertListEqual(responses, ["response_1", "response_2"])

            prompt_expected = "### Context \n context \n ### Question \n question"
            mock_create.assert_called_with(
                model="test_model",
                prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}",
                max_tokens_to_sample=300,
                stream=True
            )


if __name__ == "__main__":
    unittest.main()
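One behaviour the committed tests do not cover is the stream=True path of gen, which hands off to gen_stream. The sketch below is a hypothetical addition, not part of the commit; it assumes gen forwards the original messages list and the max_tokens keyword to gen_stream, as in the implementation shown above.

# Hypothetical extra test, not part of the commit: check that gen(stream=True) delegates to gen_stream.
import unittest
from unittest.mock import patch, Mock
from application.llm.anthropic import AnthropicLLM


class TestAnthropicLLMStreamDelegation(unittest.TestCase):

    def test_gen_delegates_to_gen_stream(self):
        llm = AnthropicLLM(api_key="TEST_API_KEY")
        messages = [{"content": "context"}, {"content": "question"}]

        # Replace gen_stream so no network call is attempted; only the delegation is checked.
        with patch.object(llm, "gen_stream", return_value=iter(["chunk"])) as mock_stream:
            result = list(llm.gen("test_model", messages, stream=True))
            self.assertEqual(result, ["chunk"])
            mock_stream.assert_called_once_with("test_model", messages, max_tokens=300)


if __name__ == "__main__":
    unittest.main()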