* Zoo Model

* Remove optional import zoo

* Read model name from zoo model

* Logprobs passed through raw response for gold choices

Co-authored-by: Simran <emailsimran@gmail.com>
Co-authored-by: Dan Fu <danfu@cs.stanford.edu>
laurel/helm
Laurel Orr 2 years ago
parent e0a76d1f93
commit 5428afdc58

@@ -7,5 +7,5 @@
 [flake8]
 exclude = .git
 max-line-length = 88
-ignore = E731, E402, W503, E203
+ignore = E731, E402, W503, E203, PAI100, PAI101, PAI201, PAI202, PAI203
 per-file-ignores = __init__.py:F401, version.py:D100

@@ -20,6 +20,12 @@ PORT = int(os.environ.get("FLASK_PORT", 5000))
 MODEL_CONSTRUCTORS = {
     "huggingface": HuggingFaceModel,
 }
+try:
+    from manifest.api.models.zoo import ZooModel
+
+    MODEL_CONSTRUCTORS["zoo"] = ZooModel  # type: ignore
+except ImportError:
+    logger.warning("Zoo model not available.")


 def parse_args() -> argparse.Namespace:
@@ -31,14 +37,19 @@ def parse_args() -> argparse.Namespace:
         type=str,
         required=True,
         help="Model type used for finding constructor.",
-        choices=["huggingface"],
+        choices=["huggingface", "zoo"],
     )
     parser.add_argument(
-        "--model_name",
+        "--model_name_or_path",
         default=None,
         type=str,
-        required=True,
-        help="Name of model. Used in initialize of model class.",
+        help="Name of model or path to model. Used in initialize of model class.",
+    )
+    parser.add_argument(
+        "--model_config",
+        default=None,
+        type=str,
+        help="Model config. Used in initialize of model class.",
     )
     parser.add_argument(
         "--cache_dir", default=None, type=str, help="Cache directory for models."
@@ -79,7 +90,10 @@ def main() -> None:
     """Run main."""
     kwargs = parse_args()
     model_type = kwargs.model_type
-    model_name = kwargs.model_name
+    model_name_or_path = kwargs.model_name_or_path
+    model_config = kwargs.model_config
+    if not model_name_or_path and not model_config:
+        raise ValueError("Must provide model_name_or_path or model_config.")
     use_accelerate = kwargs.use_accelerate_multigpu
     if use_accelerate:
         logger.info("Using accelerate. Overridding --device argument.")
@@ -91,7 +105,8 @@ def main() -> None:
     # Global model
     global model
     model = MODEL_CONSTRUCTORS[model_type](
-        model_name,
+        model_name_or_path,
+        model_config=model_config,
         cache_dir=kwargs.cache_dir,
         device=kwargs.device,
         use_accelerate=use_accelerate,
@@ -112,9 +127,10 @@ def completions() -> Dict:
     if not isinstance(prompt, str):
         raise ValueError("Prompt must be a str")
-    results = []
+    results_text = []
     for generations in model.generate(prompt, **generation_args):
-        results.append(generations)
+        results_text.append(generations)
+    results = [{"text": r, "text_logprob": None} for r in results_text]
     # transform the result into the openai format
     return OpenAIResponse(results).__dict__()
@@ -134,9 +150,10 @@ def choice_logits() -> Dict:
     if not isinstance(gold_choices, list):
         raise ValueError("Gold choices must be a list of string choices")
-    result = model.logits_scoring(prompt, gold_choices, **generation_args)
+    result, score = model.logits_scoring(prompt, gold_choices, **generation_args)
+    results = [{"text": result, "text_logprob": score}]
     # transform the result into the openai format
-    return OpenAIResponse([result]).__dict__()
+    return OpenAIResponse(results).__dict__()


 @app.route("/params", methods=["POST"])

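With this change the choice_logits handler returns both the winning gold choice and its sequence log probability, wrapped in the same OpenAI-style payload as completions. A minimal client-side sketch follows; the port and the /choice_logits route path are assumptions (the route decorator for that handler is not shown in this diff), and a model server must already be running:

import requests

# Illustrative request against a locally running manifest API server;
# address and route path are assumptions, not shown in this hunk.
response = requests.post(
    "http://127.0.0.1:5000/choice_logits",
    json={
        "prompt": "The capital of France is",
        "gold_choices": [" Paris", " London"],
    },
).json()

# Each choice now carries the chosen text and its sequence log probability.
best = response["choices"][0]
print(best["text"], best["text_logprob"])
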
@@ -1,10 +1,11 @@
 """Huggingface model."""
 import json
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Tuple

 import torch
 from transformers import (
+    AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
     BloomForCausalLM,
@@ -25,13 +26,18 @@ MODEL_REGISTRY = {
     "EleutherAI/gpt-neo-2.7B": GPTNeoForCausalLM,
     "EleutherAI/gpt-j-6B": GPTJForCausalLM,
     "EleutherAI/gpt-neox-20b": GPTNeoXForCausalLM,
+    "facebook/opt-125m": OPTForCausalLM,
     "facebook/opt-1.3b": OPTForCausalLM,
     "facebook/opt-2.7b": OPTForCausalLM,
     "facebook/opt-6.7b": OPTForCausalLM,
     "facebook/opt-13b": OPTForCausalLM,
     "facebook/opt-30b": OPTForCausalLM,
     "gpt2": GPT2LMHeadModel,
+    "bigscience/bloom-560m": BloomForCausalLM,
+    "bigscience/bloom-1b7": BloomForCausalLM,
+    "bigscience/bloom-3b": BloomForCausalLM,
     "bigscience/bloom-7b1": BloomForCausalLM,
+    "bigscience/bloom": AutoModelForCausalLM,
     "bigscience/T0pp": AutoModelForSeq2SeqLM,
     "bigscience/T0_3B": AutoModelForSeq2SeqLM,
     "google/t5-xl-lm-adapt": AutoModelForSeq2SeqLM,
@@ -117,7 +123,8 @@ class HuggingFaceModel(Model):
     def __init__(
         self,
-        model_name: str,
+        model_name_or_path: str,
+        model_config: str,
         cache_dir: str,
         device: int,
         use_accelerate: bool,
@@ -131,7 +138,8 @@ class HuggingFaceModel(Model):
         All arguments will be passed in the request from Manifest.

         Args:
-            model_name: model name string.
+            model_name_or_path: model name string.
+            model_config: model config string.
             cache_dir: cache directory for model.
             device: device to use for model.
             use_accelerate: whether to use accelerate for multi-gpu inference.
@@ -142,32 +150,43 @@ class HuggingFaceModel(Model):
         if use_accelerate and use_parallelize:
             raise ValueError("Cannot use both accelerate and parallelize")
         # Check if providing path
-        self.model_path = model_name
+        self.model_path = model_name_or_path
         if Path(self.model_path).exists() and Path(self.model_path).is_dir():
             # Try to find config
             if (Path(self.model_path) / "config.json").exists():
                 config = json.load(open(Path(self.model_path) / "config.json"))
-                model_name = config["_name_or_path"]
-        self.model_name = model_name
+                model_name_or_path = config["_name_or_path"]
+        self.model_name = model_name_or_path
         print("Model Name:", self.model_name, "Model Path:", self.model_path)
         try:
-            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name, truncation_side="left"
+            )
         except ValueError:
-            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name, truncation_side="left", use_fast=False
+            )
         dtype = torch.float16 if use_fp16 else "auto"
-        try:
-            # Try to explicitely find a fp16 copy (gpt-j-6B for example)
-            model = MODEL_REGISTRY[model_name].from_pretrained(  # type: ignore
+        if self.model_name == "bigscience/bloom":
+            model = MODEL_REGISTRY[self.model_name].from_pretrained(  # type: ignore
                 self.model_path,
                 cache_dir=cache_dir,
-                revision="float16",
-                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map="auto",
             )
-        except Exception:
-            model = MODEL_REGISTRY[model_name].from_pretrained(  # type: ignore
-                self.model_path, cache_dir=cache_dir, torch_dtype=dtype
-            )
+        else:
+            try:
+                # Try to explicitely find a fp16 copy (gpt-j-6B for example)
+                model = MODEL_REGISTRY[self.model_name].from_pretrained(  # type: ignore
+                    self.model_path,
+                    cache_dir=cache_dir,
+                    revision="float16",
+                    torch_dtype=torch.float16,
+                )
+            except Exception:
+                model = MODEL_REGISTRY[self.model_name].from_pretrained(  # type: ignore
+                    self.model_path, cache_dir=cache_dir, torch_dtype=dtype
+                )
         model.eval()
         print(f"Loaded Model DType {model.dtype}")
@@ -175,20 +194,21 @@ class HuggingFaceModel(Model):
         if not self.is_encdec:
             tokenizer.pad_token = tokenizer.eos_token
-        if use_accelerate:
-            self._dispatch_accelerate_model(model, perc_max_gpu_mem_red)
-            device = 0
-        elif use_parallelize:
-            model.parallelize()
-            device = 0
-        else:
-            if device > -1:
-                torch_device = (
-                    torch.device("cpu")
-                    if (device == -1 or not torch.cuda.is_available())
-                    else torch.device(f"cuda:{device}")
-                )
-                model = model.to(torch_device)  # type: ignore
+        if self.model_name != "bigscience/bloom":
+            if use_accelerate:
+                self._dispatch_accelerate_model(model, perc_max_gpu_mem_red)
+                device = 0
+            elif use_parallelize:
+                model.parallelize()
+                device = 0
+            else:
+                if device > -1:
+                    torch_device = (
+                        torch.device("cpu")
+                        if (device == -1 or not torch.cuda.is_available())
+                        else torch.device(f"cuda:{device}")
+                    )
+                    model = model.to(torch_device)  # type: ignore
         self.pipeline = Pipeline(  # type: ignore
             model=model, tokenizer=tokenizer, device=device
         )
@@ -258,6 +278,7 @@ class HuggingFaceModel(Model):
         dispatch_model(model, device_map=device_map)
         return

+    @torch.no_grad()
     def generate(self, prompt: str, **kwargs: Any) -> List[str]:
         """
         Generate the prompt from model.
@@ -303,9 +324,10 @@ class HuggingFaceModel(Model):
         final_results = [r["generated_text"][start_idx:] for r in result]
         return final_results

+    @torch.no_grad()
     def logits_scoring(
         self, prompt: str, gold_choices: List[str], **kwargs: Any
-    ) -> str:
+    ) -> Tuple[str, float]:
         """
         Given the prompt and gold choices, choose the best choice with max logits.
@@ -461,4 +483,4 @@ class HuggingFaceModel(Model):
         if not self.is_encdec:
             seq_log_prob = seq_log_prob * (1 / (seq_token_log_probs != 0).sum(dim=-1))
         prediction = seq_log_prob.argmax(dim=-1).item()
-        return gold_choices[int(prediction)]
+        return gold_choices[int(prediction)], seq_log_prob[int(prediction)].item()

@@ -1,20 +1,37 @@
 """Model class."""
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Tuple


 class Model(ABC):
     """Model class."""

     @abstractmethod
-    def __init__(self, model_name: str, **kwargs: Any):
+    def __init__(
+        self,
+        model_name_or_path: str,
+        model_config: str,
+        cache_dir: str,
+        device: int,
+        use_accelerate: bool,
+        use_parallelize: bool,
+        perc_max_gpu_mem_red: float,
+        use_fp16: bool,
+    ):
         """
         Initialize model.

-        kwargs are passed to model as default parameters.
+        All arguments will be passed in the request from Manifest.

         Args:
-            model_name: model name string.
+            model_name_or_path: model name string.
+            model_config: model config string.
+            cache_dir: cache directory for model.
+            device: device to use for model.
+            use_accelerate: whether to use accelerate for multi-gpu inference.
+            use_parallelize: use HF default parallelize
+            perc_max_gpu_mem_red: percent max memory reduction in accelerate
+            use_fp16: use fp16 for model weights.
         """
         raise NotImplementedError()
@@ -37,3 +54,19 @@ class Model(ABC):
             list of generated text (list of length 1 for 1 generation).
         """
         raise NotImplementedError()
+
+    @abstractmethod
+    def logits_scoring(
+        self, prompt: str, gold_choices: List[str], **kwargs: Any
+    ) -> Tuple[str, float]:
+        """
+        Given the prompt and gold choices, choose the best choice with max logits.
+
+        Args:
+            prompt: promt to generate from.
+            gold_choices: list of choices to choose from.
+
+        Returns:
+            the returned gold choice and the score.
+        """
+        raise NotImplementedError()

@@ -0,0 +1,94 @@
+"""Zoo model."""
+import os
+import sys
+from typing import Any, Dict, List, Tuple
+
+from manifest.api.models.model import Model
+
+ZOO_PATH = os.environ.get("ZOO_PATH", None)
+if not ZOO_PATH:
+    raise ImportError("ZOO_PATH environment variable not set.")
+sys.path.append(ZOO_PATH)
+
+from src.models.s4_seq import S4LMManifest  # type: ignore
+
+
+class ZooModel(Model):
+    """Zoo model."""
+
+    def __init__(
+        self,
+        model_name_or_path: str,
+        model_config: str,
+        cache_dir: str,
+        device: int,
+        use_accelerate: bool,
+        use_parallelize: bool,
+        perc_max_gpu_mem_red: float,
+        use_fp16: bool,
+    ):
+        """
+        Initialize model.
+
+        All arguments will be passed in the request from Manifest.
+
+        Args:
+            model_name_or_path: model name string.
+            model_config: model config path.
+            cache_dir: cache directory for model.
+            device: device to use for model.
+            use_accelerate: whether to use accelerate for multi-gpu inference.
+            use_parallelize: use HF default parallelize
+            perc_max_gpu_mem_red: percent max memory reduction in accelerate
+            use_fp16: use fp16 for model weights.
+        """
+        # Check if providing path
+        self.model_path = model_name_or_path
+        self.model_config = model_config
+        if not self.model_config:
+            raise ValueError("Must provide model config.")
+        self.model = S4LMManifest(
+            config_path=self.model_config,
+            weights_path=self.model_path,
+        )
+        # Can only load this after the model has been initialized
+        self.model_name = self.model.get_model_name()
+
+    def get_init_params(self) -> Dict:
+        """Return init params to determine what model is being used."""
+        return {
+            "model_name": self.model_name,
+            "model_path": self.model_path,
+            "model_config": self.model_config,
+        }
+
+    def generate(self, prompt: str, **kwargs: Any) -> List[str]:
+        """
+        Generate the prompt from model.
+
+        Outputs must be generated text, not including prompt.
+
+        Args:
+            prompt: promt to generate from.
+
+        Returns:
+            list of generated text (list of length 1 for 1 generation).
+        """
+        print(prompt)
+        final_results = self.model.generate(prompt, **kwargs)
+        return final_results
+
+    def logits_scoring(
+        self, prompt: str, gold_choices: List[str], **kwargs: Any
+    ) -> Tuple[str, float]:
+        """
+        Given the prompt and gold choices, choose the best choice with max logits.
+
+        Args:
+            prompt: promt to generate from.
+            gold_choices: list of choices to choose from.
+
+        Returns:
+            the returned gold choice and the score
+        """
+        raise NotImplementedError()

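The new ZooModel defers all loading to an external zoo checkout located through the ZOO_PATH environment variable, so that path has to be set before the module is imported. A rough sketch of constructing it directly, assuming such a checkout exists; the weight and config paths below are placeholders, not real files:

import os

# ZOO_PATH must point at the external zoo repository before the import below,
# otherwise the module raises ImportError at import time (placeholder path).
os.environ["ZOO_PATH"] = "/path/to/zoo"

from manifest.api.models.zoo import ZooModel

model = ZooModel(
    model_name_or_path="/path/to/weights.pt",  # placeholder weights path
    model_config="/path/to/config.yaml",       # placeholder config path
    cache_dir=None,
    device=0,
    use_accelerate=False,
    use_parallelize=False,
    perc_max_gpu_mem_red=1.0,
    use_fp16=False,
)
print(model.get_init_params())
print(model.generate("Hello world"))
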
@@ -2,13 +2,13 @@
 import time
 import uuid
-from typing import Any, Dict
+from typing import Any, Dict, List


 class OpenAIResponse:
     """OpenAI response."""

-    def __init__(self, results: list) -> None:
+    def __init__(self, results: List[Dict[str, Any]]) -> None:
         """Initialize response."""
         self.results = results
         self.response_id = str(uuid.uuid4())
@@ -23,7 +23,8 @@ class OpenAIResponse:
             "model": "flask_model",
             "choices": [
                 {
-                    "text": result,
+                    "text": result["text"],
+                    "text_logprob": result["text_logprob"],
                     # TODO: Add in more metadata for HF models
                     # "logprobs": {
                     #     "tokens": result["tokens"],

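OpenAIResponse now takes a list of dicts instead of bare strings, so every caller has to supply both text and text_logprob. A small sketch of the new contract; the import path and the field values are illustrative assumptions:

from manifest.api.response import OpenAIResponse  # assumed module path

# Generation results carry text_logprob=None; logits_scoring results carry a float.
results = [{"text": " Paris", "text_logprob": -0.42}]

payload = OpenAIResponse(results).__dict__()
choice = payload["choices"][0]
print(choice["text"], choice["text_logprob"])
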
@@ -144,3 +144,19 @@ class AI21Client(Client):
             return self.format_response(res.json())

         return _run_completion, request_params
+
+    def get_choice_logit_request(
+        self, query: str, gold_choices: List[str], request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
+        """
+        Get request string function for choosing max choices.
+
+        Args:
+            query: query string.
+            gold_choices: choices for model to choose from via max logits.
+
+        Returns:
+            request function that takes no input.
+            request parameters as dict.
+        """
+        raise NotImplementedError("AI21 does not support choice logit request.")

@@ -81,3 +81,20 @@ class Client(ABC):
             request parameters as dict.
         """
         raise NotImplementedError()
+
+    @abstractmethod
+    def get_choice_logit_request(
+        self, query: str, gold_choices: List[str], request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
+        """
+        Get request string function for choosing max choices.
+
+        Args:
+            query: query string.
+            gold_choices: choices for model to choose from via max logits.
+
+        Returns:
+            request function that takes no input.
+            request parameters as dict.
+        """
+        raise NotImplementedError()

@@ -149,3 +149,19 @@ class CRFMClient(Client):
             return self.format_response(request_result)

         return _run_completion, request_params
+
+    def get_choice_logit_request(
+        self, query: str, gold_choices: List[str], request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
+        """
+        Get request string function for choosing max choices.
+
+        Args:
+            query: query string.
+            gold_choices: choices for model to choose from via max logits.
+
+        Returns:
+            request function that takes no input.
+            request parameters as dict.
+        """
+        raise NotImplementedError("CRFM does not support choice logit request.")

@@ -12,6 +12,8 @@ logger = logging.getLogger(__name__)
 OPENAI_ENGINES = {
     "text-davinci-002",
+    "text-davinci-001",
+    "davinci",
     "text-curie-001",
     "text-babbage-001",
     "text-ada-001",
@@ -116,3 +118,19 @@ class OpenAIClient(Client):
                 raise e

         return _run_completion, request_params
+
+    def get_choice_logit_request(
+        self, query: str, gold_choices: List[str], request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
+        """
+        Get request string function for choosing max choices.
+
+        Args:
+            query: query string.
+            gold_choices: choices for model to choose from via max logits.
+
+        Returns:
+            request function that takes no input.
+            request parameters as dict.
+        """
+        raise NotImplementedError("OpenAI does not support choice logit request.")

@@ -86,3 +86,19 @@ class OPTClient(Client):
             return res.json()

         return _run_completion, request_params
+
+    def get_choice_logit_request(
+        self, query: str, gold_choices: List[str], request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
+        """
+        Get request string function for choosing max choices.
+
+        Args:
+            query: query string.
+            gold_choices: choices for model to choose from via max logits.
+
+        Returns:
+            request function that takes no input.
+            request parameters as dict.
+        """
+        raise NotImplementedError("OPT does not support choice logit request.")

@@ -0,0 +1,102 @@
+"""Zoo client."""
+import logging
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import requests
+
+from manifest.clients.client import Client
+
+logger = logging.getLogger(__name__)
+
+# User param -> (client param, default value)
+ZOO_PARAMS: Dict[str, Tuple[str, str]] = {}
+
+
+class ZooClient(Client):
+    """Zoo client."""
+
+    def connect(
+        self,
+        connection_str: Optional[str] = None,
+        client_args: Dict[str, Any] = {},
+    ) -> None:
+        """
+        Connect to the model.
+
+        Args:
+            connection_str: connection string.
+            client_args: client arguments.
+        """
+        self.host = connection_str.rstrip("/")
+        for key in ZOO_PARAMS:
+            setattr(self, key, client_args.pop(key, ZOO_PARAMS[key][1]))
+        self.model_params = self.get_model_params()
+
+    def close(self) -> None:
+        """Close the client."""
+        pass
+
+    def get_model_params(self) -> Dict:
+        """
+        Get model params.
+
+        By getting model params from the server, we can add to request
+        and make sure cache keys are unique to model.
+
+        Returns:
+            model params.
+        """
+        res = requests.post(self.host + "/params")
+        return res.json()
+
+    def get_model_inputs(self) -> List:
+        """
+        Get allowable model inputs.
+
+        Returns:
+            model inputs.
+        """
+        return list(ZOO_PARAMS.keys())
+
+    def get_request(
+        self, query: str, request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
+        """
+        Get request string function.
+
+        Args:
+            query: query string.
+
+        Returns:
+            request function that takes no input.
+            request parameters as dict.
+        """
+        request_params = {"prompt": query}
+        # Zoo is greedy and takes all params
+        # TODO: Once zoo is finalized, fix this
+        for key in list(request_args.keys()):
+            request_params[key] = request_args.pop(key, None)
+        request_params.update(self.model_params)
+
+        def _run_completion() -> Dict:
+            post_str = self.host + "/completions"
+            res = requests.post(post_str, json=request_params)
+            return res.json()
+
+        return _run_completion, request_params
+
+    def get_choice_logit_request(
+        self, query: str, gold_choices: List[str], request_args: Dict[str, Any] = {}
+    ) -> Tuple[Callable[[], Dict], Dict]:
+        """
+        Get request string function for choosing max choices.
+
+        Args:
+            query: query string.
+            gold_choices: choices for model to choose from via max logits.
+
+        Returns:
+            request function that takes no input.
+            request parameters as dict.
+        """
+        raise NotImplementedError("Zoo does not support choice logit request.")

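Because ZooClient is registered under the "zoo" key in CLIENT_CONSTRUCTORS (see the manifest.py hunk below), the top-level Manifest wrapper can talk to a zoo-backed server like any other client. A hedged usage sketch, assuming a zoo server is reachable at the address shown and that the usual Manifest constructor and run method apply (neither is changed in this diff):

from manifest import Manifest

# The connection string becomes ZooClient.host and is used for the /params
# and /completions requests; address and prompt are illustrative.
manifest = Manifest(
    client_name="zoo",
    client_connection="http://127.0.0.1:5000",
)
print(manifest.run("What is the capital of France?"))
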
@@ -12,6 +12,7 @@ from manifest.clients.dummy import DummyClient
 from manifest.clients.huggingface import HuggingFaceClient
 from manifest.clients.openai import OpenAIClient
 from manifest.clients.opt import OPTClient
+from manifest.clients.zoo import ZooClient
 from manifest.prompt import Prompt
 from manifest.response import Response
 from manifest.session import Session
@@ -25,6 +26,7 @@ CLIENT_CONSTRUCTORS = {
     "huggingface": HuggingFaceClient,
     "opt": OPTClient,
     "dummy": DummyClient,
+    "zoo": ZooClient,
 }

 CACHE_CONSTRUCTORS = {
@@ -83,12 +85,12 @@ class Manifest:
             )
         self.client_name = client_name
         # Must pass kwargs as dict for client "pop" methods removed used arguments
-        self.client = CLIENT_CONSTRUCTORS[client_name](  # type: ignore
-            client_connection, client_args=kwargs
-        )
         self.cache = CACHE_CONSTRUCTORS[cache_name](  # type: ignore
             cache_connection, cache_args=kwargs
         )
+        self.client = CLIENT_CONSTRUCTORS[client_name](  # type: ignore
+            client_connection, client_args=kwargs
+        )
         self.session = Session(session_id)
         if len(kwargs) > 0:
             raise ValueError(f"{list(kwargs.items())} arguments are not recognized.")

@@ -44,6 +44,7 @@ REQUIRED = [
 # What packages are optional?
 EXTRAS = {
     "dev": [
+        "autopep8>=1.6.0",
         "black>=22.3.0",
         "isort>=5.9.3",
         "flake8>=4.0.0",
