You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
100 lines
2.7 KiB
Python
100 lines
2.7 KiB
Python
import os
|
|
import pickle
|
|
import copy
|
|
import numpy as np
|
|
|
|
|
|
# Reserved special-token ids. create_lookup_tables() copies this mapping and
# numbers real vocabulary words starting at len(CODES); pad_sentence_batch()
# pads with the '<PAD>' id.
CODES = {'<PAD>': 0, '<EOS>': 1, '<UNK>': 2, '<GO>': 3 }
|
|
|
|
|
|
def load_data(path):
    """
    Read a text file and return its entire contents as one string.

    :param path: Location of the file to read; it is decoded as UTF-8.
    :return: The full text of the file.
    """
    resolved = os.path.join(path)
    with open(resolved, mode='r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents
|
|
|
|
|
|
def preprocess_and_save_data(source_path, target_path, text_to_ids,
                             output_path='/output/preprocess.p'):
    """
    Preprocess the source and target text and save the result to a file.

    Both files are read, lower-cased, and used to build vocabulary lookup
    tables; ``text_to_ids`` then converts the two texts to id sequences.
    The pickled payload is a 3-tuple:
    ``((source_ids, target_ids),
       (source_vocab_to_int, target_vocab_to_int),
       (source_int_to_vocab, target_int_to_vocab))``.

    :param source_path: Path of the source-language text file.
    :param target_path: Path of the target-language text file.
    :param text_to_ids: Callable invoked as
        ``text_to_ids(source_text, target_text, source_vocab_to_int,
        target_vocab_to_int)``; must return ``(source_ids, target_ids)``.
    :param output_path: Where to write the pickle. Defaults to the location
        that was previously hard-coded, so existing callers are unaffected.
    """
    # Preprocess
    source_text = load_data(source_path).lower()
    target_text = load_data(target_path).lower()

    source_vocab_to_int, source_int_to_vocab = create_lookup_tables(source_text)
    target_vocab_to_int, target_int_to_vocab = create_lookup_tables(target_text)

    source_ids, target_ids = text_to_ids(
        source_text, target_text, source_vocab_to_int, target_vocab_to_int)

    # Save Data
    payload = (
        (source_ids, target_ids),
        (source_vocab_to_int, target_vocab_to_int),
        (source_int_to_vocab, target_int_to_vocab),
    )
    with open(output_path, 'wb') as out_file:
        pickle.dump(payload, out_file)
|
|
|
|
|
|
def load_preprocess(path='/output/preprocess.p'):
    """
    Load the preprocessed training data from a pickle file.

    The unpickled object is returned as-is; no batching happens here.

    :param path: Pickle file to read. Defaults to the location that was
        previously hard-coded, so existing callers are unaffected.
    :return: Whatever object was pickled into the file.
    """
    # NOTE: unpickling can execute arbitrary code; only load files that this
    # project wrote itself.
    with open(path, mode='rb') as in_file:
        return pickle.load(in_file)
|
|
|
|
|
|
def create_lookup_tables(text):
    """
    Create lookup tables for vocabulary

    The text is split on whitespace. Every distinct word receives an integer
    id following the reserved ids in CODES.

    :param text: Text whose whitespace-separated words form the vocabulary.
    :return: Tuple ``(vocab_to_int, int_to_vocab)``; the two dicts are
        inverses of each other and both include the CODES entries.
    """
    # Sorting makes the ids reproducible: plain set iteration order for
    # strings changes between interpreter runs because of hash randomization.
    # Reserved tokens are removed so a literal '<PAD>' (etc.) in the text
    # cannot overwrite its reserved id and vanish from int_to_vocab.
    words = sorted(set(text.split()).difference(CODES))

    vocab_to_int = dict(CODES)
    vocab_to_int.update(
        (word, word_id) for word_id, word in enumerate(words, len(CODES)))

    int_to_vocab = {word_id: word for word, word_id in vocab_to_int.items()}

    return vocab_to_int, int_to_vocab
|
|
|
|
|
|
def save_params(params, path='/output/params.p'):
    """
    Save parameters to file

    :param params: Any picklable object holding the parameters.
    :param path: Destination pickle file. Defaults to the location that was
        previously hard-coded, so existing callers are unaffected.
    """
    with open(path, 'wb') as out_file:
        pickle.dump(params, out_file)
|
|
|
|
|
|
def load_params(path='/output/params.p'):
    """
    Load parameters from file

    :param path: Pickle file to read. Defaults to the location that was
        previously hard-coded, so existing callers are unaffected.
    :return: The object that was pickled into the file.
    """
    # NOTE: unpickling can execute arbitrary code; only load files that this
    # project wrote itself.
    with open(path, mode='rb') as in_file:
        return pickle.load(in_file)
|
|
|
|
|
|
def batch_data(source, target, batch_size):
    """
    Yield aligned (source, target) batches as numpy arrays.

    Only full batches are produced: ``len(source) // batch_size`` of them,
    so trailing sentences that do not fill a batch are skipped. Each batch
    is padded with pad_sentence_batch() before conversion to an array.

    :param source: Sequence of source sentences (lists of ids).
    :param target: Sequence of target sentences, sliced with the same
        indices as ``source``.
    :param batch_size: Number of sentences per batch.
    """
    full_batches = len(source) // batch_size
    for lo in range(0, full_batches * batch_size, batch_size):
        hi = lo + batch_size
        source_array = np.array(pad_sentence_batch(source[lo:hi]))
        target_array = np.array(pad_sentence_batch(target[lo:hi]))
        yield source_array, target_array
|
|
|
|
|
|
def pad_sentence_batch(sentence_batch):
    """
    Pad sentence with <PAD> id

    Every sentence is right-padded with ``CODES['<PAD>']`` up to the length
    of the longest sentence in the batch. The input sentences are not
    modified; new lists are returned.

    :param sentence_batch: Iterable of sentences, each a list of ids.
    :return: List of padded sentences, all the same length. An empty batch
        yields an empty list.
    """
    # Materialize once so a one-shot iterable survives both passes below.
    sentences = list(sentence_batch)
    # default=0 keeps an empty batch from raising ValueError in max().
    longest = max((len(sentence) for sentence in sentences), default=0)
    pad_id = CODES['<PAD>']
    return [sentence + [pad_id] * (longest - len(sentence))
            for sentence in sentences]
|