@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import functools
import inspect
import itertools
import logging
import uuid
@ -19,6 +20,7 @@ from typing import (
Sequence ,
Tuple ,
Union ,
cast ,
)
from urllib . parse import urlparse , urlunparse
@ -37,12 +39,20 @@ from langchain.evaluation.schema import EvaluatorType, StringEvaluator
from langchain . schema import ChatResult , LLMResult
from langchain . schema . language_model import BaseLanguageModel
from langchain . schema . messages import BaseMessage , messages_from_dict
from langchain . schema . runnable import Runnable , RunnableConfig , RunnableLambda
from langchain . smith . evaluation . config import EvalConfig , RunEvalConfig
from langchain . smith . evaluation . string_run_evaluator import StringRunEvaluatorChain
logger = logging.getLogger(__name__)

# Accepted user-supplied forms of the system under test: a factory returning a
# Chain or Runnable, a raw language model, an arbitrary callable over the
# example-inputs dict, or a bare Runnable / Chain instance.
MODEL_OR_CHAIN_FACTORY = Union[
    Callable[[], Union[Chain, Runnable]],
    BaseLanguageModel,
    Callable[[dict], Any],
    Runnable,
    Chain,
]

# Canonical (already-wrapped) form produced by _wrap_in_chain_factory: either
# a zero-arg factory or a language model used directly.
MCF = Union[Callable[[], Union[Chain, Runnable]], BaseLanguageModel]
class InputFormatError ( Exception ) :
@ -66,9 +76,9 @@ def _get_eval_project_url(api_url: str, project_id: str) -> str:
def _wrap_in_chain_factory (
llm_or_chain_factory : Union[ Chain , MODEL_OR_CHAIN_FACTORY] ,
llm_or_chain_factory : MODEL_OR_CHAIN_FACTORY,
dataset_name : str = " <my_dataset> " ,
) - > M ODEL_OR_ CHAIN_ FACTORY :
) - > M CF:
""" Forgive the user if they pass in a chain without memory instead of a chain
factory . It ' s a common mistake. Raise a more helpful error message as well. " " "
if isinstance ( llm_or_chain_factory , Chain ) :
@ -105,11 +115,31 @@ def _wrap_in_chain_factory(
return lambda : chain
elif isinstance ( llm_or_chain_factory , BaseLanguageModel ) :
return llm_or_chain_factory
elif isinstance ( llm_or_chain_factory , Runnable ) :
# Memory may exist here, but it's not elegant to check all those cases.
lcf = llm_or_chain_factory
return lambda : lcf
elif callable ( llm_or_chain_factory ) :
_model = llm_or_chain_factory ( )
try :
_model = llm_or_chain_factory ( ) # type: ignore[call-arg]
except TypeError :
# It's an arbitrary function, wrap it in a RunnableLambda
user_func = cast ( Callable , llm_or_chain_factory )
sig = inspect . signature ( user_func )
logger . info ( f " Wrapping function { sig } as RunnableLambda. " )
wrapped = RunnableLambda ( user_func )
return lambda : wrapped
constructor = cast ( Callable , llm_or_chain_factory )
if isinstance ( _model , BaseLanguageModel ) :
# It's not uncommon to do an LLM constructor instead of raw LLM,
# so we'll unpack it for the user.
return _model
return llm_or_chain_factory
elif not isinstance ( _model , Runnable ) :
# This is unlikely to happen - a constructor for a model function
return lambda : RunnableLambda ( constructor )
else :
# Typical correct case
return constructor # noqa
return llm_or_chain_factory
@ -220,7 +250,7 @@ def _get_messages(inputs: Dict[str, Any]) -> List[BaseMessage]:
def _get_project_name (
project_name : Optional [ str ] ,
llm_or_chain_factory : M ODEL_OR_ CHAIN_ FACTORY ,
llm_or_chain_factory : M CF,
) - > str :
"""
Get the project name .
@ -315,7 +345,7 @@ def _validate_example_inputs_for_chain(
def _validate_example_inputs (
examples : Iterator [ Example ] ,
llm_or_chain_factory : M ODEL_OR_ CHAIN_ FACTORY ,
llm_or_chain_factory : M CF,
input_mapper : Optional [ Callable [ [ Dict ] , Any ] ] ,
) - > Iterator [ Example ] :
""" Validate that the example inputs are valid for the model. """
@ -324,7 +354,11 @@ def _validate_example_inputs(
_validate_example_inputs_for_language_model ( first_example , input_mapper )
else :
chain = llm_or_chain_factory ( )
_validate_example_inputs_for_chain ( first_example , chain , input_mapper )
if isinstance ( chain , Chain ) :
# Otherwise it's a runnable
_validate_example_inputs_for_chain ( first_example , chain , input_mapper )
elif isinstance ( chain , Runnable ) :
logger . debug ( f " Skipping input validation for { chain } " )
return examples
@ -332,7 +366,7 @@ def _validate_example_inputs(
def _setup_evaluation (
llm_or_chain_factory : M ODEL_OR_ CHAIN_ FACTORY ,
llm_or_chain_factory : M CF,
examples : Iterator [ Example ] ,
evaluation : Optional [ RunEvalConfig ] ,
data_type : DataType ,
@ -353,8 +387,8 @@ def _setup_evaluation(
" Please specify a dataset with the default ' kv ' data type. "
)
chain = llm_or_chain_factory ( )
run_inputs = chain . input_keys
run_outputs = chain . output_keys
run_inputs = chain . input_keys if isinstance ( chain , Chain ) else None
run_outputs = chain . output_keys if isinstance ( chain , Chain ) else None
run_evaluators = _load_run_evaluators (
evaluation ,
run_type ,
@ -372,17 +406,15 @@ def _setup_evaluation(
def _determine_input_key (
config : RunEvalConfig ,
run_inputs : Optional [ List [ str ] ] ,
run_type : str ,
) - > Optional [ str ] :
input_key = None
if config . input_key :
input_key = config . input_key
if run_inputs and input_key not in run_inputs :
raise ValueError ( f " Input key { input_key } not in run inputs { run_inputs } " )
elif run_type == " llm " :
input_key = None
elif run_inputs and len ( run_inputs ) == 1 :
input_key = run_inputs [ 0 ]
el se:
el if run_input s is not Non e and len ( run_inputs ) > 1 :
raise ValueError (
f " Must specify input key for model with multiple inputs: { run_inputs } "
)
@ -393,19 +425,17 @@ def _determine_input_key(
def _determine_prediction_key (
config : RunEvalConfig ,
run_outputs : Optional [ List [ str ] ] ,
run_type : str ,
) - > Optional [ str ] :
prediction_key = None
if config . prediction_key :
prediction_key = config . prediction_key
if run_outputs and prediction_key not in run_outputs :
raise ValueError (
f " Prediction key { prediction_key } not in run outputs { run_outputs } "
)
elif run_type == " llm " :
prediction_key = None
elif run_outputs and len ( run_outputs ) == 1 :
prediction_key = run_outputs [ 0 ]
el se:
el if run_output s is not Non e and len ( run_outputs ) > 1 :
raise ValueError (
f " Must specify prediction key for model "
f " with multiple outputs: { run_outputs } "
@ -491,8 +521,8 @@ def _load_run_evaluators(
"""
eval_llm = config . eval_llm or ChatOpenAI ( model = " gpt-4 " , temperature = 0.0 )
run_evaluators = [ ]
input_key = _determine_input_key ( config , run_inputs , run_type )
prediction_key = _determine_prediction_key ( config , run_outputs , run_type )
input_key = _determine_input_key ( config , run_inputs )
prediction_key = _determine_prediction_key ( config , run_outputs )
reference_key = _determine_reference_key ( config , example_outputs )
for eval_config in config . evaluators :
run_evaluator = _construct_run_evaluator (
@ -590,7 +620,7 @@ async def _arun_llm(
async def _arun_chain(
    chain: Union[Chain, Runnable],
    inputs: Dict[str, Any],
    callbacks: Callbacks,
    *,
    # NOTE(review): this parameter line falls between the visible hunks;
    # restored by parallel with the sibling signatures — confirm upstream.
    tags: Optional[List[str]] = None,
    input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[dict, str]:
    """Run a chain or runnable asynchronously on the example inputs.

    Args:
        chain: The Chain or Runnable to invoke.
        inputs: The raw example inputs.
        callbacks: Callbacks/tracers to attach to the run.
        tags: Optional tags to attach to the run.
        input_mapper: Optional function that adapts ``inputs`` to the shape
            the chain/runnable expects; when absent, inputs pass through.

    Returns:
        The chain or runnable output (dict or string).
    """
    inputs_ = inputs if input_mapper is None else input_mapper(inputs)
    if isinstance(chain, Chain):
        # Single-input chains accept the bare value in place of a dict.
        if isinstance(inputs_, dict) and len(inputs_) == 1:
            val = next(iter(inputs_.values()))
            output = await chain.acall(val, callbacks=callbacks, tags=tags)
        else:
            output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
    else:
        # Runnables take tags/callbacks via a RunnableConfig instead.
        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
        output = await chain.ainvoke(inputs_, config=runnable_config)
    return output
async def _arun_llm_or_chain (
example : Example ,
llm_or_chain_factory : M ODEL_OR_ CHAIN_ FACTORY ,
llm_or_chain_factory : M CF,
n_repetitions : int ,
* ,
tags : Optional [ List [ str ] ] = None ,
@ -810,12 +842,12 @@ async def _arun_on_examples(
Returns :
A dictionary mapping example ids to the model outputs .
"""
llm_or_chain_factory = _wrap_in_chain_factory ( llm_or_chain_factory )
project_name = _get_project_name ( project_name , llm_or_chain_factory )
wrapped_model = _wrap_in_chain_factory ( llm_or_chain_factory )
project_name = _get_project_name ( project_name , wrapped_model )
run_evaluators , examples = _setup_evaluation (
llm_or_chain_factory , examples , evaluation , data_type
wrapped_model , examples , evaluation , data_type
)
examples = _validate_example_inputs ( examples , llm_or_chain_factory , input_mapper )
examples = _validate_example_inputs ( examples , wrapped_model , input_mapper )
results : Dict [ str , List [ Any ] ] = { }
async def process_example (
@ -824,7 +856,7 @@ async def _arun_on_examples(
""" Process a single example. """
result = await _arun_llm_or_chain (
example ,
llm_or_chain_factory ,
wrapped_model ,
num_repetitions ,
tags = tags ,
callbacks = callbacks ,
@ -911,7 +943,7 @@ def _run_llm(
def _run_chain(
    chain: Union[Chain, Runnable],
    inputs: Dict[str, Any],
    callbacks: Callbacks,
    *,
    # NOTE(review): this parameter line falls between the visible hunks;
    # restored by parallel with the sibling signatures — confirm upstream.
    tags: Optional[List[str]] = None,
    input_mapper: Optional[Callable[[Dict], Any]] = None,
) -> Union[Dict, str]:
    """Run a chain or runnable on the example inputs.

    Args:
        chain: The Chain or Runnable to invoke.
        inputs: The raw example inputs.
        callbacks: Callbacks/tracers to attach to the run.
        tags: Optional tags to attach to the run.
        input_mapper: Optional function that adapts ``inputs`` to the shape
            the chain/runnable expects; when absent, inputs pass through.

    Returns:
        The chain or runnable output (dict or string).
    """
    inputs_ = inputs if input_mapper is None else input_mapper(inputs)
    if isinstance(chain, Chain):
        # Single-input chains accept the bare value in place of a dict.
        if isinstance(inputs_, dict) and len(inputs_) == 1:
            val = next(iter(inputs_.values()))
            output = chain(val, callbacks=callbacks, tags=tags)
        else:
            output = chain(inputs_, callbacks=callbacks, tags=tags)
    else:
        # Runnables take tags/callbacks via a RunnableConfig instead.
        runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks)
        output = chain.invoke(inputs_, config=runnable_config)
    return output
def _run_llm_or_chain (
example : Example ,
llm_or_chain_factory : M ODEL_OR_ CHAIN_ FACTORY ,
llm_or_chain_factory : M CF,
n_repetitions : int ,
* ,
tags : Optional [ List [ str ] ] = None ,
@ -986,7 +1022,8 @@ def _run_llm_or_chain(
outputs . append ( output )
except Exception as e :
logger . warning (
f " { chain_or_llm } failed for example { example . id } . Error: { e } "
f " { chain_or_llm } failed for example { example . id } with inputs: "
f " { example . inputs } . \n Error: { e } " ,
)
outputs . append ( { " Error " : str ( e ) } )
if callbacks and previous_example_ids :
@ -1080,7 +1117,7 @@ def _prepare_eval_run(
dataset_name : str ,
llm_or_chain_factory : MODEL_OR_CHAIN_FACTORY ,
project_name : Optional [ str ] ,
) - > Tuple [ M ODEL_OR_ CHAIN_ FACTORY , str , Dataset , Iterator [ Example ] ] :
) - > Tuple [ M CF, str , Dataset , Iterator [ Example ] ] :
llm_or_chain_factory = _wrap_in_chain_factory ( llm_or_chain_factory , dataset_name )
project_name = _get_project_name ( project_name , llm_or_chain_factory )
try :