mirror of https://github.com/hwchase17/langchain
Synthetic Data generation (#9472)
--------- Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>pull/11006/head^2
parent
a4e0cf6300
commit
5d7c6d1bca
@ -0,0 +1,135 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.pydantic_v1 import BaseModel, root_validator
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
|
||||
class SyntheticDataGenerator(BaseModel):
|
||||
"""Generates synthetic data using the given LLM and few-shot template.
|
||||
|
||||
Utilizes the provided LLM to produce synthetic data based on the
|
||||
few-shot prompt template.
|
||||
|
||||
Attributes:
|
||||
template (FewShotPromptTemplate): Template for few-shot prompting.
|
||||
llm (Optional[BaseLanguageModel]): Large Language Model to use for generation.
|
||||
llm_chain (Optional[Chain]): LLM chain with the LLM and few-shot template.
|
||||
example_input_key (str): Key to use for storing example inputs.
|
||||
|
||||
Usage Example:
|
||||
>>> template = FewShotPromptTemplate(...)
|
||||
>>> llm = BaseLanguageModel(...)
|
||||
>>> generator = SyntheticDataGenerator(template=template, llm=llm)
|
||||
>>> results = generator.generate(subject="climate change", runs=5)
|
||||
"""
|
||||
|
||||
template: FewShotPromptTemplate
|
||||
llm: Optional[BaseLanguageModel] = None
|
||||
results: list = []
|
||||
llm_chain: Optional[Chain] = None
|
||||
example_input_key: str = "example"
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def set_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
llm_chain = values.get("llm_chain")
|
||||
llm = values.get("llm")
|
||||
few_shot_template = values.get("template")
|
||||
|
||||
if not llm_chain: # If llm_chain is None or not present
|
||||
if llm is None or few_shot_template is None:
|
||||
raise ValueError(
|
||||
"Both llm and few_shot_template must be provided if llm_chain is "
|
||||
"not given."
|
||||
)
|
||||
values["llm_chain"] = LLMChain(llm=llm, prompt=few_shot_template)
|
||||
|
||||
return values
|
||||
|
||||
@staticmethod
|
||||
def _format_dict_to_string(input_dict: Dict) -> str:
|
||||
formatted_str = ", ".join(
|
||||
[f"{key}: {value}" for key, value in input_dict.items()]
|
||||
)
|
||||
return formatted_str
|
||||
|
||||
def _update_examples(self, example: Union[BaseModel, Dict[str, Any], str]) -> None:
|
||||
"""Prevents duplicates by adding previously generated examples to the few shot
|
||||
list."""
|
||||
if self.template and self.template.examples:
|
||||
if isinstance(example, BaseModel):
|
||||
formatted_example = self._format_dict_to_string(example.dict())
|
||||
elif isinstance(example, dict):
|
||||
formatted_example = self._format_dict_to_string(example)
|
||||
else:
|
||||
formatted_example = str(example)
|
||||
self.template.examples.pop(0)
|
||||
self.template.examples.append({self.example_input_key: formatted_example})
|
||||
|
||||
def generate(self, subject: str, runs: int, *args: Any, **kwargs: Any) -> List[str]:
|
||||
"""Generate synthetic data using the given subject string.
|
||||
|
||||
Args:
|
||||
subject (str): The subject the synthetic data will be about.
|
||||
runs (int): Number of times to generate the data.
|
||||
extra (str): Extra instructions for steerability in data generation.
|
||||
|
||||
Returns:
|
||||
List[str]: List of generated synthetic data.
|
||||
|
||||
Usage Example:
|
||||
>>> results = generator.generate(subject="climate change", runs=5,
|
||||
extra="Focus on environmental impacts.")
|
||||
"""
|
||||
if self.llm_chain is None:
|
||||
raise ValueError(
|
||||
"llm_chain is none, either set either llm_chain or llm at generator "
|
||||
"construction"
|
||||
)
|
||||
for _ in range(runs):
|
||||
result = self.llm_chain.run(subject=subject, *args, **kwargs)
|
||||
self.results.append(result)
|
||||
self._update_examples(result)
|
||||
return self.results
|
||||
|
||||
async def agenerate(
|
||||
self, subject: str, runs: int, extra: str = "", *args: Any, **kwargs: Any
|
||||
) -> List[str]:
|
||||
"""Generate synthetic data using the given subject asynchronously.
|
||||
|
||||
Note: Since the LLM calls run concurrently,
|
||||
you may have fewer duplicates by adding specific instructions to
|
||||
the "extra" keyword argument.
|
||||
|
||||
Args:
|
||||
subject (str): The subject the synthetic data will be about.
|
||||
runs (int): Number of times to generate the data asynchronously.
|
||||
extra (str): Extra instructions for steerability in data generation.
|
||||
|
||||
Returns:
|
||||
List[str]: List of generated synthetic data for the given subject.
|
||||
|
||||
Usage Example:
|
||||
>>> results = await generator.agenerate(subject="climate change", runs=5,
|
||||
extra="Focus on env impacts.")
|
||||
"""
|
||||
|
||||
async def run_chain(
|
||||
subject: str, extra: str = "", *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
if self.llm_chain is not None:
|
||||
result = await self.llm_chain.arun(
|
||||
subject=subject, extra=extra, *args, **kwargs
|
||||
)
|
||||
self.results.append(result)
|
||||
|
||||
await asyncio.gather(
|
||||
*(run_chain(subject=subject, extra=extra) for _ in range(runs))
|
||||
)
|
||||
return self.results
|
@ -0,0 +1,64 @@
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
|
||||
from langchain.chains.openai_functions import create_structured_output_chain
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema import BaseLLMOutputParser, BasePromptTemplate
|
||||
|
||||
from langchain_experimental.tabular_synthetic_data.base import SyntheticDataGenerator
|
||||
|
||||
OPENAI_TEMPLATE = PromptTemplate(input_variables=["example"], template="{example}")
|
||||
|
||||
|
||||
def create_openai_data_generator(
|
||||
output_schema: Union[Dict[str, Any], Type[BaseModel]],
|
||||
llm: ChatOpenAI,
|
||||
prompt: BasePromptTemplate,
|
||||
output_parser: Optional[BaseLLMOutputParser] = None,
|
||||
**kwargs: Any
|
||||
) -> SyntheticDataGenerator:
|
||||
"""
|
||||
Create an instance of SyntheticDataGenerator tailored for OpenAI models.
|
||||
|
||||
This function creates an LLM chain designed for structured output based on the
|
||||
provided schema, language model, and prompt template. The resulting chain is then
|
||||
used to instantiate and return a SyntheticDataGenerator.
|
||||
|
||||
Args:
|
||||
output_schema (Union[Dict[str, Any], Type[BaseModel]]): Schema for expected
|
||||
output. This can be either a dictionary representing a valid JsonSchema or a
|
||||
Pydantic BaseModel class.
|
||||
|
||||
|
||||
llm (ChatOpenAI): OpenAI language model to use.
|
||||
|
||||
prompt (BasePromptTemplate): Template to be used for generating prompts.
|
||||
|
||||
|
||||
output_parser (Optional[BaseLLMOutputParser], optional): Parser for
|
||||
processing model outputs. If none is provided, a default will be inferred
|
||||
from the function types.
|
||||
|
||||
|
||||
**kwargs: Additional keyword arguments to be passed to
|
||||
`create_structured_output_chain`.
|
||||
|
||||
|
||||
Returns: SyntheticDataGenerator: An instance of the data generator set up with
|
||||
the constructed chain.
|
||||
|
||||
Usage:
|
||||
To generate synthetic data with a structured output, first define your desired
|
||||
output schema. Then, use this function to create a SyntheticDataGenerator
|
||||
instance. After obtaining the generator, you can utilize its methods to produce
|
||||
the desired synthetic data.
|
||||
"""
|
||||
# Create function calling chain to ensure structured output
|
||||
chain = create_structured_output_chain(
|
||||
output_schema, llm, prompt, output_parser=output_parser, **kwargs
|
||||
)
|
||||
|
||||
# Create the SyntheticDataGenerator instance with the created chain
|
||||
generator = SyntheticDataGenerator(template=prompt, llm_chain=chain)
|
||||
return generator
|
@ -0,0 +1,13 @@
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
DEFAULT_INPUT_KEY = "example"
|
||||
DEFAULT_PROMPT = PromptTemplate(
|
||||
input_variables=[DEFAULT_INPUT_KEY], template="{example}"
|
||||
)
|
||||
|
||||
SYNTHETIC_FEW_SHOT_PREFIX = (
|
||||
"This is a test about generating synthetic data about {subject}. Examples below:"
|
||||
)
|
||||
SYNTHETIC_FEW_SHOT_SUFFIX = (
|
||||
"""Now you generate synthetic data about {subject}. Make sure to {extra}:"""
|
||||
)
|
@ -0,0 +1,104 @@
|
||||
import pytest
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
|
||||
from langchain_experimental.tabular_synthetic_data.base import SyntheticDataGenerator
|
||||
from langchain_experimental.tabular_synthetic_data.openai import (
|
||||
OPENAI_TEMPLATE,
|
||||
create_openai_data_generator,
|
||||
)
|
||||
from langchain_experimental.tabular_synthetic_data.prompts import (
|
||||
SYNTHETIC_FEW_SHOT_PREFIX,
|
||||
SYNTHETIC_FEW_SHOT_SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
# Define the desired output schema for individual medical billing record
|
||||
class MedicalBilling(BaseModel):
|
||||
patient_id: int
|
||||
patient_name: str
|
||||
diagnosis_code: str
|
||||
procedure_code: str
|
||||
total_charge: float
|
||||
insurance_claim_amount: float
|
||||
|
||||
|
||||
examples = [
|
||||
{
|
||||
"example": """Patient ID: 123456, Patient Name: John Doe, Diagnosis Code:
|
||||
J20.9, Procedure Code: 99203, Total Charge: $500, Insurance Claim Amount:
|
||||
$350"""
|
||||
},
|
||||
{
|
||||
"example": """Patient ID: 789012, Patient Name: Johnson Smith, Diagnosis
|
||||
Code: M54.5, Procedure Code: 99213, Total Charge: $150, Insurance Claim
|
||||
Amount: $120"""
|
||||
},
|
||||
{
|
||||
"example": """Patient ID: 345678, Patient Name: Emily Stone, Diagnosis Code:
|
||||
E11.9, Procedure Code: 99214, Total Charge: $300, Insurance Claim Amount:
|
||||
$250"""
|
||||
},
|
||||
{
|
||||
"example": """Patient ID: 901234, Patient Name: Robert Miles, Diagnosis Code:
|
||||
B07.9, Procedure Code: 99204, Total Charge: $200, Insurance Claim Amount:
|
||||
$160"""
|
||||
},
|
||||
{
|
||||
"example": """Patient ID: 567890, Patient Name: Clara Jensen, Diagnosis Code:
|
||||
F41.9, Procedure Code: 99205, Total Charge: $450, Insurance Claim Amount:
|
||||
$310"""
|
||||
},
|
||||
{
|
||||
"example": """Patient ID: 234567, Patient Name: Alan Turing, Diagnosis Code:
|
||||
G40.909, Procedure Code: 99215, Total Charge: $220, Insurance Claim Amount:
|
||||
$180"""
|
||||
},
|
||||
]
|
||||
|
||||
prompt_template = FewShotPromptTemplate(
|
||||
prefix=SYNTHETIC_FEW_SHOT_PREFIX,
|
||||
examples=examples,
|
||||
suffix=SYNTHETIC_FEW_SHOT_SUFFIX,
|
||||
input_variables=["subject", "extra"],
|
||||
example_prompt=OPENAI_TEMPLATE,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def synthetic_data_generator() -> SyntheticDataGenerator:
|
||||
return create_openai_data_generator(
|
||||
output_schema=MedicalBilling,
|
||||
llm=ChatOpenAI(temperature=1), # replace with your LLM instance
|
||||
prompt=prompt_template,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_generate_synthetic(synthetic_data_generator: SyntheticDataGenerator) -> None:
|
||||
synthetic_results = synthetic_data_generator.generate(
|
||||
subject="medical_billing",
|
||||
extra="""the name must be chosen at random. Make it something you wouldn't
|
||||
normally choose.""",
|
||||
runs=10,
|
||||
)
|
||||
assert len(synthetic_results) == 10
|
||||
for row in synthetic_results:
|
||||
assert isinstance(row, MedicalBilling)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
@pytest.mark.asyncio
|
||||
async def test_agenerate_synthetic(
|
||||
synthetic_data_generator: SyntheticDataGenerator,
|
||||
) -> None:
|
||||
synthetic_results = await synthetic_data_generator.agenerate(
|
||||
subject="medical_billing",
|
||||
extra="""the name must be chosen at random. Make it something you wouldn't
|
||||
normally choose.""",
|
||||
runs=10,
|
||||
)
|
||||
assert len(synthetic_results) == 10
|
||||
for row in synthetic_results:
|
||||
assert isinstance(row, MedicalBilling)
|
Loading…
Reference in New Issue