diff --git a/docs/ecosystem/comet_tracking.ipynb b/docs/ecosystem/comet_tracking.ipynb
new file mode 100644
index 00000000..b7009c72
--- /dev/null
+++ b/docs/ecosystem/comet_tracking.ipynb
@@ -0,0 +1,352 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Using Comet with LangChain"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "![](https://user-images.githubusercontent.com/7529846/230328046-a8b18c51-12e3-4617-9b39-97614a571a2d.png)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this guide, we will demonstrate how to track your LangChain Experiments, Evaluation Metrics, and LLM Sessions with [Comet](https://www.comet.com/site/?utm_source=langchain&utm_medium=referral&utm_campaign=comet_notebook). \n",
+ "\n",
+ "\n",
+ " \"Open\n",
+ "\n",
+ "\n",
+ "**Example Project:** [Comet with LangChain](https://www.comet.com/examples/comet-example-langchain/view/b5ZThK6OFdhKWVSP3fDfRtrNF/panels?utm_source=langchain&utm_medium=referral&utm_campaign=comet_notebook)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\"comet-langchain\"\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Install Comet and Dependencies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install comet_ml\n",
+ "!pip install langchain\n",
+ "!pip install openai\n",
+ "!pip install google-search-results"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Initialize Comet and Set your Credentials"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can grab your [Comet API Key here](https://www.comet.com/signup?utm_source=langchain&utm_medium=referral&utm_campaign=comet_notebook) or click the link after initializing Comet."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import comet_ml\n",
+ "\n",
+ "comet_ml.init(project_name=\"comet-example-langchain\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Set OpenAI and SerpAPI credentials"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You will need an [OpenAI API Key](https://platform.openai.com/account/api-keys) and a [SerpAPI API Key](https://serpapi.com/dashboard) to run the following examples."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"OPENAI_API_KEY\"] = \"...\"\n",
+ "os.environ[\"SERPAPI_API_KEY\"] = \"...\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Scenario 1: Using just an LLM"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
+ "from langchain.callbacks.base import CallbackManager\n",
+ "from langchain.llms import OpenAI\n",
+ "\n",
+ "comet_callback = CometCallbackHandler(\n",
+ " project_name=\"comet-example-langchain\",\n",
+ " complexity_metrics=True,\n",
+ " stream_logs=True,\n",
+ " tags=[\"llm\"],\n",
+ " visualizations=[\"dep\"],\n",
+ ")\n",
+ "manager = 
CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
+ "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
+ "\n",
+ "llm_result = llm.generate([\"Tell me a joke\", \"Tell me a poem\", \"Tell me a fact\"] * 3)\n",
+ "print(\"LLM result\", llm_result)\n",
+ "comet_callback.flush_tracker(llm, finish=True)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Scenario 2: Using an LLM in a Chain"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
+ "from langchain.callbacks.base import CallbackManager\n",
+ "from langchain.chains import LLMChain\n",
+ "from langchain.llms import OpenAI\n",
+ "from langchain.prompts import PromptTemplate\n",
+ "\n",
+ "comet_callback = CometCallbackHandler(\n",
+ " complexity_metrics=True,\n",
+ " project_name=\"comet-example-langchain\",\n",
+ " stream_logs=True,\n",
+ " tags=[\"synopsis-chain\"],\n",
+ ")\n",
+ "manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
+ "\n",
+ "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
+ "\n",
+ "template = \"\"\"You are a playwright. Given the title of a play, it is your job to write a synopsis for that title.\n",
+ "Title: {title}\n",
+ "Playwright: This is a synopsis for the above play:\"\"\"\n",
+ "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
+ "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n",
+ "\n",
+ "test_prompts = [{\"title\": \"Documentary about Bigfoot in Paris\"}]\n",
+ "synopsis_chain.apply(test_prompts)\n",
+ "comet_callback.flush_tracker(synopsis_chain, finish=True)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Scenario 3: Using an Agent with Tools"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from langchain.agents import initialize_agent, load_tools\n",
+ "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n",
+ "from langchain.callbacks.base import CallbackManager\n",
+ "from langchain.llms import OpenAI\n",
+ "\n",
+ "comet_callback = CometCallbackHandler(\n",
+ " project_name=\"comet-example-langchain\",\n",
+ " complexity_metrics=True,\n",
+ " stream_logs=True,\n",
+ " tags=[\"agent\"],\n",
+ ")\n",
+ "manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n",
+ "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n",
+ "\n",
+ "tools = load_tools([\"serpapi\", \"llm-math\"], llm=llm, callback_manager=manager)\n",
+ "agent = initialize_agent(\n",
+ " tools,\n",
+ " llm,\n",
+ " agent=\"zero-shot-react-description\",\n",
+ " callback_manager=manager,\n",
+ " verbose=True,\n",
+ ")\n",
+ "agent.run(\n",
+ " \"Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?\"\n",
+ ")\n",
+ "comet_callback.flush_tracker(agent, finish=True)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Scenario 4: Using Custom Evaluation Metrics"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `CometCallbackHandler` also allows you to define and use Custom Evaluation Metrics to assess generated outputs from your model. Let's take a look at how this works. 
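A custom metrics function receives each generated output along with its prompt index and generation index, and returns a dictionary of metric names to values, which is logged alongside the generation. 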
\n", + "\n", + "\n", + "In the snippet below, we will use the [ROUGE](https://huggingface.co/spaces/evaluate-metric/rouge) metric to evaluate the quality of a generated summary of an input prompt. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install rouge-score" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from rouge_score import rouge_scorer\n", + "\n", + "from langchain.callbacks import CometCallbackHandler, StdOutCallbackHandler\n", + "from langchain.callbacks.base import CallbackManager\n", + "from langchain.chains import LLMChain\n", + "from langchain.llms import OpenAI\n", + "from langchain.prompts import PromptTemplate\n", + "\n", + "\n", + "class Rouge:\n", + " def __init__(self, reference):\n", + " self.reference = reference\n", + " self.scorer = rouge_scorer.RougeScorer([\"rougeLsum\"], use_stemmer=True)\n", + "\n", + " def compute_metric(self, generation, prompt_idx, gen_idx):\n", + " prediction = generation.text\n", + " results = self.scorer.score(target=self.reference, prediction=prediction)\n", + "\n", + " return {\n", + " \"rougeLsum_score\": results[\"rougeLsum\"].fmeasure,\n", + " \"reference\": self.reference,\n", + " }\n", + "\n", + "\n", + "reference = \"\"\"\n", + "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building.\n", + "It was the first structure to reach a height of 300 metres.\n", + "\n", + "It is now taller than the Chrysler Building in New York City by 5.2 metres (17 ft)\n", + "Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France .\n", + "\"\"\"\n", + "rouge_score = Rouge(reference=reference)\n", + "\n", + "template = \"\"\"Given the following article, it is your job to write a summary.\n", + "Article:\n", + "{article}\n", + "Summary: This is the summary for the above article:\"\"\"\n", + "prompt_template = PromptTemplate(input_variables=[\"article\"], template=template)\n", + "\n", + "comet_callback = CometCallbackHandler(\n", + " project_name=\"comet-example-langchain\",\n", + " complexity_metrics=False,\n", + " stream_logs=True,\n", + " tags=[\"custom_metrics\"],\n", + " custom_metrics=rouge_score.compute_metric,\n", + ")\n", + "manager = CallbackManager([StdOutCallbackHandler(), comet_callback])\n", + "llm = OpenAI(temperature=0.9, callback_manager=manager, verbose=True)\n", + "\n", + "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, callback_manager=manager)\n", + "\n", + "test_prompts = [\n", + " {\n", + " \"article\": \"\"\"\n", + " The tower is 324 metres (1,063 ft) tall, about the same height as\n", + " an 81-storey building, and the tallest structure in Paris. 
Its base is square,\n", + " measuring 125 metres (410 ft) on each side.\n", + " During its construction, the Eiffel Tower surpassed the\n", + " Washington Monument to become the tallest man-made structure in the world,\n", + " a title it held for 41 years until the Chrysler Building\n", + " in New York City was finished in 1930.\n", + "\n", + " It was the first structure to reach a height of 300 metres.\n", + " Due to the addition of a broadcasting aerial at the top of the tower in 1957,\n", + " it is now taller than the Chrysler Building by 5.2 metres (17 ft).\n", + "\n", + " Excluding transmitters, the Eiffel Tower is the second tallest\n", + " free-standing structure in France after the Millau Viaduct.\n", + " \"\"\"\n", + " }\n", + "]\n", + "synopsis_chain.apply(test_prompts)\n", + "comet_callback.flush_tracker(synopsis_chain, finish=True)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/langchain/callbacks/__init__.py b/langchain/callbacks/__init__.py index a3764c92..c6137bf9 100644 --- a/langchain/callbacks/__init__.py +++ b/langchain/callbacks/__init__.py @@ -11,6 +11,7 @@ from langchain.callbacks.base import ( CallbackManager, ) from langchain.callbacks.clearml_callback import ClearMLCallbackHandler +from langchain.callbacks.comet_ml_callback import CometCallbackHandler from langchain.callbacks.openai_info import OpenAICallbackHandler from langchain.callbacks.shared import SharedCallbackManager from langchain.callbacks.stdout import StdOutCallbackHandler @@ -78,6 +79,7 @@ __all__ = [ "AimCallbackHandler", "WandbCallbackHandler", "ClearMLCallbackHandler", + "CometCallbackHandler", "AsyncIteratorCallbackHandler", "get_openai_callback", "set_tracing_callback_manager", diff --git a/langchain/callbacks/comet_ml_callback.py b/langchain/callbacks/comet_ml_callback.py new file mode 100644 index 00000000..c716d43d --- /dev/null +++ b/langchain/callbacks/comet_ml_callback.py @@ -0,0 +1,627 @@ +import tempfile +from copy import deepcopy +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import langchain +from langchain.callbacks.base import BaseCallbackHandler +from langchain.callbacks.utils import ( + BaseMetadataCallbackHandler, + flatten_dict, + import_pandas, + import_spacy, + import_textstat, +) +from langchain.schema import AgentAction, AgentFinish, Generation, LLMResult + +LANGCHAIN_MODEL_NAME = "langchain-model" + + +def import_comet_ml() -> Any: + try: + import comet_ml # noqa: F401 + except ImportError: + raise ImportError( + "To use the comet_ml callback manager you need to have the " + "`comet_ml` python package installed. 
Please install it with" + " `pip install comet_ml`" + ) + return comet_ml + + +def _get_experiment( + workspace: Optional[str] = None, project_name: Optional[str] = None +) -> Any: + comet_ml = import_comet_ml() + + experiment = comet_ml.config.get_global_experiment() + if experiment is None: + experiment = comet_ml.Experiment( # type: ignore + workspace=workspace, + project_name=project_name, + ) + + return experiment + + +def _fetch_text_complexity_metrics(text: str) -> dict: + textstat = import_textstat() + text_complexity_metrics = { + "flesch_reading_ease": textstat.flesch_reading_ease(text), + "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text), + "smog_index": textstat.smog_index(text), + "coleman_liau_index": textstat.coleman_liau_index(text), + "automated_readability_index": textstat.automated_readability_index(text), + "dale_chall_readability_score": textstat.dale_chall_readability_score(text), + "difficult_words": textstat.difficult_words(text), + "linsear_write_formula": textstat.linsear_write_formula(text), + "gunning_fog": textstat.gunning_fog(text), + "text_standard": textstat.text_standard(text), + "fernandez_huerta": textstat.fernandez_huerta(text), + "szigriszt_pazos": textstat.szigriszt_pazos(text), + "gutierrez_polini": textstat.gutierrez_polini(text), + "crawford": textstat.crawford(text), + "gulpease_index": textstat.gulpease_index(text), + "osman": textstat.osman(text), + } + return text_complexity_metrics + + +def _summarize_metrics_for_generated_outputs(metrics: Sequence) -> dict: + pd = import_pandas() + metrics_df = pd.DataFrame(metrics) + metrics_summary = metrics_df.describe() + + return metrics_summary.to_dict() + + +class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): + """Callback Handler that logs to Comet. + + Parameters: + job_type (str): The type of comet_ml task such as "inference", + "testing" or "qc" + project_name (str): The comet_ml project name + tags (list): Tags to add to the task + task_name (str): Name of the comet_ml task + visualize (bool): Whether to visualize the run. + complexity_metrics (bool): Whether to log complexity metrics + stream_logs (bool): Whether to stream callback actions to Comet + + This handler will utilize the associated callback method and formats + the input of each callback function with metadata regarding the state of LLM run, + and adds the response to the list of records for both the {method}_records and + action. It then logs the response to Comet. 
+ """ + + def __init__( + self, + task_type: Optional[str] = "inference", + workspace: Optional[str] = None, + project_name: Optional[str] = "comet-langchain-demo", + tags: Optional[Sequence] = None, + name: Optional[str] = None, + visualizations: Optional[List[str]] = None, + complexity_metrics: bool = False, + custom_metrics: Optional[Callable] = None, + stream_logs: bool = True, + ) -> None: + """Initialize callback handler.""" + + comet_ml = import_comet_ml() + super().__init__() + + self.task_type = task_type + self.workspace = workspace + self.project_name = project_name + self.tags = tags + self.visualizations = visualizations + self.complexity_metrics = complexity_metrics + self.custom_metrics = custom_metrics + self.stream_logs = stream_logs + self.temp_dir = tempfile.TemporaryDirectory() + + self.experiment = _get_experiment(workspace, project_name) + self.experiment.log_other("Created from", "langchain") + if tags: + self.experiment.add_tags(tags) + self.name = name + if self.name: + self.experiment.set_name(self.name) + + warning = ( + "The comet_ml callback is currently in beta and is subject to change " + "based on updates to `langchain`. Please report any issues to " + "https://github.com/comet_ml/issue_tracking/issues with the tag " + "`langchain`." + ) + comet_ml.LOGGER.warning(warning) + + self.callback_columns: list = [] + self.action_records: list = [] + self.complexity_metrics = complexity_metrics + if self.visualizations: + spacy = import_spacy() + self.nlp = spacy.load("en_core_web_sm") + else: + self.nlp = None + + def _init_resp(self) -> Dict: + return {k: None for k in self.callback_columns} + + def on_llm_start( + self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + ) -> None: + """Run when LLM starts.""" + self.step += 1 + self.llm_starts += 1 + self.starts += 1 + + metadata = self._init_resp() + metadata.update({"action": "on_llm_start"}) + metadata.update(flatten_dict(serialized)) + metadata.update(self.get_custom_callback_meta()) + + for prompt in prompts: + prompt_resp = deepcopy(metadata) + prompt_resp["prompts"] = prompt + self.on_llm_start_records.append(prompt_resp) + self.action_records.append(prompt_resp) + + if self.stream_logs: + self._log_stream(prompt, metadata, self.step) + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + """Run when LLM generates a new token.""" + self.step += 1 + self.llm_streams += 1 + + resp = self._init_resp() + resp.update({"action": "on_llm_new_token", "token": token}) + resp.update(self.get_custom_callback_meta()) + + self.action_records.append(resp) + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run when LLM ends running.""" + self.step += 1 + self.llm_ends += 1 + self.ends += 1 + + metadata = self._init_resp() + metadata.update({"action": "on_llm_end"}) + metadata.update(flatten_dict(response.llm_output or {})) + metadata.update(self.get_custom_callback_meta()) + + output_complexity_metrics = [] + output_custom_metrics = [] + + for prompt_idx, generations in enumerate(response.generations): + for gen_idx, generation in enumerate(generations): + text = generation.text + + generation_resp = deepcopy(metadata) + generation_resp.update(flatten_dict(generation.dict())) + + complexity_metrics = self._get_complexity_metrics(text) + if complexity_metrics: + output_complexity_metrics.append(complexity_metrics) + generation_resp.update(complexity_metrics) + + custom_metrics = self._get_custom_metrics( + generation, prompt_idx, gen_idx + ) + if custom_metrics: + 
output_custom_metrics.append(custom_metrics)
+ generation_resp.update(custom_metrics)
+
+ if self.stream_logs:
+ self._log_stream(text, metadata, self.step)
+
+ self.action_records.append(generation_resp)
+ self.on_llm_end_records.append(generation_resp)
+
+ self._log_text_metrics(output_complexity_metrics, step=self.step)
+ self._log_text_metrics(output_custom_metrics, step=self.step)
+
+ def on_llm_error(
+ self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+ ) -> None:
+ """Run when LLM errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_chain_start(
+ self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+ ) -> None:
+ """Run when chain starts running."""
+ self.step += 1
+ self.chain_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_chain_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+
+ comet_ml = import_comet_ml()
+
+ for chain_input_key, chain_input_val in inputs.items():
+ if isinstance(chain_input_val, str):
+ input_resp = deepcopy(resp)
+ if self.stream_logs:
+ self._log_stream(chain_input_val, resp, self.step)
+ input_resp.update({chain_input_key: chain_input_val})
+ self.action_records.append(input_resp)
+
+ else:
+ comet_ml.LOGGER.warning(
+ f"Unexpected data format provided! "
+ f"Input Value for {chain_input_key} will not be logged"
+ )
+
+ def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+ """Run when chain ends running."""
+ self.step += 1
+ self.chain_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_chain_end"})
+ resp.update(self.get_custom_callback_meta())
+
+ comet_ml = import_comet_ml()
+
+ for chain_output_key, chain_output_val in outputs.items():
+ if isinstance(chain_output_val, str):
+ output_resp = deepcopy(resp)
+ if self.stream_logs:
+ self._log_stream(chain_output_val, resp, self.step)
+ output_resp.update({chain_output_key: chain_output_val})
+ self.action_records.append(output_resp)
+ else:
+ comet_ml.LOGGER.warning(
+ f"Unexpected data format provided! "
+ f"Output Value for {chain_output_key} will not be logged"
+ )
+
+ def on_chain_error(
+ self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+ ) -> None:
+ """Run when chain errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_tool_start(
+ self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
+ ) -> None:
+ """Run when tool starts running."""
+ self.step += 1
+ self.tool_starts += 1
+ self.starts += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_tool_start"})
+ resp.update(flatten_dict(serialized))
+ resp.update(self.get_custom_callback_meta())
+ if self.stream_logs:
+ self._log_stream(input_str, resp, self.step)
+
+ resp.update({"input_str": input_str})
+ self.action_records.append(resp)
+
+ def on_tool_end(self, output: str, **kwargs: Any) -> None:
+ """Run when tool ends running."""
+ self.step += 1
+ self.tool_ends += 1
+ self.ends += 1
+
+ resp = self._init_resp()
+ resp.update({"action": "on_tool_end"})
+ resp.update(self.get_custom_callback_meta())
+ if self.stream_logs:
+ self._log_stream(output, resp, self.step)
+
+ resp.update({"output": output})
+ self.action_records.append(resp)
+
+ def on_tool_error(
+ self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+ ) -> None:
+ """Run when tool errors."""
+ self.step += 1
+ self.errors += 1
+
+ def on_text(self, text: str, **kwargs: Any) -> None:
+ """
+ Run on arbitrary text. 
+ """ + self.step += 1 + self.text_ctr += 1 + + resp = self._init_resp() + resp.update({"action": "on_text"}) + resp.update(self.get_custom_callback_meta()) + if self.stream_logs: + self._log_stream(text, resp, self.step) + + resp.update({"text": text}) + self.action_records.append(resp) + + def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None: + """Run when agent ends running.""" + self.step += 1 + self.agent_ends += 1 + self.ends += 1 + + resp = self._init_resp() + output = finish.return_values["output"] + log = finish.log + + resp.update({"action": "on_agent_finish", "log": log}) + resp.update(self.get_custom_callback_meta()) + if self.stream_logs: + self._log_stream(output, resp, self.step) + + resp.update({"output": output}) + self.action_records.append(resp) + + def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: + """Run on agent action.""" + self.step += 1 + self.tool_starts += 1 + self.starts += 1 + + tool = action.tool + tool_input = action.tool_input + log = action.log + + resp = self._init_resp() + resp.update({"action": "on_agent_action", "log": log, "tool": tool}) + resp.update(self.get_custom_callback_meta()) + if self.stream_logs: + self._log_stream(tool_input, resp, self.step) + + resp.update({"tool_input": tool_input}) + self.action_records.append(resp) + + def _get_complexity_metrics(self, text: str) -> dict: + """Compute text complexity metrics using textstat. + + Parameters: + text (str): The text to analyze. + + Returns: + (dict): A dictionary containing the complexity metrics. + """ + resp = {} + if self.complexity_metrics: + text_complexity_metrics = _fetch_text_complexity_metrics(text) + resp.update(text_complexity_metrics) + + return resp + + def _get_custom_metrics( + self, generation: Generation, prompt_idx: int, gen_idx: int + ) -> dict: + """Compute Custom Metrics for an LLM Generated Output + + Args: + generation (LLMResult): Output generation from an LLM + prompt_idx (int): List index of the input prompt + gen_idx (int): List index of the generated output + + Returns: + dict: A dictionary containing the custom metrics. + """ + + resp = {} + if self.custom_metrics: + custom_metrics = self.custom_metrics(generation, prompt_idx, gen_idx) + resp.update(custom_metrics) + + return resp + + def flush_tracker( + self, + langchain_asset: Any = None, + task_type: Optional[str] = "inference", + workspace: Optional[str] = None, + project_name: Optional[str] = "comet-langchain-demo", + tags: Optional[Sequence] = None, + name: Optional[str] = None, + visualizations: Optional[List[str]] = None, + complexity_metrics: bool = False, + custom_metrics: Optional[Callable] = None, + finish: bool = False, + reset: bool = False, + ) -> None: + """Flush the tracker and setup the session. + + Everything after this will be a new table. + + Args: + name: Name of the preformed session so far so it is identifyable + langchain_asset: The langchain asset to save. + finish: Whether to finish the run. 
+ + Returns: + None + """ + self._log_session(langchain_asset) + + if langchain_asset: + self._log_model(langchain_asset) + + if finish: + self.experiment.end() + + if reset: + self._reset( + task_type, + workspace, + project_name, + tags, + name, + visualizations, + complexity_metrics, + custom_metrics, + ) + + def _log_stream(self, prompt: str, metadata: dict, step: int) -> None: + self.experiment.log_text(prompt, metadata=metadata, step=step) + + def _log_model(self, langchain_asset: Any) -> None: + comet_ml = import_comet_ml() + + model_parameters = self._get_llm_parameters(langchain_asset) + self.experiment.log_parameters(model_parameters, prefix="model") + + langchain_asset_path = Path(self.temp_dir.name, "model.json") + model_name = self.name if self.name else LANGCHAIN_MODEL_NAME + + try: + if hasattr(langchain_asset, "save"): + langchain_asset.save(langchain_asset_path) + self.experiment.log_model(model_name, str(langchain_asset_path)) + except (ValueError, AttributeError, NotImplementedError) as e: + if hasattr(langchain_asset, "save_agent"): + langchain_asset.save_agent(langchain_asset_path) + self.experiment.log_model(model_name, str(langchain_asset_path)) + else: + comet_ml.LOGGER.warning( + f"{e}" + " Could not save Langchain Asset " + f"for {langchain_asset.__class__.__name__}" + ) + + def _log_session(self, langchain_asset: Optional[Any] = None) -> None: + llm_session_df = self._create_session_analysis_dataframe(langchain_asset) + # Log the cleaned dataframe as a table + self.experiment.log_table("langchain-llm-session.csv", llm_session_df) + + metadata = {"langchain_version": str(langchain.__version__)} + # Log the langchain low-level records as a JSON file directly + self.experiment.log_asset_data( + self.action_records, "langchain-action_records.json", metadata=metadata + ) + + self._log_visualizations(llm_session_df) + + def _log_text_metrics(self, metrics: Sequence[dict], step: int) -> None: + if not metrics: + return + + metrics_summary = _summarize_metrics_for_generated_outputs(metrics) + for key, value in metrics_summary.items(): + self.experiment.log_metrics(value, prefix=key, step=step) + + def _log_visualizations(self, session_df: Any) -> None: + if not (self.visualizations and self.nlp): + return + + spacy = import_spacy() + comet_ml = import_comet_ml() + + prompts = session_df["prompts"].tolist() + outputs = session_df["text"].tolist() + + for idx, (prompt, output) in enumerate(zip(prompts, outputs)): + doc = self.nlp(output) + sentence_spans = list(doc.sents) + + for visualization in self.visualizations: + try: + html = spacy.displacy.render( + sentence_spans, + style=visualization, + options={"compact": True}, + jupyter=False, + page=True, + ) + self.experiment.log_asset_data( + html, + name=f"langchain-viz-{visualization}-{idx}.html", + metadata={"prompt": prompt}, + step=idx, + ) + except Exception as e: + comet_ml.LOGGER.warning(e) + + return + + def _reset( + self, + task_type: Optional[str] = None, + workspace: Optional[str] = None, + project_name: Optional[str] = None, + tags: Optional[Sequence] = None, + name: Optional[str] = None, + visualizations: Optional[List[str]] = None, + complexity_metrics: bool = False, + custom_metrics: Optional[Callable] = None, + ) -> None: + _task_type = task_type if task_type else self.task_type + _workspace = workspace if workspace else self.workspace + _project_name = project_name if project_name else self.project_name + _tags = tags if tags else self.tags + _name = name if name else self.name + _visualizations = 
visualizations if visualizations else self.visualizations
+ _complexity_metrics = (
+ complexity_metrics if complexity_metrics else self.complexity_metrics
+ )
+ _custom_metrics = custom_metrics if custom_metrics else self.custom_metrics
+
+ self.__init__( # type: ignore
+ task_type=_task_type,
+ workspace=_workspace,
+ project_name=_project_name,
+ tags=_tags,
+ name=_name,
+ visualizations=_visualizations,
+ complexity_metrics=_complexity_metrics,
+ custom_metrics=_custom_metrics,
+ )
+
+ self.reset_callback_meta()
+ self.temp_dir = tempfile.TemporaryDirectory()
+
+ def _create_session_analysis_dataframe(self, langchain_asset: Any = None) -> Any:
+ pd = import_pandas()
+
+ llm_parameters = self._get_llm_parameters(langchain_asset)
+ num_generations_per_prompt = llm_parameters.get("n", 1)
+
+ llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
+ # Repeat each input row based on the number of outputs generated per prompt
+ llm_start_records_df = llm_start_records_df.loc[
+ llm_start_records_df.index.repeat(num_generations_per_prompt)
+ ].reset_index(drop=True)
+ llm_end_records_df = pd.DataFrame(self.on_llm_end_records)
+
+ llm_session_df = pd.merge(
+ llm_start_records_df,
+ llm_end_records_df,
+ left_index=True,
+ right_index=True,
+ suffixes=["_llm_start", "_llm_end"],
+ )
+
+ return llm_session_df
+
+ def _get_llm_parameters(self, langchain_asset: Any = None) -> dict:
+ if not langchain_asset:
+ return {}
+ try:
+ if hasattr(langchain_asset, "agent"):
+ llm_parameters = langchain_asset.agent.llm_chain.llm.dict()
+ elif hasattr(langchain_asset, "llm_chain"):
+ llm_parameters = langchain_asset.llm_chain.llm.dict()
+ elif hasattr(langchain_asset, "llm"):
+ llm_parameters = langchain_asset.llm.dict()
+ else:
+ llm_parameters = langchain_asset.dict()
+ except Exception:
+ return {}
+
+ return llm_parameters