Harrison/few shot yaml (#682)

Co-authored-by: vintro <77507980+vintrocode@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-01-21 16:08:03 -08:00 committed by GitHub
parent a2eeaf3d43
commit e45f7e40e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 12 deletions

View File

@ -0,0 +1,4 @@
- input: happy
output: sad
- input: tall
output: short

View File

@ -0,0 +1,14 @@
_type: few_shot
input_variables:
["adjective"]
prefix:
Write antonyms for the following words.
example_prompt:
input_variables:
["input", "output"]
template:
"Input: {input}\nOutput: {output}"
examples:
examples.yaml
suffix:
"Input: {adjective}\nOutput:"

View File

@ -225,6 +225,35 @@
"!cat examples.json" "!cat examples.json"
] ]
}, },
{
"cell_type": "markdown",
"id": "d3052850",
"metadata": {},
"source": [
"And here is what the same examples stored as yaml might look like."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "901385d1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"- input: happy\r\n",
" output: sad\r\n",
"- input: tall\r\n",
" output: short\r\n"
]
}
],
"source": [
"!cat examples.yaml"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "8e300335", "id": "8e300335",
@ -236,7 +265,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 10,
"id": "e2bec0fc", "id": "e2bec0fc",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -267,7 +296,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 11,
"id": "98c8f356", "id": "98c8f356",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -293,6 +322,73 @@
"print(prompt.format(adjective=\"funny\"))" "print(prompt.format(adjective=\"funny\"))"
] ]
}, },
{
"cell_type": "markdown",
"id": "13620324",
"metadata": {},
"source": [
"The same would work if you loaded examples from the yaml file."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "831e5e4a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_type: few_shot\r\n",
"input_variables:\r\n",
" [\"adjective\"]\r\n",
"prefix: \r\n",
" Write antonyms for the following words.\r\n",
"example_prompt:\r\n",
" input_variables:\r\n",
" [\"input\", \"output\"]\r\n",
" template:\r\n",
" \"Input: {input}\\nOutput: {output}\"\r\n",
"examples:\r\n",
" examples.yaml\r\n",
"suffix:\r\n",
" \"Input: {adjective}\\nOutput:\"\r\n"
]
}
],
"source": [
"!cat few_shot_prompt_yaml_examples.yaml"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "6f0a7eaa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Write antonyms for the following words.\n",
"\n",
"Input: happy\n",
"Output: sad\n",
"\n",
"Input: tall\n",
"Output: short\n",
"\n",
"Input: funny\n",
"Output:\n"
]
}
],
"source": [
"prompt = load_prompt(\"few_shot_prompt_yaml_examples.yaml\")\n",
"print(prompt.format(adjective=\"funny\"))"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "4870aa9d", "id": "4870aa9d",
@ -304,7 +400,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 14,
"id": "9d996a86", "id": "9d996a86",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -332,7 +428,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 15,
"id": "dd2c10bb", "id": "dd2c10bb",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -369,7 +465,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 16,
"id": "6cd781ef", "id": "6cd781ef",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -400,7 +496,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 17,
"id": "533ab8a7", "id": "533ab8a7",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -437,7 +533,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 18,
"id": "0b6dd7b8", "id": "0b6dd7b8",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -458,7 +554,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 19,
"id": "76a1065d", "id": "76a1065d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -483,7 +579,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 20,
"id": "744d275d", "id": "744d275d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -530,7 +626,7 @@
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {
"hash": "b1677b440931f40d89ef8be7bf03acb108ce003de0ac9b18e8d43753ea2e7103" "hash": "8eb71adebe840dca1185e9603533462bc47eb1b1a73bf7dab2d0a8a4c932882e"
} }
} }
}, },

View File

@ -52,10 +52,17 @@ def _load_examples(config: dict) -> dict:
pass pass
elif isinstance(config["examples"], str): elif isinstance(config["examples"], str):
with open(config["examples"]) as f: with open(config["examples"]) as f:
examples = json.load(f) if config["examples"].endswith(".json"):
examples = json.load(f)
elif config["examples"].endswith((".yaml", ".yml")):
examples = yaml.safe_load(f)
else:
raise ValueError(
"Invalid file format. Only json or yaml formats are supported."
)
config["examples"] = examples config["examples"] = examples
else: else:
raise ValueError raise ValueError("Invalid examples format. Only list or string are supported.")
return config return config