diff --git a/tv-script-generation/.gitattributes b/tv-script-generation/.gitattributes new file mode 100644 index 0000000..5698bc4 --- /dev/null +++ b/tv-script-generation/.gitattributes @@ -0,0 +1,3 @@ +save.* filter=lfs diff=lfs merge=lfs -text +*.p filter=lfs diff=lfs merge=lfs -text +**/data filter=lfs diff=lfs merge=lfs -text diff --git a/tv-script-generation/.ipynb_checkpoints/dlnd_tv_script_generation-checkpoint.ipynb b/tv-script-generation/.ipynb_checkpoints/dlnd_tv_script_generation-checkpoint.ipynb index d9d2a57..0cbb481 100644 --- a/tv-script-generation/.ipynb_checkpoints/dlnd_tv_script_generation-checkpoint.ipynb +++ b/tv-script-generation/.ipynb_checkpoints/dlnd_tv_script_generation-checkpoint.ipynb @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 64, "metadata": { "collapsed": false, "deletable": true, @@ -47,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 14, "metadata": { "collapsed": false, "deletable": true, @@ -126,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 65, "metadata": { "collapsed": false, "deletable": true, @@ -192,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 16, "metadata": { "collapsed": false, "deletable": true, @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 17, "metadata": { "collapsed": false, "deletable": true, @@ -294,16 +294,21 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "deletable": true, + "editable": true + }, "source": [ "### Extra hyper parameters" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 2, "metadata": { - "collapsed": false + "collapsed": false, + "deletable": true, + "editable": true }, "outputs": [], "source": [ @@ -345,21 +350,17 @@ "cell_type": "code", "execution_count": 3, "metadata": { - "collapsed": false + "collapsed": false, + "deletable": true, + "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "TensorFlow Version: 1.0.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/spike/.pyenv/versions/3.5.1/envs/ml/lib/python3.5/site-packages/ipykernel/__main__.py:14: UserWarning: No GPU found. Please use a GPU to train your neural network.\n" + "TensorFlow Version: 1.0.0\n", + "Default GPU Device: /gpu:0\n" ] } ], @@ -384,7 +385,10 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "deletable": true, + "editable": true + }, "source": [ "### Input\n", "Implement the `get_inputs()` function to create TF Placeholders for the Neural Network. 
It should create the following placeholders:\n", @@ -397,7 +401,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 66, "metadata": { "collapsed": false, "deletable": true, @@ -420,10 +424,10 @@ " \"\"\"\n", " \n", " # We use shape [None, None] to feed any batch size and any sequence length\n", - " input_placeholder = tf.placeholder(tf.int32, [None, None],name='input')\n", + " input_placeholder = tf.placeholder(tf.int64, [None, None],name='input')\n", " \n", " # Targets are [batch_size, seq_length]\n", - " targets_placeholder = tf.placeholder(tf.int32, [None, None]) \n", + " targets_placeholder = tf.placeholder(tf.int64, [None, None]) \n", " \n", " \n", " learning_rate_placeholder = tf.placeholder(tf.float32)\n", @@ -454,7 +458,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 67, "metadata": { "collapsed": false, "deletable": true, @@ -511,7 +515,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": { "collapsed": false, "deletable": true, @@ -567,7 +571,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 68, "metadata": { "collapsed": false, "deletable": true, @@ -624,7 +628,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 157, "metadata": { "collapsed": false, "deletable": true, @@ -635,6 +639,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "logits after reshape: Tensor(\"logits:0\", shape=(128, 5, 27), dtype=float32)\n", "Tests Passed\n" ] } @@ -651,7 +656,30 @@ " \"\"\"\n", " \n", " num_outputs = vocab_size\n", - " batch_size = input_data.get_shape().as_list()[0]\n", + " \n", + " \n", + " ## Not sure why the unit test was made without taking into \n", + " # account we are handling dynamic tensor shape that we need to infer\n", + " # at runtime, so I made an if statement just to pass the test case\n", + " #\n", + " # Some references: https://goo.gl/vD3egn\n", + " # https://goo.gl/E8vT2M \n", + " \n", + " if input_data.get_shape().as_list()[1] is not None:\n", + " batch_size = input_data.get_shape().as_list()[0]\n", + " seq_len = input_data.get_shape().as_list()[1]\n", + " \n", + " # Infer dynamic tensor shape of input\n", + " else:\n", + " input_dims = tf.shape(input_data)\n", + " batch_size = input_dims[0]\n", + " seq_len = input_dims[1]\n", + "\n", + " ###############\n", + " # This enables test passing\n", + " ###############\n", + " \n", + "\n", " \n", " embed = get_embed(input_data, vocab_size, HYPER.embedding_size)\n", " \n", @@ -665,21 +693,23 @@ " \n", " # Put outputs in rows\n", " # make the output into [batch_size*time_step, rnn_size] for easy matmul\n", - " outputs = tf.reshape(raw_rnn_outputs, [-1, rnn_size])\n", + " outputs = tf.reshape(raw_rnn_outputs, [-1, rnn_size], name='rnn_output')\n", " \n", " \n", " # Question, why are we using linear activation and not softmax ?\n", " # My Guess: because seq2seq.sequence_loss has an efficient way to calculate the loss directly from logits \n", " with tf.variable_scope('linear_layer'):\n", - " linear_w = tf.Variable(tf.truncated_normal((rnn_size, num_outputs), stddev=0.1), name='linear_w')\n", + " linear_w = tf.Variable(tf.truncated_normal((rnn_size, num_outputs), stddev=0.05), name='linear_w')\n", " linear_b = tf.Variable(tf.zeros(num_outputs), name='linear_b')\n", " \n", " logits = tf.matmul(outputs, linear_w) + linear_b\n", " \n", + " \n", + " \n", " # Reshape the logits back into the original input shape -> [batch_size, seq_len, num_classes]\n", " # We do this beceause the 
loss function seq2seq.sequence_loss takes as logits a shape of [batch_size,seq_len,num_decoded_symbols]\n", - " logits = tf.reshape(logits, [batch_size, -1, num_outputs])\n", - " \n", + " logits = tf.reshape(logits, [batch_size, seq_len, num_outputs], name='logits')\n", + " print('logits after reshape: ', logits)\n", " \n", " return logits, final_state\n", "\n", @@ -728,45 +758,7 @@ }, { "cell_type": "code", - "execution_count": 141, - "metadata": { - "collapsed": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Stored '_input' (ndarray)\n", - "Stored '_target' (ndarray)\n", - "Stored 'test_int_text' (list)\n" - ] - } - ], - "source": [ - "batch_size = 128\n", - "seq_length = 5\n", - "slice_size = batch_size * seq_length\n", - "test_int_text = list(range(1000*seq_length))\n", - "n_batches = int(len(test_int_text)/slice_size)\n", - "\n", - "# input part\n", - "_input = np.array(int_text[:n_batches*slice_size])\n", - "\n", - "# target part\n", - "_target = np.array(int_text[1:n_batches*slice_size + 1])\n", - "\n", - "%store _input\n", - "%store _target\n", - "%store test_int_text\n", - "\n", - "for b in range(n_batches):\n", - " print \n" - ] - }, - { - "cell_type": "code", - "execution_count": 174, + "execution_count": 158, "metadata": { "collapsed": false, "deletable": true, @@ -774,17 +766,10 @@ }, "outputs": [ { - "ename": "AttributeError", - "evalue": "'list' object has no attribute 'shape'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0mDON\u001b[0m\u001b[0;31m'\u001b[0m\u001b[0mT\u001b[0m \u001b[0mMODIFY\u001b[0m \u001b[0mANYTHING\u001b[0m \u001b[0mIN\u001b[0m \u001b[0mTHIS\u001b[0m \u001b[0mCELL\u001b[0m \u001b[0mTHAT\u001b[0m \u001b[0mIS\u001b[0m \u001b[0mBELOW\u001b[0m \u001b[0mTHIS\u001b[0m \u001b[0mLINE\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \"\"\"\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mtests\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest_get_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_batches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/home/spike/ml/udacity/nd101/deep-learning-modified/tv-script-generation/problem_unittests.py\u001b[0m in \u001b[0;36mtest_get_batches\u001b[0;34m(get_batches)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0mtest_seq_length\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[0mtest_int_text\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mtest_seq_length\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m \u001b[0mbatches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_int_text\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_batch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_seq_length\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;31m# Check type\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in 
\u001b[0;36mget_batches\u001b[0;34m(int_text, batch_size, seq_length)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 39\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvectorize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_input\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_target\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mvectorize\u001b[0;34m(_inputs, _targets)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# Go through all inputs, targets and split them into batch_size*seq\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtargets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mseq_length\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_targets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mseq_length\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;31m# Stack inputs and targets into batch_size * seq_length\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'shape'" + "name": "stdout", + "output_type": "stream", + "text": [ + "Tests Passed\n" ] } ], @@ -802,40 +787,24 @@ " n_batches = int(len(int_text)/slice_size)\n", " \n", " # input part\n", - " _input = np.array(int_text[:n_batches*slice_size])\n", + " _inputs = np.array(int_text[:n_batches*slice_size])\n", " \n", " # target part\n", - " _target = np.array(int_text[1:n_batches*slice_size + 1])\n", - " \n", - " \n", - " def vectorize(_inputs, _targets):\n", - " # Takes flattened inputs and targets\n", - " # returns shape [n_batches, 2, batch_size, seq_length]\n", - " \n", - " # Go through all inputs, targets and split them into batch_size*seq list of items\n", - " # [batch*seq, batch*seq, ...]\n", - " inputs, targets = np.split(_inputs, batch_size*seq_length), np.split(_targets, batch_size*seq_length)\n", - " \n", - " # Reshape into [batch x seq, batch x seq, ...]\n", - " \n", - " # Stack inputs and targets into batch_size * seq_length \n", - " # Shape should become batch_size x seq_length\n", - " inputs, targets = np.stack(inputs), np.stack(targets)\n", - " \n", - " \n", - " # Stack Inputs and Targets\n", - " batches = np.concatenate((inputs, targets))\n", - " \n", - " return batch\n", + " _targets = np.array(int_text[1:n_batches*slice_size + 1])\n", " \n", + "\n", + " # Go through all inputs, targets and split them into batch_size*seq_len list of items\n", + " # [batch, batch, ...]\n", + " inputs, targets = np.split(_inputs, n_batches), np.split(_targets, n_batches)\n", " \n", - " result = vectorize(_input, _target)\n", + " # concat inputs and 
targets\n", + " batches = np.c_[inputs, targets]\n", + " #print(batches.shape)\n", " \n", - " \n", - " # preare result as reference for target shape\n", - " #result = np.empty((n_batches, 2, batch_size, seq_length), dtype=np.int32)\n", + " # Reshape into final batches output\n", + " batches = batches.reshape((-1, 2, batch_size, seq_length))\n", " \n", - " return None\n", + " return batches\n", "\n", "\n", "\"\"\"\n", @@ -865,7 +834,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 164, "metadata": { "collapsed": true, "deletable": true, @@ -874,17 +843,17 @@ "outputs": [], "source": [ "# Number of Epochs\n", - "num_epochs = None\n", + "num_epochs = 100\n", "# Batch Size\n", - "batch_size = None\n", + "batch_size = 128\n", "# RNN Size\n", - "rnn_size = None\n", + "rnn_size = 256\n", "# Sequence Length\n", - "seq_length = None\n", + "seq_length = 100\n", "# Learning Rate\n", - "learning_rate = None\n", + "learning_rate = 1e-3\n", "# Show stats for every n number of batches\n", - "show_every_n_batches = None\n", + "show_every_n_batches = 1\n", "\n", "\"\"\"\n", "DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE\n", @@ -905,13 +874,43 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 77, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "6779" + ] + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vocab_size" + ] + }, + { + "cell_type": "code", + "execution_count": 165, "metadata": { "collapsed": false, "deletable": true, "editable": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logits after reshape: Tensor(\"logits:0\", shape=(?, ?, 6779), dtype=float32)\n" + ] + } + ], "source": [ "\"\"\"\n", "DON'T MODIFY ANYTHING IN THIS CELL\n", @@ -944,6 +943,29 @@ " train_op = optimizer.apply_gradients(capped_gradients)" ] }, + { + "cell_type": "code", + "execution_count": 163, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "5" + ] + }, + "execution_count": 163, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "batches = get_batches(int_text, batch_size, seq_length)\n", + "len(batches)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -963,7 +985,66 @@ "deletable": true, "editable": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 Batch 0/5 train_loss = 8.828\n", + "Epoch 0 Batch 1/5 train_loss = 8.793\n", + "Epoch 0 Batch 2/5 train_loss = 8.737\n", + "Epoch 0 Batch 3/5 train_loss = 8.602\n", + "Epoch 0 Batch 4/5 train_loss = 8.298\n", + "Epoch 1 Batch 0/5 train_loss = 7.938\n", + "Epoch 1 Batch 1/5 train_loss = 7.662\n", + "Epoch 1 Batch 2/5 train_loss = 7.364\n", + "Epoch 1 Batch 3/5 train_loss = 7.164\n", + "Epoch 1 Batch 4/5 train_loss = 6.899\n", + "Epoch 2 Batch 0/5 train_loss = 6.596\n", + "Epoch 2 Batch 1/5 train_loss = 6.462\n", + "Epoch 2 Batch 2/5 train_loss = 6.309\n", + "Epoch 2 Batch 3/5 train_loss = 6.330\n", + "Epoch 2 Batch 4/5 train_loss = 6.250\n", + "Epoch 3 Batch 0/5 train_loss = 6.055\n", + "Epoch 3 Batch 1/5 train_loss = 6.048\n", + "Epoch 3 Batch 2/5 train_loss = 6.012\n", + "Epoch 3 Batch 3/5 train_loss = 6.133\n", + "Epoch 3 Batch 4/5 train_loss = 6.159\n", + "Epoch 4 Batch 0/5 train_loss = 5.996\n", + "Epoch 4 Batch 1/5 train_loss = 6.021\n", + "Epoch 4 Batch 2/5 train_loss = 6.010\n", + "Epoch 4 Batch 3/5 
train_loss = 6.125\n", + "Epoch 4 Batch 4/5 train_loss = 6.156\n", + "Epoch 5 Batch 0/5 train_loss = 5.978\n", + "Epoch 5 Batch 1/5 train_loss = 5.993\n", + "Epoch 5 Batch 2/5 train_loss = 5.977\n", + "Epoch 5 Batch 3/5 train_loss = 6.081\n", + "Epoch 5 Batch 4/5 train_loss = 6.103\n", + "Epoch 6 Batch 0/5 train_loss = 5.928\n", + "Epoch 6 Batch 1/5 train_loss = 5.950\n", + "Epoch 6 Batch 2/5 train_loss = 5.938\n", + "Epoch 6 Batch 3/5 train_loss = 6.053\n", + "Epoch 6 Batch 4/5 train_loss = 6.074\n", + "Epoch 7 Batch 0/5 train_loss = 5.909\n", + "Epoch 7 Batch 1/5 train_loss = 5.937\n", + "Epoch 7 Batch 2/5 train_loss = 5.925\n", + "Epoch 7 Batch 3/5 train_loss = 6.043\n", + "Epoch 7 Batch 4/5 train_loss = 6.060\n", + "Epoch 8 Batch 0/5 train_loss = 5.896\n", + "Epoch 8 Batch 1/5 train_loss = 5.922\n", + "Epoch 8 Batch 2/5 train_loss = 5.912\n", + "Epoch 8 Batch 3/5 train_loss = 6.028\n", + "Epoch 8 Batch 4/5 train_loss = 6.049\n", + "Epoch 9 Batch 0/5 train_loss = 5.889\n", + "Epoch 9 Batch 1/5 train_loss = 5.912\n", + "Epoch 9 Batch 2/5 train_loss = 5.906\n", + "Epoch 9 Batch 3/5 train_loss = 6.020\n", + "Epoch 9 Batch 4/5 train_loss = 6.042\n", + "Epoch 10 Batch 0/5 train_loss = 5.884\n", + "Epoch 10 Batch 1/5 train_loss = 5.905\n" + ] + } + ], "source": [ "\"\"\"\n", "DON'T MODIFY ANYTHING IN THIS CELL\n", @@ -1238,7 +1319,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.1" + "version": "3.5.2" }, "toc": { "colors": { diff --git a/tv-script-generation/__pycache__/helper.cpython-35.pyc b/tv-script-generation/__pycache__/helper.cpython-35.pyc index a9561b5..e3903f8 100644 Binary files a/tv-script-generation/__pycache__/helper.cpython-35.pyc and b/tv-script-generation/__pycache__/helper.cpython-35.pyc differ diff --git a/tv-script-generation/__pycache__/problem_unittests.cpython-35.pyc b/tv-script-generation/__pycache__/problem_unittests.cpython-35.pyc index fb7b788..f0ca97b 100644 Binary files a/tv-script-generation/__pycache__/problem_unittests.cpython-35.pyc and b/tv-script-generation/__pycache__/problem_unittests.cpython-35.pyc differ diff --git a/tv-script-generation/checkpoint b/tv-script-generation/checkpoint new file mode 100644 index 0000000..dd3d2b9 --- /dev/null +++ b/tv-script-generation/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "save" +all_model_checkpoint_paths: "save" diff --git a/tv-script-generation/dlnd_tv_script_generation.ipynb b/tv-script-generation/dlnd_tv_script_generation.ipynb index 7c76d1c..9a36bd0 100644 --- a/tv-script-generation/dlnd_tv_script_generation.ipynb +++ b/tv-script-generation/dlnd_tv_script_generation.ipynb @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 64, "metadata": { "collapsed": false, "deletable": true, @@ -47,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 14, "metadata": { "collapsed": false, "deletable": true, @@ -126,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 65, "metadata": { "collapsed": false, "deletable": true, @@ -192,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 16, "metadata": { "collapsed": false, "deletable": true, @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 17, "metadata": { "collapsed": false, "deletable": true, @@ -294,16 +294,21 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "deletable": true, + "editable": true + }, "source": [ "### Extra 
hyper parameters" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 177, "metadata": { - "collapsed": false + "collapsed": false, + "deletable": true, + "editable": true }, "outputs": [], "source": [ @@ -311,7 +316,7 @@ "\n", "hyper_params = (('embedding_size', 128),\n", " ('lstm_layers', 2),\n", - " ('keep_prob', 0.5)\n", + " ('keep_prob', 0.7)\n", " )\n", "\n", "\n", @@ -345,21 +350,17 @@ "cell_type": "code", "execution_count": 3, "metadata": { - "collapsed": false + "collapsed": false, + "deletable": true, + "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "TensorFlow Version: 1.0.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/spike/.pyenv/versions/3.5.1/envs/ml/lib/python3.5/site-packages/ipykernel/__main__.py:14: UserWarning: No GPU found. Please use a GPU to train your neural network.\n" + "TensorFlow Version: 1.0.0\n", + "Default GPU Device: /gpu:0\n" ] } ], @@ -384,7 +385,10 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "deletable": true, + "editable": true + }, "source": [ "### Input\n", "Implement the `get_inputs()` function to create TF Placeholders for the Neural Network. It should create the following placeholders:\n", @@ -397,7 +401,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 225, "metadata": { "collapsed": false, "deletable": true, @@ -420,13 +424,13 @@ " \"\"\"\n", " \n", " # We use shape [None, None] to feed any batch size and any sequence length\n", - " input_placeholder = tf.placeholder(tf.int32, [None, None],name='input')\n", + " input_placeholder = tf.placeholder(tf.int64, [None, None],name='input')\n", " \n", " # Targets are [batch_size, seq_length]\n", - " targets_placeholder = tf.placeholder(tf.int32, [None, None]) \n", + " targets_placeholder = tf.placeholder(tf.int64, [None, None], name='targets') \n", " \n", " \n", - " learning_rate_placeholder = tf.placeholder(tf.float32)\n", + " learning_rate_placeholder = tf.placeholder(tf.float32, name='learning_rate')\n", " return input_placeholder, targets_placeholder, learning_rate_placeholder\n", "\n", "\n", @@ -454,7 +458,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 227, "metadata": { "collapsed": false, "deletable": true, @@ -477,17 +481,19 @@ " :param rnn_size: Size of RNNs\n", " :return: Tuple (cell, initialize state)\n", " \"\"\"\n", - " lstm = tf.contrib.rnn.BasicLSTMCell(rnn_size)\n", - " \n", - " # add a dropout wrapper\n", - " drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=HYPER.keep_prob)\n", - " \n", - " #cell = tf.contrib.rnn.MultiRNNCell([drop] * HYPER.lstm_layers)\n", - " \n", - " cell = tf.contrib.rnn.MultiRNNCell([lstm] * HYPER.lstm_layers)\n", + " with tf.name_scope('RNN_layers'):\n", + " lstm = tf.contrib.rnn.BasicLSTMCell(rnn_size)\n", + "\n", + " # add a dropout wrapper\n", + " drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=HYPER.keep_prob)\n", + "\n", + " #cell = tf.contrib.rnn.MultiRNNCell([drop] * HYPER.lstm_layers)\n", + "\n", + " cell = tf.contrib.rnn.MultiRNNCell([lstm] * HYPER.lstm_layers)\n", " \n", - " initial_state = cell.zero_state(batch_size, tf.float32)\n", - " initial_state = tf.identity(initial_state, name='initial_state')\n", + " \n", + " _initial_state = cell.zero_state(batch_size, tf.float32)\n", + " initial_state = tf.identity(_initial_state, name='initial_state')\n", " \n", " return cell, initial_state\n", "\n", @@ -511,7 +517,7 @@ }, { "cell_type": "code", - 
"execution_count": 10, + "execution_count": 207, "metadata": { "collapsed": false, "deletable": true, @@ -535,11 +541,12 @@ " :param embed_dim: Number of embedding dimensions\n", " :return: Embedded input.\n", " \"\"\"\n", - " embeddings = tf.Variable(\n", - " tf.random_uniform([vocab_size, embed_dim], -1.0, 1.0)\n", - " )\n", - " \n", - " embed = tf.nn.embedding_lookup(embeddings, input_data)\n", + " with tf.name_scope('Embedding'):\n", + " embeddings = tf.Variable(\n", + " tf.random_uniform([vocab_size, embed_dim], -1.0, 1.0)\n", + " )\n", + "\n", + " embed = tf.nn.embedding_lookup(embeddings, input_data)\n", " \n", " return embed\n", "\n", @@ -567,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 228, "metadata": { "collapsed": false, "deletable": true, @@ -592,8 +599,9 @@ " \"\"\"\n", " ## NOTES\n", " # dynamic rnn automatically takes the seq size in dim=1 [batch_size, max_time, ...] time_major==false (default)\n", + " with tf.name_scope('RNN_output'):\n", + " outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)\n", " \n", - " outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)\n", " final_state = tf.identity(final_state, name='final_state')\n", " \n", " \n", @@ -624,7 +632,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 231, "metadata": { "collapsed": false, "deletable": true, @@ -635,6 +643,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "logits after reshape: Tensor(\"logits_reshape_to_loss/logits:0\", shape=(128, 5, 27), dtype=float32)\n", "Tests Passed\n" ] } @@ -651,7 +660,27 @@ " \"\"\"\n", " \n", " num_outputs = vocab_size\n", - " batch_size = input_data.get_shape().as_list()[0]\n", + " \n", + " \n", + " ## Not sure why the unit test was made without taking into \n", + " # account we are handling dynamic tensor shape that we need to infer\n", + " # at runtime, so I made an if statement just to pass the test case\n", + " #\n", + " # Some references: https://goo.gl/vD3egn\n", + " # https://goo.gl/E8vT2M \n", + " \n", + " if input_data.get_shape().as_list()[1] is not None:\n", + " batch_size = input_data.get_shape().as_list()[0]\n", + " seq_len = input_data.get_shape().as_list()[1]\n", + " \n", + " # Infer dynamic tensor shape of input\n", + " else:\n", + " input_dims = tf.shape(input_data)\n", + " batch_size = input_dims[0]\n", + " seq_len = input_dims[1]\n", + "\n", + " \n", + "\n", " \n", " embed = get_embed(input_data, vocab_size, HYPER.embedding_size)\n", " \n", @@ -663,23 +692,29 @@ " ## [batch_size, time_step, rnn_size]\n", " raw_rnn_outputs, final_state = build_rnn(cell, embed)\n", " \n", + " \n", " # Put outputs in rows\n", " # make the output into [batch_size*time_step, rnn_size] for easy matmul\n", - " outputs = tf.reshape(raw_rnn_outputs, [-1, rnn_size])\n", + " with tf.name_scope('sequence_reshape'):\n", + " outputs = tf.reshape(raw_rnn_outputs, [-1, rnn_size], name='rnn_output')\n", " \n", " \n", " # Question, why are we using linear activation and not softmax ?\n", " # My Guess: because seq2seq.sequence_loss has an efficient way to calculate the loss directly from logits \n", - " with tf.variable_scope('linear_layer'):\n", - " linear_w = tf.Variable(tf.truncated_normal((rnn_size, num_outputs), stddev=0.1), name='linear_w')\n", - " linear_b = tf.Variable(tf.zeros(num_outputs), name='linear_b')\n", + " with tf.name_scope('logits'):\n", " \n", - " logits = tf.matmul(outputs, linear_w) + linear_b\n", + " linear_w = 
tf.Variable(tf.truncated_normal((rnn_size, num_outputs), stddev=0.05), name='linear_w')\n", + " linear_b = tf.Variable(tf.zeros(num_outputs), name='linear_b')\n", + "\n", + " logits = tf.matmul(outputs, linear_w) + linear_b\n", + " \n", + " \n", " \n", " # Reshape the logits back into the original input shape -> [batch_size, seq_len, num_classes]\n", " # We do this beceause the loss function seq2seq.sequence_loss takes as logits a shape of [batch_size,seq_len,num_decoded_symbols]\n", - " logits = tf.reshape(logits, [batch_size, -1, num_outputs])\n", - " \n", + " with tf.name_scope('logits_reshape_to_loss'):\n", + " logits = tf.reshape(logits, [batch_size, seq_len, num_outputs], name='logits')\n", + " print('logits after reshape: ', logits)\n", " \n", " return logits, final_state\n", "\n", @@ -728,7 +763,7 @@ }, { "cell_type": "code", - "execution_count": 238, + "execution_count": 233, "metadata": { "collapsed": false, "deletable": true, @@ -739,135 +774,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "(7, 1280)\n", - "[[ 0 1 2 3 4]\n", - " [ 5 6 7 8 9]\n", - " [ 10 11 12 13 14]\n", - " [ 15 16 17 18 19]\n", - " [ 20 21 22 23 24]\n", - " [ 25 26 27 28 29]\n", - " [ 30 31 32 33 34]\n", - " [ 35 36 37 38 39]\n", - " [ 40 41 42 43 44]\n", - " [ 45 46 47 48 49]\n", - " [ 50 51 52 53 54]\n", - " [ 55 56 57 58 59]\n", - " [ 60 61 62 63 64]\n", - " [ 65 66 67 68 69]\n", - " [ 70 71 72 73 74]\n", - " [ 75 76 77 78 79]\n", - " [ 80 81 82 83 84]\n", - " [ 85 86 87 88 89]\n", - " [ 90 91 92 93 94]\n", - " [ 95 96 97 98 99]\n", - " [100 101 102 103 104]\n", - " [105 106 107 108 109]\n", - " [110 111 112 113 114]\n", - " [115 116 117 118 119]\n", - " [120 121 122 123 124]\n", - " [125 126 127 128 129]\n", - " [130 131 132 133 134]\n", - " [135 136 137 138 139]\n", - " [140 141 142 143 144]\n", - " [145 146 147 148 149]\n", - " [150 151 152 153 154]\n", - " [155 156 157 158 159]\n", - " [160 161 162 163 164]\n", - " [165 166 167 168 169]\n", - " [170 171 172 173 174]\n", - " [175 176 177 178 179]\n", - " [180 181 182 183 184]\n", - " [185 186 187 188 189]\n", - " [190 191 192 193 194]\n", - " [195 196 197 198 199]\n", - " [200 201 202 203 204]\n", - " [205 206 207 208 209]\n", - " [210 211 212 213 214]\n", - " [215 216 217 218 219]\n", - " [220 221 222 223 224]\n", - " [225 226 227 228 229]\n", - " [230 231 232 233 234]\n", - " [235 236 237 238 239]\n", - " [240 241 242 243 244]\n", - " [245 246 247 248 249]\n", - " [250 251 252 253 254]\n", - " [255 256 257 258 259]\n", - " [260 261 262 263 264]\n", - " [265 266 267 268 269]\n", - " [270 271 272 273 274]\n", - " [275 276 277 278 279]\n", - " [280 281 282 283 284]\n", - " [285 286 287 288 289]\n", - " [290 291 292 293 294]\n", - " [295 296 297 298 299]\n", - " [300 301 302 303 304]\n", - " [305 306 307 308 309]\n", - " [310 311 312 313 314]\n", - " [315 316 317 318 319]\n", - " [320 321 322 323 324]\n", - " [325 326 327 328 329]\n", - " [330 331 332 333 334]\n", - " [335 336 337 338 339]\n", - " [340 341 342 343 344]\n", - " [345 346 347 348 349]\n", - " [350 351 352 353 354]\n", - " [355 356 357 358 359]\n", - " [360 361 362 363 364]\n", - " [365 366 367 368 369]\n", - " [370 371 372 373 374]\n", - " [375 376 377 378 379]\n", - " [380 381 382 383 384]\n", - " [385 386 387 388 389]\n", - " [390 391 392 393 394]\n", - " [395 396 397 398 399]\n", - " [400 401 402 403 404]\n", - " [405 406 407 408 409]\n", - " [410 411 412 413 414]\n", - " [415 416 417 418 419]\n", - " [420 421 422 423 424]\n", - " [425 426 427 428 429]\n", - " [430 431 432 433 
434]\n", - " [435 436 437 438 439]\n", - " [440 441 442 443 444]\n", - " [445 446 447 448 449]\n", - " [450 451 452 453 454]\n", - " [455 456 457 458 459]\n", - " [460 461 462 463 464]\n", - " [465 466 467 468 469]\n", - " [470 471 472 473 474]\n", - " [475 476 477 478 479]\n", - " [480 481 482 483 484]\n", - " [485 486 487 488 489]\n", - " [490 491 492 493 494]\n", - " [495 496 497 498 499]\n", - " [500 501 502 503 504]\n", - " [505 506 507 508 509]\n", - " [510 511 512 513 514]\n", - " [515 516 517 518 519]\n", - " [520 521 522 523 524]\n", - " [525 526 527 528 529]\n", - " [530 531 532 533 534]\n", - " [535 536 537 538 539]\n", - " [540 541 542 543 544]\n", - " [545 546 547 548 549]\n", - " [550 551 552 553 554]\n", - " [555 556 557 558 559]\n", - " [560 561 562 563 564]\n", - " [565 566 567 568 569]\n", - " [570 571 572 573 574]\n", - " [575 576 577 578 579]\n", - " [580 581 582 583 584]\n", - " [585 586 587 588 589]\n", - " [590 591 592 593 594]\n", - " [595 596 597 598 599]\n", - " [600 601 602 603 604]\n", - " [605 606 607 608 609]\n", - " [610 611 612 613 614]\n", - " [615 616 617 618 619]\n", - " [620 621 622 623 624]\n", - " [625 626 627 628 629]\n", - " [630 631 632 633 634]\n", - " [635 636 637 638 639]]\n", "Tests Passed\n" ] } @@ -898,12 +804,12 @@ " \n", " # concat inputs and targets\n", " batches = np.c_[inputs, targets]\n", - " print(batches.shape)\n", + " #print(batches.shape)\n", " \n", " # Reshape into final batches output\n", " batches = batches.reshape((-1, 2, batch_size, seq_length))\n", "\n", - " print(batches[0][0])\n", + " #print(batches[0][1])\n", "\n", " \n", " return batches\n", @@ -936,7 +842,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 234, "metadata": { "collapsed": true, "deletable": true, @@ -945,17 +851,17 @@ "outputs": [], "source": [ "# Number of Epochs\n", - "num_epochs = None\n", + "num_epochs = 1000\n", "# Batch Size\n", - "batch_size = None\n", + "batch_size = 128\n", "# RNN Size\n", - "rnn_size = None\n", + "rnn_size = 70\n", "# Sequence Length\n", - "seq_length = None\n", + "seq_length = 100\n", "# Learning Rate\n", - "learning_rate = None\n", + "learning_rate = 1e-3\n", "# Show stats for every n number of batches\n", - "show_every_n_batches = None\n", + "show_every_n_batches = 10\n", "\n", "\"\"\"\n", "DON'T MODIFY ANYTHING IN THIS CELL THAT IS BELOW THIS LINE\n", @@ -976,13 +882,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 235, "metadata": { "collapsed": false, "deletable": true, "editable": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "logits after reshape: Tensor(\"logits_reshape_to_loss/logits:0\", shape=(?, ?, 6779), dtype=float32)\n" + ] + } + ], "source": [ "\"\"\"\n", "DON'T MODIFY ANYTHING IN THIS CELL\n", @@ -1015,6 +929,22 @@ " train_op = optimizer.apply_gradients(capped_gradients)" ] }, + { + "cell_type": "code", + "execution_count": 238, + "metadata": { + "collapsed": false, + "deletable": true, + "editable": true + }, + "outputs": [], + "source": [ + "# write out the graph for tensorboard\n", + "\n", + "with tf.Session(graph=train_graph) as sess:\n", + " file_writer = tf.summary.FileWriter('./logs/1', sess.graph)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -1028,13 +958,522 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 197, "metadata": { "collapsed": false, "deletable": true, "editable": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "5\n", + "Epoch 0 Batch 0/5 train_loss = 8.825\n", + "Epoch 2 Batch 0/5 train_loss = 6.441\n", + "Epoch 4 Batch 0/5 train_loss = 6.023\n", + "Epoch 6 Batch 0/5 train_loss = 5.927\n", + "Epoch 8 Batch 0/5 train_loss = 5.903\n", + "Epoch 10 Batch 0/5 train_loss = 5.883\n", + "Epoch 12 Batch 0/5 train_loss = 5.874\n", + "Epoch 14 Batch 0/5 train_loss = 5.858\n", + "Epoch 16 Batch 0/5 train_loss = 5.833\n", + "Epoch 18 Batch 0/5 train_loss = 5.794\n", + "Epoch 20 Batch 0/5 train_loss = 5.739\n", + "Epoch 22 Batch 0/5 train_loss = 5.682\n", + "Epoch 24 Batch 0/5 train_loss = 5.626\n", + "Epoch 26 Batch 0/5 train_loss = 5.572\n", + "Epoch 28 Batch 0/5 train_loss = 5.521\n", + "Epoch 30 Batch 0/5 train_loss = 5.471\n", + "Epoch 32 Batch 0/5 train_loss = 5.421\n", + "Epoch 34 Batch 0/5 train_loss = 5.365\n", + "Epoch 36 Batch 0/5 train_loss = 5.304\n", + "Epoch 38 Batch 0/5 train_loss = 5.244\n", + "Epoch 40 Batch 0/5 train_loss = 5.185\n", + "Epoch 42 Batch 0/5 train_loss = 5.124\n", + "Epoch 44 Batch 0/5 train_loss = 5.063\n", + "Epoch 46 Batch 0/5 train_loss = 5.003\n", + "Epoch 48 Batch 0/5 train_loss = 4.945\n", + "Epoch 50 Batch 0/5 train_loss = 4.891\n", + "Epoch 52 Batch 0/5 train_loss = 4.841\n", + "Epoch 54 Batch 0/5 train_loss = 4.794\n", + "Epoch 56 Batch 0/5 train_loss = 4.751\n", + "Epoch 58 Batch 0/5 train_loss = 4.710\n", + "Epoch 60 Batch 0/5 train_loss = 4.669\n", + "Epoch 62 Batch 0/5 train_loss = 4.638\n", + "Epoch 64 Batch 0/5 train_loss = 4.638\n", + "Epoch 66 Batch 0/5 train_loss = 4.589\n", + "Epoch 68 Batch 0/5 train_loss = 4.537\n", + "Epoch 70 Batch 0/5 train_loss = 4.501\n", + "Epoch 72 Batch 0/5 train_loss = 4.469\n", + "Epoch 74 Batch 0/5 train_loss = 4.436\n", + "Epoch 76 Batch 0/5 train_loss = 4.405\n", + "Epoch 78 Batch 0/5 train_loss = 4.375\n", + "Epoch 80 Batch 0/5 train_loss = 4.344\n", + "Epoch 82 Batch 0/5 train_loss = 4.363\n", + "Epoch 84 Batch 0/5 train_loss = 4.311\n", + "Epoch 86 Batch 0/5 train_loss = 4.274\n", + "Epoch 88 Batch 0/5 train_loss = 4.240\n", + "Epoch 90 Batch 0/5 train_loss = 4.211\n", + "Epoch 92 Batch 0/5 train_loss = 4.182\n", + "Epoch 94 Batch 0/5 train_loss = 4.155\n", + "Epoch 96 Batch 0/5 train_loss = 4.135\n", + "Epoch 98 Batch 0/5 train_loss = 4.107\n", + "Epoch 100 Batch 0/5 train_loss = 4.093\n", + "Epoch 102 Batch 0/5 train_loss = 4.053\n", + "Epoch 104 Batch 0/5 train_loss = 4.030\n", + "Epoch 106 Batch 0/5 train_loss = 4.002\n", + "Epoch 108 Batch 0/5 train_loss = 3.978\n", + "Epoch 110 Batch 0/5 train_loss = 3.951\n", + "Epoch 112 Batch 0/5 train_loss = 3.928\n", + "Epoch 114 Batch 0/5 train_loss = 3.902\n", + "Epoch 116 Batch 0/5 train_loss = 3.884\n", + "Epoch 118 Batch 0/5 train_loss = 3.862\n", + "Epoch 120 Batch 0/5 train_loss = 3.840\n", + "Epoch 122 Batch 0/5 train_loss = 3.814\n", + "Epoch 124 Batch 0/5 train_loss = 3.803\n", + "Epoch 126 Batch 0/5 train_loss = 3.775\n", + "Epoch 128 Batch 0/5 train_loss = 3.738\n", + "Epoch 130 Batch 0/5 train_loss = 3.714\n", + "Epoch 132 Batch 0/5 train_loss = 3.690\n", + "Epoch 134 Batch 0/5 train_loss = 3.665\n", + "Epoch 136 Batch 0/5 train_loss = 3.642\n", + "Epoch 138 Batch 0/5 train_loss = 3.619\n", + "Epoch 140 Batch 0/5 train_loss = 3.596\n", + "Epoch 142 Batch 0/5 train_loss = 3.577\n", + "Epoch 144 Batch 0/5 train_loss = 3.588\n", + "Epoch 146 Batch 0/5 train_loss = 3.561\n", + "Epoch 148 Batch 0/5 train_loss = 3.537\n", + "Epoch 150 Batch 0/5 train_loss = 3.494\n", + "Epoch 152 Batch 0/5 train_loss = 3.475\n", + "Epoch 154 Batch 
0/5 train_loss = 3.444\n", + "Epoch 156 Batch 0/5 train_loss = 3.431\n", + "Epoch 158 Batch 0/5 train_loss = 3.403\n", + "Epoch 160 Batch 0/5 train_loss = 3.393\n", + "Epoch 162 Batch 0/5 train_loss = 3.371\n", + "Epoch 164 Batch 0/5 train_loss = 3.352\n", + "Epoch 166 Batch 0/5 train_loss = 3.323\n", + "Epoch 168 Batch 0/5 train_loss = 3.328\n", + "Epoch 170 Batch 0/5 train_loss = 3.281\n", + "Epoch 172 Batch 0/5 train_loss = 3.261\n", + "Epoch 174 Batch 0/5 train_loss = 3.238\n", + "Epoch 176 Batch 0/5 train_loss = 3.216\n", + "Epoch 178 Batch 0/5 train_loss = 3.197\n", + "Epoch 180 Batch 0/5 train_loss = 3.172\n", + "Epoch 182 Batch 0/5 train_loss = 3.169\n", + "Epoch 184 Batch 0/5 train_loss = 3.140\n", + "Epoch 186 Batch 0/5 train_loss = 3.136\n", + "Epoch 188 Batch 0/5 train_loss = 3.145\n", + "Epoch 190 Batch 0/5 train_loss = 3.106\n", + "Epoch 192 Batch 0/5 train_loss = 3.069\n", + "Epoch 194 Batch 0/5 train_loss = 3.038\n", + "Epoch 196 Batch 0/5 train_loss = 3.019\n", + "Epoch 198 Batch 0/5 train_loss = 2.995\n", + "Epoch 200 Batch 0/5 train_loss = 2.979\n", + "Epoch 202 Batch 0/5 train_loss = 2.960\n", + "Epoch 204 Batch 0/5 train_loss = 2.943\n", + "Epoch 206 Batch 0/5 train_loss = 2.963\n", + "Epoch 208 Batch 0/5 train_loss = 2.917\n", + "Epoch 210 Batch 0/5 train_loss = 2.898\n", + "Epoch 212 Batch 0/5 train_loss = 2.867\n", + "Epoch 214 Batch 0/5 train_loss = 2.863\n", + "Epoch 216 Batch 0/5 train_loss = 2.834\n", + "Epoch 218 Batch 0/5 train_loss = 2.809\n", + "Epoch 220 Batch 0/5 train_loss = 2.797\n", + "Epoch 222 Batch 0/5 train_loss = 2.774\n", + "Epoch 224 Batch 0/5 train_loss = 2.759\n", + "Epoch 226 Batch 0/5 train_loss = 2.732\n", + "Epoch 228 Batch 0/5 train_loss = 2.742\n", + "Epoch 230 Batch 0/5 train_loss = 2.704\n", + "Epoch 232 Batch 0/5 train_loss = 2.703\n", + "Epoch 234 Batch 0/5 train_loss = 2.663\n", + "Epoch 236 Batch 0/5 train_loss = 2.672\n", + "Epoch 238 Batch 0/5 train_loss = 2.638\n", + "Epoch 240 Batch 0/5 train_loss = 2.620\n", + "Epoch 242 Batch 0/5 train_loss = 2.595\n", + "Epoch 244 Batch 0/5 train_loss = 2.585\n", + "Epoch 246 Batch 0/5 train_loss = 2.563\n", + "Epoch 248 Batch 0/5 train_loss = 2.539\n", + "Epoch 250 Batch 0/5 train_loss = 2.534\n", + "Epoch 252 Batch 0/5 train_loss = 2.517\n", + "Epoch 254 Batch 0/5 train_loss = 2.497\n", + "Epoch 256 Batch 0/5 train_loss = 2.475\n", + "Epoch 258 Batch 0/5 train_loss = 2.463\n", + "Epoch 260 Batch 0/5 train_loss = 2.478\n", + "Epoch 262 Batch 0/5 train_loss = 2.450\n", + "Epoch 264 Batch 0/5 train_loss = 2.436\n", + "Epoch 266 Batch 0/5 train_loss = 2.417\n", + "Epoch 268 Batch 0/5 train_loss = 2.384\n", + "Epoch 270 Batch 0/5 train_loss = 2.363\n", + "Epoch 272 Batch 0/5 train_loss = 2.340\n", + "Epoch 274 Batch 0/5 train_loss = 2.323\n", + "Epoch 276 Batch 0/5 train_loss = 2.314\n", + "Epoch 278 Batch 0/5 train_loss = 2.302\n", + "Epoch 280 Batch 0/5 train_loss = 2.300\n", + "Epoch 282 Batch 0/5 train_loss = 2.300\n", + "Epoch 284 Batch 0/5 train_loss = 2.283\n", + "Epoch 286 Batch 0/5 train_loss = 2.246\n", + "Epoch 288 Batch 0/5 train_loss = 2.246\n", + "Epoch 290 Batch 0/5 train_loss = 2.210\n", + "Epoch 292 Batch 0/5 train_loss = 2.203\n", + "Epoch 294 Batch 0/5 train_loss = 2.185\n", + "Epoch 296 Batch 0/5 train_loss = 2.170\n", + "Epoch 298 Batch 0/5 train_loss = 2.150\n", + "Epoch 300 Batch 0/5 train_loss = 2.130\n", + "Epoch 302 Batch 0/5 train_loss = 2.132\n", + "Epoch 304 Batch 0/5 train_loss = 2.113\n", + "Epoch 306 Batch 0/5 train_loss = 2.083\n", + "Epoch 308 Batch 0/5 
train_loss = 2.073\n", + "Epoch 310 Batch 0/5 train_loss = 2.060\n", + "Epoch 312 Batch 0/5 train_loss = 2.072\n", + "Epoch 314 Batch 0/5 train_loss = 2.081\n", + "Epoch 316 Batch 0/5 train_loss = 2.031\n", + "Epoch 318 Batch 0/5 train_loss = 2.007\n", + "Epoch 320 Batch 0/5 train_loss = 2.001\n", + "Epoch 322 Batch 0/5 train_loss = 1.987\n", + "Epoch 324 Batch 0/5 train_loss = 1.978\n", + "Epoch 326 Batch 0/5 train_loss = 1.963\n", + "Epoch 328 Batch 0/5 train_loss = 1.952\n", + "Epoch 330 Batch 0/5 train_loss = 1.932\n", + "Epoch 332 Batch 0/5 train_loss = 1.918\n", + "Epoch 334 Batch 0/5 train_loss = 1.898\n", + "Epoch 336 Batch 0/5 train_loss = 1.885\n", + "Epoch 338 Batch 0/5 train_loss = 1.872\n", + "Epoch 340 Batch 0/5 train_loss = 1.864\n", + "Epoch 342 Batch 0/5 train_loss = 1.867\n", + "Epoch 344 Batch 0/5 train_loss = 1.848\n", + "Epoch 346 Batch 0/5 train_loss = 1.821\n", + "Epoch 348 Batch 0/5 train_loss = 1.814\n", + "Epoch 350 Batch 0/5 train_loss = 1.788\n", + "Epoch 352 Batch 0/5 train_loss = 1.806\n", + "Epoch 354 Batch 0/5 train_loss = 1.790\n", + "Epoch 356 Batch 0/5 train_loss = 1.761\n", + "Epoch 358 Batch 0/5 train_loss = 1.745\n", + "Epoch 360 Batch 0/5 train_loss = 1.735\n", + "Epoch 362 Batch 0/5 train_loss = 1.718\n", + "Epoch 364 Batch 0/5 train_loss = 1.747\n", + "Epoch 366 Batch 0/5 train_loss = 1.726\n", + "Epoch 368 Batch 0/5 train_loss = 1.753\n", + "Epoch 370 Batch 0/5 train_loss = 1.703\n", + "Epoch 372 Batch 0/5 train_loss = 1.662\n", + "Epoch 374 Batch 0/5 train_loss = 1.643\n", + "Epoch 376 Batch 0/5 train_loss = 1.624\n", + "Epoch 378 Batch 0/5 train_loss = 1.617\n", + "Epoch 380 Batch 0/5 train_loss = 1.598\n", + "Epoch 382 Batch 0/5 train_loss = 1.613\n", + "Epoch 384 Batch 0/5 train_loss = 1.601\n", + "Epoch 386 Batch 0/5 train_loss = 1.584\n", + "Epoch 388 Batch 0/5 train_loss = 1.569\n", + "Epoch 390 Batch 0/5 train_loss = 1.557\n", + "Epoch 392 Batch 0/5 train_loss = 1.534\n", + "Epoch 394 Batch 0/5 train_loss = 1.534\n", + "Epoch 396 Batch 0/5 train_loss = 1.520\n", + "Epoch 398 Batch 0/5 train_loss = 1.547\n", + "Epoch 400 Batch 0/5 train_loss = 1.545\n", + "Epoch 402 Batch 0/5 train_loss = 1.521\n", + "Epoch 404 Batch 0/5 train_loss = 1.486\n", + "Epoch 406 Batch 0/5 train_loss = 1.469\n", + "Epoch 408 Batch 0/5 train_loss = 1.458\n", + "Epoch 410 Batch 0/5 train_loss = 1.442\n", + "Epoch 412 Batch 0/5 train_loss = 1.431\n", + "Epoch 414 Batch 0/5 train_loss = 1.410\n", + "Epoch 416 Batch 0/5 train_loss = 1.411\n", + "Epoch 418 Batch 0/5 train_loss = 1.412\n", + "Epoch 420 Batch 0/5 train_loss = 1.398\n", + "Epoch 422 Batch 0/5 train_loss = 1.417\n", + "Epoch 424 Batch 0/5 train_loss = 1.381\n", + "Epoch 426 Batch 0/5 train_loss = 1.355\n", + "Epoch 428 Batch 0/5 train_loss = 1.354\n", + "Epoch 430 Batch 0/5 train_loss = 1.338\n", + "Epoch 432 Batch 0/5 train_loss = 1.321\n", + "Epoch 434 Batch 0/5 train_loss = 1.326\n", + "Epoch 436 Batch 0/5 train_loss = 1.324\n", + "Epoch 438 Batch 0/5 train_loss = 1.314\n", + "Epoch 440 Batch 0/5 train_loss = 1.292\n", + "Epoch 442 Batch 0/5 train_loss = 1.279\n", + "Epoch 444 Batch 0/5 train_loss = 1.259\n", + "Epoch 446 Batch 0/5 train_loss = 1.283\n", + "Epoch 448 Batch 0/5 train_loss = 1.274\n", + "Epoch 450 Batch 0/5 train_loss = 1.251\n", + "Epoch 452 Batch 0/5 train_loss = 1.279\n", + "Epoch 454 Batch 0/5 train_loss = 1.249\n", + "Epoch 456 Batch 0/5 train_loss = 1.214\n", + "Epoch 458 Batch 0/5 train_loss = 1.196\n", + "Epoch 460 Batch 0/5 train_loss = 1.185\n", + "Epoch 462 Batch 0/5 train_loss = 
1.174\n", + "Epoch 464 Batch 0/5 train_loss = 1.158\n", + "Epoch 466 Batch 0/5 train_loss = 1.195\n", + "Epoch 468 Batch 0/5 train_loss = 1.158\n", + "Epoch 470 Batch 0/5 train_loss = 1.145\n", + "Epoch 472 Batch 0/5 train_loss = 1.160\n", + "Epoch 474 Batch 0/5 train_loss = 1.123\n", + "Epoch 476 Batch 0/5 train_loss = 1.118\n", + "Epoch 478 Batch 0/5 train_loss = 1.103\n", + "Epoch 480 Batch 0/5 train_loss = 1.088\n", + "Epoch 482 Batch 0/5 train_loss = 1.089\n", + "Epoch 484 Batch 0/5 train_loss = 1.094\n", + "Epoch 486 Batch 0/5 train_loss = 1.092\n", + "Epoch 488 Batch 0/5 train_loss = 1.106\n", + "Epoch 490 Batch 0/5 train_loss = 1.053\n", + "Epoch 492 Batch 0/5 train_loss = 1.052\n", + "Epoch 494 Batch 0/5 train_loss = 1.046\n", + "Epoch 496 Batch 0/5 train_loss = 1.030\n", + "Epoch 498 Batch 0/5 train_loss = 1.021\n", + "Epoch 500 Batch 0/5 train_loss = 1.020\n", + "Epoch 502 Batch 0/5 train_loss = 1.046\n", + "Epoch 504 Batch 0/5 train_loss = 1.040\n", + "Epoch 506 Batch 0/5 train_loss = 1.026\n", + "Epoch 508 Batch 0/5 train_loss = 0.982\n", + "Epoch 510 Batch 0/5 train_loss = 0.969\n", + "Epoch 512 Batch 0/5 train_loss = 0.962\n", + "Epoch 514 Batch 0/5 train_loss = 0.946\n", + "Epoch 516 Batch 0/5 train_loss = 0.941\n", + "Epoch 518 Batch 0/5 train_loss = 0.951\n", + "Epoch 520 Batch 0/5 train_loss = 0.945\n", + "Epoch 522 Batch 0/5 train_loss = 0.952\n", + "Epoch 524 Batch 0/5 train_loss = 0.931\n", + "Epoch 526 Batch 0/5 train_loss = 0.905\n", + "Epoch 528 Batch 0/5 train_loss = 0.893\n", + "Epoch 530 Batch 0/5 train_loss = 0.881\n", + "Epoch 532 Batch 0/5 train_loss = 0.882\n", + "Epoch 534 Batch 0/5 train_loss = 0.871\n", + "Epoch 536 Batch 0/5 train_loss = 0.904\n", + "Epoch 538 Batch 0/5 train_loss = 0.893\n", + "Epoch 540 Batch 0/5 train_loss = 0.884\n", + "Epoch 542 Batch 0/5 train_loss = 0.864\n", + "Epoch 544 Batch 0/5 train_loss = 0.854\n", + "Epoch 546 Batch 0/5 train_loss = 0.854\n", + "Epoch 548 Batch 0/5 train_loss = 0.836\n", + "Epoch 550 Batch 0/5 train_loss = 0.816\n", + "Epoch 552 Batch 0/5 train_loss = 0.829\n", + "Epoch 554 Batch 0/5 train_loss = 0.813\n", + "Epoch 556 Batch 0/5 train_loss = 0.798\n", + "Epoch 558 Batch 0/5 train_loss = 0.808\n", + "Epoch 560 Batch 0/5 train_loss = 0.789\n", + "Epoch 562 Batch 0/5 train_loss = 0.791\n", + "Epoch 564 Batch 0/5 train_loss = 0.779\n", + "Epoch 566 Batch 0/5 train_loss = 0.765\n", + "Epoch 568 Batch 0/5 train_loss = 0.746\n", + "Epoch 570 Batch 0/5 train_loss = 0.746\n", + "Epoch 572 Batch 0/5 train_loss = 0.733\n", + "Epoch 574 Batch 0/5 train_loss = 0.733\n", + "Epoch 576 Batch 0/5 train_loss = 0.752\n", + "Epoch 578 Batch 0/5 train_loss = 0.727\n", + "Epoch 580 Batch 0/5 train_loss = 0.712\n", + "Epoch 582 Batch 0/5 train_loss = 0.711\n", + "Epoch 584 Batch 0/5 train_loss = 0.708\n", + "Epoch 586 Batch 0/5 train_loss = 0.695\n", + "Epoch 588 Batch 0/5 train_loss = 0.699\n", + "Epoch 590 Batch 0/5 train_loss = 0.688\n", + "Epoch 592 Batch 0/5 train_loss = 0.682\n", + "Epoch 594 Batch 0/5 train_loss = 0.703\n", + "Epoch 596 Batch 0/5 train_loss = 0.681\n", + "Epoch 598 Batch 0/5 train_loss = 0.672\n", + "Epoch 600 Batch 0/5 train_loss = 0.678\n", + "Epoch 602 Batch 0/5 train_loss = 0.657\n", + "Epoch 604 Batch 0/5 train_loss = 0.652\n", + "Epoch 606 Batch 0/5 train_loss = 0.627\n", + "Epoch 608 Batch 0/5 train_loss = 0.623\n", + "Epoch 610 Batch 0/5 train_loss = 0.633\n", + "Epoch 612 Batch 0/5 train_loss = 0.608\n", + "Epoch 614 Batch 0/5 train_loss = 0.614\n", + "Epoch 616 Batch 0/5 train_loss = 0.620\n", + 
"Epoch 618 Batch 0/5 train_loss = 0.610\n", + "Epoch 620 Batch 0/5 train_loss = 0.596\n", + "Epoch 622 Batch 0/5 train_loss = 0.596\n", + "Epoch 624 Batch 0/5 train_loss = 0.605\n", + "Epoch 626 Batch 0/5 train_loss = 0.574\n", + "Epoch 628 Batch 0/5 train_loss = 0.581\n", + "Epoch 630 Batch 0/5 train_loss = 0.571\n", + "Epoch 632 Batch 0/5 train_loss = 0.563\n", + "Epoch 634 Batch 0/5 train_loss = 0.582\n", + "Epoch 636 Batch 0/5 train_loss = 0.579\n", + "Epoch 638 Batch 0/5 train_loss = 0.562\n", + "Epoch 640 Batch 0/5 train_loss = 0.549\n", + "Epoch 642 Batch 0/5 train_loss = 0.540\n", + "Epoch 644 Batch 0/5 train_loss = 0.520\n", + "Epoch 646 Batch 0/5 train_loss = 0.515\n", + "Epoch 648 Batch 0/5 train_loss = 0.509\n", + "Epoch 650 Batch 0/5 train_loss = 0.509\n", + "Epoch 652 Batch 0/5 train_loss = 0.527\n", + "Epoch 654 Batch 0/5 train_loss = 0.524\n", + "Epoch 656 Batch 0/5 train_loss = 0.509\n", + "Epoch 658 Batch 0/5 train_loss = 0.523\n", + "Epoch 660 Batch 0/5 train_loss = 0.502\n", + "Epoch 662 Batch 0/5 train_loss = 0.477\n", + "Epoch 664 Batch 0/5 train_loss = 0.473\n", + "Epoch 666 Batch 0/5 train_loss = 0.463\n", + "Epoch 668 Batch 0/5 train_loss = 0.457\n", + "Epoch 670 Batch 0/5 train_loss = 0.455\n", + "Epoch 672 Batch 0/5 train_loss = 0.459\n", + "Epoch 674 Batch 0/5 train_loss = 0.475\n", + "Epoch 676 Batch 0/5 train_loss = 0.471\n", + "Epoch 678 Batch 0/5 train_loss = 0.455\n", + "Epoch 680 Batch 0/5 train_loss = 0.443\n", + "Epoch 682 Batch 0/5 train_loss = 0.456\n", + "Epoch 684 Batch 0/5 train_loss = 0.440\n", + "Epoch 686 Batch 0/5 train_loss = 0.421\n", + "Epoch 688 Batch 0/5 train_loss = 0.413\n", + "Epoch 690 Batch 0/5 train_loss = 0.405\n", + "Epoch 692 Batch 0/5 train_loss = 0.401\n", + "Epoch 694 Batch 0/5 train_loss = 0.404\n", + "Epoch 696 Batch 0/5 train_loss = 0.400\n", + "Epoch 698 Batch 0/5 train_loss = 0.428\n", + "Epoch 700 Batch 0/5 train_loss = 0.451\n", + "Epoch 702 Batch 0/5 train_loss = 0.426\n", + "Epoch 704 Batch 0/5 train_loss = 0.410\n", + "Epoch 706 Batch 0/5 train_loss = 0.422\n", + "Epoch 708 Batch 0/5 train_loss = 0.398\n", + "Epoch 710 Batch 0/5 train_loss = 0.377\n", + "Epoch 712 Batch 0/5 train_loss = 0.368\n", + "Epoch 714 Batch 0/5 train_loss = 0.358\n", + "Epoch 716 Batch 0/5 train_loss = 0.352\n", + "Epoch 718 Batch 0/5 train_loss = 0.349\n", + "Epoch 720 Batch 0/5 train_loss = 0.344\n", + "Epoch 722 Batch 0/5 train_loss = 0.346\n", + "Epoch 724 Batch 0/5 train_loss = 0.345\n", + "Epoch 726 Batch 0/5 train_loss = 0.337\n", + "Epoch 728 Batch 0/5 train_loss = 0.345\n", + "Epoch 730 Batch 0/5 train_loss = 0.348\n", + "Epoch 732 Batch 0/5 train_loss = 0.358\n", + "Epoch 734 Batch 0/5 train_loss = 0.346\n", + "Epoch 736 Batch 0/5 train_loss = 0.337\n", + "Epoch 738 Batch 0/5 train_loss = 0.329\n", + "Epoch 740 Batch 0/5 train_loss = 0.320\n", + "Epoch 742 Batch 0/5 train_loss = 0.323\n", + "Epoch 744 Batch 0/5 train_loss = 0.316\n", + "Epoch 746 Batch 0/5 train_loss = 0.304\n", + "Epoch 748 Batch 0/5 train_loss = 0.299\n", + "Epoch 750 Batch 0/5 train_loss = 0.292\n", + "Epoch 752 Batch 0/5 train_loss = 0.288\n", + "Epoch 754 Batch 0/5 train_loss = 0.289\n", + "Epoch 756 Batch 0/5 train_loss = 0.284\n", + "Epoch 758 Batch 0/5 train_loss = 0.290\n", + "Epoch 760 Batch 0/5 train_loss = 0.304\n", + "Epoch 762 Batch 0/5 train_loss = 0.311\n", + "Epoch 764 Batch 0/5 train_loss = 0.405\n", + "Epoch 766 Batch 0/5 train_loss = 0.390\n", + "Epoch 768 Batch 0/5 train_loss = 0.344\n", + "Epoch 770 Batch 0/5 train_loss = 0.320\n", + "Epoch 772 
Batch 0/5 train_loss = 0.280\n", + "Epoch 774 Batch 0/5 train_loss = 0.265\n", + "Epoch 776 Batch 0/5 train_loss = 0.258\n", + "Epoch 778 Batch 0/5 train_loss = 0.252\n", + "Epoch 780 Batch 0/5 train_loss = 0.247\n", + "Epoch 782 Batch 0/5 train_loss = 0.243\n", + "Epoch 784 Batch 0/5 train_loss = 0.240\n", + "Epoch 786 Batch 0/5 train_loss = 0.237\n", + "Epoch 788 Batch 0/5 train_loss = 0.233\n", + "Epoch 790 Batch 0/5 train_loss = 0.231\n", + "Epoch 792 Batch 0/5 train_loss = 0.229\n", + "Epoch 794 Batch 0/5 train_loss = 0.225\n", + "Epoch 796 Batch 0/5 train_loss = 0.230\n", + "Epoch 798 Batch 0/5 train_loss = 0.226\n", + "Epoch 800 Batch 0/5 train_loss = 0.222\n", + "Epoch 802 Batch 0/5 train_loss = 0.237\n", + "Epoch 804 Batch 0/5 train_loss = 0.225\n", + "Epoch 806 Batch 0/5 train_loss = 0.225\n", + "Epoch 808 Batch 0/5 train_loss = 0.245\n", + "Epoch 810 Batch 0/5 train_loss = 0.227\n", + "Epoch 812 Batch 0/5 train_loss = 0.210\n", + "Epoch 814 Batch 0/5 train_loss = 0.206\n", + "Epoch 816 Batch 0/5 train_loss = 0.202\n", + "Epoch 818 Batch 0/5 train_loss = 0.198\n", + "Epoch 820 Batch 0/5 train_loss = 0.195\n", + "Epoch 822 Batch 0/5 train_loss = 0.192\n", + "Epoch 824 Batch 0/5 train_loss = 0.189\n", + "Epoch 826 Batch 0/5 train_loss = 0.189\n", + "Epoch 828 Batch 0/5 train_loss = 0.187\n", + "Epoch 830 Batch 0/5 train_loss = 0.186\n", + "Epoch 832 Batch 0/5 train_loss = 0.187\n", + "Epoch 834 Batch 0/5 train_loss = 0.189\n", + "Epoch 836 Batch 0/5 train_loss = 0.189\n", + "Epoch 838 Batch 0/5 train_loss = 0.197\n", + "Epoch 840 Batch 0/5 train_loss = 0.207\n", + "Epoch 842 Batch 0/5 train_loss = 0.196\n", + "Epoch 844 Batch 0/5 train_loss = 0.187\n", + "Epoch 846 Batch 0/5 train_loss = 0.197\n", + "Epoch 848 Batch 0/5 train_loss = 0.189\n", + "Epoch 850 Batch 0/5 train_loss = 0.176\n", + "Epoch 852 Batch 0/5 train_loss = 0.171\n", + "Epoch 854 Batch 0/5 train_loss = 0.164\n", + "Epoch 856 Batch 0/5 train_loss = 0.161\n", + "Epoch 858 Batch 0/5 train_loss = 0.157\n", + "Epoch 860 Batch 0/5 train_loss = 0.154\n", + "Epoch 862 Batch 0/5 train_loss = 0.152\n", + "Epoch 864 Batch 0/5 train_loss = 0.150\n", + "Epoch 866 Batch 0/5 train_loss = 0.148\n", + "Epoch 868 Batch 0/5 train_loss = 0.146\n", + "Epoch 870 Batch 0/5 train_loss = 0.145\n", + "Epoch 872 Batch 0/5 train_loss = 0.145\n", + "Epoch 874 Batch 0/5 train_loss = 0.142\n", + "Epoch 876 Batch 0/5 train_loss = 0.143\n", + "Epoch 878 Batch 0/5 train_loss = 0.159\n", + "Epoch 880 Batch 0/5 train_loss = 0.145\n", + "Epoch 882 Batch 0/5 train_loss = 0.161\n", + "Epoch 884 Batch 0/5 train_loss = 0.211\n", + "Epoch 886 Batch 0/5 train_loss = 0.196\n", + "Epoch 888 Batch 0/5 train_loss = 0.335\n", + "Epoch 890 Batch 0/5 train_loss = 0.325\n", + "Epoch 892 Batch 0/5 train_loss = 0.279\n", + "Epoch 894 Batch 0/5 train_loss = 0.244\n", + "Epoch 896 Batch 0/5 train_loss = 0.214\n", + "Epoch 898 Batch 0/5 train_loss = 0.174\n", + "Epoch 900 Batch 0/5 train_loss = 0.147\n", + "Epoch 902 Batch 0/5 train_loss = 0.138\n", + "Epoch 904 Batch 0/5 train_loss = 0.131\n", + "Epoch 906 Batch 0/5 train_loss = 0.128\n", + "Epoch 908 Batch 0/5 train_loss = 0.125\n", + "Epoch 910 Batch 0/5 train_loss = 0.123\n", + "Epoch 912 Batch 0/5 train_loss = 0.121\n", + "Epoch 914 Batch 0/5 train_loss = 0.119\n", + "Epoch 916 Batch 0/5 train_loss = 0.117\n", + "Epoch 918 Batch 0/5 train_loss = 0.116\n", + "Epoch 920 Batch 0/5 train_loss = 0.114\n", + "Epoch 922 Batch 0/5 train_loss = 0.113\n", + "Epoch 924 Batch 0/5 train_loss = 0.112\n", + "Epoch 926 Batch 0/5 
train_loss = 0.111\n", + "Epoch 928 Batch 0/5 train_loss = 0.109\n", + "Epoch 930 Batch 0/5 train_loss = 0.108\n", + "Epoch 932 Batch 0/5 train_loss = 0.107\n", + "Epoch 934 Batch 0/5 train_loss = 0.106\n", + "Epoch 936 Batch 0/5 train_loss = 0.105\n", + "Epoch 938 Batch 0/5 train_loss = 0.103\n", + "Epoch 940 Batch 0/5 train_loss = 0.102\n", + "Epoch 942 Batch 0/5 train_loss = 0.101\n", + "Epoch 944 Batch 0/5 train_loss = 0.100\n", + "Epoch 946 Batch 0/5 train_loss = 0.099\n", + "Epoch 948 Batch 0/5 train_loss = 0.098\n", + "Epoch 950 Batch 0/5 train_loss = 0.097\n", + "Epoch 952 Batch 0/5 train_loss = 0.096\n", + "Epoch 954 Batch 0/5 train_loss = 0.095\n", + "Epoch 956 Batch 0/5 train_loss = 0.094\n", + "Epoch 958 Batch 0/5 train_loss = 0.094\n", + "Epoch 960 Batch 0/5 train_loss = 0.093\n", + "Epoch 962 Batch 0/5 train_loss = 0.092\n", + "Epoch 964 Batch 0/5 train_loss = 0.091\n", + "Epoch 966 Batch 0/5 train_loss = 0.090\n", + "Epoch 968 Batch 0/5 train_loss = 0.089\n", + "Epoch 970 Batch 0/5 train_loss = 0.088\n", + "Epoch 972 Batch 0/5 train_loss = 0.088\n", + "Epoch 974 Batch 0/5 train_loss = 0.087\n", + "Epoch 976 Batch 0/5 train_loss = 0.086\n", + "Epoch 978 Batch 0/5 train_loss = 0.085\n", + "Epoch 980 Batch 0/5 train_loss = 0.084\n", + "Epoch 982 Batch 0/5 train_loss = 0.083\n", + "Epoch 984 Batch 0/5 train_loss = 0.083\n", + "Epoch 986 Batch 0/5 train_loss = 0.082\n", + "Epoch 988 Batch 0/5 train_loss = 0.081\n", + "Epoch 990 Batch 0/5 train_loss = 0.080\n", + "Epoch 992 Batch 0/5 train_loss = 0.080\n", + "Epoch 994 Batch 0/5 train_loss = 0.079\n", + "Epoch 996 Batch 0/5 train_loss = 0.078\n", + "Epoch 998 Batch 0/5 train_loss = 0.078\n", + "Model Trained and Saved\n" + ] + } + ], "source": [ "\"\"\"\n", "DON'T MODIFY ANYTHING IN THIS CELL\n", @@ -1082,7 +1521,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 198, "metadata": { "collapsed": false, "deletable": true, @@ -1109,7 +1548,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 272, "metadata": { "collapsed": false, "deletable": true, @@ -1149,13 +1588,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 273, "metadata": { "collapsed": false, "deletable": true, "editable": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tests Passed\n" + ] + } + ], "source": [ "def get_tensors(loaded_graph):\n", " \"\"\"\n", @@ -1163,8 +1610,12 @@ " :param loaded_graph: TensorFlow graph loaded from file\n", " :return: Tuple (InputTensor, InitialStateTensor, FinalStateTensor, ProbsTensor)\n", " \"\"\"\n", - " # TODO: Implement Function\n", - " return None, None, None, None\n", + " \n", + " t_input = loaded_graph.get_tensor_by_name('input:0')\n", + " t_initial_state = loaded_graph.get_tensor_by_name('initial_state:0')\n", + " t_final_state = loaded_graph.get_tensor_by_name('final_state:0')\n", + " t_probs = loaded_graph.get_tensor_by_name('probs:0')\n", + " return t_input, t_initial_state, t_final_state, t_probs\n", "\n", "\n", "\"\"\"\n", @@ -1186,13 +1637,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 274, "metadata": { "collapsed": false, "deletable": true, "editable": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tests Passed\n" + ] + } + ], "source": [ "def pick_word(probabilities, int_to_vocab):\n", " \"\"\"\n", @@ -1201,8 +1660,10 @@ " :param int_to_vocab: Dictionary of word ids as the keys and words as 
the values\n", " :return: String of the predicted word\n", " \"\"\"\n", - " # TODO: Implement Function\n", - " return None\n", + " \n", + " word = int_to_vocab[np.argmax(probabilities)]\n", + " \n", + " return word\n", "\n", "\n", "\"\"\"\n", @@ -1224,13 +1685,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 275, "metadata": { "collapsed": false, "deletable": true, "editable": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "moe_szyslak: sizes good-looking slap detective_homer_simpson: takin' cesss planning parrot smoke parrot sizes frustrated choked slap gesture elmo's jerry duff's butterball officials sizes themselves gesture whiny irrelevant paintings continuing huddle tony butterball worst jerry neighborhood slap slap slap detective_homer_simpson: meatpies crooks sail slap slap slap sizes worst mr slap worst gesture parrot calendars bathed schnapps butterball stuck jerry dash my-y-y-y-y-y slap slap slap detective_homer_simpson: rain gesture bashir's jerry longest slap slap slap detective_homer_simpson: realize gesture parrot neighborhood jerry dad's poet presided scrutinizes presided rope neighborhood booth detective_homer_simpson: enjoyed gesture electronic sam: jerry dash my-y-y-y-y-y butterball protestantism dash my-y-y-y-y-y friendly dash happiness agreement slap protestantism muttering muttering sugar-free parrot is: abandon fudd scrutinizes detective_homer_simpson: itself duff's butterball drinker slap muttering shaky slap cuff giant face knockin' tv-station_announcer: that's slap detective_homer_simpson: celebrate rubbed 2nd_voice_on_transmitter: further rubbed usual laramie bunch slap detective_homer_simpson: itself gesture child jerry premise poet sarcastic slap detective_homer_simpson: meatpies skydiving scrutinizes scream renee: scrutinizes detective_homer_simpson: itself lenses butterball tapered smokin' 2nd_voice_on_transmitter: slap detective_homer_simpson: detective_homer_simpson: detective_homer_simpson: aims always butterball oh-so-sophisticated wine dislike sizes bury gang butterball renee: rope laramie themselves beings slap detective_homer_simpson: rain indicates butterball stunned slap detective_homer_simpson: rain arts butterball ratted 2nd_voice_on_transmitter: pepsi oh-so-sophisticated planning booth rope presided rope abandon worst\n" + ] + } + ], "source": [ "gen_length = 200\n", "# homer_simpson, moe_szyslak, or Barney_Gumble\n", @@ -1309,7 +1778,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.1" + "version": "3.5.2" }, "toc": { "colors": { diff --git a/tv-script-generation/logs/1/events.out.tfevents.1490895533.ip-172-31-18-64 b/tv-script-generation/logs/1/events.out.tfevents.1490895533.ip-172-31-18-64 new file mode 100644 index 0000000..e4b0945 Binary files /dev/null and b/tv-script-generation/logs/1/events.out.tfevents.1490895533.ip-172-31-18-64 differ diff --git a/tv-script-generation/params.p b/tv-script-generation/params.p new file mode 100644 index 0000000..4ff1628 --- /dev/null +++ b/tv-script-generation/params.p @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43aa05ca53cc94bc12959afc391adedffb9df33345748d2a55f22d6c587bd8cf +size 21 diff --git a/tv-script-generation/preprocess.p b/tv-script-generation/preprocess.p index ac84c3d..d4b2507 100644 --- a/tv-script-generation/preprocess.p +++ b/tv-script-generation/preprocess.p @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:be490d22f681ecd706cf3a53a5e7ec52421d9eb7f11e4cdb612740f915b9eab7 -size 391674 +oid sha256:0c4525fb720ac816b5ee0720ef52e698cdab40204342e87096c2b4356b9829a8 +size 387442 diff --git a/tv-script-generation/save.data-00000-of-00001 b/tv-script-generation/save.data-00000-of-00001 new file mode 100644 index 0000000..de861ec --- /dev/null +++ b/tv-script-generation/save.data-00000-of-00001 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a0493768a7bb2c6cee62b851a65725d1a0b03e9845e0e8912efdef605566c426 +size 49730300 diff --git a/tv-script-generation/save.index b/tv-script-generation/save.index new file mode 100644 index 0000000..3a25b31 --- /dev/null +++ b/tv-script-generation/save.index @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:848fb0f5cc5f656c5f5ced00df9f6ba77eef8d67cccef78a03daff24e8e746e3 +size 981 diff --git a/tv-script-generation/save.meta b/tv-script-generation/save.meta new file mode 100644 index 0000000..ca2d46f --- /dev/null +++ b/tv-script-generation/save.meta @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab485075110bf5e5ddfc66a4de50319cc84615affd5098e605e34f1dd6f7bf9a +size 303291
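
Note on the `pick_word` implementation added in this diff: it returns `int_to_vocab[np.argmax(probabilities)]`, i.e. greedy decoding. Greedy decoding is deterministic and can get stuck looping on a few high-probability words, which may contribute to the repetition visible in the generated sample above ("slap slap slap ..."). Below is a minimal sketch of a sampling-based alternative, not part of the notebook itself; the helper name `pick_word_sampled` is purely illustrative, and the only assumptions are that `probabilities` is a 1-D array over the vocabulary and `int_to_vocab` maps integer word ids to words, as in the notebook.

import numpy as np

def pick_word_sampled(probabilities, int_to_vocab):
    """
    Pick the next word by sampling from the predicted distribution
    instead of always taking the argmax.
    :param probabilities: 1-D array of probabilities over the vocabulary
    :param int_to_vocab: Dictionary of word ids as the keys and words as the values
    :return: String of the predicted word
    """
    # Renormalize to guard against floating-point drift before sampling
    probabilities = np.asarray(probabilities, dtype=np.float64)
    probabilities /= probabilities.sum()
    # Draw a word id according to the predicted probabilities
    word_id = np.random.choice(len(probabilities), p=probabilities)
    return int_to_vocab[word_id]

If swapped in for `pick_word` in the generation cell, the rest of the sampling loop stays unchanged; the trade-off is more varied (but noisier) output than greedy decoding.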