Harrison/version 0040 (#366)

harrison/sequential_chain_from_prompts
Harrison Chase authored 1 year ago · committed by GitHub
parent 50257fce59
commit a7084ad6e4

@@ -61,24 +61,36 @@ class HuggingFacePipeline(LLM, BaseModel):
     ) -> LLM:
         """Construct the pipeline object from model_id and task."""
         try:
-            from transformers import AutoModelForCausalLM, AutoTokenizer
+            from transformers import (
+                AutoModelForCausalLM,
+                AutoModelForSeq2SeqLM,
+                AutoTokenizer,
+            )
             from transformers import pipeline as hf_pipeline
 
             tokenizer = AutoTokenizer.from_pretrained(model_id)
-            model = AutoModelForCausalLM.from_pretrained(model_id)
+            if task == "text-generation":
+                model = AutoModelForCausalLM.from_pretrained(model_id)
+            elif task == "text2text-generation":
+                model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+            else:
+                raise ValueError(
+                    f"Got invalid task {task}, "
+                    f"currently only {VALID_TASKS} are supported"
+                )
             _model_kwargs = model_kwargs or {}
             pipeline = hf_pipeline(
-                task=task, model=model, tokenizer=tokenizer, **model_kwargs
+                task=task, model=model, tokenizer=tokenizer, **_model_kwargs
             )
             if pipeline.task not in VALID_TASKS:
                 raise ValueError(
                     f"Got invalid task {pipeline.task}, "
                     f"currently only {VALID_TASKS} are supported"
                 )
             return cls(
                 pipeline=pipeline,
                 model_id=model_id,
-                model_kwargs=model_kwargs,
+                model_kwargs=_model_kwargs,
                 **kwargs,
             )
         except ImportError:
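For reference, a minimal usage sketch of the task dispatch added above. The model IDs are illustrative ("gpt2" is a stand-in; "google/flan-t5-small" is the checkpoint used in the new test), and the langchain.llms import path is assumed for this release:

# Both supported tasks now route to the matching AutoModel class.
from langchain.llms import HuggingFacePipeline

# "text-generation" loads AutoModelForCausalLM.
causal_llm = HuggingFacePipeline.from_model_id(
    model_id="gpt2", task="text-generation"
)

# "text2text-generation" loads AutoModelForSeq2SeqLM (new in this commit).
seq2seq_llm = HuggingFacePipeline.from_model_id(
    model_id="google/flan-t5-small", task="text2text-generation"
)

# Any other task string raises ValueError listing VALID_TASKS.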
@@ -100,7 +112,7 @@ class HuggingFacePipeline(LLM, BaseModel):
         return "huggingface_pipeline"
 
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
-        response = self.pipeline(text_inputs=prompt)
+        response = self.pipeline(prompt)
         if self.pipeline.task == "text-generation":
            # Text generation return includes the starter text.
            text = response[0]["generated_text"][len(prompt) :]
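Dropping `text_inputs=` for a positional argument matters because the keyword name differs across transformers pipeline classes, while both accept the input positionally. A standalone sketch of the per-task output handling; the text2text branch here is inferred from the surrounding code and is an assumption:

# Illustrative: both pipeline classes return [{"generated_text": ...}].
from transformers import pipeline as hf_pipeline

pipe = hf_pipeline(task="text2text-generation", model="google/flan-t5-small")
prompt = "Say foo:"
response = pipe(prompt)  # positional input works across pipeline classes

if pipe.task == "text-generation":
    # Causal pipelines echo the prompt, so slice it off.
    text = response[0]["generated_text"][len(prompt):]
else:
    # Seq2seq pipelines return only the generated continuation.
    text = response[0]["generated_text"]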

@@ -163,6 +163,9 @@ class OpenAI(LLM, BaseModel):
     def stream(self, prompt: str) -> Generator:
         """Call OpenAI with streaming flag and return the resulting generator.
 
+        BETA: this is a beta feature while we figure out the right abstraction.
+        Once that happens, this interface could change.
+
         Args:
             prompt: The prompts to pass into the model.
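A sketch of consuming the beta generator. The chunk layout assumed here is the raw OpenAI completion-stream payload, which is exactly the kind of detail the BETA note says could change:

from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
generator = llm.stream("Tell me a joke.")

# Assumption: each chunk mirrors the raw OpenAI streaming payload,
# with the incremental text under choices[0]["text"].
for chunk in generator:
    print(chunk["choices"][0]["text"], end="")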

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain"
-version = "0.0.39"
+version = "0.0.40"
 description = "Building applications with LLMs through composability"
 authors = []
 license = "MIT"

@@ -18,6 +18,15 @@ def test_huggingface_pipeline_text_generation() -> None:
     assert isinstance(output, str)
 
 
+def test_huggingface_pipeline_text2text_generation() -> None:
+    """Test valid call to HuggingFace text2text generation model."""
+    llm = HuggingFacePipeline.from_model_id(
+        model_id="google/flan-t5-small", task="text2text-generation"
+    )
+    output = llm("Say foo:")
+    assert isinstance(output, str)
+
+
 def test_saving_loading_llm(tmp_path: Path) -> None:
     """Test saving/loading an HuggingFaceHub LLM."""
     llm = HuggingFacePipeline.from_model_id(
