Mirror of https://github.com/hwchase17/langchain (synced 2024-11-06 03:20:49 +00:00)
split up batch llm calls into separate runs (#5804)
This commit is contained in:
parent 1da99ce013
commit e1b801be36
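In short, this change makes the callback managers create one child run per prompt (or per message list) instead of a single run for the whole batch: `CallbackManager.on_llm_start` and `on_chat_model_start` now return a list of run managers, `LLMResult.run` becomes a list of `RunInfo`, and a new `LLMResult.flatten()` splits a batched result into per-prompt results. A minimal sketch of the new calling convention, modeled on the updated unit tests below (the empty handler list and the placeholder serialized dict are stand-ins):

    from langchain.callbacks.manager import CallbackManager
    from langchain.schema import LLMResult

    manager = CallbackManager(handlers=[])  # register real handlers (tracers, token counters) here
    # One CallbackManagerForLLMRun per prompt, each with its own run_id.
    run_managers = manager.on_llm_start({}, ["What is 2 + 2?", "What is 3 + 3?"])
    for run_manager in run_managers:
        # Each prompt's output is now reported against its own child run.
        run_manager.on_llm_end(LLMResult(generations=[[]]))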
@@ -672,66 +672,72 @@ class CallbackManager(BaseCallbackManager):
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> CallbackManagerForLLMRun:
+    ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-
-        _handle_event(
-            self.handlers,
-            "on_llm_start",
-            "ignore_llm",
-            serialized,
-            prompts,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            **kwargs,
-        )
-
-        return CallbackManagerForLLMRun(
-            run_id=run_id,
-            handlers=self.handlers,
-            inheritable_handlers=self.inheritable_handlers,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            inheritable_tags=self.inheritable_tags,
-        )
+        managers = []
+        for prompt in prompts:
+            run_id_ = uuid4()
+            _handle_event(
+                self.handlers,
+                "on_llm_start",
+                "ignore_llm",
+                serialized,
+                [prompt],
+                run_id=run_id_,
+                parent_run_id=self.parent_run_id,
+                tags=self.tags,
+                **kwargs,
+            )
+
+            managers.append(
+                CallbackManagerForLLMRun(
+                    run_id=run_id_,
+                    handlers=self.handlers,
+                    inheritable_handlers=self.inheritable_handlers,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    inheritable_tags=self.inheritable_tags,
+                )
+            )
+
+        return managers

     def on_chat_model_start(
         self,
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> CallbackManagerForLLMRun:
+    ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-        _handle_event(
-            self.handlers,
-            "on_chat_model_start",
-            "ignore_chat_model",
-            serialized,
-            messages,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            **kwargs,
-        )
-
-        # Re-use the LLM Run Manager since the outputs are treated
-        # the same for now
-        return CallbackManagerForLLMRun(
-            run_id=run_id,
-            handlers=self.handlers,
-            inheritable_handlers=self.inheritable_handlers,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            inheritable_tags=self.inheritable_tags,
-        )
+        managers = []
+        for message_list in messages:
+            run_id_ = uuid4()
+            _handle_event(
+                self.handlers,
+                "on_chat_model_start",
+                "ignore_chat_model",
+                serialized,
+                [message_list],
+                run_id=run_id_,
+                parent_run_id=self.parent_run_id,
+                tags=self.tags,
+                **kwargs,
+            )
+
+            managers.append(
+                CallbackManagerForLLMRun(
+                    run_id=run_id_,
+                    handlers=self.handlers,
+                    inheritable_handlers=self.inheritable_handlers,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    inheritable_tags=self.inheritable_tags,
+                )
+            )
+
+        return managers

     def on_chain_start(
         self,
@@ -830,64 +836,84 @@ class AsyncCallbackManager(BaseCallbackManager):
         self,
         serialized: Dict[str, Any],
         prompts: List[str],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> AsyncCallbackManagerForLLMRun:
+    ) -> List[AsyncCallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-
-        await _ahandle_event(
-            self.handlers,
-            "on_llm_start",
-            "ignore_llm",
-            serialized,
-            prompts,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            **kwargs,
-        )
-
-        return AsyncCallbackManagerForLLMRun(
-            run_id=run_id,
-            handlers=self.handlers,
-            inheritable_handlers=self.inheritable_handlers,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            inheritable_tags=self.inheritable_tags,
-        )
+        tasks = []
+        managers = []
+
+        for prompt in prompts:
+            run_id_ = uuid4()
+
+            tasks.append(
+                _ahandle_event(
+                    self.handlers,
+                    "on_llm_start",
+                    "ignore_llm",
+                    serialized,
+                    [prompt],
+                    run_id=run_id_,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    **kwargs,
+                )
+            )
+
+            managers.append(
+                AsyncCallbackManagerForLLMRun(
+                    run_id=run_id_,
+                    handlers=self.handlers,
+                    inheritable_handlers=self.inheritable_handlers,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    inheritable_tags=self.inheritable_tags,
+                )
+            )
+
+        await asyncio.gather(*tasks)
+
+        return managers

     async def on_chat_model_start(
         self,
         serialized: Dict[str, Any],
         messages: List[List[BaseMessage]],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
     ) -> Any:
         if run_id is None:
             run_id = uuid4()
-        await _ahandle_event(
-            self.handlers,
-            "on_chat_model_start",
-            "ignore_chat_model",
-            serialized,
-            messages,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            **kwargs,
-        )
-
-        return AsyncCallbackManagerForLLMRun(
-            run_id=run_id,
-            handlers=self.handlers,
-            inheritable_handlers=self.inheritable_handlers,
-            parent_run_id=self.parent_run_id,
-            tags=self.tags,
-            inheritable_tags=self.inheritable_tags,
-        )
+        tasks = []
+        managers = []
+
+        for message_list in messages:
+            run_id_ = uuid4()
+
+            tasks.append(
+                _ahandle_event(
+                    self.handlers,
+                    "on_chat_model_start",
+                    "ignore_chat_model",
+                    serialized,
+                    [message_list],
+                    run_id=run_id_,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    **kwargs,
+                )
+            )
+
+            managers.append(
+                AsyncCallbackManagerForLLMRun(
+                    run_id=run_id_,
+                    handlers=self.handlers,
+                    inheritable_handlers=self.inheritable_handlers,
+                    parent_run_id=self.parent_run_id,
+                    tags=self.tags,
+                    inheritable_tags=self.inheritable_tags,
+                )
+            )
+
+        await asyncio.gather(*tasks)
+        return managers

     async def on_chain_start(
         self,
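The async manager follows the same split, dispatching the per-prompt handler events concurrently with asyncio.gather. A usage sketch mirroring the updated async tests (the empty handler list is a placeholder):

    import asyncio

    from langchain.callbacks.manager import AsyncCallbackManager
    from langchain.schema import LLMResult

    async def main() -> None:
        manager = AsyncCallbackManager(handlers=[])  # add real async handlers in practice
        # One AsyncCallbackManagerForLLMRun per prompt.
        run_managers = await manager.on_llm_start({}, ["prompt one", "prompt two"])
        for run_manager in run_managers:
            await run_manager.on_llm_end(LLMResult(generations=[[]]))

    asyncio.run(main())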
@@ -1,8 +1,8 @@
 """Callback Handler that prints to std out."""
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List

 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import LLMResult

 MODEL_COST_PER_1K_TOKENS = {
     # GPT-4 input
@@ -152,64 +152,6 @@ class OpenAICallbackHandler(BaseCallbackHandler):
         self.prompt_tokens += prompt_tokens
         self.completion_tokens += completion_tokens

-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> None:
-        """Print out that we are entering a chain."""
-        pass
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
-        """Print out that we finished a chain."""
-        pass
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_tool_start(
-        self,
-        serialized: Dict[str, Any],
-        input_str: str,
-        **kwargs: Any,
-    ) -> None:
-        """Print out the log in specified color."""
-        pass
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        """If not the final action, print out observation."""
-        pass
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> None:
-        """Do nothing."""
-        pass
-
-    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
-        """Run on agent action."""
-        pass
-
-    def on_agent_finish(
-        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
-    ) -> None:
-        """Run on agent end."""
-        pass
-
     def __copy__(self) -> "OpenAICallbackHandler":
         """Return a copy of the callback handler."""
         return self
@@ -101,26 +101,37 @@ class BaseChatModel(BaseLanguageModel, ABC):
             tags,
             self.tags,
         )
-        run_manager = callback_manager.on_chat_model_start(
-            dumpd(self), messages, invocation_params=params, options=options
-        )
-
-        try:
-            results = [
-                self._generate_with_cache(
-                    m, stop=stop, run_manager=run_manager, **kwargs
-                )
-                for m in messages
-            ]
-        except (KeyboardInterrupt, Exception) as e:
-            run_manager.on_llm_error(e)
-            raise e
+        run_managers = callback_manager.on_chat_model_start(
+            dumpd(self), messages, invocation_params=params, options=options
+        )
+        results = []
+        for i, m in enumerate(messages):
+            try:
+                results.append(
+                    self._generate_with_cache(
+                        m,
+                        stop=stop,
+                        run_manager=run_managers[i] if run_managers else None,
+                        **kwargs,
+                    )
+                )
+            except (KeyboardInterrupt, Exception) as e:
+                if run_managers:
+                    run_managers[i].on_llm_error(e)
+                raise e
+        flattened_outputs = [
+            LLMResult(generations=[res.generations], llm_output=res.llm_output)
+            for res in results
+        ]
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        run_manager.on_llm_end(output)
-        if run_manager:
-            output.run = RunInfo(run_id=run_manager.run_id)
+        if run_managers:
+            run_infos = []
+            for manager, flattened_output in zip(run_managers, flattened_outputs):
+                manager.on_llm_end(flattened_output)
+                run_infos.append(RunInfo(run_id=manager.run_id))
+            output.run = run_infos
         return output

     async def agenerate(
@@ -143,28 +154,62 @@ class BaseChatModel(BaseLanguageModel, ABC):
             tags,
             self.tags,
         )
-        run_manager = await callback_manager.on_chat_model_start(
-            dumpd(self), messages, invocation_params=params, options=options
-        )
-
-        try:
-            results = await asyncio.gather(
-                *[
-                    self._agenerate_with_cache(
-                        m, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                    for m in messages
-                ]
-            )
-        except (KeyboardInterrupt, Exception) as e:
-            await run_manager.on_llm_error(e)
-            raise e
+
+        run_managers = await callback_manager.on_chat_model_start(
+            dumpd(self), messages, invocation_params=params, options=options
+        )
+
+        results = await asyncio.gather(
+            *[
+                self._agenerate_with_cache(
+                    m,
+                    stop=stop,
+                    run_manager=run_managers[i] if run_managers else None,
+                    **kwargs,
+                )
+                for i, m in enumerate(messages)
+            ],
+            return_exceptions=True,
+        )
+        exceptions = []
+        for i, res in enumerate(results):
+            if isinstance(res, Exception):
+                if run_managers:
+                    await run_managers[i].on_llm_error(res)
+                exceptions.append(res)
+        if exceptions:
+            if run_managers:
+                await asyncio.gather(
+                    *[
+                        run_manager.on_llm_end(
+                            LLMResult(
+                                generations=[res.generations], llm_output=res.llm_output
+                            )
+                        )
+                        for run_manager, res in zip(run_managers, results)
+                        if not isinstance(res, Exception)
+                    ]
+                )
+            raise exceptions[0]
+        flattened_outputs = [
+            LLMResult(generations=[res.generations], llm_output=res.llm_output)
+            for res in results
+        ]
         llm_output = self._combine_llm_outputs([res.llm_output for res in results])
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
-        await run_manager.on_llm_end(output)
-        if run_manager:
-            output.run = RunInfo(run_id=run_manager.run_id)
+        await asyncio.gather(
+            *[
+                run_manager.on_llm_end(flattened_output)
+                for run_manager, flattened_output in zip(
+                    run_managers, flattened_outputs
+                )
+            ]
+        )
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
         return output

     def generate_prompt(
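Note that the async chat path above now calls asyncio.gather with return_exceptions=True, so a single failed prompt no longer prevents the other runs from being reported before the first error is re-raised. The idiom in isolation, with made-up coroutine names:

    import asyncio

    async def fake_generation(i: int) -> int:
        # Stand-in for a per-prompt generation; the second one fails.
        if i == 1:
            raise ValueError("generation failed")
        return i

    async def main() -> None:
        results = await asyncio.gather(
            *[fake_generation(i) for i in range(3)], return_exceptions=True
        )
        exceptions = [r for r in results if isinstance(r, Exception)]
        # Successful results can still be handled (e.g. reported to their run managers).
        print("succeeded:", [r for r in results if not isinstance(r, Exception)])
        if exceptions:
            raise exceptions[0]

    try:
        asyncio.run(main())
    except ValueError as err:
        print("first failure re-raised:", err)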
@@ -1,4 +1,5 @@
 """Base interface for large language models to expose."""
+import asyncio
 import inspect
 import json
 import warnings
@@ -151,6 +152,39 @@ class BaseLLM(BaseLanguageModel, ABC):
             prompt_strings, stop=stop, callbacks=callbacks, **kwargs
         )

+    def _generate_helper(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]],
+        run_managers: List[CallbackManagerForLLMRun],
+        new_arg_supported: bool,
+        **kwargs: Any,
+    ) -> LLMResult:
+        try:
+            output = (
+                self._generate(
+                    prompts,
+                    stop=stop,
+                    # TODO: support multiple run managers
+                    run_manager=run_managers[0] if run_managers else None,
+                    **kwargs,
+                )
+                if new_arg_supported
+                else self._generate(prompts, stop=stop)
+            )
+        except (KeyboardInterrupt, Exception) as e:
+            for run_manager in run_managers:
+                run_manager.on_llm_error(e)
+            raise e
+        flattened_outputs = output.flatten()
+        for manager, flattened_output in zip(run_managers, flattened_outputs):
+            manager.on_llm_end(flattened_output)
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
+        return output
+
     def generate(
         self,
         prompts: List[str],
@@ -161,8 +195,6 @@ class BaseLLM(BaseLanguageModel, ABC):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
-        # If string is passed in directly no errors will be raised but outputs will
-        # not make sense.
         if not isinstance(prompts, list):
             raise ValueError(
                 "Argument 'prompts' is expected to be of type List[str], received"
@@ -185,60 +217,77 @@ class BaseLLM(BaseLanguageModel, ABC):
             "run_manager"
         )
         if langchain.llm_cache is None or disregard_cache:
             # This happens when langchain.cache is None, but self.cache is True
             if self.cache is not None and self.cache:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            run_manager = callback_manager.on_llm_start(
+            run_managers = callback_manager.on_llm_start(
                 dumpd(self), prompts, invocation_params=params, options=options
             )
-            try:
-                output = (
-                    self._generate(
-                        prompts, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                    if new_arg_supported
-                    else self._generate(prompts, stop=stop, **kwargs)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                run_manager.on_llm_error(e)
-                raise e
-            run_manager.on_llm_end(output)
-            if run_manager:
-                output.run = RunInfo(run_id=run_manager.run_id)
+            output = self._generate_helper(
+                prompts, stop, run_managers, bool(new_arg_supported), **kwargs
+            )
             return output
         if len(missing_prompts) > 0:
-            run_manager = callback_manager.on_llm_start(
-                dumpd(self),
-                missing_prompts,
-                invocation_params=params,
-                options=options,
+            run_managers = callback_manager.on_llm_start(
+                dumpd(self), missing_prompts, invocation_params=params, options=options
             )
-            try:
-                new_results = (
-                    self._generate(
-                        missing_prompts, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                    if new_arg_supported
-                    else self._generate(missing_prompts, stop=stop, **kwargs)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                run_manager.on_llm_error(e)
-                raise e
-            run_manager.on_llm_end(new_results)
+            new_results = self._generate_helper(
+                missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
+            )
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
-            run_info = None
-            if run_manager:
-                run_info = RunInfo(run_id=run_manager.run_id)
+            run_info = (
+                [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
+                if run_managers
+                else None
+            )
         else:
             llm_output = {}
             run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
         return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

+    async def _agenerate_helper(
+        self,
+        prompts: List[str],
+        stop: Optional[List[str]],
+        run_managers: List[AsyncCallbackManagerForLLMRun],
+        new_arg_supported: bool,
+        **kwargs: Any,
+    ) -> LLMResult:
+        try:
+            output = (
+                await self._agenerate(
+                    prompts,
+                    stop=stop,
+                    run_manager=run_managers[0] if run_managers else None,
+                    **kwargs,
+                )
+                if new_arg_supported
+                else await self._agenerate(prompts, stop=stop)
+            )
+        except (KeyboardInterrupt, Exception) as e:
+            await asyncio.gather(
+                *[run_manager.on_llm_error(e) for run_manager in run_managers]
+            )
+            raise e
+        flattened_outputs = output.flatten()
+        await asyncio.gather(
+            *[
+                run_manager.on_llm_end(flattened_output)
+                for run_manager, flattened_output in zip(
+                    run_managers, flattened_outputs
+                )
+            ]
+        )
+        if run_managers:
+            output.run = [
+                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+            ]
+        return output
+
     async def agenerate(
         self,
         prompts: List[str],
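After this change, LLMResult.run on a batched generate call is a list with one RunInfo per prompt. A hedged sketch using a hypothetical test double (EchoLLM is made up; any BaseLLM subclass with _generate/_agenerate/_llm_type should behave the same way against this revision):

    from langchain.llms.base import BaseLLM
    from langchain.schema import Generation, LLMResult


    class EchoLLM(BaseLLM):
        """Hypothetical fake LLM that echoes each prompt back."""

        def _generate(self, prompts, stop=None, run_manager=None, **kwargs):
            return LLMResult(generations=[[Generation(text=p)] for p in prompts])

        async def _agenerate(self, prompts, stop=None, run_manager=None, **kwargs):
            return self._generate(prompts, stop=stop)

        @property
        def _llm_type(self) -> str:
            return "echo"


    llm = EchoLLM()
    result = llm.generate(["first prompt", "second prompt"])
    print(len(result.generations))  # 2, one generation list per prompt
    print(result.run)               # now a list of RunInfo, one per prompt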
@@ -266,54 +315,32 @@ class BaseLLM(BaseLanguageModel, ABC):
             "run_manager"
         )
         if langchain.llm_cache is None or disregard_cache:
             # This happens when langchain.cache is None, but self.cache is True
             if self.cache is not None and self.cache:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            run_manager = await callback_manager.on_llm_start(
+            run_managers = await callback_manager.on_llm_start(
                 dumpd(self), prompts, invocation_params=params, options=options
             )
-            try:
-                output = (
-                    await self._agenerate(
-                        prompts, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                    if new_arg_supported
-                    else await self._agenerate(prompts, stop=stop, **kwargs)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                await run_manager.on_llm_error(e, verbose=self.verbose)
-                raise e
-            await run_manager.on_llm_end(output, verbose=self.verbose)
-            if run_manager:
-                output.run = RunInfo(run_id=run_manager.run_id)
+            output = await self._agenerate_helper(
+                prompts, stop, run_managers, bool(new_arg_supported), **kwargs
+            )
             return output
         if len(missing_prompts) > 0:
-            run_manager = await callback_manager.on_llm_start(
-                dumpd(self),
-                missing_prompts,
-                invocation_params=params,
-                options=options,
+            run_managers = await callback_manager.on_llm_start(
+                dumpd(self), missing_prompts, invocation_params=params, options=options
             )
-            try:
-                new_results = (
-                    await self._agenerate(
-                        missing_prompts, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                    if new_arg_supported
-                    else await self._agenerate(missing_prompts, stop=stop, **kwargs)
-                )
-            except (KeyboardInterrupt, Exception) as e:
-                await run_manager.on_llm_error(e)
-                raise e
-            await run_manager.on_llm_end(new_results)
+            new_results = await self._agenerate_helper(
+                missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
+            )
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
-            run_info = None
-            if run_manager:
-                run_info = RunInfo(run_id=run_manager.run_id)
+            run_info = (
+                [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
+                if run_managers
+                else None
+            )
         else:
             llm_output = {}
             run_info = None
@@ -227,9 +227,35 @@ class LLMResult(BaseModel):
         each input could have multiple generations."""
     llm_output: Optional[dict] = None
     """For arbitrary LLM provider specific output."""
-    run: Optional[RunInfo] = None
+    run: Optional[List[RunInfo]] = None
     """Run metadata."""

+    def flatten(self) -> List[LLMResult]:
+        """Flatten generations into a single list."""
+        llm_results = []
+        for i, gen_list in enumerate(self.generations):
+            # Avoid double counting tokens in OpenAICallback
+            if i == 0:
+                llm_results.append(
+                    LLMResult(
+                        generations=[gen_list],
+                        llm_output=self.llm_output,
+                    )
+                )
+            else:
+                if self.llm_output is not None:
+                    llm_output = self.llm_output.copy()
+                    llm_output["token_usage"] = dict()
+                else:
+                    llm_output = None
+                llm_results.append(
+                    LLMResult(
+                        generations=[gen_list],
+                        llm_output=llm_output,
+                    )
+                )
+        return llm_results
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, LLMResult):
             return NotImplemented
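For reference, a small sketch of what the new LLMResult.flatten() yields for a two-prompt batch, assuming this revision of langchain.schema; only the first flattened result keeps token_usage, which is what lets the OpenAI cost callback avoid double counting (the numbers are made up):

    from langchain.schema import Generation, LLMResult

    batch = LLMResult(
        generations=[[Generation(text="four")], [Generation(text="six")]],
        llm_output={"token_usage": {"total_tokens": 12}, "model_name": "fake"},
    )
    first, second = batch.flatten()
    assert first.llm_output == {"token_usage": {"total_tokens": 12}, "model_name": "fake"}
    # Later results carry the same metadata but an emptied token_usage.
    assert second.llm_output == {"token_usage": {}, "model_name": "fake"}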
@@ -38,6 +38,21 @@ async def test_openai_callback() -> None:
     assert cb.total_tokens == total_tokens


+def test_openai_callback_batch_llm() -> None:
+    llm = OpenAI(temperature=0)
+    with get_openai_callback() as cb:
+        llm.generate(["What is the square root of 4?", "What is the square root of 4?"])
+
+    assert cb.total_tokens > 0
+    total_tokens = cb.total_tokens
+
+    with get_openai_callback() as cb:
+        llm("What is the square root of 4?")
+        llm("What is the square root of 4?")
+
+    assert cb.total_tokens == total_tokens
+
+
 def test_openai_callback_agent() -> None:
     llm = OpenAI(temperature=0)
     tools = load_tools(["serpapi", "llm-math"], llm=llm)
@@ -96,6 +96,15 @@ def test_openai_streaming() -> None:
         assert isinstance(token["choices"][0]["text"], str)


+def test_openai_multiple_prompts() -> None:
+    """Test completion with multiple prompts."""
+    llm = OpenAI(max_tokens=10)
+    output = llm.generate(["I'm Pickle Rick", "I'm Pickle Rick"])
+    assert isinstance(output, LLMResult)
+    assert isinstance(output.generations, list)
+    assert len(output.generations) == 2
+
+
 def test_openai_streaming_error() -> None:
     """Test error handling in stream."""
     llm = OpenAI(best_of=2)
@@ -28,6 +28,10 @@ class FakeListLLM(LLM):
         print(self.responses[self.i])
         return self.responses[self.i]

+    def get_num_tokens(self, text: str) -> int:
+        """Return number of tokens in text."""
+        return len(text.split())
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         return {}
@@ -18,11 +18,12 @@ def _test_callback_manager(
     manager: CallbackManager, *handlers: BaseFakeCallbackHandler
 ) -> None:
     """Test the CallbackManager."""
-    run_manager = manager.on_llm_start({}, [])
-    run_manager.on_llm_end(LLMResult(generations=[]))
-    run_manager.on_llm_error(Exception())
-    run_manager.on_llm_new_token("foo")
-    run_manager.on_text("foo")
+    run_managers = manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        run_manager.on_llm_end(LLMResult(generations=[]))
+        run_manager.on_llm_error(Exception())
+        run_manager.on_llm_new_token("foo")
+        run_manager.on_text("foo")

     run_manager_chain = manager.on_chain_start({"name": "foo"}, {})
     run_manager_chain.on_chain_end({})
@@ -42,11 +43,12 @@ async def _test_callback_manager_async(
     manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
 ) -> None:
     """Test the CallbackManager."""
-    run_manager = await manager.on_llm_start({}, [])
-    await run_manager.on_llm_end(LLMResult(generations=[]))
-    await run_manager.on_llm_error(Exception())
-    await run_manager.on_llm_new_token("foo")
-    await run_manager.on_text("foo")
+    run_managers = await manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        await run_manager.on_llm_end(LLMResult(generations=[]))
+        await run_manager.on_llm_error(Exception())
+        await run_manager.on_llm_new_token("foo")
+        await run_manager.on_text("foo")

     run_manager_chain = await manager.on_chain_start({"name": "foo"}, {})
     await run_manager_chain.on_chain_end({})
@@ -95,9 +97,10 @@ def test_ignore_llm() -> None:
     handler1 = FakeCallbackHandler(ignore_llm_=True)
     handler2 = FakeCallbackHandler()
     manager = CallbackManager(handlers=[handler1, handler2])
-    run_manager = manager.on_llm_start({}, [])
-    run_manager.on_llm_end(LLMResult(generations=[]))
-    run_manager.on_llm_error(Exception())
+    run_managers = manager.on_llm_start({}, ["prompt"])
+    for run_manager in run_managers:
+        run_manager.on_llm_end(LLMResult(generations=[]))
+        run_manager.on_llm_error(Exception())
     assert handler1.starts == 0
     assert handler1.ends == 0
     assert handler1.errors == 0
@@ -11,7 +11,7 @@ from freezegun import freeze_time
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.tracers.base import BaseTracer, TracerException
 from langchain.callbacks.tracers.schemas import Run
-from langchain.schema import LLMResult
+from langchain.schema import HumanMessage, LLMResult

 SERIALIZED = {"id": ["llm"]}
 SERIALIZED_CHAT = {"id": ["chat_model"]}
@@ -58,9 +58,13 @@ def test_tracer_llm_run() -> None:
 @freeze_time("2023-01-01")
 def test_tracer_chat_model_run() -> None:
     """Test tracer on a Chat Model run."""
-    uuid = uuid4()
+    tracer = FakeTracer()
+    manager = CallbackManager(handlers=[tracer])
+    run_managers = manager.on_chat_model_start(
+        serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
+    )
     compare_run = Run(
-        id=str(uuid),
+        id=str(run_managers[0].run_id),
         name="chat_model",
         start_time=datetime.utcnow(),
         end_time=datetime.utcnow(),
@@ -68,17 +72,13 @@ def test_tracer_chat_model_run() -> None:
         execution_order=1,
         child_execution_order=1,
         serialized=SERIALIZED_CHAT,
-        inputs=dict(prompts=[""]),
+        inputs=dict(prompts=["Human: "]),
         outputs=LLMResult(generations=[[]]),
         error=None,
         run_type="llm",
     )
-    tracer = FakeTracer()
-    manager = CallbackManager(handlers=[tracer])
-    run_manager = manager.on_chat_model_start(
-        serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
-    )
-    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
+    for run_manager in run_managers:
+        run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
@@ -18,7 +18,7 @@ from langchain.callbacks.tracers.langchain_v1 import (
     TracerSessionV1,
 )
 from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
-from langchain.schema import LLMResult
+from langchain.schema import HumanMessage, LLMResult

 TEST_SESSION_ID = 2023
@@ -127,9 +127,15 @@ def test_tracer_llm_run() -> None:
 @freeze_time("2023-01-01")
 def test_tracer_chat_model_run() -> None:
     """Test tracer on a Chat Model run."""
-    uuid = uuid4()
+    tracer = FakeTracer()
+
+    tracer.new_session()
+    manager = CallbackManager(handlers=[tracer])
+    run_managers = manager.on_chat_model_start(
+        serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
+    )
     compare_run = LLMRun(
-        uuid=str(uuid),
+        uuid=str(run_managers[0].run_id),
         parent_uuid=None,
         start_time=datetime.utcnow(),
         end_time=datetime.utcnow(),
@@ -137,19 +143,13 @@ def test_tracer_chat_model_run() -> None:
         execution_order=1,
         child_execution_order=1,
         serialized=SERIALIZED_CHAT,
-        prompts=[""],
+        prompts=["Human: "],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
         error=None,
     )
-    tracer = FakeTracer()
-
-    tracer.new_session()
-    manager = CallbackManager(handlers=[tracer])
-    run_manager = manager.on_chat_model_start(
-        serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
-    )
-    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
+    for run_manager in run_managers:
+        run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
@@ -49,6 +49,10 @@ class FakeLLM(BaseLLM):
     ) -> LLMResult:
         return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])

+    def get_num_tokens(self, text: str) -> int:
+        """Return number of tokens."""
+        return len(text.split())
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
@@ -28,6 +28,9 @@ class FakeLLM(LLM):
         """Return type of llm."""
         return "fake"

+    def get_num_tokens(self, text: str) -> int:
+        return len(text.split())
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         return {}
@@ -24,6 +24,10 @@ class FakeLLM(LLM):
         )
         return queries

+    def get_num_tokens(self, text: str) -> int:
+        """Return number of tokens."""
+        return len(text.split())
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""