From 0de55048b7b9c22cae2057165a50027a2ad045fc Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 2 Feb 2023 08:47:20 -0800 Subject: [PATCH] return code for pal (#844) --- docs/modules/chains/examples/pal.ipynb | 116 ++++++++++++++++++++++++- langchain/chains/pal/base.py | 11 ++- 2 files changed, 121 insertions(+), 6 deletions(-) diff --git a/docs/modules/chains/examples/pal.ipynb b/docs/modules/chains/examples/pal.ipynb index 3e0de46b..f13dc36f 100644 --- a/docs/modules/chains/examples/pal.ipynb +++ b/docs/modules/chains/examples/pal.ipynb @@ -21,6 +21,24 @@ "from langchain import OpenAI" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a58e15e", + "metadata": {}, + "outputs": [], + "source": [ + "llm = OpenAI(model_name='code-davinci-002', temperature=0, max_tokens=512)" + ] + }, + { + "cell_type": "markdown", + "id": "095adc76", + "metadata": {}, + "source": [ + "## Math Prompt" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -28,7 +46,6 @@ "metadata": {}, "outputs": [], "source": [ - "llm = OpenAI(model_name='code-davinci-002', temperature=0, max_tokens=512)\n", "pal_chain = PALChain.from_math_prompt(llm, verbose=True)" ] }, @@ -64,7 +81,7 @@ " result = total_pets\n", " return result\u001b[0m\n", "\n", - "\u001b[1m> Finished PALChain chain.\u001b[0m\n" + "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { @@ -82,6 +99,14 @@ "pal_chain.run(question)" ] }, + { + "cell_type": "markdown", + "id": "0269d20a", + "metadata": {}, + "source": [ + "## Colored Objects" + ] + }, { "cell_type": "code", "execution_count": 5, @@ -89,7 +114,6 @@ "metadata": {}, "outputs": [], "source": [ - "llm = OpenAI(model_name='code-davinci-002', temperature=0, max_tokens=512)\n", "pal_chain = PALChain.from_colored_object_prompt(llm, verbose=True)" ] }, @@ -147,10 +171,94 @@ "pal_chain.run(question)" ] }, + { + "cell_type": "markdown", + "id": "fc3d7f10", + "metadata": {}, + "source": [ + "## Intermediate Steps\n", + "You can also use the intermediate steps flag to return the code executed that generates the answer." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9d2d9c61", + "metadata": {}, + "outputs": [], + "source": [ + "pal_chain = PALChain.from_colored_object_prompt(llm, verbose=True, return_intermediate_steps=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b29b971b", + "metadata": {}, + "outputs": [], + "source": [ + "question = \"On the desk, you see two blue booklets, two purple booklets, and two yellow pairs of sunglasses. 
If I remove all the pairs of sunglasses from the desk, how many purple items remain on it?\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a2c40c28", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new PALChain chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m# Put objects into a list to record ordering\n", + "objects = []\n", + "objects += [('booklet', 'blue')] * 2\n", + "objects += [('booklet', 'purple')] * 2\n", + "objects += [('sunglasses', 'yellow')] * 2\n", + "\n", + "# Remove all pairs of sunglasses\n", + "objects = [object for object in objects if object[0] != 'sunglasses']\n", + "\n", + "# Count number of purple objects\n", + "num_purple = len([object for object in objects if object[1] == 'purple'])\n", + "answer = num_purple\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + } + ], + "source": [ + "result = pal_chain({\"question\": question})" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "efddd033", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"# Put objects into a list to record ordering\\nobjects = []\\nobjects += [('booklet', 'blue')] * 2\\nobjects += [('booklet', 'purple')] * 2\\nobjects += [('sunglasses', 'yellow')] * 2\\n\\n# Remove all pairs of sunglasses\\nobjects = [object for object in objects if object[0] != 'sunglasses']\\n\\n# Count number of purple objects\\nnum_purple = len([object for object in objects if object[1] == 'purple'])\\nanswer = num_purple\"" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result['intermediate_steps']" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "4ab20fec", + "id": "dfd88594", "metadata": {}, "outputs": [], "source": [] diff --git a/langchain/chains/pal/base.py b/langchain/chains/pal/base.py index 36c0be62..5f574aaa 100644 --- a/langchain/chains/pal/base.py +++ b/langchain/chains/pal/base.py @@ -27,6 +27,7 @@ class PALChain(Chain, BaseModel): python_globals: Optional[Dict[str, Any]] = None python_locals: Optional[Dict[str, Any]] = None output_key: str = "result" #: :meta private: + return_intermediate_steps: bool = False class Config: """Configuration for this pydantic object.""" @@ -48,7 +49,10 @@ class PALChain(Chain, BaseModel): :meta private: """ - return [self.output_key] + if not self.return_intermediate_steps: + return [self.output_key] + else: + return [self.output_key, "intermediate_steps"] def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: llm_chain = LLMChain(llm=self.llm, prompt=self.prompt) @@ -58,7 +62,10 @@ class PALChain(Chain, BaseModel): ) repl = PythonREPL(_globals=self.python_globals, _locals=self.python_locals) res = repl.run(code + f"\n{self.get_answer_expr}") - return {self.output_key: res.strip()} + output = {self.output_key: res.strip()} + if self.return_intermediate_steps: + output["intermediate_steps"] = code + return output @classmethod def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
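
Taken together, the notebook and base.py changes above let callers retrieve the generated program alongside the answer. Below is a minimal sketch of how the new `return_intermediate_steps` flag could be used end to end; it is based on the notebook cells in this patch, and it assumes `PALChain` is importable from `langchain.chains` and that an OpenAI API key (with access to `code-davinci-002`) is configured in the environment:

    from langchain import OpenAI
    from langchain.chains import PALChain

    # Same LLM configuration as the notebook in this patch.
    llm = OpenAI(model_name="code-davinci-002", temperature=0, max_tokens=512)

    # With return_intermediate_steps=True, "intermediate_steps" is added to the
    # chain's output_keys (see the output_keys change in base.py above).
    pal_chain = PALChain.from_colored_object_prompt(
        llm, verbose=True, return_intermediate_steps=True
    )

    question = (
        "On the desk, you see two blue booklets, two purple booklets, and two yellow "
        "pairs of sunglasses. If I remove all the pairs of sunglasses from the desk, "
        "how many purple items remain on it?"
    )

    result = pal_chain({"question": question})
    print(result["result"])              # the answer produced by running the code, "2"
    print(result["intermediate_steps"])  # the generated Python program, as a string

Because the program is returned verbatim, it can be logged or reviewed before the executed answer is trusted.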