@@ -39,7 +39,14 @@ from langchain.client.models import (
     ListRunsQueryParams,
 )
 from langchain.llms.base import BaseLLM
-from langchain.schema import ChatResult, LLMResult, messages_from_dict
+from langchain.schema import (
+    BaseMessage,
+    ChatResult,
+    HumanMessage,
+    LLMResult,
+    get_buffer_string,
+    messages_from_dict,
+)
 from langchain.utils import raise_for_status_with_text, xor_args

 if TYPE_CHECKING:
@@ -50,6 +57,10 @@ logger = logging.getLogger(__name__)
 MODEL_OR_CHAIN_FACTORY = Union[Callable[[], Chain], BaseLanguageModel]


+class InputFormatError(Exception):
+    """Raised when input format is invalid."""
+
+
 def _get_link_stem(url: str) -> str:
     scheme = urlsplit(url).scheme
     netloc_prefix = urlsplit(url).netloc.split(":")[0]
@@ -389,6 +400,76 @@ class LangChainPlusClient(BaseSettings):
         raise_for_status_with_text(response)
         return [Example(**dataset) for dataset in response.json()]

+    @staticmethod
+    def _get_prompts(inputs: Dict[str, Any]) -> List[str]:
+        """Get prompts from inputs."""
+        if not inputs:
+            raise InputFormatError("Inputs should not be empty.")
+        prompts = []
+        if "prompt" in inputs:
+            if not isinstance(inputs["prompt"], str):
+                raise InputFormatError(
+                    "Expected string for 'prompt', got "
+                    f"{type(inputs['prompt']).__name__}"
+                )
+            prompts = [inputs["prompt"]]
+        elif "prompts" in inputs:
+            if not isinstance(inputs["prompts"], list) or not all(
+                isinstance(i, str) for i in inputs["prompts"]
+            ):
+                raise InputFormatError(
+                    "Expected list of strings for 'prompts', "
+                    f"got {type(inputs['prompts']).__name__}"
+                )
+            prompts = inputs["prompts"]
+        elif len(inputs) == 1:
+            prompt_ = next(iter(inputs.values()))
+            if isinstance(prompt_, str):
+                prompts = [prompt_]
+            elif isinstance(prompt_, list) and all(isinstance(i, str) for i in prompt_):
+                prompts = prompt_
+            else:
+                raise InputFormatError(
+                    f"LLM Run expects string prompt input. Got {inputs}"
+                )
+        else:
+            raise InputFormatError(
+                f"LLM Run expects 'prompt' or 'prompts' in inputs. Got {inputs}"
+            )
+        return prompts
+
+    @staticmethod
+    def _get_messages(inputs: Dict[str, Any]) -> List[List[BaseMessage]]:
+        """Get Chat Messages from inputs."""
+        if not inputs:
+            raise InputFormatError("Inputs should not be empty.")
+        if "messages" in inputs:
+            single_input = inputs["messages"]
+        elif len(inputs) == 1:
+            single_input = next(iter(inputs.values()))
+        else:
+            raise InputFormatError(
+                f"Chat Run expects 'messages' in inputs. Got {inputs}"
+            )
+        if isinstance(single_input, list) and all(
+            isinstance(i, dict) for i in single_input
+        ):
+            raw_messages = [single_input]
+        elif isinstance(single_input, list) and all(
+            isinstance(i, list) for i in single_input
+        ):
+            raw_messages = single_input
+        else:
+            raise InputFormatError(
+                f"Chat Run expects List[dict] or List[List[dict]] 'messages' "
+                f"input. Got {inputs}"
+            )
+        return [messages_from_dict(batch) for batch in raw_messages]
+
     @staticmethod
     async def _arun_llm(
         llm: BaseLanguageModel,
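For reference, a hypothetical illustration of the dataset input shapes the two new helpers accept. The import path and example values are assumptions, not part of the diff; both helpers are static, so no client instance is needed.

    from langchain.client import LangChainPlusClient

    # _get_prompts accepts "prompt" (str), "prompts" (list of str), or a single
    # unnamed key holding either shape, and always returns a list of strings.
    LangChainPlusClient._get_prompts({"prompt": "Hello"})       # -> ["Hello"]
    LangChainPlusClient._get_prompts({"prompts": ["a", "b"]})   # -> ["a", "b"]
    LangChainPlusClient._get_prompts({"question": "Only key"})  # -> ["Only key"]

    # _get_messages accepts a single conversation (List[dict]) or a batch of
    # conversations (List[List[dict]]) under "messages" or a single unnamed key.
    single = {"messages": [{"type": "human", "data": {"content": "Hi"}}]}
    LangChainPlusClient._get_messages(single)  # -> one batch with one HumanMessage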
@@ -396,16 +477,31 @@ class LangChainPlusClient(BaseSettings):
         langchain_tracer: LangChainTracer,
     ) -> Union[LLMResult, ChatResult]:
         if isinstance(llm, BaseLLM):
-            if "prompt" not in inputs:
-                raise ValueError(f"LLM Run requires 'prompt' input. Got {inputs}")
-            llm_prompt: str = inputs["prompt"]
-            llm_output = await llm.agenerate([llm_prompt], callbacks=[langchain_tracer])
+            try:
+                llm_prompts = LangChainPlusClient._get_prompts(inputs)
+                llm_output = await llm.agenerate(
+                    llm_prompts, callbacks=[langchain_tracer]
+                )
+            except InputFormatError:
+                llm_messages = LangChainPlusClient._get_messages(inputs)
+                buffer_strings = [
+                    get_buffer_string(messages) for messages in llm_messages
+                ]
+                llm_output = await llm.agenerate(
+                    buffer_strings, callbacks=[langchain_tracer]
+                )
         elif isinstance(llm, BaseChatModel):
-            if "messages" not in inputs:
-                raise ValueError(f"Chat Run requires 'messages' input. Got {inputs}")
-            raw_messages: List[dict] = inputs["messages"]
-            messages = messages_from_dict(raw_messages)
-            llm_output = await llm.agenerate([messages], callbacks=[langchain_tracer])
+            try:
+                messages = LangChainPlusClient._get_messages(inputs)
+                llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
+            except InputFormatError:
+                prompts = LangChainPlusClient._get_prompts(inputs)
+                converted_messages: List[List[BaseMessage]] = [
+                    [HumanMessage(content=prompt)] for prompt in prompts
+                ]
+                llm_output = await llm.agenerate(
+                    converted_messages, callbacks=[langchain_tracer]
+                )
         else:
             raise ValueError(f"Unsupported LLM type {type(llm)}")
         return llm_output
@@ -562,18 +658,27 @@ class LangChainPlusClient(BaseSettings):
     ) -> Union[LLMResult, ChatResult]:
         """Run the language model on the example."""
         if isinstance(llm, BaseLLM):
-            if "prompt" not in inputs:
-                raise ValueError(f"LLM Run must contain 'prompt' key. Got {inputs}")
-            llm_prompt: str = inputs["prompt"]
-            llm_output = llm.generate([llm_prompt], callbacks=[langchain_tracer])
+            try:
+                llm_prompts = LangChainPlusClient._get_prompts(inputs)
+                llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
+            except InputFormatError:
+                llm_messages = LangChainPlusClient._get_messages(inputs)
+                buffer_strings = [
+                    get_buffer_string(messages) for messages in llm_messages
+                ]
+                llm_output = llm.generate(buffer_strings, callbacks=[langchain_tracer])
         elif isinstance(llm, BaseChatModel):
-            if "messages" not in inputs:
-                raise ValueError(
-                    f"Chat Model Run must contain 'messages' key. Got {inputs}"
-                )
-            raw_messages: List[dict] = inputs["messages"]
-            messages = messages_from_dict(raw_messages)
-            llm_output = llm.generate([messages], callbacks=[langchain_tracer])
+            try:
+                messages = LangChainPlusClient._get_messages(inputs)
+                llm_output = llm.generate(messages, callbacks=[langchain_tracer])
+            except InputFormatError:
+                prompts = LangChainPlusClient._get_prompts(inputs)
+                converted_messages: List[List[BaseMessage]] = [
+                    [HumanMessage(content=prompt)] for prompt in prompts
+                ]
+                llm_output = llm.generate(
+                    converted_messages, callbacks=[langchain_tracer]
+                )
         else:
             raise ValueError(f"Unsupported LLM type {type(llm)}")
         return llm_output
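As a quick sanity check on the fallback paths above, a minimal sketch of the two conversions the try/except blocks rely on. No model call is made here; the values and variable names are illustrative only, while the imports match those added in this diff.

    from langchain.schema import HumanMessage, get_buffer_string, messages_from_dict

    # Chat model handed plain prompts: each prompt is wrapped as one HumanMessage turn.
    prompts = ["What is 2 + 2?"]
    converted_messages = [[HumanMessage(content=p)] for p in prompts]

    # Completion-style LLM handed chat messages: each conversation is flattened to a
    # single prompt string via get_buffer_string before calling generate/agenerate.
    raw = [[{"type": "human", "data": {"content": "What is 2 + 2?"}}]]
    buffer_strings = [get_buffer_string(messages_from_dict(batch)) for batch in raw]
    print(buffer_strings)  # ['Human: What is 2 + 2?']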