# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 136 lines
# 5.2 KiB

import asyncio
from typing import Any, Dict, List, Optional, Union
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema.language_model import BaseLanguageModel
class SyntheticDataGenerator(BaseModel):
    """Generates synthetic data using the given LLM and few-shot template.

    Utilizes the provided LLM to produce synthetic data based on the
    few-shot prompt template.

    Attributes:
        template (FewShotPromptTemplate): Template for few-shot prompting.
        llm (Optional[BaseLanguageModel]): Large Language Model to use for generation.
        results (list): Outputs collected by ``generate``/``agenerate``. It keeps
            growing across calls on the same instance.
        llm_chain (Optional[Chain]): LLM chain with the LLM and few-shot template.
            Built from ``llm`` and ``template`` by the root validator when not given.
        example_input_key (str): Key to use for storing example inputs.

    Usage Example:
        >>> template = FewShotPromptTemplate(...)
        >>> llm = BaseLanguageModel(...)
        >>> generator = SyntheticDataGenerator(template=template, llm=llm)
        >>> results = generator.generate(subject="climate change", runs=5)
    """

    template: FewShotPromptTemplate
    llm: Optional[BaseLanguageModel] = None
    # NOTE: appended to by every generate/agenerate call; both methods return
    # this whole list, not just the current call's output.
    results: list = []
    llm_chain: Optional[Chain] = None
    example_input_key: str = "example"

    class Config:
        # Validate fields on attribute assignment as well as at construction.
        validate_assignment = True

    @root_validator(pre=False, skip_on_failure=True)
    def set_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build ``llm_chain`` from ``llm`` and ``template`` when it is not given.

        Raises:
            ValueError: If ``llm_chain`` is missing and ``llm`` or ``template``
                is also missing.
        """
        llm_chain = values.get("llm_chain")
        llm = values.get("llm")
        few_shot_template = values.get("template")

        if not llm_chain:  # If llm_chain is None or not present
            if llm is None or few_shot_template is None:
                raise ValueError(
                    "Both llm and few_shot_template must be provided if llm_chain is "
                    "not given."
                )
            values["llm_chain"] = LLMChain(llm=llm, prompt=few_shot_template)

        return values

    @staticmethod
    def _format_dict_to_string(input_dict: Dict) -> str:
        """Render a mapping as ``"key: value, key: value"``."""
        # Must be a staticmethod: it takes no ``self`` but is called via
        # ``self._format_dict_to_string(...)``.
        return ", ".join(f"{key}: {value}" for key, value in input_dict.items())

    def _update_examples(self, example: Union[BaseModel, Dict[str, Any], str]) -> None:
        """Prevents duplicates by adding previously generated examples to the few shot
        examples list.

        The example is formatted to a string and appended to
        ``template.examples`` under ``example_input_key``. Nothing is added when
        the template has no examples list or that list is empty.
        """
        if not (self.template and self.template.examples):
            return
        if isinstance(example, BaseModel):
            formatted_example = self._format_dict_to_string(example.dict())
        elif isinstance(example, dict):
            formatted_example = self._format_dict_to_string(example)
        else:
            formatted_example = str(example)
        self.template.examples.append({self.example_input_key: formatted_example})

    def generate(self, subject: str, runs: int, *args: Any, **kwargs: Any) -> List[str]:
        """Generate synthetic data using the given subject string.

        Runs are sequential. Each result is appended to ``self.results`` and
        then fed back into the template's examples (``_update_examples``), so
        later runs see earlier output.

        Args:
            subject (str): The subject the synthetic data will be about.
            runs (int): Number of times to generate the data.
            *args: Positional arguments forwarded to ``llm_chain.run``.
            **kwargs: Keyword arguments forwarded to ``llm_chain.run``, e.g.
                ``extra`` (str): Extra instructions for steerability in data
                generation.

        Returns:
            List[str]: ``self.results``, i.e. every result generated on this
            instance so far, including results from earlier calls.

        Raises:
            ValueError: If ``llm_chain`` is not set.

        Usage Example:
            >>> results = generator.generate(subject="climate change", runs=5,
            ...     extra="Focus on environmental impacts.")
        """
        if self.llm_chain is None:
            raise ValueError(
                "llm_chain is none, either set either llm_chain or llm at generator "
                "construction"
            )
        for _ in range(runs):
            result = self.llm_chain.run(*args, subject=subject, **kwargs)
            self.results.append(result)
            self._update_examples(result)
        return self.results

    async def agenerate(
        self, subject: str, runs: int, extra: str = "", *args: Any, **kwargs: Any
    ) -> List[str]:
        """Generate synthetic data using the given subject asynchronously.

        Note: Since the LLM calls run concurrently,
        you may have fewer duplicates by adding specific instructions to
        the "extra" keyword argument. Unlike ``generate``, results are not fed
        back into the template's examples.

        Args:
            subject (str): The subject the synthetic data will be about.
            runs (int): Number of times to generate the data asynchronously.
            extra (str): Extra instructions for steerability in data generation.
            *args: Positional arguments forwarded to ``llm_chain.arun``.
            **kwargs: Keyword arguments forwarded to ``llm_chain.arun``.

        Returns:
            List[str]: ``self.results``, including results from earlier calls.
            This call's results are appended in the order the calls finish.

        Raises:
            ValueError: If ``llm_chain`` is not set (consistent with ``generate``).

        Usage Example:
            >>> results = await generator.agenerate(subject="climate change", runs=5,
            ...     extra="Focus on env impacts.")
        """
        chain = self.llm_chain
        if chain is None:
            raise ValueError(
                "llm_chain is none, either set either llm_chain or llm at generator "
                "construction"
            )

        async def run_chain() -> None:
            # Forward the caller's extra args/kwargs; they were previously dropped.
            result = await chain.arun(*args, subject=subject, extra=extra, **kwargs)
            self.results.append(result)

        await asyncio.gather(*(run_chain() for _ in range(runs)))
        return self.results