mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Adds OpenAI functions powered document metadata tagger (#7521)
Adds a new document transformer that automatically extracts metadata for a document based on an input schema. I also moved `document_transformers.py` to `document_transformers/__init__.py` to group it with this new transformer - it didn't seem to cause issues in the notebook, but let me know if I've done something wrong there. Also had a linter issue I couldn't figure out: ``` MacBook-Pro:langchain jacoblee$ make lint poetry run mypy . docs/dist/conf.py: error: Duplicate module named "conf" (also at "./docs/api_reference/conf.py") docs/dist/conf.py: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules for more info docs/dist/conf.py: note: Common resolutions include: a) using `--exclude` to avoid checking one of them, b) adding `__init__.py` somewhere, c) using `--explicit-package-bases` or adjusting MYPYPATH Found 1 error in 1 file (errors prevented further checking) make: *** [lint] Error 2 ``` @rlancemartin @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
8effd90be0
commit
cdb93ab5ca
@ -131,7 +131,7 @@ chain.run(number=2, callbacks=[handler])
|
||||
The `callbacks` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) in two different places:
|
||||
|
||||
- **Constructor callbacks**: defined in the constructor, eg. `LLMChain(callbacks=[handler], tags=['a-tag'])`, which will be used for all calls made on that object, and will be scoped to that object only, eg. if you pass a handler to the `LLMChain` constructor, it will not be used by the Model attached to that chain.
|
||||
- **Request callbacks**: defined in the `call()`/`run()`/`apply()` methods used for issuing a request, eg. `chain.call(inputs, callbacks=[handler])`, which will be used for that specific request only, and all sub-requests that it contains (eg. a call to an LLMChain triggers a call to a Model, which uses the same handler passed in the `call()` method).
|
||||
- **Request callbacks**: defined in the `run()`/`apply()` methods used for issuing a request, eg. `chain.run(input, callbacks=[handler])`, which will be used for that specific request only, and all sub-requests that it contains (eg. a call to an LLMChain triggers a call to a Model, which uses the same handler passed in the `call()` method).
|
||||
|
||||
The `verbose` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) as a constructor argument, eg. `LLMChain(verbose=True)`, and it is equivalent to passing a `ConsoleCallbackHandler` to the `callbacks` argument of that object and all child objects. This is useful for debugging, as it will log all events to the console.
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
@ -26,7 +26,12 @@ Passage:
|
||||
"""
|
||||
|
||||
|
||||
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
||||
def create_tagging_chain(
|
||||
schema: dict,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[ChatPromptTemplate] = None,
|
||||
**kwargs: Any
|
||||
) -> Chain:
|
||||
"""Creates a chain that extracts information from a passage.
|
||||
|
||||
Args:
|
||||
@ -37,7 +42,7 @@ def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
||||
Chain (LLMChain) that can be used to extract information from a passage.
|
||||
"""
|
||||
function = _get_tagging_function(schema)
|
||||
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||
output_parser = JsonOutputFunctionsParser()
|
||||
llm_kwargs = get_llm_kwargs(function)
|
||||
chain = LLMChain(
|
||||
@ -45,12 +50,16 @@ def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
||||
prompt=prompt,
|
||||
llm_kwargs=llm_kwargs,
|
||||
output_parser=output_parser,
|
||||
**kwargs,
|
||||
)
|
||||
return chain
|
||||
|
||||
|
||||
def create_tagging_chain_pydantic(
|
||||
pydantic_schema: Any, llm: BaseLanguageModel
|
||||
pydantic_schema: Any,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[ChatPromptTemplate] = None,
|
||||
**kwargs: Any
|
||||
) -> Chain:
|
||||
"""Creates a chain that extracts information from a passage.
|
||||
|
||||
@ -63,7 +72,7 @@ def create_tagging_chain_pydantic(
|
||||
"""
|
||||
openai_schema = pydantic_schema.schema()
|
||||
function = _get_tagging_function(openai_schema)
|
||||
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
|
||||
llm_kwargs = get_llm_kwargs(function)
|
||||
chain = LLMChain(
|
||||
@ -71,5 +80,6 @@ def create_tagging_chain_pydantic(
|
||||
prompt=prompt,
|
||||
llm_kwargs=llm_kwargs,
|
||||
output_parser=output_parser,
|
||||
**kwargs,
|
||||
)
|
||||
return chain
|
||||
|
@ -16,4 +16,7 @@ __all__ = [
|
||||
"EmbeddingsClusteringFilter",
|
||||
"EmbeddingsRedundantFilter",
|
||||
"get_stateful_documents",
|
||||
"OpenAIMetadataTagger",
|
||||
]
|
||||
|
||||
from langchain.document_transformers.openai_functions import OpenAIMetadataTagger
|
||||
|
141
langchain/document_transformers/openai_functions.py
Normal file
141
langchain/document_transformers/openai_functions.py
Normal file
@ -0,0 +1,141 @@
|
||||
"""Document transformers that use OpenAI Functions models"""
|
||||
from typing import Any, Dict, Optional, Sequence, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.openai_functions import create_tagging_chain
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
from langchain.schema import BaseDocumentTransformer, BaseLanguageModel, Document
|
||||
|
||||
|
||||
class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):
|
||||
"""Extract metadata tags from document contents using OpenAI functions.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.document_transformers import OpenAIMetadataTagger
|
||||
from langchain.schema import Document
|
||||
|
||||
schema = {
|
||||
"properties": {
|
||||
"movie_title": { "type": "string" },
|
||||
"critic": { "type": "string" },
|
||||
"tone": {
|
||||
"type": "string",
|
||||
"enum": ["positive", "negative"]
|
||||
},
|
||||
"rating": {
|
||||
"type": "integer",
|
||||
"description": "The number of stars the critic rated the movie"
|
||||
}
|
||||
},
|
||||
"required": ["movie_title", "critic", "tone"]
|
||||
}
|
||||
|
||||
# Must be an OpenAI model that supports functions
|
||||
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
|
||||
tagging_chain = create_tagging_chain(schema, llm)
|
||||
document_transformer = OpenAIMetadataTagger(tagging_chain=tagging_chain)
|
||||
original_documents = [
|
||||
Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\This is the greatest movie ever made. 4 out of 5 stars."),
|
||||
Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
|
||||
]
|
||||
|
||||
enhanced_documents = document_transformer.transform_documents(original_documents)
|
||||
""" # noqa: E501
|
||||
|
||||
tagging_chain: LLMChain
|
||||
"""The chain used to extract metadata from each document."""
|
||||
|
||||
def transform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
"""Automatically extract and populate metadata
|
||||
for each document according to the provided schema."""
|
||||
|
||||
new_documents = []
|
||||
|
||||
for document in documents:
|
||||
extracted_metadata: Dict = self.tagging_chain.run(document.page_content) # type: ignore[assignment] # noqa: E501
|
||||
new_document = Document(
|
||||
page_content=document.page_content,
|
||||
metadata={**extracted_metadata, **document.metadata},
|
||||
)
|
||||
new_documents.append(new_document)
|
||||
return new_documents
|
||||
|
||||
async def atransform_documents(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> Sequence[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def create_metadata_tagger(
|
||||
metadata_schema: Union[Dict[str, Any], Type[BaseModel]],
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[ChatPromptTemplate] = None,
|
||||
*,
|
||||
tagging_chain_kwargs: Optional[Dict] = None,
|
||||
) -> OpenAIMetadataTagger:
|
||||
"""Create a DocumentTransformer that uses an OpenAI function chain to automatically
|
||||
tag documents with metadata based on their content and an input schema.
|
||||
|
||||
Args:
|
||||
metadata_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
|
||||
is passed in, it's assumed to already be a valid JsonSchema.
|
||||
For best results, pydantic.BaseModels should have docstrings describing what
|
||||
the schema represents and descriptions for the parameters.
|
||||
llm: Language model to use, assumed to support the OpenAI function-calling API.
|
||||
Defaults to use "gpt-3.5-turbo-0613"
|
||||
prompt: BasePromptTemplate to pass to the model.
|
||||
|
||||
Returns:
|
||||
An LLMChain that will pass the given function to the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.document_transformers import create_metadata_tagger
|
||||
from langchain.schema import Document
|
||||
|
||||
schema = {
|
||||
"properties": {
|
||||
"movie_title": { "type": "string" },
|
||||
"critic": { "type": "string" },
|
||||
"tone": {
|
||||
"type": "string",
|
||||
"enum": ["positive", "negative"]
|
||||
},
|
||||
"rating": {
|
||||
"type": "integer",
|
||||
"description": "The number of stars the critic rated the movie"
|
||||
}
|
||||
},
|
||||
"required": ["movie_title", "critic", "tone"]
|
||||
}
|
||||
|
||||
# Must be an OpenAI model that supports functions
|
||||
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
|
||||
|
||||
document_transformer = create_metadata_tagger(schema, llm)
|
||||
original_documents = [
|
||||
Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\This is the greatest movie ever made. 4 out of 5 stars."),
|
||||
Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
|
||||
]
|
||||
|
||||
enhanced_documents = document_transformer.transform_documents(original_documents)
|
||||
""" # noqa: E501
|
||||
metadata_schema = (
|
||||
metadata_schema
|
||||
if isinstance(metadata_schema, dict)
|
||||
else metadata_schema.schema()
|
||||
)
|
||||
_tagging_chain_kwargs = tagging_chain_kwargs or {}
|
||||
tagging_chain = create_tagging_chain(
|
||||
metadata_schema, llm, prompt=prompt, **_tagging_chain_kwargs
|
||||
)
|
||||
return OpenAIMetadataTagger(tagging_chain=tagging_chain)
|
Loading…
Reference in New Issue
Block a user