TensorFlow入门(五)多层 LSTM 通俗易懂版

前言: 根据我本人学习 TensorFlow 实现 LSTM 的经历,发现网上虽然也有不少教程,其中很多都是根据官方给出的例子,用多层 LSTM 来实现 PTBModel 语言模型,比如: 
tensorflow笔记:多层LSTM代码分析 
但是感觉这些例子还是太复杂了,所以这里写了个比较简单的版本,虽然不优雅,但是还是比较容易理解。

如果你想了解 LSTM 的原理的话(前提是你已经理解了普通 RNN 的原理),可以参考我前面翻译的博客: 
(译)理解 LSTM 网络 (Understanding LSTM Networks by colah)

如果你想了解 RNN 原理的话,可以参考 AK 的博客: 
The Unreasonable Effectiveness of Recurrent Neural Networks

很多朋友提到多层怎么理解,所以自己做了一个示意图,希望帮助初学者更好地理解 多层RNN.

 

 
图1 3层RNN按时间步展开

 

本例不讲原理。通过本例,你可以了解到单层 LSTM 的实现,多层 LSTM 的实现。输入输出数据的格式。 RNN 的 dropout layer 的实现。

# -*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data

# 设置 GPU 按需增长
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# 首先导入数据,看一下数据的形式
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
print mnist.train.images.shape

 

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(55000, 784)

1. 首先设置好模型用到的各个超参数

lr = 1e-3
# 在训练和测试的时候,我们想用不同的 batch_size.所以采用占位符的方式
batch_size = tf.placeholder(tf.int32)  # 注意类型必须为 tf.int32
# batch_size = 128

# 每个时刻的输入特征是28维的,就是每个时刻输入一行,一行有 28 个像素
input_size = 28
# 时序持续长度为28,即每做一次预测,需要先输入28行
timestep_size = 28
# 每个隐含层的节点数
hidden_size = 256
# LSTM layer 的层数
layer_num = 2
# 最后输出分类类别数量,如果是回归预测的话应该是 1
class_num = 10

_X = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, class_num])
keep_prob = tf.placeholder(tf.float32)

 

2. 开始搭建 LSTM 模型,其实普通 RNNs 模型也一样

# 把784个点的字符信息还原成 28 * 28 的图片
# 下面几个步骤是实现 RNN / LSTM 的关键
####################################################################
# **步骤1:RNN 的输入shape = (batch_size, timestep_size, input_size)
X = tf.reshape(_X, [-1, 28, 28])

# **步骤2:定义一层 LSTM_cell,只需要说明 hidden_size, 它会自动匹配输入的 X 的维度
lstm_cell = rnn.BasicLSTMCell(num_units=hidden_size, forget_bias=1.0, state_is_tuple=True)

# **步骤3:添加 dropout layer, 一般只设置 output_keep_prob
lstm_cell = rnn.DropoutWrapper(cell=lstm_cell, input_keep_prob=1.0, output_keep_prob=keep_prob)

# **步骤4:调用 MultiRNNCell 来实现多层 LSTM
mlstm_cell = rnn.MultiRNNCell([lstm_cell] * layer_num, state_is_tuple=True)

# **步骤5:用全零来初始化state
init_state = mlstm_cell.zero_state(batch_size, dtype=tf.float32)

# **步骤6:方法一,调用 dynamic_rnn() 来让我们构建好的网络运行起来
# ** 当 time_major==False 时, outputs.shape = [batch_size, timestep_size, hidden_size]
# ** 所以,可以取 h_state = outputs[:, -1, :] 作为最后输出
# ** state.shape = [layer_num, 2, batch_size, hidden_size],
# ** 或者,可以取 h_state = state[-1][1] 作为最后输出
# ** 最后输出维度是 [batch_size, hidden_size]
# outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)
# h_state = outputs[:, -1, :]  # 或者 h_state = state[-1][1]

# *************** 为了更好的理解 LSTM 工作原理,我们把上面 步骤6 中的函数自己来实现 ***************
# 通过查看文档你会发现, RNNCell 都提供了一个 __call__()函数(见最后附),我们可以用它来展开实现LSTM按时间步迭代。
# **步骤6:方法二,按时间步展开计算
outputs = list()
state = init_state
with tf.variable_scope('RNN'):
    for timestep in range(timestep_size):
        if timestep > 0:
            tf.get_variable_scope().reuse_variables()
        # 这里的state保存了每一层 LSTM 的状态
        (cell_output, state) = mlstm_cell(X[:, timestep, :], state)
        outputs.append(cell_output)
h_state = outputs[-1]

 

3. 设置 loss function 和 优化器,展开训练并完成测试

# 上面 LSTM 部分的输出会是一个 [hidden_size] 的tensor,我们要分类的话,还需要接一个 softmax 层
# 首先定义 softmax 的连接权重矩阵和偏置
# out_W = tf.placeholder(tf.float32, [hidden_size, class_num], name='out_Weights')
# out_bias = tf.placeholder(tf.float32, [class_num], name='out_bias')
# 开始训练和测试
W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)
bias = tf.Variable(tf.constant(0.1,shape=[class_num]), dtype=tf.float32)
y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)

# 损失和评估函数
cross_entropy = -tf.reduce_mean(y * tf.log(y_pre))
train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)

correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

sess.run(tf.global_variables_initializer())
for i in range(2000):
    _batch_size = 128
    batch = mnist.train.next_batch(_batch_size)
    if (i+1)%200 == 0:
        train_accuracy = sess.run(accuracy, feed_dict={
            _X:batch[0], y: batch[1], keep_prob: 1.0, batch_size: _batch_size})
        # 已经迭代完成的 epoch 数: mnist.train.epochs_completed
        print "Iter%d, step %d, training accuracy %g" % ( mnist.train.epochs_completed, (i+1), train_accuracy)
    sess.run(train_op, feed_dict={_X: batch[0], y: batch[1], keep_prob: 0.5, batch_size: _batch_size})

# 计算测试数据的准确率
print "test accuracy %g"% sess.run(accuracy, feed_dict={
    _X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0, batch_size:mnist.test.images.shape[0]})

Iter0, step 200, training accuracy 0.851562
Iter0, step 400, training accuracy 0.960938
Iter1, step 600, training accuracy 0.984375
Iter1, step 800, training accuracy 0.960938
Iter2, step 1000, training accuracy 0.984375
Iter2, step 1200, training accuracy 0.9375
Iter3, step 1400, training accuracy 0.96875
Iter3, step 1600, training accuracy 0.984375
Iter4, step 1800, training accuracy 0.992188
Iter4, step 2000, training accuracy 0.984375
test accuracy 0.9858

 

我们一共只迭代不到5个epoch,在测试集上就已经达到了0.9825的准确率,可以看出来 LSTM 在做这个字符分类的任务上还是比较有效的,而且我们最后一次性对 10000 张测试图片进行预测,才占了 725 MiB 的显存。而我们在之前的两层 CNNs 网络中,预测 10000 张图片一共用了 8721 MiB 的显存,差了整整 12 倍呀!! 这主要是因为 RNN/LSTM 网络中,每个时间步所用的权值矩阵都是共享的,可以通过前面介绍的 LSTM 的网络结构分析一下,整个网络的参数非常少。

4. 可视化看看 LSTM 的是怎么做分类的

毕竟 LSTM 更多的是用来做时序相关的问题,要么是文本,要么是序列预测之类的,所以很难像 CNNs 一样非常直观地看到每一层中特征的变化。在这里,我想通过可视化的方式来帮助大家理解 LSTM 是怎么样一步一步地把图片正确的给分类。

import matplotlib.pyplot as plt

 

看下面我找了一个字符 3

print mnist.train.labels[4]
[ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]

 

我们先来看看这个字符样子,上半部分还挺像 2 来的

X3 = mnist.train.images[4]
img3 = X3.reshape([28, 28])
plt.imshow(img3, cmap='gray')
plt.show()

 

我们看看在分类的时候,一行一行地输入,分为各个类别的概率会是什么样子的。

X3.shape = [-1, 784]
y_batch = mnist.train.labels[0]
y_batch.shape = [-1, class_num]

X3_outputs = np.array(sess.run(outputs, feed_dict={
            _X: X3, y: y_batch, keep_prob: 1.0, batch_size: 1}))
print X3_outputs.shape
X3_outputs.shape = [28, hidden_size]
print X3_outputs.shape

  • (28, 1, 256)
    (28, 256)

     

h_W = sess.run(W, feed_dict={
            _X:X3, y: y_batch, keep_prob: 1.0, batch_size: 1})
h_bias = sess.run(bias, feed_dict={
            _X:X3, y: y_batch, keep_prob: 1.0, batch_size: 1})
h_bias.shape = [-1, 10]

bar_index = range(class_num)
for i in xrange(X3_outputs.shape[0]):
    plt.subplot(7, 4, i+1)
    X3_h_shate = X3_outputs[i, :].reshape([-1, hidden_size])
    pro = sess.run(tf.nn.softmax(tf.matmul(X3_h_shate, h_W) + h_bias))
    plt.bar(bar_index, pro[0], width=0.2 , align='center')
    plt.axis('off')
plt.show()

在上面的图中,为了更清楚地看到线条的变化,我把坐标都去了,每一行显示了 4 个图,共有 7 行,表示了一行一行读取过程中,模型对字符的识别。可以看到,在只看到前面的几行像素时,模型根本认不出来是什么字符,随着看到的像素越来越多,最后就基本确定了它是字符 3.

好了,本次就到这里。有机会再写个优雅一点的例子,哈哈。其实学这个 LSTM 还是比较困难的,当时写 多层 CNNs 也就半天到一天的时间基本上就没啥问题了,但是这个花了我大概整整三四天,而且是在我对原理已经很了解(我自己觉得而已。。。)的情况下,所以学会了感觉还是有点小高兴的~

17-04-19补充几个资料: 
recurrent_network.py 一个简单的 tensorflow LSTM 例子。 
Tensorflow下构建LSTM模型进行序列化标注 介绍非常好的一个 NLP 开源项目。(例子中有些函数可能在新版的 tensorflow 中已经更新了,但并不影响理解)

5. 附:BASICLSTM.__call__()

'''code: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py'''

  def __call__(self, inputs, state, scope=None):
      """Long short-term memory cell (LSTM)."""
      with vs.variable_scope(scope or "basic_lstm_cell"):
          # Parameters of gates are concatenated into one multiply for efficiency.
          if self._state_is_tuple:
              c, h = state
          else:
              c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
          concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)

          # ** 下面四个 tensor,分别是四个 gate 对应的权重矩阵
          # i = input_gate, j = new_input, f = forget_gate, o = output_gate
          i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

          # ** 更新 cell 的状态:
          # ** c * sigmoid(f + self._forget_bias) 是保留上一个 timestep 的部分旧信息
          # ** sigmoid(i) * self._activation(j)  是有当前 timestep 带来的新信息
          new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
               self._activation(j))

          # ** 新的输出
          new_h = self._activation(new_c) * sigmoid(o)

          if self._state_is_tuple:
              new_state = LSTMStateTuple(new_c, new_h)
          else:
              new_state = array_ops.concat([new_c, new_h], 1)
          # ** 在(一般都是) state_is_tuple=True 情况下, new_h=new_state[1]
          # ** 在上面博文中,就有 cell_output = state[1]
          return new_h, new_state

本文转自博客园知识天地的博客,原文链接:TensorFlow入门(五)多层 LSTM 通俗易懂版,如需转载请自行联系原博主。

时间: 2024-10-12 03:41:07

TensorFlow入门(五)多层 LSTM 通俗易懂版的相关文章

tensorflow笔记:多层LSTM代码分析

标签(空格分隔): tensorflow笔记 tensorflow笔记系列:  (一) tensorflow笔记:流程,概念和简单代码注释  (二) tensorflow笔记:多层CNN代码分析  (三) tensorflow笔记:多层LSTM代码分析  (四) tensorflow笔记:常用函数说明  (五) tensorflow笔记:模型的保存与训练过程可视化  (六)tensorflow笔记:使用tf来实现word2vec 之前讲过了tensorflow中CNN的示例代码,现在我们来看RN

《ArcGIS Engine 地理信息系统开发从入门到精通(第二版)》——第6章 空间数据管理 6.1 SDE及空间数据

第6章 空间数据管理 6.1 SDE及空间数据 ArcGIS Engine 地理信息系统开发从入门到精通(第二版)6.1.1 SDE介绍 ArcSDE是数据库系统中管理地理数据库的接口,通过该接口可以往关系数据中加入空间数据,提供地理要素的空间位置及形状等信息,是ArcGIS与关系数据库之间的GIS通道.它允许用户在多种数据管理系统中管理地理信息,并使所有的ArcGIS应用程序都能够使用这些数据. ArcSDE是多用户ArcGIS系统的一个关键部件,它为DBMS提供了一个开放的接口,允许ArcG

《SQL入门经典(第5版)》一一1.4 本书使用的数据库

1.4 本书使用的数据库 SQL入门经典(第5版) 在继续讨论SQL基础知识之前,我们先来介绍一下本书后续课程中要使用的数据库.下面的小节会介绍所用的表,说明它们之间的关系.它们的结构,并展示其中包含的数据. 图1.4展示了本书范例.测验和练习中所用的表的关系.每个表都有不同的名称.包含一些字段.图中的映射线表示了特定表之间通过共用字段(通常被称为主键)建立的联系. 1.4.1 表命名标准 像商业活动中的其他标准一样,表命名标准对于保持良好的控制也是非常重要的.从前面对于表和数据的介绍可以看出,

《SAP入门经典(第4版•修订版)》——第3章 SAP技术基础知识 3.1 SAP技术101:SAP基础知识

第3章 SAP技术基础知识 经过前面章节的学习,我们已经对SAP的基本概念和使用SAP运行业务的意义有所了解了,现在我们要花一些时间讨论一下更深层的基本技术.本节我们将考察几个常用的与架构相关的技术术语,粗略地了解一下支持任何SAP应用程序都必不可少的3种核心技术:硬件.操作系统和数据库.本章最后我们要讨论一下这些技术是如何发挥作用的.即使您有着深厚的技术背景,本章的内容也仍然值得您花些时间进行了解. 3.1 SAP技术101:SAP基础知识 SAP入门经典(第4版•修订版)在第2章中,我们已经

Windows 8风格应用开发入门 五 ListView数据控件

什么是ListView数据控件? 1) ListView数据控件用来显示数据集合. 2) 继承自ItemsControl. 3) 大多数情况是纵向显示数据,显示的数据通常是排序过的. 4) 在切换到Snap View(贴靠视图)时,通常使用ListView显示数据集合. 开发入门 五 ListView数据控件-vba listview 控件"> 如何构建ListView数据控件? 首先我们需要了解一下ListView控件中一些重要属性和事件: 1) IsItemClickEnabled属性

TensorFlow入门指南

更多深度文章,请关注:https://yq.aliyun.com/cloud Tensorflow是一种被广泛使用的实现机器学习和其他涉及大量数学运算的算法的库.Tensorflow是由Google开发的,是GitHub上最流行的机器学习库之一.Google几乎在所有应用程序中都使用Tensorflow来实现机器学习.例如,如果您正在使用Google照片或Google语音搜索,则您间接使用Tensorflow模型,它们可以在大型Google硬件集群上工作,并且在感知任务方面功能强大. 这篇文章的

《C++入门经典(第6版)》导读

前言 C++入门经典(第6版) 祝贺您!当您阅读到这里时,离学习最重要的编程语言之一-- C++又近了20秒. 如果您再花23小时59分40秒,就将掌握C++编程语言的基本知识.只需24个课程(每个课程不超过1小时),就将学会重要的C++功能,如管理I/O.创建循环和数组.使用模板进行面向对象编程以及创建C++程序. 我们将这些主题组织成了结构完美.易于理解的课程.在每章中,都将通过项目.输出和代码分析,演示相关的主题.另外,还清楚地标出了语法示例,以方便参考. 每章末尾还列出了常见问题及其答案

AppleWatch开发入门五——菜单控件的使用

AppleWatch开发入门五--菜单控件的使用 一.简介         菜单也是WatchOS中一个重要的交互方式,限于Watch的屏幕尺寸,若将所有用户交互控件都紧密的排列进展示的UI中,那样难免会使用户操作困难,也会影响界面布局的简洁美观.因此,WatchOS的菜单机制是一层覆盖在屏幕上的交互界面,有如下的特点: 1.菜单是内置于InterfaceController中的,不需显式处理,只需对齐菜单项进行添加设置. 2.菜单最多可以容乃四个选项按钮. 3.通过重按可以呼出和隐藏菜单. 二

java入门新人求代码:网页版冒泡排序,谢谢。

问题描述 java入门新人求代码:网页版冒泡排序,谢谢. 用myeclipse制作一个网页版的冒泡排序,10个数字(10个空,每个空可以填写1个数字),一个提交排序按钮,点击后对输入的十个空格的数字进行冒泡排序. 谢谢. 解决方案 http://ask.csdn.net/questions/239606 不是回答你了么?新建一个html文件,粘贴上面的代码,你遇到什么具体的问题? 解决方案二: 网上那么多.自己搜索的看看.