Fix imports in notebook (#9458)

pull/9465/head^2
William FH 1 year ago committed by GitHub
parent c29fbede59
commit d4f790fd40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,19 +21,19 @@
"tags": []
},
"source": [
"To use, you should have the ``transformers`` python [package installed](https://pypi.org/project/transformers/)."
"To use, you should have the ``transformers`` python [package installed](https://pypi.org/project/transformers/), as well as [pytorch](https://pytorch.org/get-started/locally/). You can also install `xformer` for a more memory-efficient attention implementation."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "d772b637-de00-4663-bd77-9bc96d798db2",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install transformers > /dev/null"
"%pip install transformers --quiet"
]
},
{
@ -46,22 +46,14 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"id": "165ae236-962a-4763-8052-c4836d78a5d2",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:root:Failed to default session, using empty session: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /sessions (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x1117f9790>: Failed to establish a new connection: [Errno 61] Connection refused'))\n"
]
}
],
"outputs": [],
"source": [
"from langchain import HuggingFacePipeline\n",
"from langchain.llms import HuggingFacePipeline\n",
"\n",
"llm = HuggingFacePipeline.from_model_id(\n",
" model_id=\"bigscience/bloom-1b7\",\n",
@ -75,24 +67,18 @@
"id": "00104b27-0c15-4a97-b198-4512337ee211",
"metadata": {},
"source": [
"### Integrate the model in an LLMChain"
"### Create Chain\n",
"\n",
"With the model loaded into memory, you can compose it with a prompt to\n",
"form a chain."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"id": "3acf0069",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/wfh/code/lc/lckg/.venv/lib/python3.11/site-packages/transformers/generation/utils.py:1288: UserWarning: Using `max_length`'s default (64) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n",
" warnings.warn(\n",
"WARNING:root:Failed to persist run: HTTPConnectionPool(host='localhost', port=8000): Max retries exceeded with url: /chain-runs (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x144d06910>: Failed to establish a new connection: [Errno 61] Connection refused'))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@ -102,27 +88,19 @@
}
],
"source": [
"from langchain import PromptTemplate, LLMChain\n",
"from langchain.prompts import PromptTemplate\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])\n",
"prompt = PromptTemplate.from_template(template)\n",
"\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"chain = prompt | llm\n",
"\n",
"question = \"What is electroencephalography?\"\n",
"\n",
"print(llm_chain.run(question))"
"print(chain.invoke({\"question\": question}))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "843a3837",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

Loading…
Cancel
Save