diff --git a/langchain/llms/ctransformers.py b/langchain/llms/ctransformers.py
index 52223ece67..7d310109fc 100644
--- a/langchain/llms/ctransformers.py
+++ b/langchain/llms/ctransformers.py
@@ -1,9 +1,13 @@
 """Wrapper around the C Transformers library."""
-from typing import Any, Dict, Optional, Sequence
+from functools import partial
+from typing import Any, Dict, List, Optional, Sequence
 
 from pydantic import root_validator
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 
 
@@ -103,3 +107,36 @@ class CTransformers(LLM):
             text.append(chunk)
             _run_manager.on_llm_new_token(chunk, verbose=self.verbose)
         return "".join(text)
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Asynchronously call out to the CTransformers generate method.
+        Very helpful when streaming (for example, over websockets).
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: A list of strings to stop generation when encountered.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+                response = await llm._acall("Once upon a time, ")
+        """
+        text_callback = None
+        if run_manager:
+            text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
+
+        text = ""
+        for token in self.client(prompt, stop=stop, stream=True):
+            if text_callback:
+                await text_callback(token)
+            text += token
+
+        return text
diff --git a/tests/integration_tests/llms/test_ctransformers.py b/tests/integration_tests/llms/test_ctransformers.py
index ead4dbce02..2ed4c09131 100644
--- a/tests/integration_tests/llms/test_ctransformers.py
+++ b/tests/integration_tests/llms/test_ctransformers.py
@@ -1,4 +1,5 @@
 """Test C Transformers wrapper."""
+import pytest
 
 from langchain.llms import CTransformers
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -19,3 +20,20 @@ def test_ctransformers_call() -> None:
     assert isinstance(output, str)
     assert len(output) > 1
     assert 0 < callback_handler.llm_streams <= config["max_new_tokens"]
+
+
+@pytest.mark.asyncio
+async def test_ctransformers_async_inference() -> None:
+    config = {"max_new_tokens": 5}
+    callback_handler = FakeCallbackHandler()
+
+    llm = CTransformers(
+        model="marella/gpt-2-ggml",
+        config=config,
+        callbacks=[callback_handler],
+    )
+
+    output = await llm._acall(prompt="Say foo:")
+    assert isinstance(output, str)
+    assert len(output) > 1
+    assert 0 < callback_handler.llm_streams <= config["max_new_tokens"]
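# ---------------------------------------------------------------------------
# Usage sketch (not part of the diff above): a minimal example of exercising
# the new async path through LangChain's public API instead of the private
# _acall used by the integration test. This assumes the same
# "marella/gpt-2-ggml" model the test downloads; LLM.agenerate dispatches to
# _acall internally, so streamed tokens reach any callback handlers attached
# to the LLM.
import asyncio

from langchain.llms import CTransformers


async def main() -> None:
    # max_new_tokens mirrors the config used in the integration test.
    llm = CTransformers(model="marella/gpt-2-ggml", config={"max_new_tokens": 5})
    result = await llm.agenerate(["Once upon a time, "])
    print(result.generations[0][0].text)


asyncio.run(main())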