From 49341483daa405c2a0b08a54480feeba6f95847d Mon Sep 17 00:00:00 2001 From: Nik <6206742+nik-418@users.noreply.github.com> Date: Wed, 6 Sep 2023 16:46:17 +0200 Subject: [PATCH] Update Banana.dev docs to latest correct usage (#10183) - Description: this PR updates all Banana.dev-related docs to match the latest client usage. The code in the docs before this PR were out of date and would never run. - Issue: [#6404](https://github.com/langchain-ai/langchain/issues/6404) - Dependencies: - - Tag maintainer: - Twitter handle: [BananaDev_ ](https://twitter.com/BananaDev_ ) --- docs/extras/integrations/llms/banana.ipynb | 11 ++- .../integrations/providers/bananadev.mdx | 81 +++++++++---------- libs/langchain/langchain/llms/bananadev.py | 25 ++++-- 3 files changed, 64 insertions(+), 53 deletions(-) diff --git a/docs/extras/integrations/llms/banana.ipynb b/docs/extras/integrations/llms/banana.ipynb index 44e51faafa..b92db8daba 100644 --- a/docs/extras/integrations/llms/banana.ipynb +++ b/docs/extras/integrations/llms/banana.ipynb @@ -31,11 +31,16 @@ "outputs": [], "source": [ "# get new tokens: https://app.banana.dev/\n", - "# We need two tokens, not just an `api_key`: `BANANA_API_KEY` and `YOUR_MODEL_KEY`\n", + "# We need three parameters to make a Banana.dev API call:\n", + "# * a team api key\n", + "# * the model's unique key\n", + "# * the model's url slug\n", "\n", "import os\n", "from getpass import getpass\n", "\n", + "# You can get this from the main dashboard\n", + "# at https://app.banana.dev\n", "os.environ[\"BANANA_API_KEY\"] = \"YOUR_API_KEY\"\n", "# OR\n", "# BANANA_API_KEY = getpass()" @@ -70,7 +75,9 @@ "metadata": {}, "outputs": [], "source": [ - "llm = Banana(model_key=\"YOUR_MODEL_KEY\")" + "# Both of these are found in your model's \n", + "# detail page in https://app.banana.dev\n", + "llm = Banana(model_key=\"YOUR_MODEL_KEY\", model_url_slug=\"YOUR_MODEL_URL_SLUG\")" ] }, { diff --git a/docs/extras/integrations/providers/bananadev.mdx b/docs/extras/integrations/providers/bananadev.mdx index 4961e5f88b..ee7992be74 100644 --- a/docs/extras/integrations/providers/bananadev.mdx +++ b/docs/extras/integrations/providers/bananadev.mdx @@ -1,79 +1,72 @@ # Banana -This page covers how to use the Banana ecosystem within LangChain. -It is broken into two parts: installation and setup, and then references to specific Banana wrappers. +Banana provided serverless GPU inference for AI models, including a CI/CD build pipeline and a simple Python framework (Potassium) to server your models. + +This page covers how to use the [Banana](https://www.banana.dev) ecosystem within LangChain. + +It is broken into two parts: +* installation and setup, +* and then references to specific Banana wrappers. ## Installation and Setup - Install with `pip install banana-dev` -- Get an Banana api key and set it as an environment variable (`BANANA_API_KEY`) +- Get an Banana api key from the [Banana.dev dashboard](https://app.banana.dev) and set it as an environment variable (`BANANA_API_KEY`) +- Get your model's key and url slug from the model's details page ## Define your Banana Template -If you want to use an available language model template you can find one [here](https://app.banana.dev/templates/conceptofmind/serverless-template-palmyra-base). -This template uses the Palmyra-Base model by [Writer](https://writer.com/product/api/). -You can check out an example Banana repository [here](https://github.com/conceptofmind/serverless-template-palmyra-base). +You'll need to set up a Github repo for your Banana app. You can get started in 5 minutes using [this guide](https://docs.banana.dev/banana-docs/). + +Alternatively, for a ready-to-go LLM example, you can check out Banana's [CodeLlama-7B-Instruct-GPTQ](https://github.com/bananaml/demo-codellama-7b-instruct-gptq) GitHub repository. Just fork it and deploy it within Banana. + +Other starter repos are available [here](https://github.com/orgs/bananaml/repositories?q=demo-&type=all&language=&sort=). ## Build the Banana app -Banana Apps must include the "output" key in the return json. -There is a rigid response structure. +To use Banana apps within Langchain, they must include the `outputs` key +in the returned json, and the value must be a string. ```python # Return the results as a dictionary -result = {'output': result} +result = {'outputs': result} ``` An example inference function would be: ```python -def inference(model_inputs:dict) -> dict: - global model - global tokenizer - - # Parse out your arguments - prompt = model_inputs.get('prompt', None) - if prompt == None: - return {'message': "No prompt provided"} - - # Run the model - input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda() - output = model.generate( - input_ids, - max_length=100, - do_sample=True, - top_k=50, - top_p=0.95, - num_return_sequences=1, - temperature=0.9, - early_stopping=True, - no_repeat_ngram_size=3, - num_beams=5, - length_penalty=1.5, - repetition_penalty=1.5, - bad_words_ids=[[tokenizer.encode(' ', add_prefix_space=True)[0]]] - ) - - result = tokenizer.decode(output[0], skip_special_tokens=True) - # Return the results as a dictionary - result = {'output': result} - return result +@app.handler("/") +def handler(context: dict, request: Request) -> Response: + """Handle a request to generate code from a prompt.""" + model = context.get("model") + tokenizer = context.get("tokenizer") + max_new_tokens = request.json.get("max_new_tokens", 512) + temperature = request.json.get("temperature", 0.7) + prompt = request.json.get("prompt") + prompt_template=f'''[INST] Write code to solve the following coding problem that obeys the constraints and passes the example test cases. Please wrap your code answer using ```: + {prompt} + [/INST] + ''' + input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda() + output = model.generate(inputs=input_ids, temperature=temperature, max_new_tokens=max_new_tokens) + result = tokenizer.decode(output[0]) + return Response(json={"outputs": result}, status=200) ``` -You can find a full example of a Banana app [here](https://github.com/conceptofmind/serverless-template-palmyra-base/blob/main/app.py). +This example is from the `app.py` file in [CodeLlama-7B-Instruct-GPTQ](https://github.com/bananaml/demo-codellama-7b-instruct-gptq). ## Wrappers ### LLM -There exists an Banana LLM wrapper, which you can access with +Within Langchain, there exists a Banana LLM wrapper, which you can access with ```python from langchain.llms import Banana ``` -You need to provide a model key located in the dashboard: +You need to provide a model key and model url slug, which you can get from the model's details page in the [Banana.dev dashboard](https://app.banana.dev). ```python -llm = Banana(model_key="YOUR_MODEL_KEY") +llm = Banana(model_key="YOUR_MODEL_KEY", model_url_slug="YOUR_MODEL_URL_SLUG") ``` diff --git a/libs/langchain/langchain/llms/bananadev.py b/libs/langchain/langchain/llms/bananadev.py index f0659118d6..3a984a3cb2 100644 --- a/libs/langchain/langchain/llms/bananadev.py +++ b/libs/langchain/langchain/llms/bananadev.py @@ -15,6 +15,7 @@ class Banana(LLM): To use, you should have the ``banana-dev`` python package installed, and the environment variable ``BANANA_API_KEY`` set with your API key. + This is the team API key available in the Banana dashboard. Any parameters that are valid to be passed to the call can be passed in, even if not explicitly saved on this class. @@ -23,10 +24,13 @@ class Banana(LLM): .. code-block:: python from langchain.llms import Banana - banana = Banana(model_key="") + banana = Banana(model_key="", model_url_slug="") """ model_key: str = "" + """model key to use""" + + model_url_slug: str = "" """model endpoint to use""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) @@ -72,6 +76,7 @@ class Banana(LLM): """Get the identifying parameters.""" return { **{"model_key": self.model_key}, + **{"model_url_slug": self.model_url_slug}, **{"model_kwargs": self.model_kwargs}, } @@ -89,7 +94,7 @@ class Banana(LLM): ) -> str: """Call to Banana endpoint.""" try: - import banana_dev as banana + from banana_dev import Client except ImportError: raise ImportError( "Could not import banana-dev python package. " @@ -99,19 +104,25 @@ class Banana(LLM): params = {**params, **kwargs} api_key = self.banana_api_key model_key = self.model_key + model_url_slug = self.model_url_slug model_inputs = { # a json specific to your model. "prompt": prompt, **params, } - response = banana.run(api_key, model_key, model_inputs) + model = Client( + # Found in main dashboard + api_key=api_key, + # Both found in model details page + model_key=model_key, + url=f"https://{model_url_slug}.run.banana.dev", + ) + response, meta = model.call("/", model_inputs) try: - text = response["modelOutputs"][0]["output"] + text = response["outputs"] except (KeyError, TypeError): - returned = response["modelOutputs"][0] raise ValueError( - "Response should be of schema: {'output': 'text'}." - f"\nResponse was: {returned}" + "Response should be of schema: {'outputs': 'text'}." "\nTo fix this:" "\n- fork the source repo of the Banana model" "\n- modify app.py to return the above schema"