Make textstat optional in the Flyte callback handler (#7186)

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->

This PR makes the `textstat` library optional in the Flyte callback
handler.

@hinthornw, would you mind reviewing this PR, since you previously merged the
Flyte callback handler code?

---------

Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
pull/7120/head^2
Samhita Alla 1 year ago committed by GitHub
parent 6eff0fa2ca
commit 6f358bb04a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -39,6 +39,7 @@ def import_flytekit() -> Tuple[flytekit, renderer]:
def analyze_text(
text: str,
nlp: Any = None,
textstat: Any = None,
) -> dict:
"""Analyze text using textstat and spacy.
@ -51,26 +52,26 @@ def analyze_text(
files serialized to HTML string.
"""
resp: Dict[str, Any] = {}
textstat = import_textstat()
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics)
if textstat is not None:
text_complexity_metrics = {
"flesch_reading_ease": textstat.flesch_reading_ease(text),
"flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
"smog_index": textstat.smog_index(text),
"coleman_liau_index": textstat.coleman_liau_index(text),
"automated_readability_index": textstat.automated_readability_index(text),
"dale_chall_readability_score": textstat.dale_chall_readability_score(text),
"difficult_words": textstat.difficult_words(text),
"linsear_write_formula": textstat.linsear_write_formula(text),
"gunning_fog": textstat.gunning_fog(text),
"fernandez_huerta": textstat.fernandez_huerta(text),
"szigriszt_pazos": textstat.szigriszt_pazos(text),
"gutierrez_polini": textstat.gutierrez_polini(text),
"crawford": textstat.crawford(text),
"gulpease_index": textstat.gulpease_index(text),
"osman": textstat.osman(text),
}
resp.update({"text_complexity_metrics": text_complexity_metrics})
resp.update(text_complexity_metrics)
if nlp is not None:
spacy = import_spacy()
@ -78,16 +79,13 @@ def analyze_text(
dep_out = spacy.displacy.render( # type: ignore
doc, style="dep", jupyter=False, page=True
)
ent_out = spacy.displacy.render( # type: ignore
doc, style="ent", jupyter=False, page=True
)
text_visualizations = {
"dependency_tree": dep_out,
"entities": ent_out,
}
resp.update(text_visualizations)
return resp
@ -98,10 +96,19 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
def __init__(self) -> None:
"""Initialize callback handler."""
import_textstat() # Raise error since it is required
flytekit, renderer = import_flytekit()
self.pandas = import_pandas()
self.textstat = None
try:
self.textstat = import_textstat()
except ImportError:
logger.warning(
"Textstat library is not installed. \
It may result in the inability to log \
certain metrics that can be captured with Textstat."
)
spacy = None
try:
spacy = import_spacy()
@ -123,7 +130,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
"FlyteCallbackHandler uses spacy's en_core_web_sm model"
" for certain metrics. To download,"
" run the following command in your terminal:"
" `python -m spacy download en_core_web_sm` command."
" `python -m spacy download en_core_web_sm`"
)
self.table_renderer = renderer.TableRenderer
@ -180,11 +187,10 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
for generation in generations:
generation_resp = deepcopy(resp)
generation_resp.update(flatten_dict(generation.dict()))
if self.nlp:
if self.nlp or self.textstat:
generation_resp.update(
analyze_text(
generation.text,
nlp=self.nlp,
generation.text, nlp=self.nlp, textstat=self.textstat
)
)

Loading…
Cancel
Save