Mirror of https://github.com/hwchase17/langchain, synced 2024-11-16 06:13:16 +00:00
Vwp/alpaca streaming (#3468)
Co-authored-by: Luke Stanley <306671+lukestanley@users.noreply.github.com>
This commit is contained in:
parent
26035dfa59
commit
416f3bdf11
@@ -41,7 +41,9 @@
    "outputs": [],
    "source": [
     "from langchain.llms import LlamaCpp\n",
-    "from langchain import PromptTemplate, LLMChain"
+    "from langchain import PromptTemplate, LLMChain\n",
+    "from langchain.callbacks.base import CallbackManager\n",
+    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
    ]
   },
   {
@@ -67,7 +69,14 @@
   },
   "outputs": [],
   "source": [
-    "llm = LlamaCpp(model_path=\"./ggml-model-q4_0.bin\")"
+    "# Callbacks support token-wise streaming\n",
+    "callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])\n",
+    "# Verbose is required to pass to the callback manager\n",
+    "\n",
+    "# Make sure the model path is correct for your system!\n",
+    "llm = LlamaCpp(\n",
+    "    model_path=\"./ggml-model-q4_0.bin\", callback_manager=callback_manager, verbose=True\n",
+    ")"
   ]
  },
  {
@@ -84,10 +93,17 @@
   "execution_count": 6,
   "metadata": {},
   "outputs": [
+   {
+    "name": "stdout",
+    "output_type": "stream",
+    "text": [
+     " First we need to identify what year Justin Beiber was born in. A quick google search reveals that he was born on March 1st, 1994. Now we know when the Super Bowl was played in, so we can look up which NFL team won it. The NFL Superbowl of the year 1994 was won by the San Francisco 49ers against the San Diego Chargers."
+    ]
+   },
    {
     "data": {
      "text/plain": [
-      "'\\n\\nWe know that Justin Bieber is currently 25 years old and that he was born on March 1st, 1994 and that he is a singer and he has an album called Purpose, so we know that he was born when Super Bowl XXXVIII was played between Dallas and Seattle and that it took place February 1st, 2004 and that the Seattle Seahawks won 24-21, so Seattle is our answer!'"
+      "' First we need to identify what year Justin Beiber was born in. A quick google search reveals that he was born on March 1st, 1994. Now we know when the Super Bowl was played in, so we can look up which NFL team won it. The NFL Superbowl of the year 1994 was won by the San Francisco 49ers against the San Diego Chargers.'"
     ]
    },
    "execution_count": 6,
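For context, the notebook hunks above wire the LlamaCpp LLM to a streaming stdout callback. A minimal sketch of how such an LLM is typically driven through a chain in this notebook; the prompt template and question wording are illustrative assumptions (they are implied by the captured output but are not part of this diff):

```python
from langchain import PromptTemplate, LLMChain
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import LlamaCpp

# Callbacks support token-wise streaming; verbose must be enabled so the
# callback manager receives the generated tokens.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
llm = LlamaCpp(
    model_path="./ggml-model-q4_0.bin",  # make sure this path is correct for your system
    callback_manager=callback_manager,
    verbose=True,
)

# Illustrative prompt and question; the notebook's own wording may differ.
template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])
llm_chain = LLMChain(prompt=prompt, llm=llm)

llm_chain.run("What NFL team won the Super Bowl in the year Justin Bieber was born?")
```

The hunks that follow modify the LlamaCpp wrapper class itself.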
@@ -1,6 +1,6 @@
 """Wrapper around llama.cpp."""
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Generator, List, Optional
 
 from pydantic import Field, root_validator
 
@@ -87,6 +87,9 @@ class LlamaCpp(LLM):
     last_n_tokens_size: Optional[int] = 64
     """The number of tokens to look back when applying the repeat_penalty."""
 
+    streaming: bool = True
+    """Whether to stream the results, token by token."""
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that llama-cpp-python library is installed."""
@@ -139,7 +142,7 @@ class LlamaCpp(LLM):
             "top_p": self.top_p,
             "logprobs": self.logprobs,
             "echo": self.echo,
-            "stop_sequences": self.stop,
+            "stop_sequences": self.stop,  # key here is convention among LLM classes
             "repeat_penalty": self.repeat_penalty,
             "top_k": self.top_k,
         }
@@ -154,6 +157,31 @@ class LlamaCpp(LLM):
         """Return type of llm."""
         return "llama.cpp"
 
+    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        """
+        Performs sanity checks, preparing parameters in the format needed by llama_cpp.
+
+        Args:
+            stop (Optional[List[str]]): List of stop sequences for llama_cpp.
+
+        Returns:
+            Dictionary containing the combined parameters.
+        """
+
+        # Raise an error if stop sequences are given in both the input and default params
+        if self.stop and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+
+        params = self._default_params
+
+        # llama_cpp expects the "stop" key, not "stop_sequences", so we remove it:
+        params.pop("stop_sequences")
+
+        # then set it as configured, or default to an empty list:
+        params["stop"] = self.stop or stop or []
+
+        return params
+
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Call the Llama model and return the output.
 
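The stop-sequence handling above can be summarized on its own. A hypothetical, self-contained sketch (not part of the diff) of the same merge rule, with names chosen purely for illustration:

```python
from typing import List, Optional


def merge_stop(configured: Optional[List[str]], requested: Optional[List[str]]) -> List[str]:
    """Mirror of the _get_parameters stop logic: reject conflicting values,
    otherwise prefer the configured stop list, then the per-call one, then []."""
    if configured and requested is not None:
        raise ValueError("`stop` found in both the input and default params.")
    return configured or requested or []


assert merge_stop(None, ["\n"]) == ["\n"]    # per-call stop wins when nothing is configured
assert merge_stop(["###"], None) == ["###"]  # configured stop is used as-is
assert merge_stop(None, None) == []          # neither set: llama_cpp gets an empty list
```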
@@ -167,31 +195,65 @@ class LlamaCpp(LLM):
         Example:
             .. code-block:: python
 
-                from langchain.llms import LlamaCppEmbeddings
-                llm = LlamaCppEmbeddings(model_path="/path/to/local/llama/model.bin")
+                from langchain.llms import LlamaCpp
+                llm = LlamaCpp(model_path="/path/to/local/llama/model.bin")
                 llm("This is a prompt.")
         """
-        params = self._default_params
-        if self.stop and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
-        elif self.stop:
-            params["stop_sequences"] = self.stop
-        elif stop:
-            params["stop_sequences"] = stop
-        else:
-            params["stop_sequences"] = []
-
-        """Call the Llama model and return the output."""
-        text = self.client(
-            prompt=prompt,
-            max_tokens=params["max_tokens"],
-            temperature=params["temperature"],
-            top_p=params["top_p"],
-            logprobs=params["logprobs"],
-            echo=params["echo"],
-            stop=params["stop_sequences"],
-            repeat_penalty=params["repeat_penalty"],
-            top_k=params["top_k"],
-        )
-        return text["choices"][0]["text"]
+        if self.streaming:
+            # If streaming is enabled, we use the stream() method, which
+            # yields chunks as they are generated, and return the combined
+            # text from the first choice of each chunk:
+            combined_text_output = ""
+            for token in self.stream(prompt=prompt, stop=stop):
+                combined_text_output += token["choices"][0]["text"]
+            return combined_text_output
+        else:
+            params = self._get_parameters(stop)
+            result = self.client(prompt=prompt, **params)
+            return result["choices"][0]["text"]
+
+    def stream(
+        self, prompt: str, stop: Optional[List[str]] = None
+    ) -> Generator[Dict, None, None]:
+        """Yields result objects as they are generated in real time.
+
+        BETA: this is a beta feature while we figure out the right abstraction.
+        Once that happens, this interface could change.
+
+        It also calls the callback manager's on_llm_new_token event with
+        similar parameters to the OpenAI LLM class method of the same name.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            A generator representing the stream of tokens being generated.
+
+        Yields:
+            Dictionary-like objects, each containing a string token and metadata.
+            See the llama-cpp-python docs and below for more.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import LlamaCpp
+                llm = LlamaCpp(
+                    model_path="/path/to/local/model.bin",
+                    temperature=0.5
+                )
+                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
+                        stop=["'", "\n"]):
+                    result = chunk["choices"][0]
+                    print(result["text"], end='', flush=True)
+
+        """
+        params = self._get_parameters(stop)
+        result = self.client(prompt=prompt, stream=True, **params)
+        for chunk in result:
+            token = chunk["choices"][0]["text"]
+            log_probs = chunk["choices"][0].get("logprobs", None)
+            self.callback_manager.on_llm_new_token(
+                token=token, verbose=self.verbose, log_probs=log_probs
+            )
+            yield chunk
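Taken together, the new `_call`/`stream` paths can be exercised as below. A minimal consumer-side sketch using only APIs that appear in this diff; the model path and prompts are placeholders:

```python
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import LlamaCpp

# Placeholder model path; point this at a real ggml model file on your system.
llm = LlamaCpp(
    model_path="./ggml-model-q4_0.bin",
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    verbose=True,  # required so the callback manager receives on_llm_new_token events
)

# With streaming=True (the default), calling the LLM streams tokens to stdout
# via the callback handler and returns the combined text.
answer = llm("Q: Name three primary colors. A:")

# Alternatively, consume the raw chunks yourself; each chunk follows the
# OpenAI-style {"choices": [{"text": ...}]} shape.
for chunk in llm.stream("Q: Name three primary colors. A:", stop=["\n"]):
    print(chunk["choices"][0]["text"], end="", flush=True)
```

The remaining hunks extend the integration tests for the wrapper.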
@@ -1,9 +1,13 @@
 # flake8: noqa
 """Test Llama.cpp wrapper."""
 import os
+from typing import Generator
 from urllib.request import urlretrieve
 
 from langchain.llms import LlamaCpp
+from langchain.callbacks.base import CallbackManager
+
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
 def get_model() -> str:
@@ -32,3 +36,37 @@ def test_llamacpp_inference() -> None:
     llm = LlamaCpp(model_path=model_path)
     output = llm("Say foo:")
     assert isinstance(output, str)
+    assert len(output) > 1
+
+
+def test_llamacpp_streaming() -> None:
+    """Test streaming tokens from LlamaCpp."""
+    model_path = get_model()
+    llm = LlamaCpp(model_path=model_path, max_tokens=10)
+    generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["'"])
+    stream_results_string = ""
+    assert isinstance(generator, Generator)
+
+    for chunk in generator:
+        assert not isinstance(chunk, str)
+        # Note that this matches the OpenAI format:
+        assert isinstance(chunk["choices"][0]["text"], str)
+        stream_results_string += chunk["choices"][0]["text"]
+    assert len(stream_results_string.strip()) > 1
+
+
+def test_llamacpp_streaming_callback() -> None:
+    """Test that streaming correctly invokes the on_llm_new_token callback."""
+    MAX_TOKENS = 5
+    OFF_BY_ONE = 1  # There may be an off-by-one error in the upstream code!
+
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    llm = LlamaCpp(
+        model_path=get_model(),
+        callback_manager=callback_manager,
+        verbose=True,
+        max_tokens=MAX_TOKENS,
+    )
+    llm("Q: Can you count to 10? A:'1, ")
+    assert callback_handler.llm_streams <= MAX_TOKENS + OFF_BY_ONE