mirror of https://github.com/hwchase17/langchain
community[minor]: add maritalk chat (#17675)
**Description:** Adds the MariTalk chat model, which is based on an LLM specially trained for Portuguese. **Twitter handle:** @MaritacaAI
parent 08fa38d56d
commit 3438d2cbcc
docs/docs/integrations/chat/maritalk.ipynb
@@ -0,0 +1,201 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/langchain-ai/langchain/blob/master/docs/docs/integrations/chat/maritalk.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
    "\n",
    "# Maritalk\n",
    "\n",
    "## Introduction\n",
    "\n",
    "MariTalk is an assistant developed by the Brazilian company [Maritaca AI](https://www.maritaca.ai).\n",
    "MariTalk is based on language models that have been specially trained to understand Portuguese well.\n",
    "\n",
    "This notebook demonstrates how to use MariTalk with LangChain through two examples:\n",
    "\n",
    "1. A simple example of how to use MariTalk to perform a task.\n",
    "2. LLM + RAG: how to answer a question whose answer is found in a long document that does not fit within MariTalk's token limit. For this, we will use a simple searcher (BM25) to first retrieve the most relevant sections of the document and then feed them to MariTalk for answering."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installation\n",
    "First, install the LangChain library (and all its dependencies) using the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install langchain langchain-core langchain-community"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## API Key\n",
    "You will need an API key, which can be obtained from chat.maritaca.ai (\"Chaves da API\" section)."
   ]
  },
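  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to provide the key without hard-coding it into the notebook is to read it interactively, as in the minimal sketch below (the `MARITALK_API_KEY` environment variable name is an illustrative choice, not an official convention); you can then pass it as `api_key` in the next cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "# Prompt for the key so it is not saved in the notebook.\n",
    "if \"MARITALK_API_KEY\" not in os.environ:\n",
    "    os.environ[\"MARITALK_API_KEY\"] = getpass.getpass(\"MariTalk API key: \")"
   ]
  },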
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Example 1 - Pet Name Suggestions\n",
    "\n",
    "Let's define our language model, ChatMaritalk, and configure it with your API key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.prompts.chat import ChatPromptTemplate\n",
    "from langchain_community.chat_models import ChatMaritalk\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "\n",
    "llm = ChatMaritalk(\n",
    "    api_key=\"\",  # Insert your API key here\n",
    "    temperature=0.7,\n",
    "    max_tokens=100,\n",
    ")\n",
    "\n",
    "output_parser = StrOutputParser()\n",
    "\n",
    "chat_prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"You are an assistant specialized in suggesting pet names. Given the animal, you must suggest 4 names.\",\n",
    "        ),\n",
    "        (\"human\", \"I have a {animal}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "chain = chat_prompt | llm | output_parser\n",
    "\n",
    "response = chain.invoke({\"animal\": \"dog\"})\n",
    "print(response)  # should answer something like \"1. Max\\n2. Bella\\n3. Charlie\\n4. Rocky\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 2 - RAG + LLM: UNICAMP 2024 Entrance Exam Question Answering System\n",
    "For this example, we need to install some extra libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install unstructured rank_bm25 pdf2image pdfminer-six pikepdf pypdf unstructured_inference fastapi kaleido uvicorn \"pillow<10.1.0\" pillow_heif -q"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loading the database\n",
    "\n",
    "The first step is to create a database with the information from the notice. For this, we will download the notice from the COMVEST website and segment the extracted text into 500-character windows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.document_loaders import OnlinePDFLoader\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "\n",
    "# Loading the COMVEST 2024 notice\n",
    "loader = OnlinePDFLoader(\n",
    "    \"https://www.comvest.unicamp.br/wp-content/uploads/2023/10/31-2023-Dispoe-sobre-o-Vestibular-Unicamp-2024_com-retificacao.pdf\"\n",
    ")\n",
    "data = loader.load()\n",
    "\n",
    "text_splitter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size=500, chunk_overlap=100, separators=[\"\\n\", \" \", \"\"]\n",
    ")\n",
    "texts = text_splitter.split_documents(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Creating a Searcher\n",
    "Now that we have our database, we need a searcher. For this example, we will use simple BM25 as the search system, but it could be replaced by any other searcher, such as search via embeddings (a sketch of that alternative follows the next cell)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers import BM25Retriever\n",
    "\n",
    "retriever = BM25Retriever.from_documents(texts)"
   ]
  },
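  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the embeddings-based alternative mentioned above (not part of the original example; it assumes `faiss-cpu` and `sentence-transformers` are installed, and the multilingual embedding model name is just an illustrative choice):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.embeddings import HuggingFaceEmbeddings\n",
    "from langchain_community.vectorstores import FAISS\n",
    "\n",
    "# Embed the same 500-character chunks and index them in FAISS.\n",
    "embeddings = HuggingFaceEmbeddings(\n",
    "    model_name=\"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\"\n",
    ")\n",
    "vector_store = FAISS.from_documents(texts, embeddings)\n",
    "\n",
    "# Drop-in replacement for the BM25 retriever above.\n",
    "embedding_retriever = vector_store.as_retriever(search_kwargs={\"k\": 4})"
   ]
  },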
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Combining Search System + LLM\n",
    "Now that we have our searcher, we just need to implement a prompt specifying the task and invoke the chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains.question_answering import load_qa_chain\n",
    "\n",
    "# Prompt (in Portuguese): \"Based on the following documents, answer the question below.\"\n",
    "prompt = \"\"\"Baseado nos seguintes documentos, responda a pergunta abaixo.\n",
    "\n",
    "{context}\n",
    "\n",
    "Pergunta: {query}\n",
    "\"\"\"\n",
    "\n",
    "qa_prompt = ChatPromptTemplate.from_messages([(\"human\", prompt)])\n",
    "\n",
    "chain = load_qa_chain(llm, chain_type=\"stuff\", verbose=True, prompt=qa_prompt)\n",
    "\n",
    "# \"What is the maximum time allowed for the exam?\"\n",
    "query = \"Qual o tempo máximo para realização da prova?\"\n",
    "\n",
    "docs = retriever.get_relevant_documents(query)\n",
    "\n",
    "chain.invoke(\n",
    "    {\"input_documents\": docs, \"query\": query}\n",
    ")  # Should output something like: \"O tempo máximo para realização da prova é de 5 horas.\" (\"The maximum time allowed for the exam is 5 hours.\")"
   ]
  },
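  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `invoke` on this chain returns a dict rather than a plain string; the generated answer is stored under the chain's `output_text` key. A small usage sketch for capturing just the answer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result = chain.invoke({\"input_documents\": docs, \"query\": query})\n",
    "print(result[\"output_text\"])"
   ]
  }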
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
libs/community/langchain_community/chat_models/maritalk.py
@@ -0,0 +1,151 @@
from typing import Any, Dict, List, Optional, Union

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import Field


class ChatMaritalk(SimpleChatModel):
    """`MariTalk` Chat models API.

    This class allows interacting with the MariTalk chatbot API.
    To use it, you must provide an API key through the constructor.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMaritalk

            chat = ChatMaritalk(api_key="your_api_key_here")
    """

    api_key: str
    """Your MariTalk API key."""

    temperature: float = Field(default=0.7, gt=0.0, lt=1.0)
    """Run inference with this temperature.
    Must be in the open interval (0.0, 1.0)."""

    max_tokens: int = Field(default=512, gt=0)
    """The maximum number of tokens to generate in the reply."""

    do_sample: bool = Field(default=True)
    """Whether or not to use sampling; use `True` to enable."""

    top_p: float = Field(default=0.95, gt=0.0, lt=1.0)
    """Nucleus sampling parameter controlling the size of
    the probability mass considered for sampling."""

    system_message_workaround: bool = Field(default=True)
    """Whether to include a workaround for system messages
    by adding them as a user message."""

    @property
    def _llm_type(self) -> str:
        """Identifies the LLM type as 'maritalk'."""
        return "maritalk"

    def parse_messages_for_model(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Union[str, List[Union[str, Dict[Any, Any]]]]]]:
        """
        Parses messages from LangChain's format to the format expected by
        the MariTalk API.

        Parameters:
            messages (List[BaseMessage]): A list of messages in LangChain
                format to be parsed.

        Returns:
            A list of messages formatted for the MariTalk API.
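
        Example (illustrative sketch, not from the original source; assumes
        `chat` is a ChatMaritalk instance):
            .. code-block:: python

                chat.parse_messages_for_model([HumanMessage(content="Oi")])
                # -> [{"role": "user", "content": "Oi"}]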
        """
        parsed_messages = []

        for message in messages:
            if isinstance(message, HumanMessage):
                parsed_messages.append({"role": "user", "content": message.content})
            elif isinstance(message, AIMessage):
                parsed_messages.append(
                    {"role": "assistant", "content": message.content}
                )
            elif isinstance(message, SystemMessage) and self.system_message_workaround:
                # MariTalk models do not understand system messages.
                # Instead, we add them as user messages, followed by a short
                # assistant acknowledgement, to preserve the conversation turns.
                parsed_messages.append({"role": "user", "content": message.content})
                parsed_messages.append({"role": "assistant", "content": "ok"})

        return parsed_messages

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """
        Sends the parsed messages to the MariTalk API and returns the generated
        response or an error message.

        This method makes an HTTP POST request to the MariTalk API with the
        provided messages and other parameters.
        If the request is successful and the API returns a response,
        this method returns a string containing the answer.
        If the request is rate-limited or encounters another error,
        it returns a string with the error message.

        Parameters:
            messages (List[BaseMessage]): Messages to send to the model.
            stop (Optional[List[str]]): Tokens that will signal the model
                to stop generating further tokens.

        Returns:
            str: If the API call is successful, returns the answer.
                If an error occurs (e.g., rate limiting), returns a string
                describing the error.
        """
        try:
            url = "https://chat.maritaca.ai/api/chat/inference"
            headers = {"authorization": f"Key {self.api_key}"}
            stopping_tokens = stop if stop is not None else []

            parsed_messages = self.parse_messages_for_model(messages)

            data = {
                "messages": parsed_messages,
                "do_sample": self.do_sample,
                "max_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "stopping_tokens": stopping_tokens,
                **kwargs,
            }

            response = requests.post(url, json=data, headers=headers)

            if response.status_code == 429:
                return "Rate limited, please try again soon"
            elif response.ok:
                return response.json().get("answer", "No answer found")
            # Any other non-OK status falls through to the fallback below.

        except requests.exceptions.RequestException as e:
            return f"An error occurred: {str(e)}"

        # Fallback return statement, in case of unexpected code paths
        return "An unexpected error occurred"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """
        Identifies the key parameters of the chat model for logging
        or tracking purposes.

        Returns:
            A dictionary of the key configuration parameters.
        """
        return {
            "system_message_workaround": self.system_message_workaround,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "max_tokens": self.max_tokens,
        }