From 3438d2cbcc0c07d9f7b0b536423f2bac3585ec5d Mon Sep 17 00:00:00 2001
From: Rodrigo Nogueira <121117945+rodrigo-f-nogueira@users.noreply.github.com>
Date: Fri, 1 Mar 2024 17:18:23 -0300
Subject: [PATCH] community[minor]: add maritalk chat (#17675)

**Description:** Adds the MariTalk chat model, which is based on an LLM
specially trained for Portuguese.

**Twitter handle:** @MaritacaAI
---
 docs/docs/integrations/chat/maritalk.ipynb | 201 ++++++++++++++++++
 .../chat_models/__init__.py                |   2 +
 .../chat_models/maritalk.py                | 151 +++++++++++++
 .../unit_tests/chat_models/test_imports.py |   1 +
 4 files changed, 355 insertions(+)
 create mode 100644 docs/docs/integrations/chat/maritalk.ipynb
 create mode 100644 libs/community/langchain_community/chat_models/maritalk.py

diff --git a/docs/docs/integrations/chat/maritalk.ipynb b/docs/docs/integrations/chat/maritalk.ipynb
new file mode 100644
index 0000000000..82e9b75dbc
--- /dev/null
+++ b/docs/docs/integrations/chat/maritalk.ipynb
@@ -0,0 +1,201 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# MariTalk\n",
+    "\n",
+    "## Introduction\n",
+    "\n",
+    "MariTalk is an assistant developed by the Brazilian company [Maritaca AI](https://www.maritaca.ai).\n",
+    "MariTalk is based on language models that have been specially trained to understand Portuguese well.\n",
+    "\n",
+    "This notebook demonstrates how to use MariTalk with LangChain through two examples:\n",
+    "\n",
+    "1. A simple example of how to use MariTalk to perform a task.\n",
+    "2. LLM + RAG: The second example shows how to answer a question whose answer is found in a long document that does not fit within the token limit of MariTalk. For this, we will use a simple searcher (BM25) to first retrieve the most relevant sections of the document and then feed them to MariTalk for answering."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Installation\n",
+    "First, install the LangChain libraries (and all their dependencies) using the following command:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install langchain langchain-core langchain-community"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## API Key\n",
+    "You will need an API key, which can be obtained from chat.maritaca.ai (\"Chaves da API\" section)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "### Example 1 - Pet Name Suggestions\n",
+    "\n",
+    "Let's define our language model, ChatMaritalk, and configure it with your API key."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.prompts.chat import ChatPromptTemplate\n",
+    "from langchain_community.chat_models import ChatMaritalk\n",
+    "from langchain_core.output_parsers import StrOutputParser\n",
+    "\n",
+    "llm = ChatMaritalk(\n",
+    "    api_key=\"\",  # Insert your API key here\n",
+    "    temperature=0.7,\n",
+    "    max_tokens=100,\n",
+    ")\n",
+    "\n",
+    "output_parser = StrOutputParser()\n",
+    "\n",
+    "chat_prompt = ChatPromptTemplate.from_messages(\n",
+    "    [\n",
+    "        (\n",
+    "            \"system\",\n",
+    "            \"You are an assistant specialized in suggesting pet names. Given the animal, you must suggest 4 names.\",\n",
+    "        ),\n",
+    "        (\"human\", \"I have a {animal}\"),\n",
+    "    ]\n",
+    ")\n",
+    "\n",
+    "chain = chat_prompt | llm | output_parser\n",
+    "\n",
+    "response = chain.invoke({\"animal\": \"dog\"})\n",
+    "print(response)  # should answer something like \"1. Max\\n2. Bella\\n3. Charlie\\n4. Rocky\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Example 2 - RAG + LLM: UNICAMP 2024 Entrance Exam Question Answering System\n",
+    "For this example, we need to install some extra libraries:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install unstructured rank_bm25 pdf2image pdfminer-six pikepdf pypdf unstructured_inference fastapi kaleido uvicorn \"pillow<10.1.0\" pillow_heif -q"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Loading the database\n",
+    "\n",
+    "The first step is to create a database with the information from the exam notice. To do this, we will download the notice from the COMVEST website and segment the extracted text into 500-character windows."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.document_loaders import OnlinePDFLoader\n",
+    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
+    "\n",
+    "# Loading the COMVEST 2024 notice\n",
+    "loader = OnlinePDFLoader(\n",
+    "    \"https://www.comvest.unicamp.br/wp-content/uploads/2023/10/31-2023-Dispoe-sobre-o-Vestibular-Unicamp-2024_com-retificacao.pdf\"\n",
+    ")\n",
+    "data = loader.load()\n",
+    "\n",
+    "text_splitter = RecursiveCharacterTextSplitter(\n",
+    "    chunk_size=500, chunk_overlap=100, separators=[\"\\n\", \" \", \"\"]\n",
+    ")\n",
+    "texts = text_splitter.split_documents(data)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Creating a Searcher\n",
+    "Now that we have our database, we need a searcher. For this example, we will use simple BM25 as the search system, but it could be replaced by any other searcher (such as search via embeddings)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.retrievers import BM25Retriever\n",
+    "\n",
+    "retriever = BM25Retriever.from_documents(texts)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Combining Search System + LLM\n",
+    "Now that we have our searcher, we just need to implement a prompt specifying the task and invoke the chain."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chains.question_answering import load_qa_chain\n",
+    "\n",
+    "# Prompt in Portuguese: \"Based on the following documents, answer the question below.\"\n",
+    "prompt = \"\"\"Baseado nos seguintes documentos, responda a pergunta abaixo.\n",
+    "\n",
+    "{context}\n",
+    "\n",
+    "Pergunta: {query}\n",
+    "\"\"\"\n",
+    "\n",
+    "qa_prompt = ChatPromptTemplate.from_messages([(\"human\", prompt)])\n",
+    "\n",
+    "chain = load_qa_chain(llm, chain_type=\"stuff\", verbose=True, prompt=qa_prompt)\n",
+    "\n",
+    "# \"What is the maximum time allowed for completing the exam?\"\n",
+    "query = \"Qual o tempo máximo para realização da prova?\"\n",
+    "\n",
+    "docs = retriever.get_relevant_documents(query)\n",
+    "\n",
+    "chain.invoke(\n",
+    "    {\"input_documents\": docs, \"query\": query}\n",
+    ")  # Should output something like: \"O tempo máximo para realização da prova é de 5 horas.\" (\"The maximum time allowed for the exam is 5 hours.\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py
index 00f7ffeb40..1d128f82ac 100644
--- a/libs/community/langchain_community/chat_models/__init__.py
+++ b/libs/community/langchain_community/chat_models/__init__.py
@@ -43,6 +43,7 @@ from langchain_community.chat_models.konko import ChatKonko
 from langchain_community.chat_models.litellm import ChatLiteLLM
 from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter
 from langchain_community.chat_models.llama_edge import LlamaEdgeChatService
+from langchain_community.chat_models.maritalk import ChatMaritalk
 from langchain_community.chat_models.minimax import MiniMaxChat
 from langchain_community.chat_models.mlflow import ChatMlflow
 from langchain_community.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
@@ -99,4 +100,5 @@ __all__ = [
     "ChatYuan2",
     "ChatZhipuAI",
     "ChatKinetica",
+    "ChatMaritalk",
 ]
diff --git a/libs/community/langchain_community/chat_models/maritalk.py b/libs/community/langchain_community/chat_models/maritalk.py
new file mode 100644
index 0000000000..ab90de19c8
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/maritalk.py
@@ -0,0 +1,151 @@
+from typing import Any, Dict, List, Optional, Union
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.chat_models import SimpleChatModel
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
+from langchain_core.pydantic_v1 import Field
+
+
+class ChatMaritalk(SimpleChatModel):
+    """`MariTalk` Chat models API.
+
+    This class allows interacting with the MariTalk chatbot API.
+    To use it, you must provide an API key through the constructor.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.chat_models import ChatMaritalk
+            chat = ChatMaritalk(api_key="your_api_key_here")
+    """
+
+    api_key: str
+    """Your MariTalk API key."""
+
+    temperature: float = Field(default=0.7, gt=0.0, lt=1.0)
+    """Run inference with this temperature.
+    Must be in the open interval (0.0, 1.0)."""
+
+    max_tokens: int = Field(default=512, gt=0)
+    """The maximum number of tokens to generate in the reply."""
+
+    do_sample: bool = Field(default=True)
+    """Whether or not to use sampling; use `True` to enable."""
+
+    top_p: float = Field(default=0.95, gt=0.0, lt=1.0)
+    """Nucleus sampling parameter controlling the size of
+    the probability mass considered for sampling."""
+
+    system_message_workaround: bool = Field(default=True)
+    """Whether to include a workaround for system messages
+    by adding them as user messages."""
+
+    @property
+    def _llm_type(self) -> str:
+        """Identifies the LLM type as 'maritalk'."""
+        return "maritalk"
+
+    def parse_messages_for_model(
+        self, messages: List[BaseMessage]
+    ) -> List[Dict[str, Union[str, List[Union[str, Dict[Any, Any]]]]]]:
+        """
+        Parses messages from LangChain's format to the format expected by
+        the MariTalk API.
+
+        Parameters:
+            messages (List[BaseMessage]): A list of messages in LangChain
+                format to be parsed.
+
+        Returns:
+            A list of messages formatted for the MariTalk API.
+        """
+        parsed_messages = []
+
+        for message in messages:
+            if isinstance(message, HumanMessage):
+                parsed_messages.append({"role": "user", "content": message.content})
+            elif isinstance(message, AIMessage):
+                parsed_messages.append(
+                    {"role": "assistant", "content": message.content}
+                )
+            elif isinstance(message, SystemMessage) and self.system_message_workaround:
+                # MariTalk models do not understand system messages.
+                # Instead, we add these messages as user messages.
+                parsed_messages.append({"role": "user", "content": message.content})
+                parsed_messages.append({"role": "assistant", "content": "ok"})
+
+        return parsed_messages
+
+    def _call(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Sends the parsed messages to the MariTalk API and returns the generated
+        response or an error message.
+
+        This method makes an HTTP POST request to the MariTalk API with the
+        provided messages and other parameters.
+        If the request is successful and the API returns a response,
+        this method returns a string containing the answer.
+        If the request is rate-limited or encounters another error,
+        it returns a string with the error message.
+
+        Parameters:
+            messages (List[BaseMessage]): Messages to send to the model.
+            stop (Optional[List[str]]): Tokens that will signal the model
+                to stop generating further tokens.
+
+        Returns:
+            str: If the API call is successful, returns the answer.
+                If an error occurs (e.g., rate limiting), returns a string
+                describing the error.
+ """ + try: + url = "https://chat.maritaca.ai/api/chat/inference" + headers = {"authorization": f"Key {self.api_key}"} + stopping_tokens = stop if stop is not None else [] + + parsed_messages = self.parse_messages_for_model(messages) + + data = { + "messages": parsed_messages, + "do_sample": self.do_sample, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "stopping_tokens": stopping_tokens, + **kwargs, + } + + response = requests.post(url, json=data, headers=headers) + if response.status_code == 429: + return "Rate limited, please try again soon" + elif response.ok: + return response.json().get("answer", "No answer found") + + except requests.exceptions.RequestException as e: + return f"An error occurred: {str(e)}" + + # Fallback return statement, in case of unexpected code paths + return "An unexpected error occurred" + + @property + def _identifying_params(self) -> Dict[str, Any]: + """ + Identifies the key parameters of the chat model for logging + or tracking purposes. + + Returns: + A dictionary of the key configuration parameters. + """ + return { + "system_message_workaround": self.system_message_workaround, + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens, + } diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index f851a6348d..61ef5f831c 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -13,6 +13,7 @@ EXPECTED_ALL = [ "ChatDeepInfra", "ChatGooglePalm", "ChatHuggingFace", + "ChatMaritalk", "ChatMlflow", "ChatMLflowAIGateway", "ChatOllama",