"""LangServe chain that turns an uploaded, base64-encoded file into LLM output.

The chain decodes the file, fills a one-shot chat prompt (an example user
message and AI response selected by ``num_rows * num_cols``), calls GPT-4, and
post-processes the text with ``parse_llm_output``.
"""

import base64
import json
from typing import Optional

from langchain_community.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import Field
from langserve import CustomUserType

from .prompts import (
    AI_REPONSE_DICT,
    FULL_PROMPT,
    USER_EXAMPLE_DICT,
    create_prompt,
)
from .utils import parse_llm_output

llm = ChatOpenAI(temperature=0, model="gpt-4")

prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(FULL_PROMPT),
        # One-shot example exchange; both texts are chosen per request by
        # _get_user_example / _get_ai_response below.
        ("human", "{user_example}"),
        ("ai", "{ai_response}"),
        ("human", "{input}"),
    ],
)


# ATTENTION: Inherit from CustomUserType instead of BaseModel otherwise
# the server will decode it into a dict instead of a pydantic model.
class FileProcessingRequest(CustomUserType):
    """Request including a base64 encoded file."""

    # The extra field is used to specify a widget for the playground UI.
    file: str = Field(..., extra={"widget": {"type": "base64file"}})
    # May be omitted; forwarded to create_prompt unchanged (None if omitted).
    num_plates: Optional[int] = None
    num_rows: int = 8
    num_cols: int = 12


def _load_file(request: FileProcessingRequest) -> str:
    """Decode the request's base64 payload and return it as UTF-8 text.

    Raises:
        binascii.Error: if the payload is malformed base64 (e.g. bad padding).
        UnicodeDecodeError: if the decoded bytes are not valid UTF-8.
    """
    return base64.b64decode(request.file.encode("utf-8")).decode("utf-8")


def _load_prompt(request: FileProcessingRequest):
    """Build the ``hint`` prompt variable from the request's layout parameters."""
    return create_prompt(
        num_plates=request.num_plates,
        num_rows=request.num_rows,
        num_cols=request.num_cols,
    )


def _get_col_range_str(request: FileProcessingRequest) -> str:
    """Describe the column range, e.g. ``"from 1 to 12"``; ``""`` if num_cols is 0."""
    if request.num_cols:
        return f"from 1 to {request.num_cols}"
    else:
        return ""


def _get_json_format(request: FileProcessingRequest) -> str:
    """Return a JSON example of the expected output format for this layout.

    The example block starts at row 12 / column 1 and its end indices are
    derived so it spans exactly ``num_rows`` rows and ``num_cols`` columns.
    """
    return json.dumps(
        [
            {
                "row_start": 12,
                "row_end": 12 + request.num_rows - 1,
                "col_start": 1,
                "col_end": 1 + request.num_cols - 1,
                "contents": "Entity ID",
            }
        ]
    )


def _lookup_by_plate_size(table: dict, request: FileProcessingRequest):
    """Return the entry of *table* keyed by ``num_rows * num_cols``.

    Raises:
        KeyError: if *table* has no entry for that cell count. The message
            names the requested layout and lists the supported sizes, instead
            of the bare number a plain dict lookup would report.
    """
    cell_count = request.num_rows * request.num_cols
    try:
        return table[cell_count]
    except KeyError:
        supported = ", ".join(str(size) for size in table)
        raise KeyError(
            f"No few-shot example for a {request.num_rows}x{request.num_cols} "
            f"layout ({cell_count} cells); supported sizes: {supported}"
        ) from None


def _get_user_example(request: FileProcessingRequest):
    """Return the example user message matching the request's layout size."""
    return _lookup_by_plate_size(USER_EXAMPLE_DICT, request)


def _get_ai_response(request: FileProcessingRequest):
    """Return the example AI response matching the request's layout size."""
    return _lookup_by_plate_size(AI_REPONSE_DICT, request)


chain = (
    {
        # Should add validation to ensure numeric indices
        "input": _load_file,
        "hint": _load_prompt,
        "col_range_str": _get_col_range_str,
        "json_format": _get_json_format,
        "user_example": _get_user_example,
        "ai_response": _get_ai_response,
    }
    | prompt
    | llm
    | StrOutputParser()
    | parse_llm_output
).with_types(input_type=FileProcessingRequest)