Vwp/alpaca streaming (#3468)

Co-authored-by: Luke Stanley <306671+lukestanley@users.noreply.github.com>
fix_agent_callbacks
Zander Chase 1 year ago committed by GitHub
parent 26035dfa59
commit 416f3bdf11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -41,7 +41,9 @@
"outputs": [],
"source": [
"from langchain.llms import LlamaCpp\n",
"from langchain import PromptTemplate, LLMChain"
"from langchain import PromptTemplate, LLMChain\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
]
},
{
@ -67,7 +69,14 @@
},
"outputs": [],
"source": [
"llm = LlamaCpp(model_path=\"./ggml-model-q4_0.bin\")"
"# Callbacks support token-wise streaming\n",
"callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])\n",
"# Verbose is required to pass to the callback manager\n",
"\n",
"# Make sure the model path is correct for your system!\n",
"llm = LlamaCpp(\n",
" model_path=\"./ggml-model-q4_0.bin\", callback_manager=callback_manager, verbose=True\n",
")"
]
},
{
@ -84,10 +93,17 @@
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" First we need to identify what year Justin Beiber was born in. A quick google search reveals that he was born on March 1st, 1994. Now we know when the Super Bowl was played in, so we can look up which NFL team won it. The NFL Superbowl of the year 1994 was won by the San Francisco 49ers against the San Diego Chargers."
]
},
{
"data": {
"text/plain": [
"'\\n\\nWe know that Justin Bieber is currently 25 years old and that he was born on March 1st, 1994 and that he is a singer and he has an album called Purpose, so we know that he was born when Super Bowl XXXVIII was played between Dallas and Seattle and that it took place February 1st, 2004 and that the Seattle Seahawks won 24-21, so Seattle is our answer!'"
"' First we need to identify what year Justin Beiber was born in. A quick google search reveals that he was born on March 1st, 1994. Now we know when the Super Bowl was played in, so we can look up which NFL team won it. The NFL Superbowl of the year 1994 was won by the San Francisco 49ers against the San Diego Chargers.'"
]
},
"execution_count": 6,

@ -1,6 +1,6 @@
"""Wrapper around llama.cpp."""
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Generator, List, Optional
from pydantic import Field, root_validator
@ -87,6 +87,9 @@ class LlamaCpp(LLM):
last_n_tokens_size: Optional[int] = 64
"""The number of tokens to look back when applying the repeat_penalty."""
streaming: bool = True
"""Whether to stream the results, token by token."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
@ -139,7 +142,7 @@ class LlamaCpp(LLM):
"top_p": self.top_p,
"logprobs": self.logprobs,
"echo": self.echo,
"stop_sequences": self.stop,
"stop_sequences": self.stop, # key here is convention among LLM classes
"repeat_penalty": self.repeat_penalty,
"top_k": self.top_k,
}
@ -154,6 +157,31 @@ class LlamaCpp(LLM):
"""Return type of llm."""
return "llama.cpp"
def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
"""
Performs sanity check, preparing paramaters in format needed by llama_cpp.
Args:
stop (Optional[List[str]]): List of stop sequences for llama_cpp.
Returns:
Dictionary containing the combined parameters.
"""
# Raise error if stop sequences are in both input and default params
if self.stop and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
params = self._default_params
# llama_cpp expects the "stop" key not this, so we remove it:
params.pop("stop_sequences")
# then sets it as configured, or default to an empty list:
params["stop"] = self.stop or stop or []
return params
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call the Llama model and return the output.
@ -167,31 +195,65 @@ class LlamaCpp(LLM):
Example:
.. code-block:: python
from langchain.llms import LlamaCppEmbeddings
llm = LlamaCppEmbeddings(model_path="/path/to/local/llama/model.bin")
from langchain.llms import LlamaCpp
llm = LlamaCpp(model_path="/path/to/local/llama/model.bin")
llm("This is a prompt.")
"""
params = self._default_params
if self.stop and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop:
params["stop_sequences"] = self.stop
elif stop:
params["stop_sequences"] = stop
if self.streaming:
# If streaming is enabled, we use the stream
# method that yields as they are generated
# and return the combined strings from the first choices's text:
combined_text_output = ""
for token in self.stream(prompt=prompt, stop=stop):
combined_text_output += token["choices"][0]["text"]
return combined_text_output
else:
params["stop_sequences"] = []
"""Call the Llama model and return the output."""
text = self.client(
prompt=prompt,
max_tokens=params["max_tokens"],
temperature=params["temperature"],
top_p=params["top_p"],
logprobs=params["logprobs"],
echo=params["echo"],
stop=params["stop_sequences"],
repeat_penalty=params["repeat_penalty"],
top_k=params["top_k"],
)
return text["choices"][0]["text"]
params = self._get_parameters(stop)
result = self.client(prompt=prompt, **params)
return result["choices"][0]["text"]
def stream(
self, prompt: str, stop: Optional[List[str]] = None
) -> Generator[Dict, None, None]:
"""Yields results objects as they are generated in real time.
BETA: this is a beta feature while we figure out the right abstraction:
Once that happens, this interface could change.
It also calls the callback manager's on_llm_new_token event with
similar parameters to the OpenAI LLM class method of the same name.
Args:
prompt: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens being generated.
Yields:
A dictionary like objects containing a string token and metadata.
See llama-cpp-python docs and below for more.
Example:
.. code-block:: python
from langchain.llms import LlamaCpp
llm = LlamaCpp(
model_path="/path/to/local/model.bin",
temperature = 0.5
)
for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
stop=["'","\n"]):
result = chunk["choices"][0]
print(result["text"], end='', flush=True)
"""
params = self._get_parameters(stop)
result = self.client(prompt=prompt, stream=True, **params)
for chunk in result:
token = chunk["choices"][0]["text"]
log_probs = chunk["choices"][0].get("logprobs", None)
self.callback_manager.on_llm_new_token(
token=token, verbose=self.verbose, log_probs=log_probs
)
yield chunk

@ -1,9 +1,13 @@
# flake8: noqa
"""Test Llama.cpp wrapper."""
import os
from typing import Generator
from urllib.request import urlretrieve
from langchain.llms import LlamaCpp
from langchain.callbacks.base import CallbackManager
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def get_model() -> str:
@ -32,3 +36,37 @@ def test_llamacpp_inference() -> None:
llm = LlamaCpp(model_path=model_path)
output = llm("Say foo:")
assert isinstance(output, str)
assert len(output) > 1
def test_llamacpp_streaming() -> None:
"""Test streaming tokens from LlamaCpp."""
model_path = get_model()
llm = LlamaCpp(model_path=model_path, max_tokens=10)
generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["'"])
stream_results_string = ""
assert isinstance(generator, Generator)
for chunk in generator:
assert not isinstance(chunk, str)
# Note that this matches the OpenAI format:
assert isinstance(chunk["choices"][0]["text"], str)
stream_results_string += chunk["choices"][0]["text"]
assert len(stream_results_string.strip()) > 1
def test_llamacpp_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
MAX_TOKENS = 5
OFF_BY_ONE = 1 # There may be an off by one error in the upstream code!
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = LlamaCpp(
model_path=get_model(),
callback_manager=callback_manager,
verbose=True,
max_tokens=MAX_TOKENS,
)
llm("Q: Can you count to 10? A:'1, ")
assert callback_handler.llm_streams <= MAX_TOKENS + OFF_BY_ONE

Loading…
Cancel
Save