diff --git a/.gitignore b/.gitignore
index 7114a35..2c05f14 100644
--- a/.gitignore
+++ b/.gitignore
@@ -126,3 +126,6 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# vim
+*.swp
diff --git a/cli/inference_one_block.py b/cli/inference_one_block.py
index c7ea7f5..bda240a 100644
--- a/cli/inference_one_block.py
+++ b/cli/inference_one_block.py
@@ -32,18 +32,29 @@ if __name__ == "__main__":
     parser.add_argument("--layer_index", default=0, type=int, help="Optional path to saved block state dict")
     parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
     parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
+    parser.add_argument("--block-path", default='', type=str, help="The path to the Bloom block-path")
     args = parser.parse_args()
 
     if args.device is None:
         args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f'Using device {args.device}')
 
     config = DistributedBloomConfig.from_json_file(args.config)
-    block = BloomBlock(config, args.layer_index).to(args.device)
+    block = BloomBlock(config, args.layer_index)
+
+    if args.block_path != '':
+        print(f'Loading block from {args.block_path}')
+        block.load_state_dict( torch.load(args.block_path))
+        #print(list(block_data.keys()))
+        #block.load(args.block_path)
+
+    block = block.to(args.device)
+    block = block.to(torch.bfloat16)
 
     cache = None
 
     for i in trange(args.num_steps):
-        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
+        dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device).to(torch.bfloat16)
         alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
         with torch.no_grad():
             outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)