Classification with an LSTM

In the earlier post Introduction to Recurrent Neural Networks (循环神经网络简介) we covered what an RNN is. In this post we use TensorFlow to implement a very simple classification task, with a detailed walkthrough of the code so that I don't forget it later ( ╯□╰ ).

The code:

import tensorflow as tf
import time
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("data/", one_hot=True)

lr = 0.001               # learning rate
training_iters = 100000  # stop after this many training samples
batch_size = 128
display_step = 10        # print progress every 10 steps

n_input = 28    # input size per time step (one row of 28 pixels)
n_step = 28     # number of time steps (28 rows per image)
n_hidden = 128  # number of units in the LSTM cell
n_classes = 10  # MNIST digits 0-9

x = tf.placeholder(tf.float32, [None, n_step, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
# weights and bias of the final fully connected layer
w = tf.Variable(tf.random_normal([n_hidden, n_classes]))
b = tf.Variable(tf.random_normal([n_classes]))

def RNN(x):
    cell = rnn.BasicLSTMCell(n_hidden)
    # outputs has shape (batch, time_step, output_size)
    outputs, states = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
    # Each time step processes one row of pixels. Since this is a
    # classification task, we only need the output of the last time step:
    # it carries the context of all the preceding steps. Transposing to
    # (time_step, batch, output_size) lets [-1] select that last step.
    return tf.matmul(tf.transpose(outputs, [1, 0, 2])[-1], w) + b

pred = RNN(x)

cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))

optimizer = tf.train.AdamOptimizer(lr).minimize(cost)
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# record accuracy and loss for TensorBoard
tf.summary.scalar('accuracy', accuracy)
tf.summary.scalar('loss', cost)
summaries = tf.summary.merge_all()

tt = time.time()
with tf.Session() as sess:
    train_writer = tf.summary.FileWriter('logs/', sess.graph)
    init = tf.global_variables_initializer()
    sess.run(init)
    step = 1
    while batch_size * step < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # each image arrives as a flat 784-vector; reshape it into a
        # sequence of 28 rows of 28 pixels
        batch_x = batch_x.reshape(batch_size, n_step, n_input)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            acc, loss = sess.run(
                [accuracy, cost], feed_dict={x: batch_x,
                                             y: batch_y})
            print("Iter " + str(step * batch_size) + ", Minibatch Loss= " + \
                  "{:.6f}".format(loss) + ", Training Accuracy= " + \
                  "{:.5f}".format(acc))
        if step % 100 == 0:
            s = sess.run(summaries, feed_dict={x: batch_x, y: batch_y})
            train_writer.add_summary(s, global_step=step)

        step += 1
    print("Optimization Finished!")

    # evaluate on the first 128 test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, n_step, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:", \
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
    print(time.time() - tt)

As for why an LSTM can handle a classification task at all: we treat each image as a sequence of pixel rows, and we keep only the prediction from the last time step, because each image has exactly one class label, and by the last time step the prediction has absorbed information from the entire image.
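
As a side note, the transpose-then-[-1] trick in RNN() is only one way to pull out the last time step. Below is a minimal sketch of two equivalent alternatives, assuming the same single-layer BasicLSTMCell (for that cell, dynamic_rnn's states is an LSTMStateTuple whose h field is exactly the last output):

import tensorflow as tf
from tensorflow.contrib import rnn

x = tf.placeholder(tf.float32, [None, 28, 28])
cell = rnn.BasicLSTMCell(128)
# outputs: (batch, time_step, output_size); states: LSTMStateTuple(c, h)
outputs, states = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

last_by_transpose = tf.transpose(outputs, [1, 0, 2])[-1]  # as in RNN() above
last_by_slice = outputs[:, -1, :]  # slice the time axis directly
last_by_state = states.h           # final hidden state; equal to the above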

After 92.33 seconds of training, the accuracy on the first 128 test images reached 0.976563. The training log:

Iter 1280, Minibatch Loss= 1.757248, Training Accuracy= 0.36719
Iter 2560, Minibatch Loss= 1.679881, Training Accuracy= 0.46094
Iter 3840, Minibatch Loss= 1.272040, Training Accuracy= 0.53125
Iter 5120, Minibatch Loss= 0.972794, Training Accuracy= 0.70312
Iter 6400, Minibatch Loss= 0.653458, Training Accuracy= 0.81250
Iter 7680, Minibatch Loss= 0.723198, Training Accuracy= 0.80469
Iter 8960, Minibatch Loss= 0.522374, Training Accuracy= 0.88281
Iter 10240, Minibatch Loss= 0.538418, Training Accuracy= 0.82031
Iter 11520, Minibatch Loss= 0.408725, Training Accuracy= 0.87500
Iter 12800, Minibatch Loss= 0.482461, Training Accuracy= 0.82812
Iter 14080, Minibatch Loss= 0.562638, Training Accuracy= 0.85156
Iter 15360, Minibatch Loss= 0.211240, Training Accuracy= 0.93750
Iter 16640, Minibatch Loss= 0.359455, Training Accuracy= 0.83594
Iter 17920, Minibatch Loss= 0.387070, Training Accuracy= 0.89844
Iter 19200, Minibatch Loss= 0.372927, Training Accuracy= 0.88281
Iter 20480, Minibatch Loss= 0.310806, Training Accuracy= 0.88281
Iter 21760, Minibatch Loss= 0.301324, Training Accuracy= 0.90625
Iter 23040, Minibatch Loss= 0.249105, Training Accuracy= 0.90625
Iter 24320, Minibatch Loss= 0.184059, Training Accuracy= 0.94531
Iter 25600, Minibatch Loss= 0.206372, Training Accuracy= 0.92969
Iter 26880, Minibatch Loss= 0.210895, Training Accuracy= 0.93750
Iter 28160, Minibatch Loss= 0.290178, Training Accuracy= 0.89062
Iter 29440, Minibatch Loss= 0.206363, Training Accuracy= 0.95312
Iter 30720, Minibatch Loss= 0.185184, Training Accuracy= 0.93750
Iter 32000, Minibatch Loss= 0.209735, Training Accuracy= 0.92969
Iter 33280, Minibatch Loss= 0.097975, Training Accuracy= 0.96875
Iter 34560, Minibatch Loss= 0.255305, Training Accuracy= 0.90625
Iter 35840, Minibatch Loss= 0.086390, Training Accuracy= 0.99219
Iter 37120, Minibatch Loss= 0.060936, Training Accuracy= 0.99219
Iter 38400, Minibatch Loss= 0.267242, Training Accuracy= 0.91406
Iter 39680, Minibatch Loss= 0.256366, Training Accuracy= 0.91406
Iter 40960, Minibatch Loss= 0.147943, Training Accuracy= 0.96094
Iter 42240, Minibatch Loss= 0.135919, Training Accuracy= 0.97656
Iter 43520, Minibatch Loss= 0.095043, Training Accuracy= 0.96875
Iter 44800, Minibatch Loss= 0.204021, Training Accuracy= 0.89844
Iter 46080, Minibatch Loss= 0.179127, Training Accuracy= 0.94531
Iter 47360, Minibatch Loss= 0.138613, Training Accuracy= 0.95312
Iter 48640, Minibatch Loss= 0.266192, Training Accuracy= 0.92188
Iter 49920, Minibatch Loss= 0.108609, Training Accuracy= 0.97656
Iter 51200, Minibatch Loss= 0.161269, Training Accuracy= 0.94531
Iter 52480, Minibatch Loss= 0.113301, Training Accuracy= 0.97656
Iter 53760, Minibatch Loss= 0.083430, Training Accuracy= 0.98438
Iter 55040, Minibatch Loss= 0.132865, Training Accuracy= 0.96094
Iter 56320, Minibatch Loss= 0.100576, Training Accuracy= 0.96875
Iter 57600, Minibatch Loss= 0.227338, Training Accuracy= 0.92969
Iter 58880, Minibatch Loss= 0.114374, Training Accuracy= 0.95312
Iter 60160, Minibatch Loss= 0.226231, Training Accuracy= 0.95312
Iter 61440, Minibatch Loss= 0.121312, Training Accuracy= 0.97656
Iter 62720, Minibatch Loss= 0.122376, Training Accuracy= 0.96875
Iter 64000, Minibatch Loss= 0.073419, Training Accuracy= 0.97656
Iter 65280, Minibatch Loss= 0.077218, Training Accuracy= 0.98438
Iter 66560, Minibatch Loss= 0.124091, Training Accuracy= 0.94531
Iter 67840, Minibatch Loss= 0.062538, Training Accuracy= 0.98438
Iter 69120, Minibatch Loss= 0.185202, Training Accuracy= 0.92969
Iter 70400, Minibatch Loss= 0.117486, Training Accuracy= 0.96094
Iter 71680, Minibatch Loss= 0.069971, Training Accuracy= 0.98438
Iter 72960, Minibatch Loss= 0.068637, Training Accuracy= 0.97656
Iter 74240, Minibatch Loss= 0.078597, Training Accuracy= 0.97656
Iter 75520, Minibatch Loss= 0.065298, Training Accuracy= 0.99219
Iter 76800, Minibatch Loss= 0.141951, Training Accuracy= 0.94531
Iter 78080, Minibatch Loss= 0.073046, Training Accuracy= 0.97656
Iter 79360, Minibatch Loss= 0.092249, Training Accuracy= 0.96875
Iter 80640, Minibatch Loss= 0.135571, Training Accuracy= 0.96094
Iter 81920, Minibatch Loss= 0.166597, Training Accuracy= 0.94531
Iter 83200, Minibatch Loss= 0.103855, Training Accuracy= 0.97656
Iter 84480, Minibatch Loss= 0.100378, Training Accuracy= 0.96875
Iter 85760, Minibatch Loss= 0.115921, Training Accuracy= 0.98438
Iter 87040, Minibatch Loss= 0.108667, Training Accuracy= 0.96875
Iter 88320, Minibatch Loss= 0.099972, Training Accuracy= 0.96094
Iter 89600, Minibatch Loss= 0.129549, Training Accuracy= 0.96875
Iter 90880, Minibatch Loss= 0.081135, Training Accuracy= 0.96875
Iter 92160, Minibatch Loss= 0.100124, Training Accuracy= 0.96875
Iter 93440, Minibatch Loss= 0.138314, Training Accuracy= 0.96094
Iter 94720, Minibatch Loss= 0.094340, Training Accuracy= 0.95312
Iter 96000, Minibatch Loss= 0.140076, Training Accuracy= 0.95312
Iter 97280, Minibatch Loss= 0.062758, Training Accuracy= 0.99219
Iter 98560, Minibatch Loss= 0.074103, Training Accuracy= 0.98438
Iter 99840, Minibatch Loss= 0.064638, Training Accuracy= 0.97656
Optimization Finished!
Testing Accuracy: 0.976563
92.33089709281921
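
One caveat: the 0.976563 above is measured on only the first 128 test images (test_len = 128). As a sketch, you could evaluate on the full 10,000-image test set inside the same Session, though a very large feed may need to be split into batches if memory is tight:

test_data = mnist.test.images.reshape((-1, n_step, n_input))
test_label = mnist.test.labels
print("Full testing accuracy:",
      sess.run(accuracy, feed_dict={x: test_data, y: test_label}))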
