community[patch]: Improve Kuzu Cypher generation prompt (#20481)

- [x] **PR title**: "community: improve kuzu cypher generation prompt"

- [x] **PR message**: ***Delete this entire checklist*** and replace
with
- **Description:** Improves the Kùzu Cypher generation prompt to be more
robust to open source LLM outputs
    - **Issue:** N/A
    - **Dependencies:** N/A
    - **Twitter handle:** @kuzudb

- [x] **Add tests and docs**: If you're adding a new integration, please
include
No new tests (non-breaking. change)

- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/
This commit is contained in:
Prashanth Rao 2024-04-16 21:01:36 -04:00 committed by GitHub
parent bce69ae43d
commit 295b9b704b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 4 deletions

View File

@ -1,6 +1,7 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain_community.graphs.kuzu_graph import KuzuGraph
@ -14,6 +15,30 @@ from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_
from langchain.chains.llm import LLMChain
def remove_prefix(text: str, prefix: str) -> str:
if text.startswith(prefix):
return text[len(prefix) :]
return text
def extract_cypher(text: str) -> str:
"""Extract Cypher code from a text.
Args:
text: Text to extract Cypher code from.
Returns:
Cypher code extracted from the text.
"""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
class KuzuQAChain(Chain):
"""Question-answering against a graph by generating Cypher statements for Kùzu.
@ -84,6 +109,9 @@ class KuzuQAChain(Chain):
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in triple backticks
# with the language marker "cypher"
generated_cypher = remove_prefix(extract_cypher(generated_cypher), "cypher")
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(

View File

@ -76,10 +76,11 @@ NGQL_GENERATION_PROMPT = PromptTemplate(
KUZU_EXTRA_INSTRUCTIONS = """
Instructions:
Generate statement with Kùzu Cypher dialect (rather than standard):
1. do not use `WHERE EXISTS` clause to check the existence of a property because Kùzu database has a fixed schema.
2. do not omit relationship pattern. Always use `()-[]->()` instead of `()->()`.
3. do not include any notes or comments even if the statement does not produce the expected result.
Generate the Kùzu dialect of Cypher with the following rules in mind:
1. Do not use a `WHERE EXISTS` clause to check the existence of a property.
2. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`.
3. Do not include any notes or comments even if the statement does not produce the expected result.
```\n"""
KUZU_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(