Batching for hf_pipeline (#10795)

The Hugging Face pipeline integration in LangChain (used for locally hosted models)
does not support batching. If you send in a batch of prompts, it simply
processes them serially using the base implementation of `_generate`:
https://github.com/docugami/langchain/blob/master/libs/langchain/langchain/llms/base.py#L1004C2-L1004C29

This PR adds support for batching in this pipeline, so that GPUs can be
fully saturated. I updated the accompanying notebook to show GPU batch
inference.
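
As a rough usage sketch (the model id, device, and batch size below are illustrative, not part of this PR), the batched path can be exercised through the standard `generate()` call once this change is in:

```python
from langchain.llms import HuggingFacePipeline

# Illustrative values only; pick a model and batch_size that fit your GPU.
llm = HuggingFacePipeline.from_model_id(
    model_id="bigscience/bloom-1b7",
    task="text-generation",
    device=0,      # GPU device index; -1 runs on CPU
    batch_size=4,  # number of prompts sent to the pipeline per call
    model_kwargs={"temperature": 0, "max_length": 64},
)

# generate() accepts a list of prompts; with this PR the underlying
# transformers pipeline is invoked once per batch_size chunk rather
# than once per prompt.
result = llm.generate([f"What is the number {i} in French?" for i in range(4)])
for gens in result.generations:
    print(gens[0].text)
```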

---------

Co-authored-by: Taqi Jaffri <tjaffri@docugami.com>
Taqi Jaffri 1 year ago committed by GitHub
parent aa6e6db8c7
commit b7290f01d8

@ -46,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"id": "165ae236-962a-4763-8052-c4836d78a5d2",
"metadata": {
"tags": []
@ -75,18 +75,10 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "3acf0069",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" First, we need to understand what is an electroencephalogram. An electroencephalogram is a recording of brain activity. It is a recording of brain activity that is made by placing electrodes on the scalp. The electrodes are placed\n"
]
}
],
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate\n",
"\n",
@ -101,6 +93,42 @@
"\n",
"print(chain.invoke({\"question\": question}))"
]
},
{
"cell_type": "markdown",
"id": "dbbc3a37",
"metadata": {},
"source": [
"### Batch GPU Inference\n",
"\n",
"If running on a device with GPU, you can also run inference on the GPU in batch mode."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "097ba62f",
"metadata": {},
"outputs": [],
"source": [
"gpu_llm = HuggingFacePipeline.from_model_id(\n",
" model_id=\"bigscience/bloom-1b7\",\n",
" task=\"text-generation\",\n",
" device=0, # -1 for CPU\n",
" batch_size=2, # adjust as needed based on GPU map and model size.\n",
" model_kwargs={\"temperature\": 0, \"max_length\": 64},\n",
")\n",
"\n",
"gpu_chain = prompt | gpu_llm.bind(stop=[\"\\n\\n\"])\n",
"\n",
"questions = []\n",
"for i in range(4):\n",
" questions.append({\"question\": f\"What is the number {i} in french?\"})\n",
"\n",
"answers = gpu_chain.batch(questions)\n",
"for answer in answers:\n",
" print(answer)"
]
}
],
"metadata": {
@ -119,7 +147,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.2"
"version": "3.8.10"
}
},
"nbformat": 4,

@ -1,20 +1,24 @@
from __future__ import annotations
import importlib.util
import logging
from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.base import BaseLLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra
from langchain.schema import Generation, LLMResult
DEFAULT_MODEL_ID = "gpt2"
DEFAULT_TASK = "text-generation"
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
DEFAULT_BATCH_SIZE = 4
logger = logging.getLogger(__name__)
class HuggingFacePipeline(LLM):
class HuggingFacePipeline(BaseLLM):
"""HuggingFace Pipeline API.
To use, you should have the ``transformers`` python package installed.
@ -52,6 +56,8 @@ class HuggingFacePipeline(LLM):
"""Key word arguments passed to the model."""
pipeline_kwargs: Optional[dict] = None
"""Key word arguments passed to the pipeline."""
batch_size: int = DEFAULT_BATCH_SIZE
"""Batch size to use when passing multiple documents to generate."""
class Config:
"""Configuration for this pydantic object."""
@ -66,8 +72,9 @@ class HuggingFacePipeline(LLM):
device: int = -1,
model_kwargs: Optional[dict] = None,
pipeline_kwargs: Optional[dict] = None,
batch_size: int = DEFAULT_BATCH_SIZE,
**kwargs: Any,
) -> LLM:
) -> HuggingFacePipeline:
"""Construct the pipeline object from model_id and task."""
try:
from transformers import (
@ -128,6 +135,7 @@ class HuggingFacePipeline(LLM):
model=model,
tokenizer=tokenizer,
device=device,
batch_size=batch_size,
model_kwargs=_model_kwargs,
**_pipeline_kwargs,
)
@ -141,6 +149,7 @@ class HuggingFacePipeline(LLM):
model_id=model_id,
model_kwargs=_model_kwargs,
pipeline_kwargs=_pipeline_kwargs,
batch_size=batch_size,
**kwargs,
)
@ -157,28 +166,47 @@ class HuggingFacePipeline(LLM):
def _llm_type(self) -> str:
return "huggingface_pipeline"
def _call(
def _generate(
self,
prompt: str,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
response = self.pipeline(prompt)
if self.pipeline.task == "text-generation":
# Text generation return includes the starter text.
text = response[0]["generated_text"][len(prompt) :]
elif self.pipeline.task == "text2text-generation":
text = response[0]["generated_text"]
elif self.pipeline.task == "summarization":
text = response[0]["summary_text"]
else:
raise ValueError(
f"Got invalid task {self.pipeline.task}, "
f"currently only {VALID_TASKS} are supported"
)
if stop:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text
) -> LLMResult:
# List to hold all results
text_generations: List[str] = []
for i in range(0, len(prompts), self.batch_size):
batch_prompts = prompts[i : i + self.batch_size]
# Process batch of prompts
responses = self.pipeline(batch_prompts)
# Process each response in the batch
for j, response in enumerate(responses):
if isinstance(response, list):
# if model returns multiple generations, pick the top one
response = response[0]
if self.pipeline.task == "text-generation":
# Text generation return includes the starter text
text = response["generated_text"][len(batch_prompts[j]) :]
elif self.pipeline.task == "text2text-generation":
text = response["generated_text"]
elif self.pipeline.task == "summarization":
text = response["summary_text"]
else:
raise ValueError(
f"Got invalid task {self.pipeline.task}, "
f"currently only {VALID_TASKS} are supported"
)
if stop:
# Enforce stop tokens
text = enforce_stop_tokens(text, stop)
# Append the processed text to results
text_generations.append(text)
return LLMResult(
generations=[[Generation(text=text)] for text in text_generations]
)
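
The chunking loop above is the heart of the change: prompts are sliced into `batch_size` groups and each group goes to the transformers pipeline in a single call, so the pipeline can pad and run the whole slice on the GPU together. A minimal standalone sketch of that pattern (LangChain-free; `run_batch` is a hypothetical stand-in for the pipeline call):

```python
from typing import Callable, List


def chunked_generate(
    prompts: List[str],
    run_batch: Callable[[List[str]], List[str]],
    batch_size: int = 4,
) -> List[str]:
    """Call run_batch once per batch_size slice of prompts, preserving order."""
    outputs: List[str] = []
    for i in range(0, len(prompts), batch_size):
        outputs.extend(run_batch(prompts[i : i + batch_size]))
    return outputs


# Toy stand-in for the HF pipeline: uppercases each prompt in the batch.
print(chunked_generate([f"prompt {n}" for n in range(5)],
                       lambda batch: [p.upper() for p in batch],
                       batch_size=2))
```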
