[LLMonitorCallbackHandler] Various improvements (#13151)

Small improvements for the llmonitor callback handler, like better
support for non-openai models.


---------

Co-authored-by: vincelwt <vince@lyser.io>
This commit is contained in:
Hugues Chocart 2023-11-17 08:39:36 +01:00 committed by GitHub
parent c1b041c188
commit 35e04f204b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@ import os
import traceback
import warnings
from contextvars import ContextVar
from typing import Any, Dict, List, Literal, Union
from typing import Any, Dict, List, Union, cast
from uuid import UUID
import requests
@ -15,11 +15,30 @@ from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.messages import BaseMessage
from langchain.schema.output import LLMResult
logger = logging.getLogger(__name__)
DEFAULT_API_URL = "https://app.llmonitor.com"
user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None)
user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None)
PARAMS_TO_CAPTURE = [
"temperature",
"top_p",
"top_k",
"stop",
"presence_penalty",
"frequence_penalty",
"seed",
"function_call",
"functions",
"tools",
"tool_choice",
"response_format",
"max_tokens",
"logit_bias",
]
class UserContextManager:
"""Context manager for LLMonitor user context."""
@ -66,6 +85,10 @@ def _parse_input(raw_input: Any) -> Any:
if not raw_input:
return None
# if it's an array of 1, just parse the first element
if isinstance(raw_input, list) and len(raw_input) == 1:
return _parse_input(raw_input[0])
if not isinstance(raw_input, dict):
return _serialize(raw_input)
@ -115,17 +138,11 @@ def _parse_output(raw_output: dict) -> Any:
def _parse_lc_role(
role: str,
) -> Union[Literal["user", "ai", "system", "function"], None]:
) -> str:
if role == "human":
return "user"
elif role == "ai":
return "ai"
elif role == "system":
return "system"
elif role == "function":
return "function"
else:
return None
return role
def _get_user_id(metadata: Any) -> Any:
@ -148,13 +165,15 @@ def _get_user_props(metadata: Any) -> Any:
def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]:
keys = ["function_call", "tool_calls", "tool_call_id", "name"]
parsed = {"text": message.content, "role": _parse_lc_role(message.type)}
function_call = (message.additional_kwargs or {}).get("function_call")
if function_call is not None:
parsed["functionCall"] = function_call
parsed.update(
{
key: cast(Any, message.additional_kwargs.get(key))
for key in keys
if message.additional_kwargs.get(key) is not None
}
)
return parsed
@ -213,19 +232,20 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
self.__track_event = llmonitor.track_event
except ImportError:
warnings.warn(
logger.warning(
"""[LLMonitor] To use the LLMonitor callback handler you need to
have the `llmonitor` Python package installed. Please install it
with `pip install llmonitor`"""
)
self.__has_valid_config = False
return
if parse(self.__llmonitor_version) < parse("0.0.20"):
warnings.warn(
if parse(self.__llmonitor_version) < parse("0.0.32"):
logger.warning(
f"""[LLMonitor] The installed `llmonitor` version is
{self.__llmonitor_version} but `LLMonitorCallbackHandler` requires
at least version 0.0.20 upgrade `llmonitor` with `pip install
--upgrade llmonitor`"""
{self.__llmonitor_version}
but `LLMonitorCallbackHandler` requires at least version 0.0.32
upgrade `llmonitor` with `pip install --upgrade llmonitor`"""
)
self.__has_valid_config = False
@ -236,9 +256,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
_app_id = app_id or os.getenv("LLMONITOR_APP_ID")
if _app_id is None:
warnings.warn(
"""[LLMonitor] app_id must be provided either as an argument or as
an environment variable"""
logger.warning(
"""[LLMonitor] app_id must be provided either as an argument or
as an environment variable"""
)
self.__has_valid_config = False
else:
@ -252,7 +272,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
if not res.ok:
raise ConnectionError()
except Exception:
warnings.warn(
logger.warning(
f"""[LLMonitor] Could not connect to the LLMonitor API at
{self.__api_url}"""
)
@ -273,7 +293,27 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
name = kwargs.get("invocation_params", {}).get("model_name")
params = kwargs.get("invocation_params", {})
params.update(
serialized.get("kwargs", {})
) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
name = (
params.get("model")
or params.get("model_name")
or params.get("model_id")
)
if not name and "anthropic" in params.get("_type"):
name = "claude-2"
extra = {
param: params.get(param)
for param in PARAMS_TO_CAPTURE
if params.get(param) is not None
}
input = _parse_input(prompts)
self.__track_event(
@ -285,8 +325,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
name=name,
input=input,
tags=tags,
extra=extra,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_llm_start: {e}")
@ -304,10 +346,31 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
) -> Any:
if self.__has_valid_config is False:
return
try:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
name = kwargs.get("invocation_params", {}).get("model_name")
params = kwargs.get("invocation_params", {})
params.update(
serialized.get("kwargs", {})
) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
name = (
params.get("model")
or params.get("model_name")
or params.get("model_id")
)
if not name and "anthropic" in params.get("_type"):
name = "claude-2"
extra = {
param: params.get(param)
for param in PARAMS_TO_CAPTURE
if params.get(param) is not None
}
input = _parse_lc_messages(messages[0])
self.__track_event(
@ -319,13 +382,13 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
name=name,
input=input,
tags=tags,
extra=extra,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
logging.warning(
f"[LLMonitor] An error occurred in on_chat_model_start: {e}"
)
logger.error(f"[LLMonitor] An error occurred in on_chat_model_start: {e}")
def on_llm_end(
self,
@ -340,25 +403,18 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
try:
token_usage = (response.llm_output or {}).get("token_usage", {})
parsed_output = [
{
"text": generation.text,
"role": "ai",
**(
{
"functionCall": generation.message.additional_kwargs[
"function_call"
]
}
parsed_output: Any = [
_parse_lc_message(generation.message)
if hasattr(generation, "message")
and hasattr(generation.message, "additional_kwargs")
and "function_call" in generation.message.additional_kwargs
else {}
),
}
else generation.text
for generation in response.generations[0]
]
# if it's an array of 1, just parse the first element
if len(parsed_output) == 1:
parsed_output = parsed_output[0]
self.__track_event(
"llm",
"end",
@ -369,9 +425,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
"prompt": token_usage.get("prompt_tokens"),
"completion": token_usage.get("completion_tokens"),
},
app_id=self.__app_id,
)
except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_llm_end: {e}")
logger.error(f"[LLMonitor] An error occurred in on_llm_end: {e}")
def on_tool_start(
self,
@ -402,9 +459,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
tags=tags,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_tool_start: {e}")
logger.error(f"[LLMonitor] An error occurred in on_tool_start: {e}")
def on_tool_end(
self,
@ -424,9 +482,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
app_id=self.__app_id,
)
except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_tool_end: {e}")
logger.error(f"[LLMonitor] An error occurred in on_tool_end: {e}")
def on_chain_start(
self,
@ -473,9 +532,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
tags=tags,
metadata=metadata,
user_props=user_props,
app_id=self.__app_id,
)
except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_chain_start: {e}")
logger.error(f"[LLMonitor] An error occurred in on_chain_start: {e}")
def on_chain_end(
self,
@ -496,9 +556,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
app_id=self.__app_id,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_chain_end: {e}")
logger.error(f"[LLMonitor] An error occurred in on_chain_end: {e}")
def on_agent_action(
self,
@ -521,9 +582,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name,
input=input,
app_id=self.__app_id,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_agent_action: {e}")
logger.error(f"[LLMonitor] An error occurred in on_agent_action: {e}")
def on_agent_finish(
self,
@ -544,9 +606,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output,
app_id=self.__app_id,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_agent_finish: {e}")
logger.error(f"[LLMonitor] An error occurred in on_agent_finish: {e}")
def on_chain_error(
self,
@ -565,9 +628,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_chain_error: {e}")
logger.error(f"[LLMonitor] An error occurred in on_chain_error: {e}")
def on_tool_error(
self,
@ -586,9 +650,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_tool_error: {e}")
logger.error(f"[LLMonitor] An error occurred in on_tool_error: {e}")
def on_llm_error(
self,
@ -607,9 +672,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
)
except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_llm_error: {e}")
logger.error(f"[LLMonitor] An error occurred in on_llm_error: {e}")
__all__ = ["LLMonitorCallbackHandler", "identify"]