langchain/templates/neo4j-cypher-ft/neo4j_cypher_ft/chain.py

168 lines
4.7 KiB
Python
Raw Normal View History

from typing import List, Optional, Union
2023-10-27 02:44:30 +00:00
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain_community.graphs import Neo4jGraph
from langchain_core.messages import (
AIMessage,
SystemMessage,
ToolMessage,
)
docs[patch], templates[patch]: Import from core (#14575) Update imports to use core for the low-hanging fruit changes. Ran following ```bash git grep -l 'langchain.schema.runnable' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.runnable/langchain_core.runnables/g' git grep -l 'langchain.schema.output_parser' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.output_parser/langchain_core.output_parsers/g' git grep -l 'langchain.schema.messages' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.messages/langchain_core.messages/g' git grep -l 'langchain.schema.chat_histry' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.chat_history/langchain_core.chat_history/g' git grep -l 'langchain.schema.prompt_template' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.prompt_template/langchain_core.prompts/g' git grep -l 'from langchain.pydantic_v1' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.pydantic_v1/from langchain_core.pydantic_v1/g' git grep -l 'from langchain.tools.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.tools\.base/from langchain_core.tools/g' git grep -l 'from langchain.chat_models.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.chat_models.base/from langchain_core.language_models.chat_models/g' git grep -l 'from langchain.llms.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.llms\.base\ /from langchain_core.language_models.llms\ /g' git grep -l 'from langchain.embeddings.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.embeddings\.base/from langchain_core.embeddings/g' git grep -l 'from langchain.vectorstores.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.vectorstores\.base/from langchain_core.vectorstores/g' git grep -l 'from langchain.agents.tools' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.agents\.tools/from langchain_core.tools/g' 
git grep -l 'from langchain.schema.output' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.output\ /from langchain_core.outputs\ /g' git grep -l 'from langchain.schema.embeddings' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.embeddings/from langchain_core.embeddings/g' git grep -l 'from langchain.schema.document' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.document/from langchain_core.documents/g' git grep -l 'from langchain.schema.agent' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.agent/from langchain_core.agents/g' git grep -l 'from langchain.schema.prompt ' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.prompt\ /from langchain_core.prompt_values /g' git grep -l 'from langchain.schema.language_model' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.language_model/from langchain_core.language_models/g' ```
2023-12-12 00:49:10 +00:00
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
docs[patch], templates[patch]: Import from core (#14575) Update imports to use core for the low-hanging fruit changes. Ran following ```bash git grep -l 'langchain.schema.runnable' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.runnable/langchain_core.runnables/g' git grep -l 'langchain.schema.output_parser' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.output_parser/langchain_core.output_parsers/g' git grep -l 'langchain.schema.messages' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.messages/langchain_core.messages/g' git grep -l 'langchain.schema.chat_histry' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.chat_history/langchain_core.chat_history/g' git grep -l 'langchain.schema.prompt_template' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.prompt_template/langchain_core.prompts/g' git grep -l 'from langchain.pydantic_v1' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.pydantic_v1/from langchain_core.pydantic_v1/g' git grep -l 'from langchain.tools.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.tools\.base/from langchain_core.tools/g' git grep -l 'from langchain.chat_models.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.chat_models.base/from langchain_core.language_models.chat_models/g' git grep -l 'from langchain.llms.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.llms\.base\ /from langchain_core.language_models.llms\ /g' git grep -l 'from langchain.embeddings.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.embeddings\.base/from langchain_core.embeddings/g' git grep -l 'from langchain.vectorstores.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.vectorstores\.base/from langchain_core.vectorstores/g' git grep -l 'from langchain.agents.tools' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.agents\.tools/from langchain_core.tools/g' 
git grep -l 'from langchain.schema.output' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.output\ /from langchain_core.outputs\ /g' git grep -l 'from langchain.schema.embeddings' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.embeddings/from langchain_core.embeddings/g' git grep -l 'from langchain.schema.document' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.document/from langchain_core.documents/g' git grep -l 'from langchain.schema.agent' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.agent/from langchain_core.agents/g' git grep -l 'from langchain.schema.prompt ' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.prompt\ /from langchain_core.prompt_values /g' git grep -l 'from langchain.schema.language_model' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.language_model/from langchain_core.language_models/g' ```
2023-12-12 00:49:10 +00:00
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
2023-10-27 02:44:30 +00:00
# Connection to Neo4j. Neo4jGraph() is called with no arguments, so connection
# settings come from its own defaults (presumably the NEO4J_URI /
# NEO4J_USERNAME / NEO4J_PASSWORD environment variables — confirm).
graph = Neo4jGraph()

# Cypher validation tool for relationship directions: one Schema triple per
# (start label, relationship type, end label) found in the structured schema.
# ``or []`` guards against a schema with no "relationships" entry (e.g. an
# empty structured schema), where ``.get`` returns None and iterating it would
# raise TypeError at import time.
corrector_schema = [
    Schema(el["start"], el["type"], el["end"])
    for el in (graph.structured_schema.get("relationships") or [])
]
cypher_validation = CypherQueryCorrector(corrector_schema)

# LLMs: ``cypher_llm`` writes the Cypher statement; ``qa_llm`` is used both for
# entity extraction and for phrasing the final answer.
cypher_llm = ChatOpenAI(model="gpt-4", temperature=0.0)
qa_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.0)
# Extract entities from text
class Entities(BaseModel):
    """Identifying information about entities."""

    # The class docstring above and this field's description are handed to the
    # model via ``with_structured_output(Entities)``, so their exact wording
    # steers what gets extracted.
    names: List[str] = Field(
        ...,
        description=(
            "All the person, organization, or business "
            "entities that appear in the text"
        ),
    )
# Entity-extraction prompt. ``{question}`` is the only template variable and is
# filled from the chain input.
_entity_system_text = (
    "You are extracting organization and person entities " "from the text."
)
_entity_human_text = (
    "Use the given format to extract information from the following input: "
    "{question}"
)
prompt = ChatPromptTemplate.from_messages(
    messages=[("system", _entity_system_text), ("human", _entity_human_text)]
)
# Fulltext index query
# Characters that carry meaning in Lucene query syntax, which Neo4j full-text
# index queries are parsed with. Entity names come from the LLM's reading of
# free-form user text, so a name such as "AC/DC" or "Yahoo!" would otherwise be
# parsed as operators and make the query fail instead of matching literally.
_LUCENE_SPECIAL_CHARS = '\\+-&|!(){}[]^"~*?:/'
_LUCENE_ESCAPE_TABLE = str.maketrans(
    {char: "\\" + char for char in _LUCENE_SPECIAL_CHARS}
)


def _escape_lucene(text: str) -> str:
    """Return *text* with every Lucene query-syntax character backslash-escaped."""
    return text.translate(_LUCENE_ESCAPE_TABLE)


def map_to_database(entities: "Entities") -> Optional[str]:
    """Map extracted entity names onto node names stored in Neo4j.

    Each name in ``entities.names`` is looked up in the ``entity`` full-text
    index as a prefix match (``<name>*``), keeping only the top hit
    (``limit:1``).

    Args:
        entities: Structured output of the entity-extraction chain.

    Returns:
        One ``"<entity> maps to <db name> in database"`` line per name that
        matched, each newline-terminated. An empty string when nothing
        matched; never ``None`` in practice.
    """
    lines = []
    for entity in entities.names:
        if not entity.strip():
            # A blank name would reduce the query to a bare "*" wildcard.
            continue
        response = graph.query(
            "CALL db.index.fulltext.queryNodes('entity', $entity + '*', {limit:1})"
            " YIELD node,score RETURN node.name AS result",
            {"entity": _escape_lucene(entity)},
        )
        if response:
            lines.append(f"{entity} maps to {response[0]['result']} in database\n")
    return "".join(lines)
# Entity extraction: the prompt feeds ``qa_llm`` constrained to return an
# ``Entities`` object.
entity_chain = prompt.pipe(qa_llm.with_structured_output(Entities))
# Generate Cypher statement based on natural language input.
# Template variables: ``schema`` (graph schema text), ``entities_list`` (output
# of map_to_database) and ``question`` (the user's question).
cypher_template = """Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question:
{schema}
Entities in the question map to the following database values:
{entities_list}
Question: {question}
Cypher query:"""  # noqa: E501
# Chat prompt for Cypher generation: a terse system instruction followed by the
# filled-in ``cypher_template`` as the human turn.
_cypher_system_text = (
    "Given an input question, convert it to a Cypher query. " "No pre-amble."
)
cypher_prompt = ChatPromptTemplate.from_messages(
    messages=[("system", _cypher_system_text), ("human", cypher_template)]
)
# Cypher-generation pipeline: extract entity names, map them to stored node
# names via the full-text index, attach the graph schema, then prompt
# ``cypher_llm`` and return its raw text output.
cypher_response = RunnablePassthrough.assign(names=entity_chain).pipe(
    RunnablePassthrough.assign(
        entities_list=lambda inputs: map_to_database(inputs["names"]),
        schema=lambda _inputs: graph.get_schema,
    ),
    cypher_prompt,
    # Stop sequence presumably keeps the model from continuing past the query
    # into an invented "CypherResult:" section.
    cypher_llm.bind(stop=["\nCypherResult:"]),
    StrOutputParser(),
)
# Generate natural language response based on database results.
# The system text is supplied as a ready-made SystemMessage rather than a
# template string; ``{question}`` is the human turn, and the tool-call messages
# built for ``function_response`` are spliced in after it.
response_system = """You are an assistant that helps to form nice and human
understandable answers based on the provided information from tools.
Do not add any other information that wasn't present in the tools, and use
very concise style in interpreting results!
"""
response_prompt = ChatPromptTemplate.from_messages(
    messages=[
        SystemMessage(content=response_system),
        HumanMessagePromptTemplate.from_template("{question}"),
        MessagesPlaceholder(variable_name="function_response"),
    ]
)
def get_function_response(
    query: str, question: str
) -> List[Union[AIMessage, ToolMessage]]:
    """Run the generated Cypher query and wrap its result as a tool-call exchange.

    The returned pair imitates an OpenAI tool-calling round trip: an assistant
    message that "calls" a ``GetInformation`` tool with the user's question,
    followed by a tool message carrying the database result. The response
    prompt inserts these through its ``function_response`` placeholder.

    Args:
        query: Cypher statement produced by the LLM; relationship directions
            are validated/corrected by ``cypher_validation`` before execution.
        question: The original natural-language question.

    Returns:
        ``[AIMessage, ToolMessage]`` describing the tool call and its output.
    """
    import json  # local import keeps this template's import block untouched

    corrected_query = cypher_validation(query)
    # CypherQueryCorrector returns an empty string when the query cannot be
    # reconciled with the schema. Sending "" to Neo4j fails, so treat it as
    # "no results" (same handling as GraphCypherQAChain).
    context = graph.query(corrected_query) if corrected_query else []

    tool_call_id = "call_H7fABDuzEau48T10Qn0Lsh0D"
    # json.dumps escapes quotes, backslashes and newlines in the question, so
    # the arguments are always valid JSON. Compact separators and
    # ensure_ascii=False keep the payload identical to the previous hand-built
    # string for ordinary questions.
    arguments = json.dumps(
        {"question": question}, separators=(",", ":"), ensure_ascii=False
    )
    return [
        AIMessage(
            content="",
            additional_kwargs={
                "tool_calls": [
                    {
                        "id": tool_call_id,
                        "function": {
                            "arguments": arguments,
                            "name": "GetInformation",
                        },
                        "type": "function",
                    }
                ]
            },
        ),
        ToolMessage(content=str(context), tool_call_id=tool_call_id),
    ]
# End-to-end chain: generate the Cypher query, run it and package the result as
# tool-call messages, then let ``qa_llm`` phrase the final answer as text.
chain = RunnablePassthrough.assign(query=cypher_response).pipe(
    RunnablePassthrough.assign(
        function_response=lambda inputs: get_function_response(
            inputs["query"], inputs["question"]
        )
    ),
    response_prompt,
    qa_llm,
    StrOutputParser(),
)
# Add typing for input: the chain is invoked with a mapping holding a single
# ``question`` string, which the prompts above consume.
class Question(BaseModel):
    question: str  # the user's natural-language question


chain = chain.with_types(input_type=Question)