From 6f358bb04a1b74df090992dd7ee871b7763273e4 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 5 Jul 2023 22:45:56 +0530 Subject: [PATCH] make textstat optional in the flyte callback handler (#7186) This PR makes the `textstat` library optional in the Flyte callback handler. @hinthornw, would you mind reviewing this PR since you merged the flyte callback handler code previously? --------- Signed-off-by: Samhita Alla --- langchain/callbacks/flyte_callback.py | 62 +++++++++++++++------------ 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/langchain/callbacks/flyte_callback.py b/langchain/callbacks/flyte_callback.py index 89123db668..d806f0a3d9 100644 --- a/langchain/callbacks/flyte_callback.py +++ b/langchain/callbacks/flyte_callback.py @@ -39,6 +39,7 @@ def import_flytekit() -> Tuple[flytekit, renderer]: def analyze_text( text: str, nlp: Any = None, + textstat: Any = None, ) -> dict: """Analyze text using textstat and spacy. @@ -51,26 +52,26 @@ def analyze_text( files serialized to HTML string. """ resp: Dict[str, Any] = {} - textstat = import_textstat() - text_complexity_metrics = { - "flesch_reading_ease": textstat.flesch_reading_ease(text), - "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text), - "smog_index": textstat.smog_index(text), - "coleman_liau_index": textstat.coleman_liau_index(text), - "automated_readability_index": textstat.automated_readability_index(text), - "dale_chall_readability_score": textstat.dale_chall_readability_score(text), - "difficult_words": textstat.difficult_words(text), - "linsear_write_formula": textstat.linsear_write_formula(text), - "gunning_fog": textstat.gunning_fog(text), - "fernandez_huerta": textstat.fernandez_huerta(text), - "szigriszt_pazos": textstat.szigriszt_pazos(text), - "gutierrez_polini": textstat.gutierrez_polini(text), - "crawford": textstat.crawford(text), - "gulpease_index": textstat.gulpease_index(text), - "osman": textstat.osman(text), - } - resp.update({"text_complexity_metrics": text_complexity_metrics}) - resp.update(text_complexity_metrics) + if textstat is not None: + text_complexity_metrics = { + "flesch_reading_ease": textstat.flesch_reading_ease(text), + "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text), + "smog_index": textstat.smog_index(text), + "coleman_liau_index": textstat.coleman_liau_index(text), + "automated_readability_index": textstat.automated_readability_index(text), + "dale_chall_readability_score": textstat.dale_chall_readability_score(text), + "difficult_words": textstat.difficult_words(text), + "linsear_write_formula": textstat.linsear_write_formula(text), + "gunning_fog": textstat.gunning_fog(text), + "fernandez_huerta": textstat.fernandez_huerta(text), + "szigriszt_pazos": textstat.szigriszt_pazos(text), + "gutierrez_polini": textstat.gutierrez_polini(text), + "crawford": textstat.crawford(text), + "gulpease_index": textstat.gulpease_index(text), + "osman": textstat.osman(text), + } + resp.update({"text_complexity_metrics": text_complexity_metrics}) + resp.update(text_complexity_metrics) if nlp is not None: spacy = import_spacy() @@ -78,16 +79,13 @@ def analyze_text( dep_out = spacy.displacy.render( # type: ignore doc, style="dep", jupyter=False, page=True ) - ent_out = spacy.displacy.render( # type: ignore doc, style="ent", jupyter=False, page=True ) - text_visualizations = { "dependency_tree": dep_out, "entities": ent_out, } - resp.update(text_visualizations) return resp @@ -98,10 +96,19 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): def __init__(self) -> None: """Initialize callback handler.""" - import_textstat() # Raise error since it is required flytekit, renderer = import_flytekit() self.pandas = import_pandas() + self.textstat = None + try: + self.textstat = import_textstat() + except ImportError: + logger.warning( + "Textstat library is not installed. \ + It may result in the inability to log \ + certain metrics that can be captured with Textstat." + ) + spacy = None try: spacy = import_spacy() @@ -123,7 +130,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): "FlyteCallbackHandler uses spacy's en_core_web_sm model" " for certain metrics. To download," " run the following command in your terminal:" - " `python -m spacy download en_core_web_sm` command." + " `python -m spacy download en_core_web_sm`" ) self.table_renderer = renderer.TableRenderer @@ -180,11 +187,10 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): for generation in generations: generation_resp = deepcopy(resp) generation_resp.update(flatten_dict(generation.dict())) - if self.nlp: + if self.nlp or self.textstat: generation_resp.update( analyze_text( - generation.text, - nlp=self.nlp, + generation.text, nlp=self.nlp, textstat=self.textstat ) )