From c844aaa7a6a0ab847f8077ac07db947a7f9aa9e8 Mon Sep 17 00:00:00 2001
From: William FH <13333726+hinthornw@users.noreply.github.com>
Date: Tue, 29 Aug 2023 19:27:22 -0700
Subject: [PATCH] Weakref to tracer (#9954)

Prevent memory/thread leakage
---
 .../langchain/callbacks/tracers/langchain.py  | 25 +++++++++++++------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/libs/langchain/langchain/callbacks/tracers/langchain.py b/libs/langchain/langchain/callbacks/tracers/langchain.py
index 57b57ee270..0f57697721 100644
--- a/libs/langchain/langchain/callbacks/tracers/langchain.py
+++ b/libs/langchain/langchain/callbacks/tracers/langchain.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import logging
 import os
+import weakref
 from concurrent.futures import Future, ThreadPoolExecutor, wait
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Set, Union
@@ -18,8 +19,10 @@ from langchain.schema.messages import BaseMessage
 logger = logging.getLogger(__name__)
 
 _LOGGED = set()
-_TRACERS: List[LangChainTracer] = []
+_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet()
 _CLIENT: Optional[Client] = None
+_MAX_EXECUTORS = 10  # TODO: Remove once write queue is implemented
+_EXECUTORS: List[ThreadPoolExecutor] = []
 
 
 def log_error_once(method: str, exception: Exception) -> None:
@@ -34,8 +37,9 @@ def log_error_once(method: str, exception: Exception) -> None:
 def wait_for_all_tracers() -> None:
     """Wait for all tracers to finish."""
     global _TRACERS
-    for tracer in _TRACERS:
-        tracer.wait_for_futures()
+    for tracer in list(_TRACERS):
+        if tracer is not None:
+            tracer.wait_for_futures()
 
 
 def _get_client() -> Client:
@@ -68,17 +72,22 @@ class LangChainTracer(BaseTracer):
             "LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default")
         )
         if use_threading:
-            # set max_workers to 1 to process tasks in order
-            self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
-                max_workers=1
-            )
+            global _MAX_EXECUTORS
+            if len(_EXECUTORS) < _MAX_EXECUTORS:
+                self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
+                    max_workers=1
+                )
+                _EXECUTORS.append(self.executor)
+            else:
+                self.executor = _EXECUTORS.pop(0)
+                _EXECUTORS.append(self.executor)
         else:
             self.executor = None
         self.client = client or _get_client()
         self._futures: Set[Future] = set()
         self.tags = tags or []
         global _TRACERS
-        _TRACERS.append(self)
+        _TRACERS.add(self)
 
     def on_chat_model_start(
         self,