Weakref to tracer (#9954)

Prevent memory/thread leakage
This commit is contained in:
William FH 2023-08-29 19:27:22 -07:00 committed by GitHub
parent a05fed9369
commit c844aaa7a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import logging import logging
import os import os
import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait from concurrent.futures import Future, ThreadPoolExecutor, wait
from datetime import datetime from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Union from typing import Any, Callable, Dict, List, Optional, Set, Union
@ -18,8 +19,10 @@ from langchain.schema.messages import BaseMessage
# Module-level logger for this file.
logger = logging.getLogger(__name__)
# Messages already logged once — presumably used by log_error_once to
# dedupe repeated error output; body not visible here, TODO confirm.
_LOGGED = set()
# All live tracer instances. A WeakSet so that registration here never
# keeps a tracer (and its executor/thread) alive after callers drop it —
# this is the memory/thread-leak fix this commit introduces.
_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet()
# Shared client instance; NOTE(review): looks lazily created via
# _get_client() — confirm against that helper.
_CLIENT: Optional[Client] = None
_MAX_EXECUTORS = 10  # TODO: Remove once write queue is implemented
# Pool of single-worker executors reused round-robin by new tracers once
# _MAX_EXECUTORS have been created (see LangChainTracer.__init__).
_EXECUTORS: List[ThreadPoolExecutor] = []
def log_error_once(method: str, exception: Exception) -> None: def log_error_once(method: str, exception: Exception) -> None:
@ -34,7 +37,8 @@ def log_error_once(method: str, exception: Exception) -> None:
def wait_for_all_tracers() -> None:
    """Block until every registered tracer has flushed its pending futures.

    Iterates a ``list()`` snapshot of the module-level ``_TRACERS``
    ``WeakSet`` so that tracers garbage-collected while we wait cannot
    mutate the set mid-iteration (which would raise ``RuntimeError``).

    Fixes versus the committed version:
    - dropped the no-op ``global _TRACERS`` statement — the function only
      reads the global, it never rebinds it;
    - dropped the dead ``if tracer is not None`` guard — a ``WeakSet``
      never yields ``None``: dead referents are simply absent from
      iteration, and ``None`` is not weak-referenceable so it can never
      be added in the first place.
    """
    for tracer in list(_TRACERS):
        tracer.wait_for_futures()
@ -68,17 +72,22 @@ class LangChainTracer(BaseTracer):
"LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default") "LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default")
) )
if use_threading: if use_threading:
# set max_workers to 1 to process tasks in order global _MAX_EXECUTORS
if len(_EXECUTORS) < _MAX_EXECUTORS:
self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor( self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
max_workers=1 max_workers=1
) )
_EXECUTORS.append(self.executor)
else:
self.executor = _EXECUTORS.pop(0)
_EXECUTORS.append(self.executor)
else: else:
self.executor = None self.executor = None
self.client = client or _get_client() self.client = client or _get_client()
self._futures: Set[Future] = set() self._futures: Set[Future] = set()
self.tags = tags or [] self.tags = tags or []
global _TRACERS global _TRACERS
_TRACERS.append(self) _TRACERS.add(self)
def on_chat_model_start( def on_chat_model_start(
self, self,