@@ -5,7 +5,16 @@ import asyncio
import functools
import logging
from datetime import datetime
from typing import Any , Callable , Coroutine , Dict , Iterator , List , Optional , Union
from typing import (
Any ,
Callable ,
Coroutine ,
Dict ,
Iterator ,
List ,
Optional ,
Union ,
)
from langchainplus_sdk import LangChainPlusClient
from langchainplus_sdk . schemas import Example
@@ -104,6 +113,8 @@ async def _arun_llm(
llm : BaseLanguageModel ,
inputs : Dict [ str , Any ] ,
langchain_tracer : Optional [ LangChainTracer ] ,
* ,
tags : Optional [ List [ str ] ] = None ,
) - > Union [ LLMResult , ChatResult ] :
callbacks : Optional [ List [ BaseCallbackHandler ] ] = (
[ langchain_tracer ] if langchain_tracer else None
@@ -111,21 +122,27 @@ async def _arun_llm(
if isinstance ( llm , BaseLLM ) :
try :
llm_prompts = _get_prompts ( inputs )
llm_output = await llm . agenerate ( llm_prompts , callbacks = callbacks )
llm_output = await llm . agenerate (
llm_prompts , callbacks = callbacks , tags = tags
)
except InputFormatError :
llm_messages = _get_messages ( inputs )
buffer_strings = [ get_buffer_string ( messages ) for messages in llm_messages ]
llm_output = await llm . agenerate ( buffer_strings , callbacks = callbacks )
llm_output = await llm . agenerate (
buffer_strings , callbacks = callbacks , tags = tags
)
elif isinstance ( llm , BaseChatModel ) :
try :
messages = _get_messages ( inputs )
llm_output = await llm . agenerate ( messages , callbacks = callbacks )
llm_output = await llm . agenerate ( messages , callbacks = callbacks , tags = tags )
except InputFormatError :
prompts = _get_prompts ( inputs )
converted_messages : List [ List [ BaseMessage ] ] = [
[ HumanMessage ( content = prompt ) ] for prompt in prompts
]
llm_output = await llm . agenerate ( converted_messages , callbacks = callbacks )
llm_output = await llm . agenerate (
converted_messages , callbacks = callbacks , tags = tags
)
else :
raise ValueError ( f " Unsupported LLM type { type ( llm ) } " )
return llm_output
@@ -136,6 +153,8 @@ async def _arun_llm_or_chain(
llm_or_chain_factory : MODEL_OR_CHAIN_FACTORY ,
n_repetitions : int ,
langchain_tracer : Optional [ LangChainTracer ] ,
* ,
tags : Optional [ List [ str ] ] = None ,
) - > Union [ List [ dict ] , List [ str ] , List [ LLMResult ] , List [ ChatResult ] ] :
""" Run the chain asynchronously. """
if langchain_tracer is not None :
@@ -150,11 +169,16 @@
try :
if isinstance ( llm_or_chain_factory , BaseLanguageModel ) :
output : Any = await _arun_llm (
llm_or_chain_factory , example . inputs , langchain_tracer
llm_or_chain_factory ,
example . inputs ,
langchain_tracer ,
tags = tags ,
)
else :
chain = llm_or_chain_factory ( )
output = await chain . acall ( example . inputs , callbacks = callbacks )
output = await chain . acall (
example . inputs , callbacks = callbacks , tags = tags
)
outputs . append ( output )
except Exception as e :
logger . warning ( f " Chain failed for example { example . id } . Error: { e } " )
@@ -230,6 +254,7 @@ async def arun_on_examples(
num_repetitions : int = 1 ,
session_name : Optional [ str ] = None ,
verbose : bool = False ,
tags : Optional [ List [ str ] ] = None ,
) - > Dict [ str , Any ] :
"""
Run the chain on examples and store traces to the specified session name .
@@ -245,6 +270,7 @@
intervals .
session_name : Session name to use when tracing runs .
verbose : Whether to print progress .
tags : Tags to add to the traces .
Returns :
A dictionary mapping example ids to the model outputs .
@@ -260,6 +286,7 @@
llm_or_chain_factory ,
num_repetitions ,
tracer ,
tags = tags ,
)
results [ str ( example . id ) ] = result
job_state [ " num_processed " ] + = 1
@@ -282,12 +309,14 @@ def run_llm(
llm : BaseLanguageModel ,
inputs : Dict [ str , Any ] ,
callbacks : Callbacks ,
* ,
tags : Optional [ List [ str ] ] = None ,
) - > Union [ LLMResult , ChatResult ] :
""" Run the language model on the example. """
if isinstance ( llm , BaseLLM ) :
try :
llm_prompts = _get_prompts ( inputs )
llm_output = llm . generate ( llm_prompts , callbacks = callbacks )
llm_output = llm . generate ( llm_prompts , callbacks = callbacks , tags = tags )
except InputFormatError :
llm_messages = _get_messages ( inputs )
buffer_strings = [ get_buffer_string ( messages ) for messages in llm_messages ]
@@ -295,13 +324,15 @@ def run_llm(
elif isinstance ( llm , BaseChatModel ) :
try :
messages = _get_messages ( inputs )
llm_output = llm . generate ( messages , callbacks = callbacks )
llm_output = llm . generate ( messages , callbacks = callbacks , tags = tags )
except InputFormatError :
prompts = _get_prompts ( inputs )
converted_messages : List [ List [ BaseMessage ] ] = [
[ HumanMessage ( content = prompt ) ] for prompt in prompts
]
llm_output = llm . generate ( converted_messages , callbacks = callbacks )
llm_output = llm . generate (
converted_messages , callbacks = callbacks , tags = tags
)
else :
raise ValueError ( f " Unsupported LLM type { type ( llm ) } " )
return llm_output
@@ -312,6 +343,8 @@ def run_llm_or_chain(
llm_or_chain_factory : MODEL_OR_CHAIN_FACTORY ,
n_repetitions : int ,
langchain_tracer : Optional [ LangChainTracer ] = None ,
* ,
tags : Optional [ List [ str ] ] = None ,
) - > Union [ List [ dict ] , List [ str ] , List [ LLMResult ] , List [ ChatResult ] ] :
""" Run the chain synchronously. """
if langchain_tracer is not None :
@@ -325,10 +358,12 @@ def run_llm_or_chain(
for _ in range ( n_repetitions ) :
try :
if isinstance ( llm_or_chain_factory , BaseLanguageModel ) :
output : Any = run_llm ( llm_or_chain_factory , example . inputs , callbacks )
output : Any = run_llm (
llm_or_chain_factory , example . inputs , callbacks , tags = tags
)
else :
chain = llm_or_chain_factory ( )
output = chain ( example . inputs , callbacks = callbacks )
output = chain ( example . inputs , callbacks = callbacks , tags = tags )
outputs . append ( output )
except Exception as e :
logger . warning ( f " Chain failed for example { example . id } . Error: { e } " )
@@ -345,6 +380,7 @@ def run_on_examples(
num_repetitions : int = 1 ,
session_name : Optional [ str ] = None ,
verbose : bool = False ,
tags : Optional [ List [ str ] ] = None ,
) - > Dict [ str , Any ] :
""" Run the chain on examples and store traces to the specified session name.
@@ -359,6 +395,7 @@ def run_on_examples(
intervals .
session_name : Session name to use when tracing runs .
verbose : Whether to print progress .
tags : Tags to add to the run traces .
Returns :
A dictionary mapping example ids to the model outputs .
"""
@@ -370,6 +407,7 @@ def run_on_examples(
llm_or_chain_factory ,
num_repetitions ,
langchain_tracer = tracer ,
tags = tags ,
)
if verbose :
print ( f " { i + 1 } processed " , flush = True , end = " \r " )
@@ -401,6 +439,7 @@ async def arun_on_dataset(
session_name : Optional [ str ] = None ,
verbose : bool = False ,
client : Optional [ LangChainPlusClient ] = None ,
tags : Optional [ List [ str ] ] = None ,
) - > Dict [ str , Any ] :
"""
Run the chain on a dataset and store traces to the specified session name .
@@ -420,6 +459,7 @@ async def arun_on_dataset(
verbose : Whether to print progress .
client : Client to use to read the dataset . If not provided , a new
client will be created using the credentials in the environment .
tags : Tags to add to each run in the session .
Returns :
A dictionary containing the run ' s session name and the resulting model outputs.
@@ -436,6 +476,7 @@ async def arun_on_dataset(
num_repetitions = num_repetitions ,
session_name = session_name ,
verbose = verbose ,
tags = tags ,
)
return {
" session_name " : session_name ,
@@ -451,6 +492,7 @@ def run_on_dataset(
session_name : Optional [ str ] = None ,
verbose : bool = False ,
client : Optional [ LangChainPlusClient ] = None ,
tags : Optional [ List [ str ] ] = None ,
) - > Dict [ str , Any ] :
""" Run the chain on a dataset and store traces to the specified session name.
@@ -468,6 +510,7 @@ def run_on_dataset(
verbose : Whether to print progress .
client : Client to use to access the dataset . If None , a new client
will be created using the credentials in the environment .
tags : Tags to add to each run in the session .
Returns :
A dictionary containing the run ' s session name and the resulting model outputs.
@@ -482,6 +525,7 @@ def run_on_dataset(
num_repetitions = num_repetitions ,
session_name = session_name ,
verbose = verbose ,
tags = tags ,
)
return {
" session_name " : session_name ,