|
|
|
@ -26,7 +26,6 @@ from pydantic import BaseSettings, Field, root_validator
|
|
|
|
|
from requests import Response
|
|
|
|
|
|
|
|
|
|
from langchain.base_language import BaseLanguageModel
|
|
|
|
|
from langchain.callbacks.manager import tracing_v2_enabled
|
|
|
|
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
|
|
|
|
from langchain.chains.base import Chain
|
|
|
|
|
from langchain.chat_models.base import BaseChatModel
|
|
|
|
@ -351,14 +350,13 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
|
|
|
|
outputs.append({"Error": str(e)})
|
|
|
|
|
finally:
|
|
|
|
|
langchain_tracer.example_id = previous_example_id
|
|
|
|
|
langchain_tracer.example_id = previous_example_id
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
async def _gather_with_concurrency(
|
|
|
|
|
n: int,
|
|
|
|
|
initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracer, Dict]]],
|
|
|
|
|
initializer: Callable[[], Coroutine[Any, Any, LangChainTracer]],
|
|
|
|
|
*async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]],
|
|
|
|
|
) -> List[Any]:
|
|
|
|
|
"""
|
|
|
|
@ -373,21 +371,28 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
A list of results from the coroutines.
|
|
|
|
|
"""
|
|
|
|
|
semaphore = asyncio.Semaphore(n)
|
|
|
|
|
tracer, job_state = await initializer()
|
|
|
|
|
job_state = {"num_processed": 0}
|
|
|
|
|
|
|
|
|
|
tracer_queue: asyncio.Queue[LangChainTracer] = asyncio.Queue()
|
|
|
|
|
for _ in range(n):
|
|
|
|
|
tracer_queue.put_nowait(await initializer())
|
|
|
|
|
|
|
|
|
|
async def run_coroutine_with_semaphore(
|
|
|
|
|
async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]]
|
|
|
|
|
) -> Any:
|
|
|
|
|
async with semaphore:
|
|
|
|
|
return await async_func(tracer, job_state)
|
|
|
|
|
tracer = await tracer_queue.get()
|
|
|
|
|
try:
|
|
|
|
|
result = await async_func(tracer, job_state)
|
|
|
|
|
finally:
|
|
|
|
|
tracer_queue.put_nowait(tracer)
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
return await asyncio.gather(
|
|
|
|
|
*(run_coroutine_with_semaphore(function) for function in async_funcs)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def _tracer_initializer(
|
|
|
|
|
self, session_name: str
|
|
|
|
|
) -> Tuple[LangChainTracer, dict]:
|
|
|
|
|
async def _tracer_initializer(self, session_name: str) -> LangChainTracer:
|
|
|
|
|
"""
|
|
|
|
|
Initialize a tracer to share across tasks.
|
|
|
|
|
|
|
|
|
@ -397,11 +402,9 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
Returns:
|
|
|
|
|
A LangChainTracer instance with an active session.
|
|
|
|
|
"""
|
|
|
|
|
job_state = {"num_processed": 0}
|
|
|
|
|
with tracing_v2_enabled(session_name=session_name) as session:
|
|
|
|
|
tracer = LangChainTracer()
|
|
|
|
|
tracer.session = session
|
|
|
|
|
return tracer, job_state
|
|
|
|
|
tracer = LangChainTracer(session_name=session_name)
|
|
|
|
|
tracer.ensure_session()
|
|
|
|
|
return tracer
|
|
|
|
|
|
|
|
|
|
async def arun_on_dataset(
|
|
|
|
|
self,
|
|
|
|
@ -513,8 +516,7 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
|
|
|
|
|
outputs.append({"Error": str(e)})
|
|
|
|
|
finally:
|
|
|
|
|
langchain_tracer.example_id = previous_example_id
|
|
|
|
|
langchain_tracer.example_id = previous_example_id
|
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
def run_on_dataset(
|
|
|
|
@ -550,18 +552,16 @@ class LangChainPlusClient(BaseSettings):
|
|
|
|
|
dataset = self.read_dataset(dataset_name=dataset_name)
|
|
|
|
|
examples = list(self.list_examples(dataset_id=str(dataset.id)))
|
|
|
|
|
results: Dict[str, Any] = {}
|
|
|
|
|
with tracing_v2_enabled(session_name=session_name) as session:
|
|
|
|
|
tracer = LangChainTracer()
|
|
|
|
|
tracer.session = session
|
|
|
|
|
|
|
|
|
|
for i, example in enumerate(examples):
|
|
|
|
|
result = self.run_llm_or_chain(
|
|
|
|
|
example,
|
|
|
|
|
tracer,
|
|
|
|
|
llm_or_chain_factory,
|
|
|
|
|
num_repetitions,
|
|
|
|
|
)
|
|
|
|
|
if verbose:
|
|
|
|
|
print(f"{i+1} processed", flush=True, end="\r")
|
|
|
|
|
results[str(example.id)] = result
|
|
|
|
|
tracer = LangChainTracer(session_name=session_name)
|
|
|
|
|
tracer.ensure_session()
|
|
|
|
|
for i, example in enumerate(examples):
|
|
|
|
|
result = self.run_llm_or_chain(
|
|
|
|
|
example,
|
|
|
|
|
tracer,
|
|
|
|
|
llm_or_chain_factory,
|
|
|
|
|
num_repetitions,
|
|
|
|
|
)
|
|
|
|
|
if verbose:
|
|
|
|
|
print(f"{i+1} processed", flush=True, end="\r")
|
|
|
|
|
results[str(example.id)] = result
|
|
|
|
|
return results
|
|
|
|
|