Adds OpenAI functions powered document metadata tagger (#7521)

Adds a new document transformer that automatically extracts metadata for
a document based on an input schema. I also moved
`document_transformers.py` to `document_transformers/__init__.py` to
group it with this new transformer - it didn't seem to cause issues in
the notebook, but let me know if I've done something wrong there.

Also had a linter issue I couldn't figure out:

```
MacBook-Pro:langchain jacoblee$ make lint
poetry run mypy .
docs/dist/conf.py: error: Duplicate module named "conf" (also at "./docs/api_reference/conf.py")
docs/dist/conf.py: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules for more info
docs/dist/conf.py: note: Common resolutions include: a) using `--exclude` to avoid checking one of them, b) adding `__init__.py` somewhere, c) using `--explicit-package-bases` or adjusting MYPYPATH
Found 1 error in 1 file (errors prevented further checking)
make: *** [lint] Error 2
```

@rlancemartin @baskaryan

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Jacob Lee 2023-07-12 22:12:41 -07:00 committed by GitHub
parent 8effd90be0
commit cdb93ab5ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 160 additions and 6 deletions

View File

@ -131,7 +131,7 @@ chain.run(number=2, callbacks=[handler])
The `callbacks` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) in two different places:
- **Constructor callbacks**: defined in the constructor, eg. `LLMChain(callbacks=[handler], tags=['a-tag'])`, which will be used for all calls made on that object, and will be scoped to that object only, eg. if you pass a handler to the `LLMChain` constructor, it will not be used by the Model attached to that chain.
- **Request callbacks**: defined in the `call()`/`run()`/`apply()` methods used for issuing a request, eg. `chain.call(inputs, callbacks=[handler])`, which will be used for that specific request only, and all sub-requests that it contains (eg. a call to an LLMChain triggers a call to a Model, which uses the same handler passed in the `call()` method).
- **Request callbacks**: defined in the `run()`/`apply()` methods used for issuing a request, eg. `chain.run(input, callbacks=[handler])`, which will be used for that specific request only, and all sub-requests that it contains (eg. a call to an LLMChain triggers a call to a Model, which uses the same handler passed in the `call()` method).
The `verbose` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) as a constructor argument, eg. `LLMChain(verbose=True)`, and it is equivalent to passing a `ConsoleCallbackHandler` to the `callbacks` argument of that object and all child objects. This is useful for debugging, as it will log all events to the console.

View File

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
@ -26,7 +26,12 @@ Passage:
"""
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
def create_tagging_chain(
schema: dict,
llm: BaseLanguageModel,
prompt: Optional[ChatPromptTemplate] = None,
**kwargs: Any
) -> Chain:
"""Creates a chain that extracts information from a passage.
Args:
@ -37,7 +42,7 @@ def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
Chain (LLMChain) that can be used to extract information from a passage.
"""
function = _get_tagging_function(schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = JsonOutputFunctionsParser()
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
@ -45,12 +50,16 @@ def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=output_parser,
**kwargs,
)
return chain
def create_tagging_chain_pydantic(
pydantic_schema: Any, llm: BaseLanguageModel
pydantic_schema: Any,
llm: BaseLanguageModel,
prompt: Optional[ChatPromptTemplate] = None,
**kwargs: Any
) -> Chain:
"""Creates a chain that extracts information from a passage.
@ -63,7 +72,7 @@ def create_tagging_chain_pydantic(
"""
openai_schema = pydantic_schema.schema()
function = _get_tagging_function(openai_schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
@ -71,5 +80,6 @@ def create_tagging_chain_pydantic(
prompt=prompt,
llm_kwargs=llm_kwargs,
output_parser=output_parser,
**kwargs,
)
return chain

View File

@ -16,4 +16,7 @@ __all__ = [
"EmbeddingsClusteringFilter",
"EmbeddingsRedundantFilter",
"get_stateful_documents",
"OpenAIMetadataTagger",
]
from langchain.document_transformers.openai_functions import OpenAIMetadataTagger

View File

@ -0,0 +1,141 @@
"""Document transformers that use OpenAI Functions models"""
from typing import Any, Dict, Optional, Sequence, Type, Union
from pydantic import BaseModel
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions import create_tagging_chain
from langchain.prompts import ChatPromptTemplate
from langchain.schema import BaseDocumentTransformer, BaseLanguageModel, Document
class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):
"""Extract metadata tags from document contents using OpenAI functions.
Example:
.. code-block:: python
from langchain.chat_models import ChatOpenAI
from langchain.document_transformers import OpenAIMetadataTagger
from langchain.schema import Document
schema = {
"properties": {
"movie_title": { "type": "string" },
"critic": { "type": "string" },
"tone": {
"type": "string",
"enum": ["positive", "negative"]
},
"rating": {
"type": "integer",
"description": "The number of stars the critic rated the movie"
}
},
"required": ["movie_title", "critic", "tone"]
}
# Must be an OpenAI model that supports functions
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
tagging_chain = create_tagging_chain(schema, llm)
document_transformer = OpenAIMetadataTagger(tagging_chain=tagging_chain)
original_documents = [
Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\This is the greatest movie ever made. 4 out of 5 stars."),
Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
]
enhanced_documents = document_transformer.transform_documents(original_documents)
""" # noqa: E501
tagging_chain: LLMChain
"""The chain used to extract metadata from each document."""
def transform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
"""Automatically extract and populate metadata
for each document according to the provided schema."""
new_documents = []
for document in documents:
extracted_metadata: Dict = self.tagging_chain.run(document.page_content) # type: ignore[assignment] # noqa: E501
new_document = Document(
page_content=document.page_content,
metadata={**extracted_metadata, **document.metadata},
)
new_documents.append(new_document)
return new_documents
async def atransform_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> Sequence[Document]:
raise NotImplementedError
def create_metadata_tagger(
metadata_schema: Union[Dict[str, Any], Type[BaseModel]],
llm: BaseLanguageModel,
prompt: Optional[ChatPromptTemplate] = None,
*,
tagging_chain_kwargs: Optional[Dict] = None,
) -> OpenAIMetadataTagger:
"""Create a DocumentTransformer that uses an OpenAI function chain to automatically
tag documents with metadata based on their content and an input schema.
Args:
metadata_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
is passed in, it's assumed to already be a valid JsonSchema.
For best results, pydantic.BaseModels should have docstrings describing what
the schema represents and descriptions for the parameters.
llm: Language model to use, assumed to support the OpenAI function-calling API.
Defaults to use "gpt-3.5-turbo-0613"
prompt: BasePromptTemplate to pass to the model.
Returns:
An LLMChain that will pass the given function to the model.
Example:
.. code-block:: python
from langchain.chat_models import ChatOpenAI
from langchain.document_transformers import create_metadata_tagger
from langchain.schema import Document
schema = {
"properties": {
"movie_title": { "type": "string" },
"critic": { "type": "string" },
"tone": {
"type": "string",
"enum": ["positive", "negative"]
},
"rating": {
"type": "integer",
"description": "The number of stars the critic rated the movie"
}
},
"required": ["movie_title", "critic", "tone"]
}
# Must be an OpenAI model that supports functions
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
document_transformer = create_metadata_tagger(schema, llm)
original_documents = [
Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\This is the greatest movie ever made. 4 out of 5 stars."),
Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
]
enhanced_documents = document_transformer.transform_documents(original_documents)
""" # noqa: E501
metadata_schema = (
metadata_schema
if isinstance(metadata_schema, dict)
else metadata_schema.schema()
)
_tagging_chain_kwargs = tagging_chain_kwargs or {}
tagging_chain = create_tagging_chain(
metadata_schema, llm, prompt=prompt, **_tagging_chain_kwargs
)
return OpenAIMetadataTagger(tagging_chain=tagging_chain)