Weakref to tracer (#9954)

Prevent memory/thread leakage
pull/9974/head
William FH 1 year ago committed by GitHub
parent a05fed9369
commit c844aaa7a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import os
import weakref
from concurrent.futures import Future, ThreadPoolExecutor, wait
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Union
@ -18,8 +19,10 @@ from langchain.schema.messages import BaseMessage
logger = logging.getLogger(__name__)
_LOGGED = set()
_TRACERS: List[LangChainTracer] = []
_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet()
_CLIENT: Optional[Client] = None
_MAX_EXECUTORS = 10 # TODO: Remove once write queue is implemented
_EXECUTORS: List[ThreadPoolExecutor] = []
def log_error_once(method: str, exception: Exception) -> None:
@ -34,8 +37,9 @@ def log_error_once(method: str, exception: Exception) -> None:
def wait_for_all_tracers() -> None:
"""Wait for all tracers to finish."""
global _TRACERS
for tracer in _TRACERS:
tracer.wait_for_futures()
for tracer in list(_TRACERS):
if tracer is not None:
tracer.wait_for_futures()
def _get_client() -> Client:
@ -68,17 +72,22 @@ class LangChainTracer(BaseTracer):
"LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default")
)
if use_threading:
# set max_workers to 1 to process tasks in order
self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
max_workers=1
)
global _MAX_EXECUTORS
if len(_EXECUTORS) < _MAX_EXECUTORS:
self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
max_workers=1
)
_EXECUTORS.append(self.executor)
else:
self.executor = _EXECUTORS.pop(0)
_EXECUTORS.append(self.executor)
else:
self.executor = None
self.client = client or _get_client()
self._futures: Set[Future] = set()
self.tags = tags or []
global _TRACERS
_TRACERS.append(self)
_TRACERS.add(self)
def on_chat_model_start(
self,

Loading…
Cancel
Save