From 35e04f204ba3e69356a4f8f557ea88f46d2fa389 Mon Sep 17 00:00:00 2001 From: Hugues Chocart Date: Fri, 17 Nov 2023 08:39:36 +0100 Subject: [PATCH] [LLMonitorCallbackHandler] Various improvements (#13151) Small improvements for the llmonitor callback handler, like better support for non-openai models. --------- Co-authored-by: vincelwt --- .../langchain/callbacks/llmonitor_callback.py | 178 ++++++++++++------ 1 file changed, 122 insertions(+), 56 deletions(-) diff --git a/libs/langchain/langchain/callbacks/llmonitor_callback.py b/libs/langchain/langchain/callbacks/llmonitor_callback.py index ee7bc4dd38..886e583f77 100644 --- a/libs/langchain/langchain/callbacks/llmonitor_callback.py +++ b/libs/langchain/langchain/callbacks/llmonitor_callback.py @@ -4,7 +4,7 @@ import os import traceback import warnings from contextvars import ContextVar -from typing import Any, Dict, List, Literal, Union +from typing import Any, Dict, List, Union, cast from uuid import UUID import requests @@ -15,11 +15,30 @@ from langchain.schema.agent import AgentAction, AgentFinish from langchain.schema.messages import BaseMessage from langchain.schema.output import LLMResult +logger = logging.getLogger(__name__) + DEFAULT_API_URL = "https://app.llmonitor.com" user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None) user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None) +PARAMS_TO_CAPTURE = [ + "temperature", + "top_p", + "top_k", + "stop", + "presence_penalty", + "frequence_penalty", + "seed", + "function_call", + "functions", + "tools", + "tool_choice", + "response_format", + "max_tokens", + "logit_bias", +] + class UserContextManager: """Context manager for LLMonitor user context.""" @@ -66,6 +85,10 @@ def _parse_input(raw_input: Any) -> Any: if not raw_input: return None + # if it's an array of 1, just parse the first element + if isinstance(raw_input, list) and len(raw_input) == 1: + return _parse_input(raw_input[0]) + if not isinstance(raw_input, dict): return _serialize(raw_input) @@ -115,17 +138,11 @@ def _parse_output(raw_output: dict) -> Any: def _parse_lc_role( role: str, -) -> Union[Literal["user", "ai", "system", "function"], None]: +) -> str: if role == "human": return "user" - elif role == "ai": - return "ai" - elif role == "system": - return "system" - elif role == "function": - return "function" else: - return None + return role def _get_user_id(metadata: Any) -> Any: @@ -148,13 +165,15 @@ def _get_user_props(metadata: Any) -> Any: def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]: + keys = ["function_call", "tool_calls", "tool_call_id", "name"] parsed = {"text": message.content, "role": _parse_lc_role(message.type)} - - function_call = (message.additional_kwargs or {}).get("function_call") - - if function_call is not None: - parsed["functionCall"] = function_call - + parsed.update( + { + key: cast(Any, message.additional_kwargs.get(key)) + for key in keys + if message.additional_kwargs.get(key) is not None + } + ) return parsed @@ -213,19 +232,20 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): self.__track_event = llmonitor.track_event except ImportError: - warnings.warn( + logger.warning( """[LLMonitor] To use the LLMonitor callback handler you need to have the `llmonitor` Python package installed. Please install it with `pip install llmonitor`""" ) self.__has_valid_config = False + return - if parse(self.__llmonitor_version) < parse("0.0.20"): - warnings.warn( + if parse(self.__llmonitor_version) < parse("0.0.32"): + logger.warning( f"""[LLMonitor] The installed `llmonitor` version is - {self.__llmonitor_version} but `LLMonitorCallbackHandler` requires - at least version 0.0.20 upgrade `llmonitor` with `pip install - --upgrade llmonitor`""" + {self.__llmonitor_version} + but `LLMonitorCallbackHandler` requires at least version 0.0.32 + upgrade `llmonitor` with `pip install --upgrade llmonitor`""" ) self.__has_valid_config = False @@ -236,9 +256,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): _app_id = app_id or os.getenv("LLMONITOR_APP_ID") if _app_id is None: - warnings.warn( - """[LLMonitor] app_id must be provided either as an argument or as - an environment variable""" + logger.warning( + """[LLMonitor] app_id must be provided either as an argument or + as an environment variable""" ) self.__has_valid_config = False else: @@ -252,7 +272,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): if not res.ok: raise ConnectionError() except Exception: - warnings.warn( + logger.warning( f"""[LLMonitor] Could not connect to the LLMonitor API at {self.__api_url}""" ) @@ -273,7 +293,27 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): try: user_id = _get_user_id(metadata) user_props = _get_user_props(metadata) - name = kwargs.get("invocation_params", {}).get("model_name") + + params = kwargs.get("invocation_params", {}) + params.update( + serialized.get("kwargs", {}) + ) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty + + name = ( + params.get("model") + or params.get("model_name") + or params.get("model_id") + ) + + if not name and "anthropic" in params.get("_type"): + name = "claude-2" + + extra = { + param: params.get(param) + for param in PARAMS_TO_CAPTURE + if params.get(param) is not None + } + input = _parse_input(prompts) self.__track_event( @@ -285,8 +325,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): name=name, input=input, tags=tags, + extra=extra, metadata=metadata, user_props=user_props, + app_id=self.__app_id, ) except Exception as e: warnings.warn(f"[LLMonitor] An error occurred in on_llm_start: {e}") @@ -304,10 +346,31 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): ) -> Any: if self.__has_valid_config is False: return + try: user_id = _get_user_id(metadata) user_props = _get_user_props(metadata) - name = kwargs.get("invocation_params", {}).get("model_name") + + params = kwargs.get("invocation_params", {}) + params.update( + serialized.get("kwargs", {}) + ) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty + + name = ( + params.get("model") + or params.get("model_name") + or params.get("model_id") + ) + + if not name and "anthropic" in params.get("_type"): + name = "claude-2" + + extra = { + param: params.get(param) + for param in PARAMS_TO_CAPTURE + if params.get(param) is not None + } + input = _parse_lc_messages(messages[0]) self.__track_event( @@ -319,13 +382,13 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): name=name, input=input, tags=tags, + extra=extra, metadata=metadata, user_props=user_props, + app_id=self.__app_id, ) except Exception as e: - logging.warning( - f"[LLMonitor] An error occurred in on_chat_model_start: {e}" - ) + logger.error(f"[LLMonitor] An error occurred in on_chat_model_start: {e}") def on_llm_end( self, @@ -340,25 +403,18 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): try: token_usage = (response.llm_output or {}).get("token_usage", {}) - parsed_output = [ - { - "text": generation.text, - "role": "ai", - **( - { - "functionCall": generation.message.additional_kwargs[ - "function_call" - ] - } - if hasattr(generation, "message") - and hasattr(generation.message, "additional_kwargs") - and "function_call" in generation.message.additional_kwargs - else {} - ), - } + + parsed_output: Any = [ + _parse_lc_message(generation.message) + if hasattr(generation, "message") + else generation.text for generation in response.generations[0] ] + # if it's an array of 1, just parse the first element + if len(parsed_output) == 1: + parsed_output = parsed_output[0] + self.__track_event( "llm", "end", @@ -369,9 +425,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): "prompt": token_usage.get("prompt_tokens"), "completion": token_usage.get("completion_tokens"), }, + app_id=self.__app_id, ) except Exception as e: - warnings.warn(f"[LLMonitor] An error occurred in on_llm_end: {e}") + logger.error(f"[LLMonitor] An error occurred in on_llm_end: {e}") def on_tool_start( self, @@ -402,9 +459,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): tags=tags, metadata=metadata, user_props=user_props, + app_id=self.__app_id, ) except Exception as e: - warnings.warn(f"[LLMonitor] An error occurred in on_tool_start: {e}") + logger.error(f"[LLMonitor] An error occurred in on_tool_start: {e}") def on_tool_end( self, @@ -424,9 +482,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): run_id=str(run_id), parent_run_id=str(parent_run_id) if parent_run_id else None, output=output, + app_id=self.__app_id, ) except Exception as e: - warnings.warn(f"[LLMonitor] An error occurred in on_tool_end: {e}") + logger.error(f"[LLMonitor] An error occurred in on_tool_end: {e}") def on_chain_start( self, @@ -473,9 +532,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): tags=tags, metadata=metadata, user_props=user_props, + app_id=self.__app_id, ) except Exception as e: - warnings.warn(f"[LLMonitor] An error occurred in on_chain_start: {e}") + logger.error(f"[LLMonitor] An error occurred in on_chain_start: {e}") def on_chain_end( self, @@ -496,9 +556,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): run_id=str(run_id), parent_run_id=str(parent_run_id) if parent_run_id else None, output=output, + app_id=self.__app_id, ) except Exception as e: - logging.warning(f"[LLMonitor] An error occurred in on_chain_end: {e}") + logger.error(f"[LLMonitor] An error occurred in on_chain_end: {e}") def on_agent_action( self, @@ -521,9 +582,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): parent_run_id=str(parent_run_id) if parent_run_id else None, name=name, input=input, + app_id=self.__app_id, ) except Exception as e: - logging.warning(f"[LLMonitor] An error occurred in on_agent_action: {e}") + logger.error(f"[LLMonitor] An error occurred in on_agent_action: {e}") def on_agent_finish( self, @@ -544,9 +606,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): run_id=str(run_id), parent_run_id=str(parent_run_id) if parent_run_id else None, output=output, + app_id=self.__app_id, ) except Exception as e: - logging.warning(f"[LLMonitor] An error occurred in on_agent_finish: {e}") + logger.error(f"[LLMonitor] An error occurred in on_agent_finish: {e}") def on_chain_error( self, @@ -565,9 +628,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): run_id=str(run_id), parent_run_id=str(parent_run_id) if parent_run_id else None, error={"message": str(error), "stack": traceback.format_exc()}, + app_id=self.__app_id, ) except Exception as e: - logging.warning(f"[LLMonitor] An error occurred in on_chain_error: {e}") + logger.error(f"[LLMonitor] An error occurred in on_chain_error: {e}") def on_tool_error( self, @@ -586,9 +650,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): run_id=str(run_id), parent_run_id=str(parent_run_id) if parent_run_id else None, error={"message": str(error), "stack": traceback.format_exc()}, + app_id=self.__app_id, ) except Exception as e: - logging.warning(f"[LLMonitor] An error occurred in on_tool_error: {e}") + logger.error(f"[LLMonitor] An error occurred in on_tool_error: {e}") def on_llm_error( self, @@ -607,9 +672,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): run_id=str(run_id), parent_run_id=str(parent_run_id) if parent_run_id else None, error={"message": str(error), "stack": traceback.format_exc()}, + app_id=self.__app_id, ) except Exception as e: - logging.warning(f"[LLMonitor] An error occurred in on_llm_error: {e}") + logger.error(f"[LLMonitor] An error occurred in on_llm_error: {e}") __all__ = ["LLMonitorCallbackHandler", "identify"]