Commit Graph

213 Commits

Author SHA1 Message Date
Aleksandr Borzunov
fe2eb15766 Add timeouts 2022-11-25 04:56:09 +00:00
Aleksandr Borzunov
cd829fde92 Refactor _{forward,backward}_stream() 2022-11-25 04:19:21 +00:00
Aleksandr Borzunov
4518d65fdd Add MIT license 2022-11-24 22:25:51 +00:00
Alexander Borzunov
898f614515
Fix floating point issues in block_selection.py (#89) 2022-11-25 02:17:59 +04:00
Alexander Borzunov
c07a7e0812
Add "Terms of Use" 2022-11-21 18:54:07 +04:00
Artem Chumachenko
0d9c7de0bd
Add sst-2 ipynb example (#86)
- Add sst-2 example of a prompt-based training
- Have some enhancement in the persona-chat example
2022-11-07 13:55:00 +04:00
Alexander Borzunov
57e8d2e721
Implement exponential backoff for forward & backward (#85) 2022-11-02 01:21:15 +04:00
Alexander Borzunov
ee4e69c254
Enable rebalancing by default (#84) 2022-11-02 00:50:01 +04:00
Artem Chumachenko
2cb82dd648
Add colab-related changes (#80)
Add some stuff to work on COLAB more comfortable.

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
2022-11-01 13:35:16 +04:00
Alexander Borzunov
87fd6a4f08
Fix "Too many open files" during rebalancing (#83)
Now, the number of open files stays the same after every rebalancing.
2022-11-01 04:37:13 +04:00
Alexander Borzunov
f64eb3a665
Update hivemind to 1.1.2, mark model argument as required (#81) 2022-10-26 03:23:18 +04:00
Alexander Borzunov
149f433763
Rebalance swarm when necessary (#34) 2022-10-12 14:28:27 +04:00
Alexander Borzunov
640bbc38a9
Make even smaller readability changes 2022-09-20 15:03:57 +04:00
Alexander Borzunov
d1b012b479
Make small readability & style changes to the instructions (#77) 2022-09-20 15:00:59 +04:00
justheuristic
fef48d7d99
Use bitsandbytes==0.34.0, update readme (#76)
* unlock bnb backward
* Fix bnb version in README
* Update requirements.txt
2022-09-20 13:07:34 +03:00
justheuristic
8caf1145a8
Quality of life changes: update readme, simplify run_server interface (#75)
- run_server now accepts model name as both positional and keyword argument
- changed names in README to account for interface updates
- moved model conversion from README to a separate wiki page
- updated requirements.txt
2022-09-20 03:51:57 +03:00
Artem Chumachenko
1046911dea
Add prompt tuning example on Personachat dataset (#69) 2022-09-19 14:52:35 +04:00
justheuristic
3fdcc55a56
fix protobuf version (#74)
* fix protobuf version
2022-09-18 04:54:08 +03:00
justheuristic
e92487e5d2
Update dependency versions (#71)
* update dependency versions
* install bitsandbytes cpuonly from pip
* remove deprecated API from task pool
* clearer startup logs

Co-authored-by: Tim Dettmers <dettmers@cs.washington.edu>
2022-09-13 03:51:15 +03:00
Pavel Samygin
50535a8435
Priority tasks (#47)
* priority in handlers and backend pools
* simple points system on server side
* priortize task in handler before submit task
* fix tests
* s/expert/block/g

Co-authored-by: justheuristic <justheuristic@gmail.com>
2022-09-10 22:24:42 +03:00
justheuristic
892d18fea7
Build cpuonly from bitsandbytes main (#70)
Build cpuonly from main
2022-09-08 21:06:19 +03:00
justheuristic
f3984b192a
Make attention cache wait until memory is freed (#53)
Previously, attempting to allocate with MemoryCache that does not have enough space would throw AllocationFailed.

PR changes this behavior to the following:
- by default, wait until memory is freed by other tenants (FIFO)
- if could not allocate within timeout, throw AllocationFailed
- if allocated size is too big to fit even in empty cache, throw AllocationFailed

- [x] passes existing tests
- [x] passes manual load tests

p.s. if anyone wondered: using mp.Condition will not make the code simpler, their lock behavior is slightly different to what we need here

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
2022-09-07 02:14:34 +03:00
justheuristic
8a0c056929
Fix calling rpc_info multiple times (#60)
call info once
2022-09-07 01:41:23 +03:00
Artem Chumachenko
ada98a1b37
Add deep prompt inference (#66)
Add deep prompt in inference_step. Small refactoring in deep prompt code.
2022-09-06 21:33:00 +04:00
Alexander Borzunov
54ad745bed
Warn that current instructions involve 6B model but we will replace them soon (#63) 2022-09-05 15:05:59 +04:00
Alexander Borzunov
5f0c5329d4
Update readme with arxiv link and more discussions (#62)
Co-authored-by: justheuristic <justheuristic@gmail.com>
2022-09-05 12:04:50 +04:00
Alexander Borzunov
9bea7b9ea8
Update bullet points with feedback from Tim and other people (#61)
Co-authored-by: Tim Dettmers <tim.dettmers@gmail.com>
2022-09-03 06:38:18 +04:00
Alexander Borzunov
7653562aa1
Use latest version of Petals scheme, shrink Petals logo (#59) 2022-09-02 15:38:04 +04:00
Alexander Borzunov
2eb5843852
Update readme for the 1st public release (#57) 2022-09-01 08:41:49 +04:00
Pavel Samygin
0be21775af
remove transformer block, implement as sequential of size 1 (#54)
* remove transformer block, implement as sequence size 1
* reimplement get_remote_module
* fix readme

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
2022-09-01 04:26:31 +03:00
Artem Chumachenko
77220c718c
Add shallow prefix-tuned inference (#55)
* Add prefix-tuned inference

* Add prefix-tuned inference

* Add preseq_length in prefix size
2022-08-31 13:21:25 +04:00
justheuristic
d271b75dd4
Let users specify sequence length instead of assuming 2048 (#52)
- Maximum length is now provided in `.inference_session(max_length=100)`
   - previously, we would always assume max length = 2048
- added a generic way to forward **kwargs to inference session
  - for compatibility with #47 
  - Note to @borzunov : it does *not* pass them arbitrarily, but instead checks for kwarg names at the bottom level
- run_server can be started with a custom max_length for inference
- renamed --cache_size_bytes to --attention_cache_bytes (to avoid collision with --cache_dir)
- --attn_cache_bytes can now support humane file sizes (e.g. 300MB instead of 314572800)
- made some server-side errors more human-readable to user (e.g. when max length is exceeded)

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
2022-08-29 21:04:37 +03:00
Dmitry Baranchuk
948877149c
Fix recovering for sequential_backward (#50) 2022-08-29 01:05:00 +03:00
Dmitry Baranchuk
24ba3433e4
[Fix] make distributed seq cls to not create the full bloom model (#49) 2022-08-28 20:20:51 +03:00
justheuristic
f12d0deee9
[quickfix 1/n] remove expensive assertions in inference code (#48)
remove expensive assertions in inference code
2022-08-28 18:43:07 +03:00
Dmitry Baranchuk
0fd2caa4be
Convert actual model weights (#46) 2022-08-17 23:32:14 +03:00
justheuristic
a2634001e9
Reduce vocabulary size in test model, fix bug in routing when overlapped (#45)
This PR reduces this vocabulary size to save memory during conversion, keeping only the first 50k tokens
As a result, 

* tests that load client-side embeddings need significantly less RAM
* we can now run CI tests with 4 servers instead of 2 - needed to test routing - see bugs uncovered
* some of the servers now use load balancing
* CI convert_model now takes 4-5 minutes (was 6-7)
2022-08-17 18:50:52 +03:00
Dmitry Baranchuk
5745882c67
fix rpc_forward_stream 2022-08-13 01:55:02 +06:00
Dmitry Baranchuk
6095f58681
Deep distributed prompt tuning (#42)
* implemented an option to add learnable prompts to intermediate layers
* added support for prompts (as input) in rpc_forward and rpc_backward
* added a test to check that RemoteSequential works correctly with deep prompts

Co-authored-by: justheuristic <justheuristic@gmail.com>
2022-08-12 18:28:21 +03:00
justheuristic
9460220a10
make pytest outputs more verbose (#44)
this PR adds --verbose and --duration* to pytest
2022-08-10 18:53:28 +03:00
Dmitry Baranchuk
c4aa1f49df
Rename 350m -> 560m (#43) 2022-08-10 11:03:10 +03:00
Dmitry Baranchuk
11a424837f
integrate mixed-8bit model (#39)
* integrate mixed-8bit model
* Fix bug with model duplication in RAM
* set throughput=1.0 to fix zero throughput problem
* add revision support
* update hivemind and bitsandbytes
* update deploy scripts
* update installation instructions
2022-08-04 09:57:37 +03:00
Alexander Borzunov
7d39d46966
Use "PETALS" as the readme title (#40)
Since we've chosen the system name, let's use it in the repo name and the readme title.
2022-08-02 18:48:54 +04:00
Dmitry Baranchuk
04a2b6f5e3
Support various backend dtypes & async serialization (#38) 2022-07-28 18:33:58 +03:00
Artem Chumachenko
d989b94614
Pack of Inference Changes (#37)
* Return multibatch mode

* Add tests

* fixes
2022-07-27 10:19:45 +04:00
Dmitry Baranchuk
6573076883
Sequential and parallel forward / backward (#36) 2022-07-23 14:32:39 +03:00
justheuristic
f0cffbf67e
Miscellaneous fixes to automatic tests (#35)
1. __Reduce memory usage in in test_full_model__ 
     - previously, loading the full model would consistently fail IF github is enforcing memory limit [example](https://github.com/bigscience-workshop/distributed-bloom/runs/7473920049?check_suite_focus=true)
     - the new version uses accelerate to save 2GB of peak memory, that was previously used when loading both reference model AND its state dict at the same time - only to load that state dict :)
2. __Safer delays when creating servers__
    - run-tests will now wait for a few seconds after creating the first server - and before creating the second one, so as to make 
sure that the first server creates a DHT instance that subsequent servers can connect to.
    - also increased the wait time after creating servers by 30 seconds to make sure we load the model in time even when bumping into slow remotes on HF side
3. __Fix environment variables in CI to avoid build conflicts__
    - the previous code was using a wrong environment variable that was always "main". The current one will correctly resolve branch name, both in main and on pull request.
    - For reference, below you can find sample environments when running CI in both cases: on pull request and on push to main.

<details>
<summary> Environment variables when building this branch (on pull request) </summary>

SELENIUM_JAR_PATH=/usr/share/java/selenium-server.jar GOROOT_1_17_X64=/opt/hostedtoolcache/go/1.17.12/x64 CONDA=/usr/share/miniconda GITHUB_WORKSPACE=/home/runner/work/distributed-bloom/distributed-bloom JAVA_HOME_11_X64=/usr/lib/jvm/temurin-11-jdk-amd64 GITHUB_PATH=/home/runner/work/_temp/_runner_file_commands/add_path_0aba811a-a04b-40a2-ba42-79efb2723e9e GITHUB_ACTION=__run_2 JAVA_HOME=/usr/lib/jvm/temurin-11-jdk-amd64 GITHUB_RUN_NUMBER=98 RUNNER_NAME=GitHub Actions 3 GRADLE_HOME=/usr/share/gradle-7.5 XDG_CONFIG_HOME=/home/runner/.config DOTNET_SKIP_FIRST_TIME_EXPERIENCE=1 ANT_HOME=/usr/share/ant JAVA_HOME_8_X64=/usr/lib/jvm/temurin-8-jdk-amd64 HOMEBREW_PREFIX=/home/linuxbrew/.linuxbrew pythonLocation=/opt/hostedtoolcache/Python/3.9.13/x64 GITHUB_REF_TYPE=branch HOMEBREW_CLEANUP_PERIODIC_FULL_DAYS=3650 BOOTSTRAP_HASKELL_NONINTERACTIVE=1 *** PIPX_BIN_DIR=/opt/pipx_bin DEPLOYMENT_BASEPATH=/opt/runner GITHUB_ACTIONS=true ANDROID_NDK_LATEST_HOME=/usr/local/lib/android/sdk/ndk/24.0.8215888 GITHUB_SHA=3b457e8a14e5ecb0d65d6e4c0e9161f7756a8861 POWERSHELL_DISTRIBUTION_CHANNEL=GitHub-Actions-ubuntu20 DOTNET_MULTILEVEL_LOOKUP=0 GITHUB_REF=refs/pull/35/merge RUNNER_OS=Linux GITHUB_REF_PROTECTED=false HOME=/home/runner GITHUB_API_URL=https://api.github.com/ LANG=C.UTF-8 BLOOM_TESTING_WRITE_TOKEN=*** RUNNER_TRACKING_ID=github_cc9b46e4-56a1-40c5-ba08-5a91e21f0f95 STATS_KEEPALIVE=false RUNNER_ARCH=X64 RUNNER_TEMP=/home/runner/work/_temp EDGEWEBDRIVER=/usr/local/share/edge_driver GITHUB_ENV=/home/runner/work/_temp/_runner_file_commands/set_env_0aba811a-a04b-40a2-ba42-79efb2723e9e GITHUB_EVENT_PATH=/home/runner/work/_temp/_github_workflow/event.json INVOCATION_ID=8f0072e74f2847c0851e7ff9b5e4af7c GITHUB_EVENT_NAME=pull_request GITHUB_RUN_ID=2720198689 JAVA_HOME_17_X64=/usr/lib/jvm/temurin-17-jdk-amd64 ANDROID_NDK_HOME=/usr/local/lib/android/sdk/ndk-bundle GITHUB_STEP_SUMMARY=/home/runner/work/_temp/_runner_file_commands/step_summary_0aba811a-a04b-40a2-ba42-79efb2723e9e HOMEBREW_NO_AUTO_UPDATE=1 GITHUB_ACTOR=justheuristic NVM_DIR=/home/runner/.nvm SGX_AESM_ADDR=1 GITHUB_RUN_ATTEMPT=1 ANDROID_HOME=/usr/local/lib/android/sdk GITHUB_GRAPHQL_URL=https://api.github.com/graphql ACCEPT_EULA=Y RUNNER_USER=runner USER=runner GITHUB_SERVER_URL=https://github.com/ HOMEBREW_CELLAR=/home/linuxbrew/.linuxbrew/Cellar PIPX_HOME=/opt/pipx GECKOWEBDRIVER=/usr/local/share/gecko_driver CHROMEWEBDRIVER=/usr/local/share/chrome_driver SHLVL=0 ANDROID_SDK_ROOT=/usr/local/lib/android/sdk VCPKG_INSTALLATION_ROOT=/usr/local/share/vcpkg HOMEBREW_REPOSITORY=/home/linuxbrew/.linuxbrew/Homebrew RUNNER_TOOL_CACHE=/opt/hostedtoolcache ImageVersion=20220717.1 DOTNET_NOLOGO=1 GITHUB_REF_NAME=35/merge STATS_PFS=true GRAALVM_11_ROOT=/usr/local/graalvm/graalvm-ce-java11-22.1.0 GITHUB_JOB=convert-model LD_LIBRARY_PATH=/opt/hostedtoolcache/Python/3.9.13/x64/lib XDG_RUNTIME_DIR=/run/user/1001 AZURE_EXTENSION_DIR=/opt/az/azcliextensions PERFLOG_LOCATION_SETTING=RUNNER_PERFLOG GITHUB_REPOSITORY=bigscience-workshop/distributed-bloom ANDROID_NDK_ROOT=/usr/local/lib/android/sdk/ndk-bundle CHROME_BIN=/usr/bin/google-chrome GOROOT_1_18_X64=/opt/hostedtoolcache/go/1.18.4/x64 GITHUB_RETENTION_DAYS=90 JOURNAL_STREAM=8:23653 RUNNER_WORKSPACE=/home/runner/work/distributed-bloom LEIN_HOME=/usr/local/lib/lein LEIN_JAR=/usr/local/lib/lein/self-installs/leiningen-2.9.8-standalone.jar GITHUB_ACTION_REPOSITORY= PATH=/opt/hostedtoolcache/Python/3.9.13/x64/bin:/opt/hostedtoolcache/Python/3.9.13/x64:/home/linuxbrew/.linuxbrew/bin:/home/linuxbrew/.linuxbrew/sbin:/home/runner/.local/bin:/opt/pipx_bin:/home/runner/.cargo/bin:/home/runner/.config/composer/vendor/bin:/usr/local/.ghcup/bin:/home/runner/.dotnet/tools:/snap/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin RUNNER_PERFLOG=/home/runner/perflog GITHUB_BASE_REF=main CI=true SWIFT_PATH=/usr/share/swift/usr/bin ImageOS=ubuntu20 GITHUB_REPOSITORY_OWNER=bigscience-workshop GITHUB_HEAD_REF=fix-branch-name GITHUB_ACTION_REF= GITHUB_WORKFLOW=Tests DEBIAN_FRONTEND=noninteractive AGENT_TOOLSDIRECTORY=/opt/hostedtoolcache GOROOT_1_16_X64=/opt/hostedtoolcache/go/1.16.15/x64 _=/usr/bin/env
</details>
<details>
<summary> Environment variables when building in main (on push) </summary>

SELENIUM_JAR_PATH=/usr/share/java/selenium-server.jar GOROOT_1_17_X64=/opt/hostedtoolcache/go/1.17.11/x64 CONDA=/usr/share/miniconda GITHUB_WORKSPACE=/home/runner/work/distributed-bloom/distributed-bloom JAVA_HOME_11_X64=/usr/lib/jvm/temurin-11-jdk-amd64 GITHUB_PATH=/home/runner/work/_temp/_runner_file_commands/add_path_cd6c1ed2-0d0f-496d-b7a6-ffa476dcc144 GITHUB_ACTION=__run_2 JAVA_HOME=/usr/lib/jvm/temurin-11-jdk-amd64 GITHUB_RUN_NUMBER=53 RUNNER_NAME=GitHub Actions 3 GRADLE_HOME=/usr/share/gradle-7.4.2 XDG_CONFIG_HOME=/home/runner/.config DOTNET_SKIP_FIRST_TIME_EXPERIENCE=1 ANT_HOME=/usr/share/ant JAVA_HOME_8_X64=/usr/lib/jvm/temurin-8-jdk-amd64 HOMEBREW_PREFIX=/home/linuxbrew/.linuxbrew pythonLocation=/opt/hostedtoolcache/Python/3.9.13/x64 GITHUB_REF_TYPE=branch HOMEBREW_CLEANUP_PERIODIC_FULL_DAYS=3650 BOOTSTRAP_HASKELL_NONINTERACTIVE=1 *** PIPX_BIN_DIR=/opt/pipx_bin DEPLOYMENT_BASEPATH=/opt/runner GITHUB_ACTIONS=true ANDROID_NDK_LATEST_HOME=/usr/local/lib/android/sdk/ndk/24.0.8215888 GITHUB_SHA=49242d81006454d687ff3293c49f6bf234793627 POWERSHELL_DISTRIBUTION_CHANNEL=GitHub-Actions-ubuntu20 DOTNET_MULTILEVEL_LOOKUP=0 GITHUB_REF=refs/heads/main RUNNER_OS=Linux GITHUB_REF_PROTECTED=true HOME=/home/runner GITHUB_API_URL=https://api.github.com/ LANG=C.UTF-8 BLOOM_TESTING_WRITE_TOKEN=*** RUNNER_TRACKING_ID=github_7668f06a-99e1-4ed1-81e9-46d75fab3f33 STATS_KEEPALIVE=false RUNNER_ARCH=X64 RUNNER_TEMP=/home/runner/work/_temp EDGEWEBDRIVER=/usr/local/share/edge_driver GITHUB_ENV=/home/runner/work/_temp/_runner_file_commands/set_env_cd6c1ed2-0d0f-496d-b7a6-ffa476dcc144 GITHUB_EVENT_PATH=/home/runner/work/_temp/_github_workflow/event.json INVOCATION_ID=3dadac48981b4a679a33224db89be1ed GITHUB_EVENT_NAME=push GITHUB_RUN_ID=2680158280 JAVA_HOME_17_X64=/usr/lib/jvm/temurin-17-jdk-amd64 ANDROID_NDK_HOME=/usr/local/lib/android/sdk/ndk-bundle GITHUB_STEP_SUMMARY=/home/runner/work/_temp/_runner_file_commands/step_summary_cd6c1ed2-0d0f-496d-b7a6-ffa476dcc144 HOMEBREW_NO_AUTO_UPDATE=1 GITHUB_ACTOR=justheuristic NVM_DIR=/home/runner/.nvm SGX_AESM_ADDR=1 GITHUB_RUN_ATTEMPT=1 ANDROID_HOME=/usr/local/lib/android/sdk GITHUB_GRAPHQL_URL=https://api.github.com/graphql ACCEPT_EULA=Y RUNNER_USER=runner USER=runner GITHUB_SERVER_URL=https://github.com/ HOMEBREW_CELLAR=/home/linuxbrew/.linuxbrew/Cellar PIPX_HOME=/opt/pipx GECKOWEBDRIVER=/usr/local/share/gecko_driver CHROMEWEBDRIVER=/usr/local/share/chrome_driver SHLVL=0 ANDROID_SDK_ROOT=/usr/local/lib/android/sdk VCPKG_INSTALLATION_ROOT=/usr/local/share/vcpkg HOMEBREW_REPOSITORY=/home/linuxbrew/.linuxbrew/Homebrew RUNNER_TOOL_CACHE=/opt/hostedtoolcache ImageVersion=20220710.1 DOTNET_NOLOGO=1 GITHUB_REF_NAME=main STATS_PFS=true GRAALVM_11_ROOT=/usr/local/graalvm/graalvm-ce-java11-22.1.0 GITHUB_JOB=convert-model LD_LIBRARY_PATH=/opt/hostedtoolcache/Python/3.9.13/x64/lib XDG_RUNTIME_DIR=/run/user/1001 AZURE_EXTENSION_DIR=/opt/az/azcliextensions PERFLOG_LOCATION_SETTING=RUNNER_PERFLOG GITHUB_REPOSITORY=bigscience-workshop/distributed-bloom CHROME_BIN=/usr/bin/google-chrome ANDROID_NDK_ROOT=/usr/local/lib/android/sdk/ndk-bundle GOROOT_1_18_X64=/opt/hostedtoolcache/go/1.18.3/x64 GITHUB_RETENTION_DAYS=90 JOURNAL_STREAM=8:22000 RUNNER_WORKSPACE=/home/runner/work/distributed-bloom LEIN_HOME=/usr/local/lib/lein LEIN_JAR=/usr/local/lib/lein/self-installs/leiningen-2.9.8-standalone.jar GITHUB_ACTION_REPOSITORY= PATH=/opt/hostedtoolcache/Python/3.9.13/x64/bin:/opt/hostedtoolcache/Python/3.9.13/x64:/home/linuxbrew/.linuxbrew/bin:/home/linuxbrew/.linuxbrew/sbin:/home/runner/.local/bin:/opt/pipx_bin:/home/runner/.cargo/bin:/home/runner/.config/composer/vendor/bin:/usr/local/.ghcup/bin:/home/runner/.dotnet/tools:/snap/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin RUNNER_PERFLOG=/home/runner/perflog GITHUB_BASE_REF= CI=true SWIFT_PATH=/usr/share/swift/usr/bin ImageOS=ubuntu20 GITHUB_REPOSITORY_OWNER=bigscience-workshop GITHUB_HEAD_REF= GITHUB_ACTION_REF= GITHUB_WORKFLOW=Tests DEBIAN_FRONTEND=noninteractive AGENT_TOOLSDIRECTORY=/opt/hostedtoolcache GOROOT_1_16_X64=/opt/hostedtoolcache/go/1.16.15/x64 _=/usr/bin/env
</details>



Co-authored-by: Dmitry Baranchuk <dmitrybaranchuk@gmail.com>
2022-07-22 22:38:40 +03:00
Dmitry Baranchuk
7de3acf909
Fix is_subsequence (#32) 2022-07-19 22:37:19 +03:00
justheuristic
f0c7383181
Implement RemoteSequential slicing and extra repr, add tests (#30)
- finish renaming RemoteSequenceInfo -> RemoteSequenceManager (why: if it was an *Info, user would expect it to be similar - to a dataclass; whereas in actuality, the class is doing heavy network interactions on its own)
- implement RemoteSequenceManager.make_sequence (from https://pastebin.com/uXgy2U8B )
- make RemoteSequentialInferenceSession use RemoteSequenceManager.make_sequence
- make tests pass again
- make it possible to create inference session without RemoteTransformerBlock
- make a standalone test for RemoteSequential
- rollback convert-model

Co-authored-by: Tim Dettmers <tim.dettmers@gmail.com>
2022-07-19 04:28:04 +03:00
Artem Chumachenko
6ee942e915
Add GenerationMixin class (#29)
Add generation abstraction, that's using inference_session.
Added modes:
- Greedy, top-k/top-p sampling
- Multibatch generation
- Constraint abstraction
In the future, will add prefix-tuned generation, beam-search and more hf-like stuff.
2022-07-19 00:44:16 +03:00