forked from Archives/langchain
Add Invocation Params (#4509)
### Add Invocation Params to Logged Run Adds an llm type to each chat model as well as an override of the dict() method to log the invocation parameters for each call --------- Co-authored-by: Ankush Gola <ankush.gola@gmail.com>
This commit is contained in:
parent
59853fc876
commit
f4d3cf2dfb
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, Mapping
|
||||||
|
|
||||||
from pydantic import root_validator
|
from pydantic import root_validator
|
||||||
|
|
||||||
@ -110,3 +110,12 @@ class AzureChatOpenAI(ChatOpenAI):
|
|||||||
**super()._default_params,
|
**super()._default_params,
|
||||||
"engine": self.deployment_name,
|
"engine": self.deployment_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {**self._default_params}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "azure-openai-chat"
|
||||||
|
@ -2,7 +2,7 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Optional
|
from typing import Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
from pydantic import Extra, Field, root_validator
|
from pydantic import Extra, Field, root_validator
|
||||||
|
|
||||||
@ -65,11 +65,14 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
|||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""Top Level call"""
|
"""Top Level call"""
|
||||||
|
|
||||||
|
params = self.dict()
|
||||||
|
params["stop"] = stop
|
||||||
|
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
callbacks, self.callbacks, self.verbose
|
callbacks, self.callbacks, self.verbose
|
||||||
)
|
)
|
||||||
run_manager = callback_manager.on_chat_model_start(
|
run_manager = callback_manager.on_chat_model_start(
|
||||||
{"name": self.__class__.__name__}, messages
|
{"name": self.__class__.__name__}, messages, invocation_params=params
|
||||||
)
|
)
|
||||||
|
|
||||||
new_arg_supported = inspect.signature(self._generate).parameters.get(
|
new_arg_supported = inspect.signature(self._generate).parameters.get(
|
||||||
@ -98,12 +101,14 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
|||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""Top Level call"""
|
"""Top Level call"""
|
||||||
|
params = self.dict()
|
||||||
|
params["stop"] = stop
|
||||||
|
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = AsyncCallbackManager.configure(
|
||||||
callbacks, self.callbacks, self.verbose
|
callbacks, self.callbacks, self.verbose
|
||||||
)
|
)
|
||||||
run_manager = await callback_manager.on_chat_model_start(
|
run_manager = await callback_manager.on_chat_model_start(
|
||||||
{"name": self.__class__.__name__}, messages
|
{"name": self.__class__.__name__}, messages, invocation_params=params
|
||||||
)
|
)
|
||||||
|
|
||||||
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
||||||
@ -181,6 +186,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
|||||||
result = self([HumanMessage(content=message)], stop=stop)
|
result = self([HumanMessage(content=message)], stop=stop)
|
||||||
return result.content
|
return result.content
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of chat model."""
|
||||||
|
|
||||||
|
def dict(self, **kwargs: Any) -> Dict:
|
||||||
|
"""Return a dictionary of the LLM."""
|
||||||
|
starter_dict = dict(self._identifying_params)
|
||||||
|
starter_dict["_type"] = self._llm_type
|
||||||
|
return starter_dict
|
||||||
|
|
||||||
|
|
||||||
class SimpleChatModel(BaseChatModel):
|
class SimpleChatModel(BaseChatModel):
|
||||||
def _generate(
|
def _generate(
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Wrapper around Google's PaLM Chat API."""
|
"""Wrapper around Google's PaLM Chat API."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, root_validator
|
from pydantic import BaseModel, root_validator
|
||||||
|
|
||||||
@ -256,3 +256,18 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return _response_to_result(response, stop)
|
return _response_to_result(response, stop)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
"""Get the identifying parameters."""
|
||||||
|
return {
|
||||||
|
"model_name": self.model_name,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"top_k": self.top_k,
|
||||||
|
"n": self.n,
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "google-palm-chat"
|
||||||
|
@ -347,6 +347,11 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
return {**{"model_name": self.model_name}, **self._default_params}
|
return {**{"model_name": self.model_name}, **self._default_params}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return type of chat model."""
|
||||||
|
return "openai-chat"
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
"""Calculate num tokens with tiktoken package."""
|
"""Calculate num tokens with tiktoken package."""
|
||||||
# tiktoken NOT supported for Python 3.7 or below
|
# tiktoken NOT supported for Python 3.7 or below
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""PromptLayer wrapper."""
|
"""PromptLayer wrapper."""
|
||||||
import datetime
|
import datetime
|
||||||
from typing import List, Optional
|
from typing import Any, List, Mapping, Optional
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
@ -109,3 +109,15 @@ class PromptLayerChatOpenAI(ChatOpenAI):
|
|||||||
generation.generation_info = {}
|
generation.generation_info = {}
|
||||||
generation.generation_info["pl_request_id"] = pl_request_id
|
generation.generation_info["pl_request_id"] = pl_request_id
|
||||||
return generated_responses
|
return generated_responses
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "promptlayer-openai-chat"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
return {
|
||||||
|
**super()._identifying_params,
|
||||||
|
"pl_tags": self.pl_tags,
|
||||||
|
"return_pl_id": self.return_pl_id,
|
||||||
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Fake Chat Model wrapper for testing purposes."""
|
"""Fake Chat Model wrapper for testing purposes."""
|
||||||
from typing import List, Optional
|
from typing import Any, List, Mapping, Optional
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
@ -30,3 +30,11 @@ class FakeChatModel(SimpleChatModel):
|
|||||||
message = AIMessage(content=output_str)
|
message = AIMessage(content=output_str)
|
||||||
generation = ChatGeneration(message=message)
|
generation = ChatGeneration(message=message)
|
||||||
return ChatResult(generations=[generation])
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "fake-chat-model"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
return {"key": "fake"}
|
||||||
|
Loading…
Reference in New Issue
Block a user