forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
2.0 KiB
Python
55 lines
2.0 KiB
Python
"""Select examples based on length."""
|
|
import re
|
|
from typing import Callable, Dict, List
|
|
|
|
from pydantic import BaseModel, validator
|
|
|
|
from langchain.prompts.example_selector.base import BaseExampleSelector
|
|
from langchain.prompts.prompt import PromptTemplate
|
|
|
|
|
|
class LengthBasedExampleSelector(BaseExampleSelector, BaseModel):
    """Select examples based on length.

    Examples are taken in their stored order and added to the selection for
    as long as the running total (input text plus already selected examples)
    stays within ``max_length``, measured with ``get_text_length``.
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    get_text_length: Callable[[str], int] = lambda x: len(re.split("\n| ", x))
    """Function to measure prompt length. Defaults to word count."""

    max_length: int = 2048
    """Max length for the prompt, beyond which examples are cut."""

    example_text_lengths: List[int] = []  #: :meta private:

    @validator("example_text_lengths", always=True)
    def calculate_example_text_lengths(cls, v: List[int], values: Dict) -> List[int]:
        """Calculate text lengths if they don't exist.

        Args:
            v: Lengths supplied by the caller; an empty list means
                "not provided".
            values: Fields validated before this one; ``example_prompt``,
                ``get_text_length`` and ``examples`` are read from it.

        Returns:
            ``v`` unchanged when it is non-empty, otherwise one length per
            example, measured on the example formatted with
            ``example_prompt``.
        """
        # Check if text lengths were passed in
        if v:
            return v
        # If they were not, calculate them
        example_prompt = values["example_prompt"]
        get_text_length = values["get_text_length"]
        string_examples = [example_prompt.format(**eg) for eg in values["examples"]]
        return [get_text_length(eg) for eg in string_examples]

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Select which examples to use based on the input lengths.

        Args:
            input_variables: The prompt inputs; their values are joined with
                single spaces and measured to find out how much of
                ``max_length`` is left for examples.

        Returns:
            The leading examples, in order, whose combined length fits in the
            remaining budget. Selection stops at the first example that does
            not fit; later examples are not considered. The list is empty
            when the inputs alone use up the budget.
        """
        inputs = " ".join(input_variables.values())
        remaining_length = self.max_length - self.get_text_length(inputs)
        i = 0
        examples = []
        while remaining_length > 0 and i < len(self.examples):
            new_length = remaining_length - self.example_text_lengths[i]
            # Stop before adding an example that would overshoot the budget.
            # (Testing the index here instead would never trigger, letting
            # the selection exceed ``max_length``.)
            if new_length < 0:
                break
            examples.append(self.examples[i])
            remaining_length = new_length
            i += 1
        return examples
|