diff --git a/docs/modules/prompts/output_parsers/examples/datetime.ipynb b/docs/modules/prompts/output_parsers/examples/datetime.ipynb new file mode 100644 index 00000000..630d6f3a --- /dev/null +++ b/docs/modules/prompts/output_parsers/examples/datetime.ipynb @@ -0,0 +1,134 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "07311335", + "metadata": {}, + "source": [ + "# Datetime\n", + "\n", + "This OutputParser shows how to parse LLM output into datetime format." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "77e49a3d", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.prompts import PromptTemplate\n", + "from langchain.output_parsers import DatetimeOutputParser\n", + "from langchain.chains import LLMChain\n", + "from langchain.llms import OpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ace93488", + "metadata": {}, + "outputs": [], + "source": [ + "output_parser = DatetimeOutputParser()\n", + "template = \"\"\"Answer the users question:\n", + "\n", + "{question}\n", + "\n", + "{format_instructions}\"\"\"\n", + "prompt = PromptTemplate.from_template(template, partial_variables={\"format_instructions\": output_parser.get_format_instructions()})" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9240a3ae", + "metadata": {}, + "outputs": [], + "source": [ + "chain = LLMChain(prompt=prompt, llm=OpenAI())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ad62eacc", + "metadata": {}, + "outputs": [], + "source": [ + "output = chain.run(\"around when was bitcoin founded?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "96657765", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'\\n\\n2008-01-03T18:15:05.000000Z'" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "bf714e52", + 
# --- langchain/output_parsers/datetime.py ---
import random
import re
from datetime import datetime, timedelta
from typing import Any, List, Optional

from langchain.schema import BaseOutputParser, OutputParserException
from langchain.utils import comma_list


def _strftime_padded_year(dt: datetime, pattern: str) -> str:
    """Format ``dt`` with ``pattern``, forcing ``%Y`` to four digits.

    Some C libraries (notably glibc) render ``%Y`` without zero padding for
    years below 1000, while ``datetime.strptime`` requires a four-digit
    year. Substituting the year up front keeps every generated example
    parseable with the very pattern that produced it.
    """

    def _expand(match: "re.Match[str]") -> str:
        # Directives are consumed as two-character pairs, so an escaped
        # "%%Y" is seen as "%%" followed by a literal "Y" and left alone.
        directive = match.group(0)
        return f"{dt.year:04d}" if directive == "%Y" else directive

    return dt.strftime(re.sub(r"%.", _expand, pattern))


def _generate_random_datetime_strings(
    pattern: str,
    n: int = 3,
    start_date: datetime = datetime(1, 1, 1),
    end_date: Optional[datetime] = None,
) -> List[str]:
    """Generate ``n`` random datetime strings formatted with ``pattern``.

    Args:
        pattern: ``strftime`` format codes describing the desired output.
        n: Number of example strings to produce.
        start_date: Lower bound of the sampled range.
        end_date: Upper bound of the sampled range. When omitted it is
            resolved at call time to roughly ten years (3650 days) from now.

    Returns:
        A list of ``n`` strings, each a uniformly sampled datetime between
        ``start_date`` and ``end_date`` rendered with ``pattern``.
    """
    if end_date is None:
        # Resolved per call: a ``datetime.now()`` expression in the signature
        # would be evaluated once at import and then frozen for the process.
        end_date = datetime.now() + timedelta(days=3650)
    span_seconds = (end_date - start_date).total_seconds()
    return [
        _strftime_padded_year(
            start_date + timedelta(seconds=random.uniform(0, span_seconds)),
            pattern,
        )
        for _ in range(n)
    ]


class DatetimeOutputParser(BaseOutputParser[datetime]):
    """Parse LLM output into a ``datetime`` using a ``strptime`` pattern."""

    # ``strftime``/``strptime`` pattern the model's answer must follow.
    format: str = "%Y-%m-%dT%H:%M:%S.%fZ"

    def get_format_instructions(self) -> str:
        """Return prompt text stating the pattern plus random examples."""
        examples = comma_list(_generate_random_datetime_strings(self.format))
        # Built from adjacent literals so that no newline or source
        # indentation ends up inside the prompt sent to the model.
        return (
            "Write a datetime string that matches the "
            f'following pattern: "{self.format}". Examples: {examples}'
        )

    def parse(self, response: str) -> datetime:
        """Parse ``response`` (surrounding whitespace ignored).

        Raises:
            OutputParserException: If the stripped text does not match
                ``self.format``.
        """
        try:
            return datetime.strptime(response.strip(), self.format)
        except ValueError as e:
            raise OutputParserException(
                f"Could not parse datetime string: {response}"
            ) from e

    @property
    def _type(self) -> str:
        return "datetime"


# --- langchain/utils.py (helper added by the same change) ---
def comma_list(items: List[Any]) -> str:
    """Return the ``str()`` of every item joined by ``", "``."""
    return ", ".join(str(item) for item in items)
from datetime import datetime

from langchain.output_parsers.datetime import DatetimeOutputParser
from langchain.schema import OutputParserException


def test_datetime_output_parser_parse() -> None:
    """Round-trip several formats and check malformed input is rejected."""
    parser = DatetimeOutputParser()

    # The default format keeps microseconds, so the round trip is exact.
    date = datetime.now()
    datestr = date.strftime(parser.format)
    assert parser.parse(datestr) == date

    # Second-resolution format: everything but the microseconds survives.
    parser.format = "%Y-%m-%dT%H:%M:%S"
    date = datetime.now()
    datestr = date.strftime(parser.format)
    assert parser.parse(datestr) == date.replace(microsecond=0)

    # Time-only format: strptime fills in a placeholder date, so only the
    # time-of-day component is compared.
    parser.format = "%H:%M:%S"
    date = datetime.now()
    datestr = date.strftime(parser.format)
    assert parser.parse(datestr).time() == date.replace(microsecond=0).time()

    # Malformed input must surface as OutputParserException. The check is
    # written with try/except rather than asserting inside the try block so
    # that a parser which wrongly accepts the input actually fails the test.
    for bad_input in ("", "not a datetime", "2023-01-01"):
        try:
            parser.parse(bad_input)
        except OutputParserException:
            continue
        raise AssertionError(
            f"Expected OutputParserException for input {bad_input!r}"
        )