Compare commits

...

3 Commits

Author SHA1 Message Date
William Fu-Hinthorn
9814161edb update tests 2022-10-27 06:49:50 -07:00
William Fu-Hinthorn
86fdeaf4ec interface 2022-10-27 06:40:20 -07:00
William Fu-Hinthorn
ad53a2ef81 Self-consistent draft 2022-10-27 06:30:58 -07:00
19 changed files with 602 additions and 36 deletions

View File

@ -12,6 +12,7 @@ from langchain.chains import (
ReActChain,
SelfAskWithSearchChain,
SerpAPIChain,
SelfConsistencyChain,
)
from langchain.docstore import Wikipedia
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
@ -22,6 +23,7 @@ __all__ = [
"LLMMathChain",
"PythonChain",
"SelfAskWithSearchChain",
"SelfConsistencyChain",
"SerpAPIChain",
"Cohere",
"OpenAI",

View File

@ -5,12 +5,13 @@ from langchain.chains.python import PythonChain
from langchain.chains.react.base import ReActChain
from langchain.chains.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.chains.serpapi import SerpAPIChain
from langchain.chains.self_consistency.base import SelfConsistencyChain
__all__ = [
"LLMChain",
"LLMMathChain",
"PythonChain",
"SelfAskWithSearchChain",
"SelfConsistencyChain",
"SerpAPIChain",
"ReActChain",
]

View File

@ -1,11 +1,14 @@
"""Chain that just formats a prompt and calls an LLM."""
from typing import Any, Dict, List
from collections import defaultdict
import math
from typing import Any, Callable, Dict, List, Tuple
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.llms.base import LLM
from langchain.prompt import Prompt
import re
class LLMChain(Chain, BaseModel):
@ -73,3 +76,90 @@ class LLMChain(Chain, BaseModel):
completion = llm.predict(adjective="funny")
"""
return self(kwargs)[self.output_key]
class ChainOfThoughtParser(BaseModel):
"""Parser to separate the reasoning steps from the answer."""
reasoning_parser: Callable[[str], str]
"""Function to parse the reasoning steps from the generated text."""
answer_parser: Callable[[str], str]
"""Function to parse the answer from the generated text."""
def parse_completion(self, text: str) -> Tuple[str, str]:
"""Parse the reasoning steps and answer from the completion."""
reasoning = self.reasoning_parser(text)
answer = self.answer_parser(text)
return reasoning, answer
# Default parser returns the string preceding "The answer is" (case invariant) as the reasoning
# and the string following as the answer.
_UNKNOWN_ANSWER = "I don't know."
def _default_answer_parser(text: str) -> str:
"""Default answer parser."""
try:
# Use re to split the text along "The answer is" (case invariant) and return the second
# element of the resulting list.
return re.split(r"(?i)the\sanswer\sis", text)[1].strip()
except IndexError:
return _UNKNOWN_ANSWER
def _default_reasoning_parser(text: str) -> str:
"""Default reasoning parser."""
try:
return re.split(r"(?i)the\sanswer\sis", text)[0].strip()
except IndexError:
return text
DEFAULT_CHAIN_OF_THOUGHT_PARSER = ChainOfThoughtParser(
reasoning_parser=_default_reasoning_parser, answer_parser=_default_answer_parser
)
class SelfConsistencyLLMChain(LLMChain, BaseModel):
"""LLM Chain that uses self-consistency to improve the reliability of its outputs."""
parser: ChainOfThoughtParser = DEFAULT_CHAIN_OF_THOUGHT_PARSER
"""Parser to separate the reasoning steps from the answer."""
max_iterations: int = 5
"""Maximum number of iterations to run."""
normalize_probs: bool = True
def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Run the chain."""
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format(**selected_inputs)
kwargs = {}
if "stop" in inputs:
kwargs["stop"] = inputs["stop"]
answers = defaultdict(float)
responses = defaultdict(list)
n = 0
while n < self.max_iterations:
_responses = self.llm.generate(prompt, **kwargs)
for response in _responses:
reasoning, answer = self.parser.parse_completion(response.text)
if response.logprobs is not None:
total_logprob = sum(response.logprobs)
if self.normalize_probs:
total_logprob /= len(response.logprobs)
generated_prob = math.exp(total_logprob)
else:
generated_prob = 1.0
answers[answer] += generated_prob
responses[answer].append((reasoning, answer, generated_prob))
n += 1
answer = max(answers, key=answers.get)
sorted_answers = sorted(responses[answer], key=lambda x: x[2], reverse=True)
if answer == _UNKNOWN_ANSWER:
# If the model doesn't know, output the related reasoning steps.
flipped_response = sorted_answers[0][0]
else:
flipped_response = answer
return {self.output_key: flipped_response}

View File

@ -0,0 +1,116 @@
"""Implement an LLM driven browser."""
from typing import Dict, List, Optional
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import SelfConsistencyLLMChain
from langchain.chains.self_consistency.prompts.anli_prompt import ANLI_PROMPT
from langchain.chains.self_consistency.prompts.aqua_rat_prompt import AQUA_RAT_PROMPT
from langchain.chains.self_consistency.prompts.arc_prompt import ARC_PROMPT
from langchain.chains.self_consistency.prompts.arithmetic_reasoning_prompt import (
ARITHMETIC_REASONING_PROMPT,
)
from langchain.chains.self_consistency.prompts.boolq_prompt import BOOLQ_PROMPT
from langchain.chains.self_consistency.prompts.hotpotqa_prompt import HOTPOTQA_PROMPT
from langchain.chains.self_consistency.prompts.esnli_prompt import ESNLI_PROMPT
from langchain.llms.base import LLM
from langchain.llms.openai import OpenAI
from langchain.prompt import Prompt
_CLASS_TO_PROMPT: Dict[str, Prompt] = {
"anli": ANLI_PROMPT,
"aqua_rat": AQUA_RAT_PROMPT,
"arc": ARC_PROMPT,
"arithmetic_reasoning": ARITHMETIC_REASONING_PROMPT,
"boolq": BOOLQ_PROMPT,
"esnli": ESNLI_PROMPT,
"hotpotqa": HOTPOTQA_PROMPT,
}
# TODO: Add auto-routing and more prompts
_FALLBACK_MAP: Dict[str, str] = {
"nli": "anli",
"natural_language_inference": "anli",
"rte": "anli",
"math": "aqua_rat",
"qna": "hotpotqa",
}
class SelfConsistencyChain(Chain, BaseModel):
"""Implement an LLM chain to reason in a self-consistent manner.
Based on Self-Consistency Improves Chain of Thought Reasoning in
Language Models
Example:
.. code-block:: python
from langchain import SelfConsistencyChain, OpenAI
natbot = SelfConsistencyChain(llm=OpenAI(), objective="Buy me a new hat.")
"""
llm: LLM
"""LLM wrapper to use."""
default_task: str
"""The default task to run."""
input_key: str = "prompt_inputs" #: :meta private:
output_key: str = "answer" #: :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@classmethod
def from_default(cls, objective: str) -> "SelfConsistencyChain":
"""Load with default LLM."""
llm = OpenAI(temperature=0.5, best_of=10, n=3, max_tokens=50)
return cls(llm=llm, objective=objective)
@property
def input_keys(self) -> List[str]:
"""Expect different keys depending on the task.
:meta private:
"""
return [self.input_key, "task"]
@property
def output_keys(self) -> List[str]:
"""Return command.
:meta private:
"""
return [self.output_key]
def _get_prompt(self, task: Optional[str]) -> Prompt:
"""Get the prompt for the task."""
if task in _CLASS_TO_PROMPT:
return _CLASS_TO_PROMPT[task]
if task in _FALLBACK_MAP:
return _CLASS_TO_PROMPT[_FALLBACK_MAP[task]]
raise ValueError(f"Unknown task {task}")
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
task = inputs["task"]
prompt = self._get_prompt(task)
llm_executor = SelfConsistencyLLMChain(prompt=prompt, llm=self.llm)
llm_inputs = inputs[self.input_key]
if 'choices' in llm_inputs:
if isinstance(llm_inputs['choices'], list):
llm_inputs['choices'] = ' '.join([f"({chr(97+i)}) {choice}" for i, choice in enumerate(llm_inputs['choices'])])
answer = llm_executor.predict(**llm_inputs)
return {self.output_key: answer}
def run(self, **kwargs: str) -> str:
"""Figure out next browser command to run."""
task = kwargs.pop("task", self.default_task)
_inputs = {
self.input_key: kwargs,
"task": task,
}
return self(_inputs)[self.output_key]

View File

@ -0,0 +1,74 @@
"""Prompts for adversarial NLI."""
# From https://arxiv.org/pdf/2203.11171.pdf
from langchain.prompt import Prompt
_PROMPT_TEMPLATE = """Premise:
"Conceptually cream skimming has two basic dimensions - product and geography."
Based on this premise, can we conclude the hypothesis "Product and geography are what make cream skimming
work." is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: Based on "cream skimming has two basic dimensions" we cant infer that these two dimensions are what
make cream skimming work. The answer is it is not possible to tell.
Premise:
"One of our member will carry out your instructions minutely."
Based on this premise, can we conclude the hypothesis "A member of my team will execute your orders with
immense precision." is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: "one of" means the same as "a member of", "carry out" means the same as "execute", and "minutely" means
the same as "immense precision". The answer is yes.
Premise:
"Fun for adults and children."
Based on this premise, can we conclude the hypothesis "Fun for only children." is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: "adults and children" contradicts "only children". The answer is no.
Premise:
"He turned and smiled at Vrenna."
Based on this premise, can we conclude the hypothesis "He smiled at Vrenna who was walking slowly behind
him with her mother." is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: the premise does not say anything about "Vrenna was walking". The answer is it is not possible to tell.
Premise:
"well you see that on television also"
Based on this premise, can we conclude the hypothesis "You can see that on television, as well." is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: "also" and "as well" mean the same thing. The answer is yes.
Premise:
"Vrenna and I both fought him and he nearly took us."
Based on this premise, can we conclude the hypothesis "Neither Vrenna nor myself have ever fought him." is
true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: "Vrenna and I both" contradicts "neither Vrenna nor myself". The answer is no.
Premise:
{premise}
Based on this premise, can we conclude the hypothesis "{hypothesis}" is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A:"""
ANLI_PROMPT = Prompt(
input_variables=["premise", "hypothesis"],
template=_PROMPT_TEMPLATE,
)

View File

@ -0,0 +1,29 @@
"""Prompts for the middle school Aqua RAT dataset."""
# From https://arxiv.org/pdf/2203.11171.pdf
from langchain.prompt import Prompt
_PROMPT_TEMPLATE = """Q: John found that the average of 15 numbers is 40. If 10 is added to each number then the mean of the
numbers is? Answer Choices: (a) 50 (b) 45 (c) 65 (d) 78 (e) 64
A: If 10 is added to each number, then the mean of the numbers also increases by 10. So the new mean
would be 50. The answer is (a).
Q: If a / b = 3/4 and 8a + 5b = 22,then find the value of a. Answer Choices: (a) 1/2 (b) 3/2 (c) 5/2 (d) 4/2 (e)
7/2
A: If a / b = 3/4, then b = 4a / 3. So 8a + 5(4a / 3) = 22. This simplifies to 8a + 20a / 3 = 22, which means
44a / 3 = 22. So a is equal to 3/2. The answer is (b).
Q: A person is traveling at 20 km/hr and reached his destiny in 2.5 hr then find the distance? Answer Choices:
(a) 53 km (b) 55 km (c) 52 km (d) 60 km (e) 50 km
A: The distance that the person traveled would have been 20 km/hr * 2.5 hrs = 50 km. The answer is (e).
Q: How many keystrokes are needed to type the numbers from 1 to 500? Answer Choices: (a) 1156 (b) 1392
(c) 1480 (d) 1562 (e) 1788
A: There are 9 one-digit numbers from 1 to 9. There are 90 two-digit numbers from 10 to 99. There are 401
three-digit numbers from 100 to 500. 9 + 90(2) + 401(3) = 1392. The answer is (b).
Q: {question} Answer Choices: {choices}
A:"""
AQUA_RAT_PROMPT = Prompt(
input_variables=["question", "choices"],
template=_PROMPT_TEMPLATE,
)

View File

@ -0,0 +1,28 @@
"""Prompt for the ARC dataset."""
# From https://arxiv.org/pdf/2203.11171.pdf
from langchain.prompt import Prompt
_PROMPT_TEMPLATE = """Q: George wants to warm his hands quickly by rubbing them. Which skin surface will produce the most
heat? (a) dry palms. (b) wet palms. (c) palms covered with oil. (d) palms covered with lotion.
A: Dry surfaces will more likely cause more friction via rubbing than other smoother surfaces, hence dry
palms will produce the most heat. The answer is (a).
Q: Which factor will most likely cause a person to develop a fever? (a) a leg muscle relaxing after exercise.
(b) a bacterial population in the bloodstream. (c) several viral particles on the skin. (d) carbohydrates being
digested in the stomach.
A: Option (b), bacterial population is the most likely cause for a person developing fever. The answer is (b).
Q: Which change in the state of water particles causes the particles to become arranged in a fixed position?
(a) boiling. (b) melting. (c) freezing. (d) evaporating.
A: When water is freezed, the particles are arranged in a fixed position; the particles are still moving for all
other options. The answer is (c).
Q: When a switch is used in an electrical circuit, the switch can (a) cause the charge to build. (b) increase
and decrease the voltage. (c) cause the current to change direction. (d) stop and start the flow of current.
A: The function of a switch is to start and stop the flow of a current. The answer is (d).
Q: {question} {choices}
A:"""
ARC_PROMPT = Prompt(
input_variables=["question", "choices"],
template=_PROMPT_TEMPLATE,
)

View File

@ -0,0 +1,42 @@
"""Prompt for arithmetic reasoning."""
# From https://arxiv.org/pdf/2203.11171.pdf
from langchain.prompt import Prompt
_PROMPT_TEMPLATE = """Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done,
there will be 21 trees. How many trees did the grove workers plant today?
A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted.
So, they must have planted 21 - 15 = 6 trees. The answer is 6.
Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.
Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
A: Leah had 32 chocolates and Leahs sister had 42. That means there were originally 32 + 42 = 74
chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.
Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops
did Jason give to Denny?
A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of
lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.
Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does
he have now?
A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so
in total he has 7 + 2 = 9 toys. The answer is 9.
Q: There were nine computers in the server room. Five more computers were installed each day, from
monday to thursday. How many computers are now in the server room?
A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 =
20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers.
The answer is 29.
Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many
golf balls did he have at the end of wednesday?
A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On
Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in
beginning, so now she has $23 - $15 = $8. The answer is 8.
Q: {question}
A:"""
ARITHMETIC_REASONING_PROMPT = Prompt(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)

View File

@ -0,0 +1,24 @@
"""Prompt for BoolQ."""
# From https://arxiv.org/pdf/2203.11171.pdf
from langchain.prompt import Prompt
_PROMPT_TEMPLATE = """Q: does system of a down have 2 singers?
A: System of a Down currently consists of Serj Tankian, Daron Malakian, Shavo Odadjian and John Dolmayan.
Serj and Daron do vocals, so the band does have two singers. The answer is yes.
Q: do iran and afghanistan speak the same language?
A: Iran and Afghanistan both speak the Indo-European language Persian. The answer is yes.
Q: is a cello and a bass the same thing?
A: The cello is played sitting down with the instrument between the knees, whereas the double bass is played
standing or sitting on a stool. The answer is no.
Q: can you use oyster card at epsom station?
A: Epsom railway station serves the town of Epsom in Surrey and is not in the London Oyster card zone. The
answer is no.
Q: {question}
A:"""
BOOLQ_PROMPT = Prompt(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)

View File

@ -0,0 +1,66 @@
"""Prompt for e-SNLI."""
# From https://arxiv.org/pdf/2203.11171.pdf
from langchain.prompt import Prompt
_PROMPT_TEMPLATE = """Premise:
"A person on a horse jumps over a broken down airplane."
Based on this premise, can we conclude the hypothesis "A person is training his horse for a competition." is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: The person is not necessarily training his horse. The answer is it is not possible to tell.
Premise:
"A person on a horse jumps over a broken down airplane."
Based on this premise, can we conclude the hypothesis "A person is at a diner, ordering an omelette." is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: One jumping horse cannot be in a diner ordering food. The answer is no.
Premise:
"A person on a horse jumps over a broken down airplane."
Based on this premise, can we conclude the hypothesis "A person is outdoors, on a horse." is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: A broken down airplane is outdoors. The answer is yes.
Premise:
"Children smiling and waving at camera."
Based on this premise, can we conclude the hypothesis "They are smiling at their parents." is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: Just because they are smiling and waving at a camera does not imply their parents or anyone is anyone behind
it. The answer is it is not possible to tell.
Premise:
"Children smiling and waving at camera."
Based on this premise, can we conclude the hypothesis "The kids are frowning." is true? OPTIONS:
- yes
- no
- it is not possible to tell
A: One cannot be smiling and frowning at the same time. The answer is no.
Premise:
"Children smiling and waving at camera."
Based on this premise, can we conclude the hypothesis "There are children present." is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A:The children must be present to see them smiling and waving. The answer is yes.
Premise:
\"{premise}\"
Based on this premise, can we conclude the hypothesis \"{hypothesis}\" is true?
OPTIONS:
- yes
- no
- it is not possible to tell
A: """
ESNLI_PROMPT = Prompt(
input_variables=["premise", "hypothesis"],
template=_PROMPT_TEMPLATE,
)

View File

@ -0,0 +1,23 @@
"""Prompt for HotPotQA."""
# From https://arxiv.org/pdf/2203.11171.pdf
from langchain.prompt import Prompt
_PROMPT_TEMPLATE = """Q: Which magazine was started first Arthurs Magazine or First for Women?
A: Arthurs Magazine started in 1844. First for Women started in 1989. So Arthurs Magazine was started first.
The answer is Arthurs Magazine.
Q: The Oberoi family is part of a hotel company that has a head office in what city?
A: The Oberoi family is part of the hotel company called The Oberoi Group. The Oberoi Group has its head
office in Delhi. The answer is Delhi.
Q: What nationality was James Henry Millers wife?
A: James Henry Millers wife is June Miller. June Miller is an American. The answer is American.
Q: The Dutch-Belgian television series that "House of Anubis" was based on first aired in what year?
A: "House of Anubis" is based on the DutchBelgian television series Het Huis Anubis. Het Huis Anubis is first
aired in September 2006. The answer is 2006.
Q: {question} Reason step-by-step.
A:"""
HOTPOTQA_PROMPT = Prompt(
input_variables=["question"],
template=_PROMPT_TEMPLATE,
)

View File

@ -1,11 +1,25 @@
"""Base interface for large language models to expose."""
from abc import ABC, abstractmethod
from typing import List, Optional
from dataclasses import dataclass
@dataclass
class CompletionOutput:
"""A completion output."""
text: str
"""The generated text."""
logprobs: Optional[List[float]] = None
"""The total log probability assigned to the generated text."""
class LLM(ABC):
"""LLM wrapper should take in a prompt and return a string."""
"""LLM wrapper that should take in a prompt and return a string."""
@abstractmethod
def generate(self, prompt: str, stop: Optional[List[str]] = None) -> List[CompletionOutput]:
"""Generate strings for the given prompt and input."""
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Run the LLM on the given prompt and input."""
return self.generate(prompt=prompt, stop=stop)[0].text

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.base import LLM, CompletionOutput
from langchain.llms.utils import enforce_stop_tokens
@ -18,11 +18,11 @@ class Cohere(BaseModel, LLM):
.. code-block:: python
from langchain import Cohere
cohere = Cohere(model="gptd-instruct-tft")
cohere = Cohere(model="small")
"""
client: Any #: :meta private:
model: str = "gptd-instruct-tft"
model: str = "small"
"""Model name to use."""
max_tokens: int = 256
@ -43,6 +43,12 @@ class Cohere(BaseModel, LLM):
presence_penalty: int = 0
"""Penalizes repeated tokens."""
num_generations: int = 1
"""Number of generations to return."""
return_likelihoods: bool = True
"""Whether to return the likelihoods of the generated tokens."""
class Config:
"""Configuration for this pydantic object."""
@ -67,7 +73,7 @@ class Cohere(BaseModel, LLM):
)
return values
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def generate(self, prompt: str, stop: Optional[List[str]] = None) -> List[CompletionOutput]:
"""Call out to Cohere's generate endpoint.
Args:
@ -92,10 +98,22 @@ class Cohere(BaseModel, LLM):
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
stop_sequences=stop,
num_generations=self.num_generations,
return_likelihoods="GENERATION" if self.return_likelihoods else None,
)
text = response.generations[0].text
results = []
for generation in response.generations:
txt = generation.text
if stop is not None:
# If stop tokens are provided, Cohere's endpoint returns them.
# In order to make this consistent with other endpoints, we strip them.
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
txt = enforce_stop_tokens(txt, stop)
N = len(generation.token_likelihoods)
logprobs = [token.likelihood / N for token in generation.token_likelihoods]
results.append(
CompletionOutput(
text=txt,
logprobs=logprobs,
)
)
return results

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.base import LLM, CompletionOutput
from langchain.llms.utils import enforce_stop_tokens
DEFAULT_REPO_ID = "gpt2"
@ -69,14 +69,22 @@ class HuggingFaceHub(BaseModel, LLM):
@property
def _default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling HuggingFace Hub API."""
# Convert temperature from [0, 1] to [1, 100] so 0 maps to 1 and 1 maps to 100.
temperature = self.temperature
if 0.0 <= temperature <= 1.0:
temperature = 1.0 + (temperature * 99.0)
# TODO: Add support for returning logprobs once added to the API.
return {
"temperature": self.temperature,
"temperature": temperature,
"max_new_tokens": self.max_new_tokens,
"top_p": self.top_p,
"num_return_sequences": self.num_return_sequences,
"return_full_text": False,
}
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def generate(
self, prompt: str, stop: Optional[List[str]] = None
) -> List[CompletionOutput]:
"""Call out to HuggingFace Hub's inference endpoint.
Args:
@ -94,9 +102,14 @@ class HuggingFaceHub(BaseModel, LLM):
response = self.client(inputs=prompt, params=self._default_params)
if "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")
text = response[0]["generated_text"][len(prompt) :]
if stop is not None:
return []
results = []
for result in response:
text = result["generated_text"]
if stop is not None:
# This is a bit hacky, but I can't figure out a better way to enforce
# stop tokens when making calls to huggingface_hub.
text = enforce_stop_tokens(text, stop)
return text
results.append(CompletionOutput(text=text))
return results

View File

@ -4,9 +4,21 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.base import LLM, CompletionOutput
def _get_completion_logprobs(txt: str, tokens : List[str], token_logprobs: List[float]) -> List[float]:
"""Get the log probabilities corresponding to the tokens generated."""
N = len(txt)
_total = 0
results = []
for i in range(len(tokens)):
if _total >= N:
break
_total += len(tokens[i])
results.append(token_logprobs[i])
return results
class OpenAI(BaseModel, LLM):
"""Wrapper around OpenAI large language models.
@ -37,6 +49,9 @@ class OpenAI(BaseModel, LLM):
"""How many completions to generate for each prompt."""
best_of: int = 1
"""Generates best_of completions server-side and returns the "best"."""
logprobs: Optional[int] = 0
"""Returns the log probabilities of the generated tokens."""
class Config:
"""Configuration for this pydantic object."""
@ -73,9 +88,10 @@ class OpenAI(BaseModel, LLM):
"presence_penalty": self.presence_penalty,
"n": self.n,
"best_of": self.best_of,
"logprobs": self.logprobs,
}
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def generate(self, prompt: str, stop: Optional[List[str]] = None) -> List[CompletionOutput]:
"""Call out to OpenAI's create endpoint.
Args:
@ -93,4 +109,14 @@ class OpenAI(BaseModel, LLM):
response = self.client.create(
model=self.model_name, prompt=prompt, stop=stop, **self._default_params
)
return response["choices"][0]["text"]
results = []
for choice in response["choices"]:
text = choice["text"]
truncated_logprobs = None
if choice["logprobs"] is not None:
tokens = choice["logprobs"]["tokens"]
token_logprobs = choice["logprobs"]["token_logprobs"]
truncated_logprobs = _get_completion_logprobs(choice["text"], tokens, token_logprobs)
results.append(CompletionOutput(text=text, logprobs=truncated_logprobs))
return results

View File

@ -3,18 +3,18 @@
from typing import List, Optional
from langchain.chains.natbot.base import NatBotChain
from langchain.llms.base import LLM
from langchain.llms.base import LLM, CompletionOutput
class FakeLLM(LLM):
"""Fake LLM wrapper for testing purposes."""
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def generate(self, prompt: str, stop: Optional[List[str]] = None) -> List[CompletionOutput]:
"""Return `foo` if longer than 10000 words, else `bar`."""
if len(prompt) > 10000:
return "foo"
return [CompletionOutput(text="foo")]
else:
return "bar"
return [CompletionOutput(text="bar")]
def test_proper_inputs() -> None:

View File

@ -8,7 +8,7 @@ from langchain.chains.llm import LLMChain
from langchain.chains.react.base import ReActChain, predict_until_observation
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.llms.base import LLM
from langchain.llms.base import LLM, CompletionOutput
from langchain.prompt import Prompt
_PAGE_CONTENT = """This is a page about LangChain.
@ -30,10 +30,10 @@ class FakeListLLM(LLM):
self.responses = responses
self.i = -1
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def generate(self, prompt: str, stop: Optional[List[str]] = None) -> List[CompletionOutput]:
"""Increment counter, and then return response in that index."""
self.i += 1
return self.responses[self.i]
return [CompletionOutput(text=self.responses[self.i])]
class FakeDocstore(Docstore):

View File

@ -1,7 +1,7 @@
"""Fake LLM wrapper for testing purposes."""
from typing import List, Mapping, Optional
from langchain.llms.base import LLM
from langchain.llms.base import LLM, CompletionOutput
class FakeLLM(LLM):
@ -11,11 +11,11 @@ class FakeLLM(LLM):
"""Initialize with optional lookup of queries."""
self._queries = queries
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
def generate(self, prompt: str, stop: Optional[List[str]] = None) -> List[CompletionOutput]:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
if self._queries is not None:
return self._queries[prompt]
return [CompletionOutput(text=self._queries[prompt])]
if stop is None:
return "foo"
return [CompletionOutput(text="foo")]
else:
return "bar"
return [CompletionOutput(text="bar")]