You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

23 lines
834 B
Python

import numpy as np
def split_data(chars, batch_size, num_steps, split_frac=0.9):
    """Split a sequence into training and validation input/target arrays.

    The sequence is truncated to a whole number of batches, where one batch
    consumes ``batch_size * num_steps`` elements. Targets are the inputs
    shifted forward by one position. Both are cut into ``batch_size``
    equal consecutive pieces that are stacked as rows, and the columns are
    then divided into a leading training part and a trailing validation part.

    Args:
        chars: Sequence to split; converted with ``np.asarray``.
            Presumably a 1-D array of encoded characters -- confirm
            against the caller.
        batch_size: Number of rows in each returned array.
        num_steps: Number of columns consumed per batch.
        split_frac: Fraction of the whole batches assigned to training.

    Returns:
        A tuple ``(train_x, train_y, val_x, val_y)``. Each array has
        ``batch_size`` rows; the training arrays have
        ``int(n_batches * split_frac) * num_steps`` columns and the
        validation arrays hold the remaining columns.
    """
    chars = np.asarray(chars)
    slice_size = batch_size * num_steps

    # The targets are shifted by one, so one element past the inputs must
    # exist. Counting batches over len - 1 keeps x and y the same length even
    # when len(chars) is an exact multiple of slice_size; max() guards the
    # empty-input case, where (0 - 1) // slice_size would be -1.
    n_batches = max((len(chars) - 1) // slice_size, 0)
    n_used = n_batches * slice_size

    x = chars[:n_used]
    y = chars[1:n_used + 1]

    # One row per piece: row i holds the i-th consecutive chunk of the data.
    x = np.stack(np.split(x, batch_size))
    y = np.stack(np.split(y, batch_size))

    # The split is made on a batch boundary so that no batch straddles the
    # training and validation sets.
    split_col = int(n_batches * split_frac) * num_steps
    train_x, train_y = x[:, :split_col], y[:, :split_col]
    val_x, val_y = x[:, split_col:], y[:, split_col:]
    return train_x, train_y, val_x, val_y
def get_batch(arrs, num_steps):
    """Yield consecutive column windows taken from every array in ``arrs``.

    The number of windows is derived from the column count of the first
    array, which must be two-dimensional. Each yielded item is a list with
    one ``num_steps``-wide column slice per input array, in the same order
    as ``arrs``. Columns left over after the last full window are not
    yielded.
    """
    _, width = arrs[0].shape
    window_count = int(width / num_steps)
    # Walk the left edge of each window instead of a batch index.
    for start in range(0, window_count * num_steps, num_steps):
        stop = start + num_steps
        yield [arr[:, start:stop] for arr in arrs]