split up batch llm calls into separate runs (#5804)

Ankush Gola 2023-06-24 21:03:31 -07:00 committed by GitHub
parent 1da99ce013
commit e1b801be36
14 changed files with 401 additions and 293 deletions
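
At a high level, this change makes each prompt (or message list) in a batched call its own run: `CallbackManager.on_llm_start` and `on_chat_model_start` now return a list of run managers, one per input, instead of a single manager for the whole batch. A minimal sketch of the new contract, assuming only that some handler is registered (StdOutCallbackHandler is an arbitrary choice):

from langchain.callbacks import StdOutCallbackHandler
from langchain.callbacks.manager import CallbackManager
from langchain.schema import LLMResult

manager = CallbackManager(handlers=[StdOutCallbackHandler()])

# One CallbackManagerForLLMRun (and one run_id) is created per prompt.
run_managers = manager.on_llm_start({}, ["prompt one", "prompt two"])
assert len(run_managers) == 2
assert run_managers[0].run_id != run_managers[1].run_id

# Each run is ended individually with its own (flattened) LLMResult.
for run_manager in run_managers:
    run_manager.on_llm_end(LLMResult(generations=[[]]))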

View File

@@ -672,66 +672,72 @@ class CallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForLLMRun:
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
managers = []
for prompt in prompts:
run_id_ = uuid4()
_handle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
[prompt],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
_handle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
managers.append(
CallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
)
return CallbackManagerForLLMRun(
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
return managers
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForLLMRun:
) -> List[CallbackManagerForLLMRun]:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
_handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
# Re-use the LLM Run Manager since the outputs are treated
# the same for now
return CallbackManagerForLLMRun(
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
managers = []
for message_list in messages:
run_id_ = uuid4()
_handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
[message_list],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
managers.append(
CallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
)
return managers
def on_chain_start(
self,
@@ -830,64 +836,84 @@ class AsyncCallbackManager(BaseCallbackManager):
self,
serialized: Dict[str, Any],
prompts: List[str],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> AsyncCallbackManagerForLLMRun:
) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
await _ahandle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
prompts,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
tasks = []
managers = []
return AsyncCallbackManagerForLLMRun(
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
for prompt in prompts:
run_id_ = uuid4()
tasks.append(
_ahandle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
serialized,
[prompt],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
)
managers.append(
AsyncCallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
)
await asyncio.gather(*tasks)
return managers
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if run_id is None:
run_id = uuid4()
tasks = []
managers = []
await _ahandle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
for message_list in messages:
run_id_ = uuid4()
return AsyncCallbackManagerForLLMRun(
run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
tasks.append(
_ahandle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
[message_list],
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs,
)
)
managers.append(
AsyncCallbackManagerForLLMRun(
run_id=run_id_,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
)
)
await asyncio.gather(*tasks)
return managers
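The async manager mirrors this: the per-input on_llm_start / on_chat_model_start events are dispatched concurrently with asyncio.gather, and a matching list of AsyncCallbackManagerForLLMRun objects is returned. A hedged sketch of that usage (handler choice is again arbitrary):

import asyncio

from langchain.callbacks import StdOutCallbackHandler
from langchain.callbacks.manager import AsyncCallbackManager
from langchain.schema import LLMResult


async def demo() -> None:
    manager = AsyncCallbackManager(handlers=[StdOutCallbackHandler()])
    # One async run manager per prompt, with start events fired concurrently.
    run_managers = await manager.on_llm_start({}, ["a", "b", "c"])
    assert len(run_managers) == 3
    # Each run is also ended individually.
    await asyncio.gather(
        *[rm.on_llm_end(LLMResult(generations=[[]])) for rm in run_managers]
    )


asyncio.run(demo())
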
async def on_chain_start(
self,

View File

@@ -1,8 +1,8 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import LLMResult
MODEL_COST_PER_1K_TOKENS = {
# GPT-4 input
@@ -152,64 +152,6 @@ class OpenAICallbackHandler(BaseCallbackHandler):
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
pass
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Print out the log in specified color."""
pass
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
pass
def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run on agent action."""
pass
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
pass
def __copy__(self) -> "OpenAICallbackHandler":
"""Return a copy of the callback handler."""
return self

View File

@@ -101,26 +101,37 @@ class BaseChatModel(BaseLanguageModel, ABC):
tags,
self.tags,
)
run_manager = callback_manager.on_chat_model_start(
run_managers = callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options
)
try:
results = [
self._generate_with_cache(
m, stop=stop, run_manager=run_manager, **kwargs
results = []
for i, m in enumerate(messages):
try:
results.append(
self._generate_with_cache(
m,
stop=stop,
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
)
for m in messages
]
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
except (KeyboardInterrupt, Exception) as e:
if run_managers:
run_managers[i].on_llm_error(e)
raise e
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
if run_managers:
run_infos = []
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
run_infos.append(RunInfo(run_id=manager.run_id))
output.run = run_infos
return output
async def agenerate(
@@ -143,28 +154,62 @@ class BaseChatModel(BaseLanguageModel, ABC):
tags,
self.tags,
)
run_manager = await callback_manager.on_chat_model_start(
run_managers = await callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options
)
try:
results = await asyncio.gather(
*[
self._agenerate_with_cache(
m, stop=stop, run_manager=run_manager, **kwargs
)
for m in messages
]
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
results = await asyncio.gather(
*[
self._agenerate_with_cache(
m,
stop=stop,
run_manager=run_managers[i] if run_managers else None,
**kwargs,
)
for i, m in enumerate(messages)
],
return_exceptions=True,
)
exceptions = []
for i, res in enumerate(results):
if isinstance(res, Exception):
if run_managers:
await run_managers[i].on_llm_error(res)
exceptions.append(res)
if exceptions:
if run_managers:
await asyncio.gather(
*[
run_manager.on_llm_end(
LLMResult(
generations=[res.generations], llm_output=res.llm_output
)
)
for run_manager, res in zip(run_managers, results)
if not isinstance(res, Exception)
]
)
raise exceptions[0]
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
await run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
)
]
)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
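
For chat models, the visible effect is that LLMResult.run is now a list with one RunInfo per input message list, and each run manager receives its own flattened, single-input result on on_llm_end. A hedged usage sketch, assuming the ChatOpenAI integration and an OPENAI_API_KEY in the environment (any BaseChatModel subclass should behave the same):

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI(temperature=0)
output = chat.generate(
    [[HumanMessage(content="Say foo")], [HumanMessage(content="Say bar")]]
)

# One RunInfo per batched input, aligned with output.generations.
assert output.run is not None and len(output.run) == 2
for run_info, generation_list in zip(output.run, output.generations):
    print(run_info.run_id, generation_list[0].text)
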
def generate_prompt(

View File

@@ -1,4 +1,5 @@
"""Base interface for large language models to expose."""
import asyncio
import inspect
import json
import warnings
@@ -151,6 +152,39 @@ class BaseLLM(BaseLanguageModel, ABC):
prompt_strings, stop=stop, callbacks=callbacks, **kwargs
)
def _generate_helper(
self,
prompts: List[str],
stop: Optional[List[str]],
run_managers: List[CallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
try:
output = (
self._generate(
prompts,
stop=stop,
# TODO: support multiple run managers
run_manager=run_managers[0] if run_managers else None,
**kwargs,
)
if new_arg_supported
else self._generate(prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
for run_manager in run_managers:
run_manager.on_llm_error(e)
raise e
flattened_outputs = output.flatten()
for manager, flattened_output in zip(run_managers, flattened_outputs):
manager.on_llm_end(flattened_output)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
def generate(
self,
prompts: List[str],
@@ -161,8 +195,6 @@ class BaseLLM(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
# If string is passed in directly no errors will be raised but outputs will
# not make sense.
if not isinstance(prompts, list):
raise ValueError(
"Argument 'prompts' is expected to be of type List[str], received"
@@ -185,60 +217,77 @@ class BaseLLM(BaseLanguageModel, ABC):
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager = callback_manager.on_llm_start(
run_managers = callback_manager.on_llm_start(
dumpd(self), prompts, invocation_params=params, options=options
)
try:
output = (
self._generate(
prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else self._generate(prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
output = self._generate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
return output
if len(missing_prompts) > 0:
run_manager = callback_manager.on_llm_start(
dumpd(self),
missing_prompts,
invocation_params=params,
options=options,
run_managers = callback_manager.on_llm_start(
dumpd(self), missing_prompts, invocation_params=params, options=options
)
new_results = self._generate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
try:
new_results = (
self._generate(
missing_prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else self._generate(missing_prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
run_manager.on_llm_end(new_results)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = None
if run_manager:
run_info = RunInfo(run_id=run_manager.run_id)
run_info = (
[RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
if run_managers
else None
)
else:
llm_output = {}
run_info = None
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
async def _agenerate_helper(
self,
prompts: List[str],
stop: Optional[List[str]],
run_managers: List[AsyncCallbackManagerForLLMRun],
new_arg_supported: bool,
**kwargs: Any,
) -> LLMResult:
try:
output = (
await self._agenerate(
prompts,
stop=stop,
run_manager=run_managers[0] if run_managers else None,
**kwargs,
)
if new_arg_supported
else await self._agenerate(prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
await asyncio.gather(
*[run_manager.on_llm_error(e) for run_manager in run_managers]
)
raise e
flattened_outputs = output.flatten()
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
)
]
)
if run_managers:
output.run = [
RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
]
return output
async def agenerate(
self,
prompts: List[str],
@@ -266,54 +315,32 @@ class BaseLLM(BaseLanguageModel, ABC):
"run_manager"
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager = await callback_manager.on_llm_start(
run_managers = await callback_manager.on_llm_start(
dumpd(self), prompts, invocation_params=params, options=options
)
try:
output = (
await self._agenerate(
prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else await self._agenerate(prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e, verbose=self.verbose)
raise e
await run_manager.on_llm_end(output, verbose=self.verbose)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
output = await self._agenerate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
return output
if len(missing_prompts) > 0:
run_manager = await callback_manager.on_llm_start(
dumpd(self),
missing_prompts,
invocation_params=params,
options=options,
run_managers = await callback_manager.on_llm_start(
dumpd(self), missing_prompts, invocation_params=params, options=options
)
new_results = await self._agenerate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
try:
new_results = (
await self._agenerate(
missing_prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else await self._agenerate(missing_prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
await run_manager.on_llm_end(new_results)
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = None
if run_manager:
run_info = RunInfo(run_id=run_manager.run_id)
run_info = (
[RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
if run_managers
else None
)
else:
llm_output = {}
run_info = None
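
The same bookkeeping applies to plain LLMs: generate now starts one run per prompt and attaches a list of RunInfo objects to the result. A minimal sketch with a hypothetical FakeEchoLLM (not part of this diff), assuming no LLM cache is configured:

from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class FakeEchoLLM(LLM):
    """Hypothetical test double that echoes the prompt back."""

    @property
    def _llm_type(self) -> str:
        return "fake-echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return prompt

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {}


llm = FakeEchoLLM()
output = llm.generate(["first prompt", "second prompt"])
# One RunInfo per prompt instead of a single RunInfo for the whole batch.
assert output.run is not None and len(output.run) == 2
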

View File

@@ -227,9 +227,35 @@ class LLMResult(BaseModel):
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
run: Optional[RunInfo] = None
run: Optional[List[RunInfo]] = None
"""Run metadata."""
def flatten(self) -> List[LLMResult]:
"""Flatten generations into a single list."""
llm_results = []
for i, gen_list in enumerate(self.generations):
# Avoid double counting tokens in OpenAICallback
if i == 0:
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=self.llm_output,
)
)
else:
if self.llm_output is not None:
llm_output = self.llm_output.copy()
llm_output["token_usage"] = dict()
else:
llm_output = None
llm_results.append(
LLMResult(
generations=[gen_list],
llm_output=llm_output,
)
)
return llm_results
def __eq__(self, other: object) -> bool:
if not isinstance(other, LLMResult):
return NotImplemented
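
flatten() is what lets each per-prompt run report a sensible result: a batched LLMResult is split into one single-input LLMResult per generation list, and token_usage is kept only on the first copy so cost callbacks do not double count. A small sketch of that behavior (the token counts are made up for illustration):

from langchain.schema import Generation, LLMResult

batched = LLMResult(
    generations=[[Generation(text="foo")], [Generation(text="bar")]],
    llm_output={"token_usage": {"total_tokens": 7}, "model_name": "fake-model"},
)

first, second = batched.flatten()
# The first flattened result keeps the full llm_output...
assert first.llm_output == {
    "token_usage": {"total_tokens": 7},
    "model_name": "fake-model",
}
# ...while later ones carry an emptied token_usage to avoid double counting.
assert second.llm_output == {"token_usage": {}, "model_name": "fake-model"}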

View File

@@ -38,6 +38,21 @@ async def test_openai_callback() -> None:
assert cb.total_tokens == total_tokens
def test_openai_callback_batch_llm() -> None:
llm = OpenAI(temperature=0)
with get_openai_callback() as cb:
llm.generate(["What is the square root of 4?", "What is the square root of 4?"])
assert cb.total_tokens > 0
total_tokens = cb.total_tokens
with get_openai_callback() as cb:
llm("What is the square root of 4?")
llm("What is the square root of 4?")
assert cb.total_tokens == total_tokens
def test_openai_callback_agent() -> None:
llm = OpenAI(temperature=0)
tools = load_tools(["serpapi", "llm-math"], llm=llm)

View File

@@ -96,6 +96,15 @@ def test_openai_streaming() -> None:
assert isinstance(token["choices"][0]["text"], str)
def test_openai_multiple_prompts() -> None:
"""Test completion with multiple prompts."""
llm = OpenAI(max_tokens=10)
output = llm.generate(["I'm Pickle Rick", "I'm Pickle Rick"])
assert isinstance(output, LLMResult)
assert isinstance(output.generations, list)
assert len(output.generations) == 2
def test_openai_streaming_error() -> None:
"""Test error handling in stream."""
llm = OpenAI(best_of=2)

View File

@@ -28,6 +28,10 @@ class FakeListLLM(LLM):
print(self.responses[self.i])
return self.responses[self.i]
def get_num_tokens(self, text: str) -> int:
"""Return number of tokens in text."""
return len(text.split())
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {}

View File

@@ -18,11 +18,12 @@ def _test_callback_manager(
manager: CallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
run_manager = manager.on_llm_start({}, [])
run_manager.on_llm_end(LLMResult(generations=[]))
run_manager.on_llm_error(Exception())
run_manager.on_llm_new_token("foo")
run_manager.on_text("foo")
run_managers = manager.on_llm_start({}, ["prompt"])
for run_manager in run_managers:
run_manager.on_llm_end(LLMResult(generations=[]))
run_manager.on_llm_error(Exception())
run_manager.on_llm_new_token("foo")
run_manager.on_text("foo")
run_manager_chain = manager.on_chain_start({"name": "foo"}, {})
run_manager_chain.on_chain_end({})
@@ -42,11 +43,12 @@ async def _test_callback_manager_async(
manager: AsyncCallbackManager, *handlers: BaseFakeCallbackHandler
) -> None:
"""Test the CallbackManager."""
run_manager = await manager.on_llm_start({}, [])
await run_manager.on_llm_end(LLMResult(generations=[]))
await run_manager.on_llm_error(Exception())
await run_manager.on_llm_new_token("foo")
await run_manager.on_text("foo")
run_managers = await manager.on_llm_start({}, ["prompt"])
for run_manager in run_managers:
await run_manager.on_llm_end(LLMResult(generations=[]))
await run_manager.on_llm_error(Exception())
await run_manager.on_llm_new_token("foo")
await run_manager.on_text("foo")
run_manager_chain = await manager.on_chain_start({"name": "foo"}, {})
await run_manager_chain.on_chain_end({})
@@ -95,9 +97,10 @@ def test_ignore_llm() -> None:
handler1 = FakeCallbackHandler(ignore_llm_=True)
handler2 = FakeCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2])
run_manager = manager.on_llm_start({}, [])
run_manager.on_llm_end(LLMResult(generations=[]))
run_manager.on_llm_error(Exception())
run_managers = manager.on_llm_start({}, ["prompt"])
for run_manager in run_managers:
run_manager.on_llm_end(LLMResult(generations=[]))
run_manager.on_llm_error(Exception())
assert handler1.starts == 0
assert handler1.ends == 0
assert handler1.errors == 0

View File

@@ -11,7 +11,7 @@ from freezegun import freeze_time
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.tracers.base import BaseTracer, TracerException
from langchain.callbacks.tracers.schemas import Run
from langchain.schema import LLMResult
from langchain.schema import HumanMessage, LLMResult
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}
@@ -58,9 +58,13 @@ def test_tracer_llm_run() -> None:
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
uuid = uuid4()
tracer = FakeTracer()
manager = CallbackManager(handlers=[tracer])
run_managers = manager.on_chat_model_start(
serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
)
compare_run = Run(
id=str(uuid),
id=str(run_managers[0].run_id),
name="chat_model",
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
@@ -68,17 +72,13 @@ def test_tracer_chat_model_run() -> None:
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED_CHAT,
inputs=dict(prompts=[""]),
inputs=dict(prompts=["Human: "]),
outputs=LLMResult(generations=[[]]),
error=None,
run_type="llm",
)
tracer = FakeTracer()
manager = CallbackManager(handlers=[tracer])
run_manager = manager.on_chat_model_start(
serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
)
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
for run_manager in run_managers:
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]

View File

@@ -18,7 +18,7 @@ from langchain.callbacks.tracers.langchain_v1 import (
TracerSessionV1,
)
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
from langchain.schema import LLMResult
from langchain.schema import HumanMessage, LLMResult
TEST_SESSION_ID = 2023
@@ -127,9 +127,15 @@ def test_tracer_llm_run() -> None:
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
uuid = uuid4()
tracer = FakeTracer()
tracer.new_session()
manager = CallbackManager(handlers=[tracer])
run_managers = manager.on_chat_model_start(
serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
)
compare_run = LLMRun(
uuid=str(uuid),
uuid=str(run_managers[0].run_id),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
@@ -137,19 +143,13 @@ def test_tracer_chat_model_run() -> None:
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED_CHAT,
prompts=[""],
prompts=["Human: "],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
manager = CallbackManager(handlers=[tracer])
run_manager = manager.on_chat_model_start(
serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
)
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
for run_manager in run_managers:
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]

View File

@@ -49,6 +49,10 @@ class FakeLLM(BaseLLM):
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
def get_num_tokens(self, text: str) -> int:
"""Return number of tokens."""
return len(text.split())
@property
def _llm_type(self) -> str:
"""Return type of llm."""

View File

@@ -28,6 +28,9 @@ class FakeLLM(LLM):
"""Return type of llm."""
return "fake"
def get_num_tokens(self, text: str) -> int:
return len(text.split())
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {}

View File

@@ -24,6 +24,10 @@ class FakeLLM(LLM):
)
return queries
def get_num_tokens(self, text: str) -> int:
"""Return number of tokens."""
return len(text.split())
@property
def _llm_type(self) -> str:
"""Return type of llm."""