mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Adds OpenAI functions powered document metadata tagger (#7521)
Adds a new document transformer that automatically extracts metadata for a document based on an input schema. I also moved `document_transformers.py` to `document_transformers/__init__.py` to group it with this new transformer - it didn't seem to cause issues in the notebook, but let me know if I've done something wrong there. Also had a linter issue I couldn't figure out: ``` MacBook-Pro:langchain jacoblee$ make lint poetry run mypy . docs/dist/conf.py: error: Duplicate module named "conf" (also at "./docs/api_reference/conf.py") docs/dist/conf.py: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules for more info docs/dist/conf.py: note: Common resolutions include: a) using `--exclude` to avoid checking one of them, b) adding `__init__.py` somewhere, c) using `--explicit-package-bases` or adjusting MYPYPATH Found 1 error in 1 file (errors prevented further checking) make: *** [lint] Error 2 ``` @rlancemartin @baskaryan --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
8effd90be0
commit
cdb93ab5ca
@ -131,7 +131,7 @@ chain.run(number=2, callbacks=[handler])
|
|||||||
The `callbacks` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) in two different places:
|
The `callbacks` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) in two different places:
|
||||||
|
|
||||||
- **Constructor callbacks**: defined in the constructor, eg. `LLMChain(callbacks=[handler], tags=['a-tag'])`, which will be used for all calls made on that object, and will be scoped to that object only, eg. if you pass a handler to the `LLMChain` constructor, it will not be used by the Model attached to that chain.
|
- **Constructor callbacks**: defined in the constructor, eg. `LLMChain(callbacks=[handler], tags=['a-tag'])`, which will be used for all calls made on that object, and will be scoped to that object only, eg. if you pass a handler to the `LLMChain` constructor, it will not be used by the Model attached to that chain.
|
||||||
- **Request callbacks**: defined in the `call()`/`run()`/`apply()` methods used for issuing a request, eg. `chain.call(inputs, callbacks=[handler])`, which will be used for that specific request only, and all sub-requests that it contains (eg. a call to an LLMChain triggers a call to a Model, which uses the same handler passed in the `call()` method).
|
- **Request callbacks**: defined in the `run()`/`apply()` methods used for issuing a request, eg. `chain.run(input, callbacks=[handler])`, which will be used for that specific request only, and all sub-requests that it contains (eg. a call to an LLMChain triggers a call to a Model, which uses the same handler passed in the `run()` method).
|
||||||
|
|
||||||
The `verbose` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) as a constructor argument, eg. `LLMChain(verbose=True)`, and it is equivalent to passing a `ConsoleCallbackHandler` to the `callbacks` argument of that object and all child objects. This is useful for debugging, as it will log all events to the console.
|
The `verbose` argument is available on most objects throughout the API (Chains, Models, Tools, Agents, etc.) as a constructor argument, eg. `LLMChain(verbose=True)`, and it is equivalent to passing a `ConsoleCallbackHandler` to the `callbacks` argument of that object and all child objects. This is useful for debugging, as it will log all events to the console.
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
@ -26,7 +26,12 @@ Passage:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
def create_tagging_chain(
|
||||||
|
schema: dict,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
prompt: Optional[ChatPromptTemplate] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Chain:
|
||||||
"""Creates a chain that extracts information from a passage.
|
"""Creates a chain that extracts information from a passage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -37,7 +42,7 @@ def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
|||||||
Chain (LLMChain) that can be used to extract information from a passage.
|
Chain (LLMChain) that can be used to extract information from a passage.
|
||||||
"""
|
"""
|
||||||
function = _get_tagging_function(schema)
|
function = _get_tagging_function(schema)
|
||||||
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||||
output_parser = JsonOutputFunctionsParser()
|
output_parser = JsonOutputFunctionsParser()
|
||||||
llm_kwargs = get_llm_kwargs(function)
|
llm_kwargs = get_llm_kwargs(function)
|
||||||
chain = LLMChain(
|
chain = LLMChain(
|
||||||
@ -45,12 +50,16 @@ def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
output_parser=output_parser,
|
output_parser=output_parser,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
return chain
|
return chain
|
||||||
|
|
||||||
|
|
||||||
def create_tagging_chain_pydantic(
|
def create_tagging_chain_pydantic(
|
||||||
pydantic_schema: Any, llm: BaseLanguageModel
|
pydantic_schema: Any,
|
||||||
|
llm: BaseLanguageModel,
|
||||||
|
prompt: Optional[ChatPromptTemplate] = None,
|
||||||
|
**kwargs: Any
|
||||||
) -> Chain:
|
) -> Chain:
|
||||||
"""Creates a chain that extracts information from a passage.
|
"""Creates a chain that extracts information from a passage.
|
||||||
|
|
||||||
@ -63,7 +72,7 @@ def create_tagging_chain_pydantic(
|
|||||||
"""
|
"""
|
||||||
openai_schema = pydantic_schema.schema()
|
openai_schema = pydantic_schema.schema()
|
||||||
function = _get_tagging_function(openai_schema)
|
function = _get_tagging_function(openai_schema)
|
||||||
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
prompt = prompt or ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||||
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
|
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
|
||||||
llm_kwargs = get_llm_kwargs(function)
|
llm_kwargs = get_llm_kwargs(function)
|
||||||
chain = LLMChain(
|
chain = LLMChain(
|
||||||
@ -71,5 +80,6 @@ def create_tagging_chain_pydantic(
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
llm_kwargs=llm_kwargs,
|
llm_kwargs=llm_kwargs,
|
||||||
output_parser=output_parser,
|
output_parser=output_parser,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
return chain
|
return chain
|
||||||
|
@ -16,4 +16,7 @@ __all__ = [
|
|||||||
"EmbeddingsClusteringFilter",
|
"EmbeddingsClusteringFilter",
|
||||||
"EmbeddingsRedundantFilter",
|
"EmbeddingsRedundantFilter",
|
||||||
"get_stateful_documents",
|
"get_stateful_documents",
|
||||||
|
"OpenAIMetadataTagger",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
from langchain.document_transformers.openai_functions import OpenAIMetadataTagger
|
||||||
|
141
langchain/document_transformers/openai_functions.py
Normal file
141
langchain/document_transformers/openai_functions.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
"""Document transformers that use OpenAI Functions models"""
|
||||||
|
from typing import Any, Dict, Optional, Sequence, Type, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.chains.openai_functions import create_tagging_chain
|
||||||
|
from langchain.prompts import ChatPromptTemplate
|
||||||
|
from langchain.schema import BaseDocumentTransformer, BaseLanguageModel, Document
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):
    """Extract metadata tags from document contents using OpenAI functions.

    Example:
        .. code-block:: python

                from langchain.chat_models import ChatOpenAI
                from langchain.document_transformers import OpenAIMetadataTagger
                from langchain.schema import Document

                schema = {
                    "properties": {
                        "movie_title": { "type": "string" },
                        "critic": { "type": "string" },
                        "tone": {
                            "type": "string",
                            "enum": ["positive", "negative"]
                        },
                        "rating": {
                            "type": "integer",
                            "description": "The number of stars the critic rated the movie"
                        }
                    },
                    "required": ["movie_title", "critic", "tone"]
                }

                # Must be an OpenAI model that supports functions
                llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
                tagging_chain = create_tagging_chain(schema, llm)
                document_transformer = OpenAIMetadataTagger(tagging_chain=tagging_chain)
                original_documents = [
                    Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\nThis is the greatest movie ever made. 4 out of 5 stars."),
                    Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
                ]

                enhanced_documents = document_transformer.transform_documents(original_documents)
    """  # noqa: E501

    tagging_chain: LLMChain
    """The chain used to extract metadata from each document."""

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Automatically extract and populate metadata
        for each document according to the provided schema."""

        tagged_documents = []

        for doc in documents:
            # The chain's output parser yields a dict of extracted fields.
            extracted: Dict = self.tagging_chain.run(doc.page_content)  # type: ignore[assignment] # noqa: E501
            # Pre-existing metadata takes precedence over extracted values
            # on key collisions.
            tagged_documents.append(
                Document(
                    page_content=doc.page_content,
                    metadata={**extracted, **doc.metadata},
                )
            )
        return tagged_documents

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        # No async implementation yet.
        raise NotImplementedError
|
|
||||||
|
def create_metadata_tagger(
    metadata_schema: Union[Dict[str, Any], Type[BaseModel]],
    llm: BaseLanguageModel,
    prompt: Optional[ChatPromptTemplate] = None,
    *,
    tagging_chain_kwargs: Optional[Dict] = None,
) -> OpenAIMetadataTagger:
    """Create a DocumentTransformer that uses an OpenAI function chain to automatically
    tag documents with metadata based on their content and an input schema.

    Args:
        metadata_schema: Either a dictionary or pydantic.BaseModel class. If a dictionary
            is passed in, it's assumed to already be a valid JsonSchema.
            For best results, pydantic.BaseModels should have docstrings describing what
            the schema represents and descriptions for the parameters.
        llm: Language model to use, assumed to support the OpenAI function-calling API.
        prompt: Optional ChatPromptTemplate to pass to the underlying tagging chain.
        tagging_chain_kwargs: Extra keyword arguments forwarded to
            ``create_tagging_chain``.

    Returns:
        An OpenAIMetadataTagger document transformer wrapping the tagging chain.

    Example:
        .. code-block:: python

                from langchain.chat_models import ChatOpenAI
                from langchain.document_transformers import create_metadata_tagger
                from langchain.schema import Document

                schema = {
                    "properties": {
                        "movie_title": { "type": "string" },
                        "critic": { "type": "string" },
                        "tone": {
                            "type": "string",
                            "enum": ["positive", "negative"]
                        },
                        "rating": {
                            "type": "integer",
                            "description": "The number of stars the critic rated the movie"
                        }
                    },
                    "required": ["movie_title", "critic", "tone"]
                }

                # Must be an OpenAI model that supports functions
                llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")

                document_transformer = create_metadata_tagger(schema, llm)
                original_documents = [
                    Document(page_content="Review of The Bee Movie\nBy Roger Ebert\n\nThis is the greatest movie ever made. 4 out of 5 stars."),
                    Document(page_content="Review of The Godfather\nBy Anonymous\n\nThis movie was super boring. 1 out of 5 stars.", metadata={"reliable": False}),
                ]

                enhanced_documents = document_transformer.transform_documents(original_documents)
    """  # noqa: E501
    # Normalize pydantic model classes to their JsonSchema dict form so the
    # tagging chain always receives a plain schema dict.
    metadata_schema = (
        metadata_schema
        if isinstance(metadata_schema, dict)
        else metadata_schema.schema()
    )
    _tagging_chain_kwargs = tagging_chain_kwargs or {}
    tagging_chain = create_tagging_chain(
        metadata_schema, llm, prompt=prompt, **_tagging_chain_kwargs
    )
    return OpenAIMetadataTagger(tagging_chain=tagging_chain)
|
Loading…
Reference in New Issue
Block a user