From 42b892c21be7278689cabdb83101631f286ffc34 Mon Sep 17 00:00:00 2001
From: Alexander Hoyle
Date: Sun, 26 Feb 2023 20:54:43 -0500
Subject: [PATCH] Avoid IntegrityError for SQLiteCache updates (#1286)

When using a `SQLiteCache`, passing duplicate `(prompt, llm, idx)` tuples to
[`update_cache()`](https://github.com/hwchase17/langchain/blob/c5dd491a21bde7a65c66c761aa0aad3734978008/langchain/llms/base.py#L39)
raises an `IntegrityError`. This can happen when the same prompt appears more
than once within a batch.

This PR changes the SQLAlchemy `session.add()` to a `session.merge()` in
`cache.py`, [following the solution from this SO thread](https://stackoverflow.com/questions/10322514/dealing-with-duplicate-primary-keys-on-insert-in-sqlalchemy-declarative-style).
I believe this fixes #983, but I'm not entirely sure, since that issue also
involves async calls.

Here's a minimal example that reproduces the error:

```python
from pathlib import Path

import langchain
from langchain.cache import SQLiteCache

llm = langchain.OpenAI(
    model_name="text-ada-001",
    openai_api_key=Path("/.openai_api_key").read_text().strip(),
)
langchain.llm_cache = SQLiteCache("test_cache.db")

# five identical prompts in one batch -> duplicate (prompt, llm, idx) keys
llm.generate(["a"] * 5)
```

```
> IntegrityError: (sqlite3.IntegrityError) UNIQUE constraint failed: full_llm_cache.prompt, full_llm_cache.llm, full_llm_cache.idx
[SQL: INSERT INTO full_llm_cache (prompt, llm, idx, response) VALUES (?, ?, ?, ?)]
[parameters: ('a', "[('_type', 'openai'), ('best_of', 1), ('frequency_penalty', 0), ('logit_bias', {}), ('max_tokens', 256), ('model_name', 'text-ada-001'), ('n', 1), ('presence_penalty', 0), ('request_timeout', None), ('stop', None), ('temperature', 0.7), ('top_p', 1)]", 0, '\n\nA is for air.\n\nA is for atmosphere.')]
(Background on this error at: https://sqlalche.me/e/14/gkpj)
```

After the change, the cache behaves as expected:

```python
class Output:
    def __init__(self, text):
        self.text = text


# make dummy data
cache = SQLiteCache("test_cache_2.db")
cache.update(prompt="prompt_0", llm_string="llm_0", return_val=[Output("text_0")])
cache.engine.execute("SELECT * FROM full_llm_cache").fetchall()
# output
# > [('prompt_0', 'llm_0', 0, 'text_0')]
```

```python
# update existing data; before this change, this would have thrown an `IntegrityError`
cache.update(prompt="prompt_0", llm_string="llm_0", return_val=[Output("text_0_new")])
cache.engine.execute("SELECT * FROM full_llm_cache").fetchall()
# output
# > [('prompt_0', 'llm_0', 0, 'text_0_new')]
```
---
 langchain/cache.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/langchain/cache.py b/langchain/cache.py
index bdf444fe..3d7149d9 100644
--- a/langchain/cache.py
+++ b/langchain/cache.py
@@ -87,7 +87,7 @@ class SQLAlchemyCache(BaseCache):
                 prompt=prompt, llm=llm_string, response=generation.text, idx=i
             )
             with Session(self.engine) as session, session.begin():
-                session.add(item)
+                session.merge(item)
 
 
 class SQLiteCache(SQLAlchemyCache):
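
For reviewers unfamiliar with the distinction: `add()` unconditionally queues an
`INSERT`, while `merge()` reconciles the object against any existing row with the
same primary key, emitting an `UPDATE` instead of a conflicting `INSERT`. Below is
a minimal standalone sketch of that behavior; the `KV` model and in-memory database
are illustrative only and not part of langchain:

```python
from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class KV(Base):
    """Toy table whose primary key can collide, like full_llm_cache."""

    __tablename__ = "kv"
    key = Column(String, primary_key=True)
    value = Column(String)


engine = create_engine("sqlite://")  # throwaway in-memory database
Base.metadata.create_all(engine)

# First write: no row with this key exists yet, so merge() INSERTs.
with Session(engine) as session, session.begin():
    session.merge(KV(key="a", value="first"))

# Second write with the same primary key: merge() finds the existing row
# and UPDATEs it; add() here would emit a second INSERT and raise
# IntegrityError on commit.
with Session(engine) as session, session.begin():
    session.merge(KV(key="a", value="second"))

with Session(engine) as session:
    print(session.get(KV, "a").value)  # -> second
```

Note that `merge()` issues a `SELECT` by primary key (unless the instance is
already in the session's identity map) before choosing between `INSERT` and
`UPDATE`, so it trades a small read overhead for upsert-like semantics.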