Add Invocation Params (#4509)

### Add Invocation Params to Logged Run


Adds an llm type to each chat model as well as an override of the dict()
method to log the invocation parameters for each call

---------

Co-authored-by: Ankush Gola <ankush.gola@gmail.com>
This commit is contained in:
Zander Chase 2023-05-11 15:34:06 -07:00 committed by GitHub
parent 59853fc876
commit f4d3cf2dfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 77 additions and 7 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any, Dict from typing import Any, Dict, Mapping
from pydantic import root_validator from pydantic import root_validator
@ -110,3 +110,12 @@ class AzureChatOpenAI(ChatOpenAI):
**super()._default_params, **super()._default_params,
"engine": self.deployment_name, "engine": self.deployment_name,
} }
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {**self._default_params}
@property
def _llm_type(self) -> str:
return "azure-openai-chat"

View File

@ -2,7 +2,7 @@ import asyncio
import inspect import inspect
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, Field, root_validator from pydantic import Extra, Field, root_validator
@ -65,11 +65,14 @@ class BaseChatModel(BaseLanguageModel, ABC):
) -> LLMResult: ) -> LLMResult:
"""Top Level call""" """Top Level call"""
params = self.dict()
params["stop"] = stop
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose callbacks, self.callbacks, self.verbose
) )
run_manager = callback_manager.on_chat_model_start( run_manager = callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages {"name": self.__class__.__name__}, messages, invocation_params=params
) )
new_arg_supported = inspect.signature(self._generate).parameters.get( new_arg_supported = inspect.signature(self._generate).parameters.get(
@ -98,12 +101,14 @@ class BaseChatModel(BaseLanguageModel, ABC):
callbacks: Callbacks = None, callbacks: Callbacks = None,
) -> LLMResult: ) -> LLMResult:
"""Top Level call""" """Top Level call"""
params = self.dict()
params["stop"] = stop
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose callbacks, self.callbacks, self.verbose
) )
run_manager = await callback_manager.on_chat_model_start( run_manager = await callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages {"name": self.__class__.__name__}, messages, invocation_params=params
) )
new_arg_supported = inspect.signature(self._agenerate).parameters.get( new_arg_supported = inspect.signature(self._agenerate).parameters.get(
@ -181,6 +186,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
result = self([HumanMessage(content=message)], stop=stop) result = self([HumanMessage(content=message)], stop=stop)
return result.content return result.content
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {}
@property
@abstractmethod
def _llm_type(self) -> str:
"""Return type of chat model."""
def dict(self, **kwargs: Any) -> Dict:
"""Return a dictionary of the LLM."""
starter_dict = dict(self._identifying_params)
starter_dict["_type"] = self._llm_type
return starter_dict
class SimpleChatModel(BaseChatModel): class SimpleChatModel(BaseChatModel):
def _generate( def _generate(

View File

@ -1,7 +1,7 @@
"""Wrapper around Google's PaLM Chat API.""" """Wrapper around Google's PaLM Chat API."""
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, root_validator from pydantic import BaseModel, root_validator
@ -256,3 +256,18 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
) )
return _response_to_result(response, stop) return _response_to_result(response, stop)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
"model_name": self.model_name,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"n": self.n,
}
@property
def _llm_type(self) -> str:
return "google-palm-chat"

View File

@ -347,6 +347,11 @@ class ChatOpenAI(BaseChatModel):
"""Get the identifying parameters.""" """Get the identifying parameters."""
return {**{"model_name": self.model_name}, **self._default_params} return {**{"model_name": self.model_name}, **self._default_params}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "openai-chat"
def get_num_tokens(self, text: str) -> int: def get_num_tokens(self, text: str) -> int:
"""Calculate num tokens with tiktoken package.""" """Calculate num tokens with tiktoken package."""
# tiktoken NOT supported for Python 3.7 or below # tiktoken NOT supported for Python 3.7 or below

View File

@ -1,6 +1,6 @@
"""PromptLayer wrapper.""" """PromptLayer wrapper."""
import datetime import datetime
from typing import List, Optional from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -109,3 +109,15 @@ class PromptLayerChatOpenAI(ChatOpenAI):
generation.generation_info = {} generation.generation_info = {}
generation.generation_info["pl_request_id"] = pl_request_id generation.generation_info["pl_request_id"] = pl_request_id
return generated_responses return generated_responses
@property
def _llm_type(self) -> str:
return "promptlayer-openai-chat"
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {
**super()._identifying_params,
"pl_tags": self.pl_tags,
"return_pl_id": self.return_pl_id,
}

View File

@ -1,5 +1,5 @@
"""Fake Chat Model wrapper for testing purposes.""" """Fake Chat Model wrapper for testing purposes."""
from typing import List, Optional from typing import Any, List, Mapping, Optional
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -30,3 +30,11 @@ class FakeChatModel(SimpleChatModel):
message = AIMessage(content=output_str) message = AIMessage(content=output_str)
generation = ChatGeneration(message=message) generation = ChatGeneration(message=message)
return ChatResult(generations=[generation]) return ChatResult(generations=[generation])
@property
def _llm_type(self) -> str:
return "fake-chat-model"
@property
def _identifying_params(self) -> Mapping[str, Any]:
return {"key": "fake"}