Mirror of https://github.com/hwchase17/langchain (synced 2024-11-10 01:10:59 +00:00)
7306600e2f
**Description:** The Databricks LLM does not support SerDe (serialization/deserialization) of `transform_input_fn` and `transform_output_fn`, so after saving and loading, the LLM is broken. This PR serializes these functions into a hex string using pickle and saves the hex string in the YAML file. Using pickle to serialize a function can be flaky, but this is a simple workaround that unblocks many use cases. If more sophisticated SerDe is needed, we can improve it later.

**Test:** Added a simple unit test. I also tested manually on Databricks and it works well. The saved YAML looks like:

```yaml
llm:
  _type: databricks
  cluster_driver_port: null
  cluster_id: null
  databricks_uri: databricks
  endpoint_name: databricks-mixtral-8x7b-instruct
  extra_params: {}
  host: e2-dogfood.staging.cloud.databricks.com
  max_tokens: null
  model_kwargs: null
  n: 1
  stop: null
  task: null
  temperature: 0.0
  transform_input_fn: 80049520000000000000008c085f5f6d61696e5f5f948c0f7472616e73666f726d5f696e7075749493942e
  transform_output_fn: null
```

@baskaryan

```python
from langchain_community.embeddings import DatabricksEmbeddings
from langchain_community.llms import Databricks
from langchain.chains import RetrievalQA
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
import mlflow

embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")


def transform_input(**request):
    # Rewrite the prompt-style request into a chat-style "messages" payload
    request["messages"] = [
        {
            "role": "user",
            "content": request["prompt"],
        }
    ]
    del request["prompt"]
    return request


llm = Databricks(
    endpoint_name="databricks-mixtral-8x7b-instruct",
    transform_input_fn=transform_input,
)

persist_dir = "faiss_databricks_embedding"

# Create the vector db, persist the db to a local fs folder
loader = TextLoader("state_of_the_union.txt")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
docs = text_splitter.split_documents(documents)
db = FAISS.from_documents(docs, embeddings)
db.save_local(persist_dir)


def load_retriever(persist_directory):
    # MLflow calls this (via loader_fn) with persist_dir to rebuild the retriever on load
    embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
    vectorstore = FAISS.load_local(persist_directory, embeddings)
    return vectorstore.as_retriever()


retriever = load_retriever(persist_dir)
retrievalQA = RetrievalQA.from_llm(llm=llm, retriever=retriever)

# Log the chain; with this PR, transform_input_fn is saved as a pickled hex string
with mlflow.start_run() as run:
    logged_model = mlflow.langchain.log_model(
        retrievalQA,
        artifact_path="retrieval_qa",
        loader_fn=load_retriever,
        persist_dir=persist_dir,
    )

# Load the retrievalQA chain
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)
print(loaded_model.predict([{"query": "What did the president say about Ketanji Brown Jackson"}]))
```
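The pickle-to-hex round trip described in the commit message can be sketched with the standard library alone. This is a minimal sketch, not the PR's actual code: the helpers `fn_to_hex`, `fn_from_hex`, and `pickled_strings` are hypothetical names. Decoding the `transform_input_fn` value from the saved YAML with `pickletools` also shows why the approach can be flaky: pickle stores a function by reference (module `__main__`, name `transform_input`), not its code, so whichever process loads the model must be able to resolve that same name.

```python
import pickle
import pickletools


def fn_to_hex(fn):
    """Hypothetical helper: pickle a callable and encode the bytes as hex text for YAML."""
    return pickle.dumps(fn).hex()


def fn_from_hex(data):
    """Hypothetical helper: decode the hex text and unpickle the callable."""
    return pickle.loads(bytes.fromhex(data))


def pickled_strings(data):
    """Hypothetical helper: collect the string arguments of every opcode in a hex-encoded pickle."""
    return [arg for _, arg, _ in pickletools.genops(bytes.fromhex(data)) if isinstance(arg, str)]


def transform_input(**request):
    # Same transform as in the PR example: prompt-style request -> chat-style payload
    request["messages"] = [{"role": "user", "content": request["prompt"]}]
    del request["prompt"]
    return request


# Round trip: unpickling resolves the function by name, so we get the same object back
encoded = fn_to_hex(transform_input)
restored = fn_from_hex(encoded)
print(restored is transform_input)  # True
print(restored(prompt="Hi", temperature=0.0))
# {'temperature': 0.0, 'messages': [{'role': 'user', 'content': 'Hi'}]}

# The value saved in the YAML decodes to a bare module/name reference
saved = "80049520000000000000008c085f5f6d61696e5f5f948c0f7472616e73666f726d5f696e7075749493942e"
print(pickled_strings(saved))  # ['__main__', 'transform_input']
```

Because only the reference is stored, a function defined in a notebook or script (`__main__`) can be restored only where an object with the same module and name exists. Unpickling also imports and resolves whatever the pickle references, so such values should be loaded only from trusted files.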
langchain_community/
scripts/
tests/
Makefile
poetry.lock
pyproject.toml
README.md
🦜️🧑🤝🧑 LangChain Community
Quick Install
pip install langchain-community
What is it?
LangChain Community contains third-party integrations that implement the base interfaces defined in LangChain Core, making them ready-to-use in any LangChain application.
For full documentation see the API reference.
📕 Releases & Versioning
`langchain-community` is currently on version `0.0.x`.
All changes will be accompanied by a patch version increase.
💁 Contributing
As an open-source project in a rapidly developing field, we are extremely open to contributions, whether it be in the form of a new feature, improved infrastructure, or better documentation.
For detailed information on how to contribute, see the Contributing Guide.