[core, langchain] modelio code improvements (#15277)

pull/15284/head
Harrison Chase 6 months ago committed by GitHub
parent 694bbb14cd
commit b86803153e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -8,7 +8,7 @@ class BaseExampleSelector(ABC):
@abstractmethod
def add_example(self, example: Dict[str, str]) -> Any:
"""Add new example to store for a key."""
"""Add new example to store."""
@abstractmethod
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:

@ -3,7 +3,7 @@ from langchain_core.output_parsers.base import (
BaseLLMOutputParser,
BaseOutputParser,
)
from langchain_core.output_parsers.json import SimpleJsonOutputParser
from langchain_core.output_parsers.json import JsonOutputParser, SimpleJsonOutputParser
from langchain_core.output_parsers.list import (
CommaSeparatedListOutputParser,
ListOutputParser,
@ -30,4 +30,5 @@ __all__ = [
"BaseCumulativeTransformOutputParser",
"SimpleJsonOutputParser",
"XMLOutputParser",
"JsonOutputParser",
]

@ -0,0 +1,11 @@
# flake8: noqa
JSON_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
Here is the output schema:
```
{schema}
```"""

@ -3,12 +3,14 @@ from __future__ import annotations
import json
import re
from json import JSONDecodeError
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Type
import jsonpatch # type: ignore[import]
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.pydantic_v1 import BaseModel
def _replace_new_line(match: re.Match[str]) -> str:
@ -170,7 +172,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
return json_obj
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object.
When used in streaming mode, it will yield partial JSON objects containing
@ -180,6 +182,8 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
describing the difference between the previous and the current object.
"""
pydantic_object: Optional[Type[BaseModel]] = None
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
@ -190,6 +194,26 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
except JSONDecodeError as e:
raise OutputParserException(f"Invalid json output: {text}") from e
def get_format_instructions(self) -> str:
if self.pydantic_object is None:
return "Return a JSON object."
else:
schema = self.pydantic_object.schema()
# Remove extraneous fields.
reduced_schema = schema
if "title" in reduced_schema:
del reduced_schema["title"]
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema_str = json.dumps(reduced_schema)
return JSON_FORMAT_INSTRUCTIONS.format(schema=schema_str)
@property
def _type(self) -> str:
return "simple_json_output_parser"
# For backwards compatibility
SimpleJsonOutputParser = JsonOutputParser

@ -34,7 +34,11 @@ class XMLOutputParser(BaseTransformOutputParser):
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, List[Any]]:
text = text.strip("`").strip("xml")
# Try to find XML string within triple backticks
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:
# If match found, use the content within the backticks
text = match.group(2)
encoding_match = self.encoding_matcher.search(text)
if encoding_match:
text = encoding_match.group(2)

@ -13,6 +13,7 @@ EXPECTED_ALL = [
"BaseCumulativeTransformOutputParser",
"SimpleJsonOutputParser",
"XMLOutputParser",
"JsonOutputParser",
]

@ -39,8 +39,12 @@ class DatetimeOutputParser(BaseOutputParser[datetime]):
def get_format_instructions(self) -> str:
examples = comma_list(_generate_random_datetime_strings(self.format))
return f"""Write a datetime string that matches the
following pattern: "{self.format}". Examples: {examples}"""
return (
f"Write a datetime string that matches the "
f"following pattern: '{self.format}'.\n\n"
f"Examples: {examples}\n\n"
f"Return ONLY this string, no other words!"
)
def parse(self, response: str) -> datetime:
try:

Loading…
Cancel
Save