forked from Archives/langchain
Add Invocation Params (#4509)
### Add Invocation Params to Logged Run Adds an llm type to each chat model as well as an override of the dict() method to log the invocation parameters for each call --------- Co-authored-by: Ankush Gola <ankush.gola@gmail.com>
This commit is contained in:
parent
59853fc876
commit
f4d3cf2dfb
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Mapping
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
@ -110,3 +110,12 @@ class AzureChatOpenAI(ChatOpenAI):
|
||||
**super()._default_params,
|
||||
"engine": self.deployment_name,
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**self._default_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "azure-openai-chat"
|
||||
|
@ -2,7 +2,7 @@ import asyncio
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
@ -65,11 +65,14 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
run_manager = callback_manager.on_chat_model_start(
|
||||
{"name": self.__class__.__name__}, messages
|
||||
{"name": self.__class__.__name__}, messages, invocation_params=params
|
||||
)
|
||||
|
||||
new_arg_supported = inspect.signature(self._generate).parameters.get(
|
||||
@ -98,12 +101,14 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
callbacks: Callbacks = None,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
params = self.dict()
|
||||
params["stop"] = stop
|
||||
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose
|
||||
)
|
||||
run_manager = await callback_manager.on_chat_model_start(
|
||||
{"name": self.__class__.__name__}, messages
|
||||
{"name": self.__class__.__name__}, messages, invocation_params=params
|
||||
)
|
||||
|
||||
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
||||
@ -181,6 +186,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
result = self([HumanMessage(content=message)], stop=stop)
|
||||
return result.content
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return a dictionary of the LLM."""
|
||||
starter_dict = dict(self._identifying_params)
|
||||
starter_dict["_type"] = self._llm_type
|
||||
return starter_dict
|
||||
|
||||
|
||||
class SimpleChatModel(BaseChatModel):
|
||||
def _generate(
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Wrapper around Google's PaLM Chat API."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, root_validator
|
||||
|
||||
@ -256,3 +256,18 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
||||
)
|
||||
|
||||
return _response_to_result(response, stop)
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {
|
||||
"model_name": self.model_name,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
"n": self.n,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "google-palm-chat"
|
||||
|
@ -347,6 +347,11 @@ class ChatOpenAI(BaseChatModel):
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model_name": self.model_name}, **self._default_params}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "openai-chat"
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Calculate num tokens with tiktoken package."""
|
||||
# tiktoken NOT supported for Python 3.7 or below
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""PromptLayer wrapper."""
|
||||
import datetime
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -109,3 +109,15 @@ class PromptLayerChatOpenAI(ChatOpenAI):
|
||||
generation.generation_info = {}
|
||||
generation.generation_info["pl_request_id"] = pl_request_id
|
||||
return generated_responses
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "promptlayer-openai-chat"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {
|
||||
**super()._identifying_params,
|
||||
"pl_tags": self.pl_tags,
|
||||
"return_pl_id": self.return_pl_id,
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Fake Chat Model wrapper for testing purposes."""
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@ -30,3 +30,11 @@ class FakeChatModel(SimpleChatModel):
|
||||
message = AIMessage(content=output_str)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-chat-model"
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
return {"key": "fake"}
|
||||
|
Loading…
Reference in New Issue
Block a user