Add Invocation Params (#4509)

### Add Invocation Params to Logged Run


Adds an `_llm_type` property to each chat model, as well as an override of the `dict()` method on the base class, so that the invocation parameters are logged for each call.
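
For orientation, here is a minimal sketch of the pattern the diffs below implement: each chat model declares an `_llm_type` plus its `_identifying_params`, and a shared `dict()` merges the two into a loggable payload. `ToyChatModel` is a hypothetical stand-in, not a class from this change; only `dict()`, `_llm_type`, and `_identifying_params` mirror the diff.

```python
# Minimal sketch of the pattern; ToyChatModel is hypothetical, but
# dict()/_llm_type/_identifying_params mirror the base-class diff below.
from typing import Any, Dict, Mapping


class ToyChatModel:
    temperature: float = 0.7

    @property
    def _llm_type(self) -> str:
        return "toy-chat-model"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"temperature": self.temperature}

    def dict(self, **kwargs: Any) -> Dict:
        starter_dict = dict(self._identifying_params)
        starter_dict["_type"] = self._llm_type
        return starter_dict


print(ToyChatModel().dict())  # {'temperature': 0.7, '_type': 'toy-chat-model'}
```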

---------

Co-authored-by: Ankush Gola <ankush.gola@gmail.com>
Zander Chase, 2023-05-11 15:34:06 -07:00, committed by GitHub
parent 59853fc876
commit f4d3cf2dfb
6 changed files with 77 additions and 7 deletions

#### AzureChatOpenAI

```diff
@@ -2,7 +2,7 @@
 from __future__ import annotations

 import logging
-from typing import Any, Dict
+from typing import Any, Dict, Mapping

 from pydantic import root_validator
@@ -110,3 +110,12 @@ class AzureChatOpenAI(ChatOpenAI):
             **super()._default_params,
             "engine": self.deployment_name,
         }
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {**self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        return "azure-openai-chat"
```

#### BaseChatModel

```diff
@@ -2,7 +2,7 @@ import asyncio
 import inspect
 import warnings
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Mapping, Optional

 from pydantic import Extra, Field, root_validator
@@ -65,11 +65,14 @@ class BaseChatModel(BaseLanguageModel, ABC):
     ) -> LLMResult:
         """Top Level call"""
+        params = self.dict()
+        params["stop"] = stop
+
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
         run_manager = callback_manager.on_chat_model_start(
-            {"name": self.__class__.__name__}, messages
+            {"name": self.__class__.__name__}, messages, invocation_params=params
         )

         new_arg_supported = inspect.signature(self._generate).parameters.get(
@@ -98,12 +101,14 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callbacks: Callbacks = None,
     ) -> LLMResult:
         """Top Level call"""
+        params = self.dict()
+        params["stop"] = stop
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
         run_manager = await callback_manager.on_chat_model_start(
-            {"name": self.__class__.__name__}, messages
+            {"name": self.__class__.__name__}, messages, invocation_params=params
         )

         new_arg_supported = inspect.signature(self._agenerate).parameters.get(
@@ -181,6 +186,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
         result = self([HumanMessage(content=message)], stop=stop)
         return result.content

+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {}
+
+    @property
+    @abstractmethod
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+
+    def dict(self, **kwargs: Any) -> Dict:
+        """Return a dictionary of the LLM."""
+        starter_dict = dict(self._identifying_params)
+        starter_dict["_type"] = self._llm_type
+        return starter_dict
+

 class SimpleChatModel(BaseChatModel):
     def _generate(
```
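
Because `params` is forwarded as `invocation_params`, any handler whose `on_chat_model_start` accepts `**kwargs` can read it. Below is a sketch under the assumption that the callback manager passes the keyword through to handlers unchanged; `InvocationParamsLogger` is hypothetical.

```python
# Hypothetical handler; assumes invocation_params arrives in **kwargs,
# as the on_chat_model_start call in the diff above suggests.
from typing import Any, Dict, List

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import BaseMessage


class InvocationParamsLogger(BaseCallbackHandler):
    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        # invocation_params is the model's dict() plus the stop words.
        print(kwargs.get("invocation_params"))
```

An instance would be passed via the `callbacks` argument when invoking a chat model.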

#### ChatGooglePalm

```diff
@@ -1,7 +1,7 @@
 """Wrapper around Google's PaLM Chat API."""
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional

 from pydantic import BaseModel, root_validator
@@ -256,3 +256,18 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         )

         return _response_to_result(response, stop)
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "model_name": self.model_name,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "n": self.n,
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        return "google-palm-chat"
```

#### ChatOpenAI

```diff
@@ -347,6 +347,11 @@ class ChatOpenAI(BaseChatModel):
         """Get the identifying parameters."""
         return {**{"model_name": self.model_name}, **self._default_params}

+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "openai-chat"
+
     def get_num_tokens(self, text: str) -> int:
         """Calculate num tokens with tiktoken package."""
         # tiktoken NOT supported for Python 3.7 or below
```
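
Illustratively, serializing a configured `ChatOpenAI` now yields its parameters plus the type tag. The non-`_type` keys come from `_identifying_params`, so the exact contents here are an assumption.

```python
# Illustrative only; assumes OPENAI_API_KEY is set so construction succeeds.
from langchain.chat_models import ChatOpenAI

model = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
params = model.dict()
assert params["_type"] == "openai-chat"
# Remaining keys come from _identifying_params, e.g. model_name, temperature, ...
print(params)
```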

#### PromptLayerChatOpenAI

```diff
@@ -1,6 +1,6 @@
 """PromptLayer wrapper."""
 import datetime
-from typing import List, Optional
+from typing import Any, List, Mapping, Optional

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -109,3 +109,15 @@ class PromptLayerChatOpenAI(ChatOpenAI):
                     generation.generation_info = {}
                 generation.generation_info["pl_request_id"] = pl_request_id
         return generated_responses
+
+    @property
+    def _llm_type(self) -> str:
+        return "promptlayer-openai-chat"
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {
+            **super()._identifying_params,
+            "pl_tags": self.pl_tags,
+            "return_pl_id": self.return_pl_id,
+        }
```

#### FakeChatModel

```diff
@@ -1,5 +1,5 @@
 """Fake Chat Model wrapper for testing purposes."""
-from typing import List, Optional
+from typing import Any, List, Mapping, Optional

 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -30,3 +30,11 @@ class FakeChatModel(SimpleChatModel):
         message = AIMessage(content=output_str)
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
+
+    @property
+    def _llm_type(self) -> str:
+        return "fake-chat-model"
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return {"key": "fake"}
```