Fix Async Shared Resource Bug (#4751)

Use an async queue to distribute tracers across concurrent tasks rather than
inappropriately sharing a single tracer among all of them.
dynamic_agent_tools
Zander Chase 1 year ago committed by GitHub
parent 3f0357f94a
commit a128d95aeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -26,7 +26,6 @@ from pydantic import BaseSettings, Field, root_validator
from requests import Response from requests import Response
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
@@ -351,14 +350,13 @@ class LangChainPlusClient(BaseSettings):
except Exception as e: except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}") logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)}) outputs.append({"Error": str(e)})
finally: langchain_tracer.example_id = previous_example_id
langchain_tracer.example_id = previous_example_id
return outputs return outputs
@staticmethod @staticmethod
async def _gather_with_concurrency( async def _gather_with_concurrency(
n: int, n: int,
initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracer, Dict]]], initializer: Callable[[], Coroutine[Any, Any, LangChainTracer]],
*async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]], *async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]],
) -> List[Any]: ) -> List[Any]:
""" """
@@ -373,21 +371,28 @@ class LangChainPlusClient(BaseSettings):
A list of results from the coroutines. A list of results from the coroutines.
""" """
semaphore = asyncio.Semaphore(n) semaphore = asyncio.Semaphore(n)
tracer, job_state = await initializer() job_state = {"num_processed": 0}
tracer_queue: asyncio.Queue[LangChainTracer] = asyncio.Queue()
for _ in range(n):
tracer_queue.put_nowait(await initializer())
async def run_coroutine_with_semaphore( async def run_coroutine_with_semaphore(
async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]] async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]]
) -> Any: ) -> Any:
async with semaphore: async with semaphore:
return await async_func(tracer, job_state) tracer = await tracer_queue.get()
try:
result = await async_func(tracer, job_state)
finally:
tracer_queue.put_nowait(tracer)
return result
return await asyncio.gather( return await asyncio.gather(
*(run_coroutine_with_semaphore(function) for function in async_funcs) *(run_coroutine_with_semaphore(function) for function in async_funcs)
) )
async def _tracer_initializer( async def _tracer_initializer(self, session_name: str) -> LangChainTracer:
self, session_name: str
) -> Tuple[LangChainTracer, dict]:
""" """
Initialize a tracer to share across tasks. Initialize a tracer to share across tasks.
@@ -397,11 +402,9 @@ class LangChainPlusClient(BaseSettings):
Returns: Returns:
A LangChainTracer instance with an active session. A LangChainTracer instance with an active session.
""" """
job_state = {"num_processed": 0} tracer = LangChainTracer(session_name=session_name)
with tracing_v2_enabled(session_name=session_name) as session: tracer.ensure_session()
tracer = LangChainTracer() return tracer
tracer.session = session
return tracer, job_state
async def arun_on_dataset( async def arun_on_dataset(
self, self,
@@ -513,8 +516,7 @@ class LangChainPlusClient(BaseSettings):
except Exception as e: except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}") logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)}) outputs.append({"Error": str(e)})
finally: langchain_tracer.example_id = previous_example_id
langchain_tracer.example_id = previous_example_id
return outputs return outputs
def run_on_dataset( def run_on_dataset(
@@ -550,18 +552,16 @@ class LangChainPlusClient(BaseSettings):
dataset = self.read_dataset(dataset_name=dataset_name) dataset = self.read_dataset(dataset_name=dataset_name)
examples = list(self.list_examples(dataset_id=str(dataset.id))) examples = list(self.list_examples(dataset_id=str(dataset.id)))
results: Dict[str, Any] = {} results: Dict[str, Any] = {}
with tracing_v2_enabled(session_name=session_name) as session: tracer = LangChainTracer(session_name=session_name)
tracer = LangChainTracer() tracer.ensure_session()
tracer.session = session for i, example in enumerate(examples):
result = self.run_llm_or_chain(
for i, example in enumerate(examples): example,
result = self.run_llm_or_chain( tracer,
example, llm_or_chain_factory,
tracer, num_repetitions,
llm_or_chain_factory, )
num_repetitions, if verbose:
) print(f"{i+1} processed", flush=True, end="\r")
if verbose: results[str(example.id)] = result
print(f"{i+1} processed", flush=True, end="\r")
results[str(example.id)] = result
return results return results

Loading…
Cancel
Save