TensorFlow教程之进阶指南 3.2 变量:创建、初始化、保存和加载

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权。

当训练模型时,用变量来存储和更新参数。变量包含张量 (Tensor)存放于内存的缓存区。建模时它们需要被明确地初始化,模型训练后它们必须被存储到磁盘。这些变量的值可在之后模型训练和分析是被加载。

本文档描述以下两个TensorFlow类。点击以下链接可查看完整的API文档:

创建

当创建一个变量时,你将一个张量作为初始值传入构造函数Variable()。TensorFlow提供了一系列操作符来初始化张量,初始值是常量或是随机值

注意,所有这些操作符都需要你指定张量的shape。那个形状自动成为变量的shape。变量的shape通常是固定的,但TensorFlow提供了高级的机制来重新调整其行列数。

# Create two variables.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")

调用tf.Variable()添加一些操作(Op, operation)到graph:

  • 一个Variable操作存放变量的值。
  • 一个初始化op将变量设置为初始值。这事实上是一个tf.assign操作.
  • 初始值的操作,例如示例中对biases变量的zeros操作也被加入了graph。

tf.Variable的返回值是Python的tf.Variable类的一个实例。

初始化

变量的初始化必须在模型的其它操作运行之前先明确地完成。最简单的方法就是添加一个给所有变量初始化的操作,并在使用模型之前首先运行那个操作。

你或者可以从检查点文件中重新获取变量值,详见下文。

使用tf.initialize_all_variables()添加一个操作对变量做初始化。记得在完全构建好模型并加载之后再运行那个操作。

# Create two variables.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()

# Later, when launching the model
with tf.Session() as sess:
  # Run the init operation.
  sess.run(init_op)
  ...
  # Use the model
  ...

由另一个变量初始化

你有时候会需要用另一个变量的初始化值给当前变量初始化。由于tf.initialize_all_variables()是并行地初始化所有变量,所以在有这种需求的情况下需要小心。

用其它变量的值初始化一个新的变量时,使用其它变量的initialized_value()属性。你可以直接把已初始化的值作为新变量的初始值,或者把它当做tensor计算得到一个值赋予新变量。

# Create a variable with a random value.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
# Create another variable with the same value as 'weights'.
w2 = tf.Variable(weights.initialized_value(), name="w2")
# Create another variable with twice the value of 'weights'
w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")

自定义初始化

tf.initialize_all_variables()函数便捷地添加一个op来初始化模型的所有变量。你也可以给它传入一组变量进行初始化。详情请见Variables Documentation,包括检查变量是否被初始化。

保存和加载

最简单的保存和恢复模型的方法是使用tf.train.Saver对象。构造器给graph的所有变量,或是定义在列表里的变量,添加saverestoreops。saver对象提供了方法来运行这些ops,定义检查点文件的读写路径。

检查点文件

变量存储在二进制文件里,主要包含从变量名到tensor值的映射关系。

当你创建一个Saver对象时,你可以选择性地为检查点文件中的变量挑选变量名。默认情况下,将每个变量Variable.name属性的值。

保存变量

tf.train.Saver()创建一个Saver来管理模型中的所有变量。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print "Model saved in file: ", save_path

恢复变量

用同一个Saver对象来恢复变量。注意,当你从文件中恢复变量时,不需要事先对它们做初始化。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print "Model restored."
  # Do some work with the model
  ...

选择存储和恢复哪些变量

如果你不给tf.train.Saver()传入任何参数,那么saver将处理graph中的所有变量。其中每一个变量都以变量创建时传入的名称被保存。

有时候在检查点文件中明确定义变量的名称很有用。举个例子,你也许已经训练得到了一个模型,其中有个变量命名为"weights",你想把它的值恢复到一个新的变量"params"中。

有时候仅保存和恢复模型的一部分变量很有用。再举个例子,你也许训练得到了一个5层神经网络,现在想训练一个6层的新模型,可以将之前5层模型的参数导入到新模型的前5层中。

你可以通过给tf.train.Saver()构造函数传入Python字典,很容易地定义需要保持的变量及对应名称:键对应使用的名称,值对应被管理的变量。

注意:

  • 如果需要保存和恢复模型变量的不同子集,可以创建任意多个saver对象。同一个变量可被列入多个saver对象中,只有当saver的restore()函数被运行时,它的值才会发生改变。
  • 如果你仅在session开始时恢复模型变量的一个子集,你需要对剩下的变量执行初始化op。详情请见tf.initialize_variables()
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore only 'v2' using the name "my_v2"
saver = tf.train.Saver({"my_v2": v2})
# Use the saver object normally after that.
...
时间: 2025-01-27 05:04:50

TensorFlow教程之进阶指南 3.2 变量:创建、初始化、保存和加载的相关文章

TensorFlow教程之进阶指南 3.1 总览

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 综述 Overview Variables: 创建,初始化,保存,和恢复 TensorFlow Variables 是内存中的容纳 tensor 的缓存.这一小节介绍了用它们在模型训练时(during training)创建.保存和更新模型参数(model parameters) 的方法. TensorFlow 机制 101 用 MNIST 手写数字识别作为一个小例子,一步一步的将使用 TensorFlow 基

TensorFlow教程之进阶指南 3.10 共享变量

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 共享变量 你可以在怎么使用变量中所描述的方式来创建,初始化,保存及加载单一的变量.但是当创建复杂的模块时,通常你需要共享大量变量集并且如果你还想在同一个地方初始化这所有的变量,我们又该怎么做呢.本教程就是演示如何使用tf.variable_scope() 和tf.get_variable()两个方法来实现这一点. 问题 假设你为图片过滤器创建了一个简单的模块,和我们的卷积神经网络教程模块相似,但是这里包括两个卷

TensorFlow教程之进阶指南 3.5 读取数据

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 数据读取 TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据. 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况). 目录 数据读取 供给数据(Feeding) 从文件读取数据 文

TensorFlow教程之进阶指南 3.6 线程和队列

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 在使用TensorFlow进行异步计算时,队列是一种强大的机制. 正如TensorFlow中的其他组件一样,队列就是TensorFlow图中的节点.这是一种有状态的节点,就像变量一样:其他节点可以修改它的内容.具体来说,其他节点可以把新元素插入到队列后端(rear),也可以把队列前端(front)的元素删除. 为了感受一下队列,让我们来看一个简单的例子.我们先创建一个"先入先出"的队列(FIFOQue

TensorFlow教程之进阶指南 3.3 TensorBoard:可视化学习

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. TensorBoard:可视化学习 TensorBoard 涉及到的运算,通常是在训练庞大的深度神经网络中出现的复杂而又难以理解的运算. 为了更方便 TensorFlow 程序的理解.调试与优化,我们发布了一套叫做 TensorBoard 的可视化工具.你可以用 TensorBoard 来展现你的 TensorFlow 图像,绘制图像生成的定量指标图以及附加数据. 当 TensorBoard 设置完成后,它应该

TensorFlow教程之进阶指南 3.7 添加新的OP

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 增加一个新 Op 预备知识: 对 C++ 有一定了解. 已经下载 TensorFlow 源代码并有能力编译它. 如果现有的库没有涵盖你想要的操作, 你可以自己定制一个. 为了使定制的 Op 能够兼容原有的库 , 你必须做以下工作: 在一个 C++ 文件中注册新 Op. Op 的注册与实现是相互独立的. 在其注册时描述了 Op 该如何执行. 例如, 注册 Op 时定义了 Op 的名字, 并指定了它的输入和输出.

TensorFlow教程之进阶指南 3.8 自定义数据读取

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 自定义数据读取 基本要求: 熟悉 C++ 编程. 确保下载 TensorFlow 源文件, 并可编译使用. 我们将支持文件格式的任务分成两部分: 文件格式: 我们使用 Reader Op来从文件中读取一个 record (可以使任意字符串). 记录格式: 我们使用解码器或者解析运算将一个字符串记录转换为TensorFlow可以使用的张量. 例如, 读取一个 CSV 文件,我们使用 一个文本读写器, 然后是从一行

TensorFlow教程之进阶指南 3.9 使用 GPUs

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 支持的设备 在一套标准的系统上通常有多个计算设备. TensorFlow 支持 CPU 和 GPU 这两种设备. 我们用指定字符串strings 来标识这些设备. 比如: "/cpu:0": 机器中的 CPU "/gpu:0": 机器中的 GPU, 如果你有一个的话. "/gpu:1": 机器中的第二个 GPU, 以此类推... 如果一个 TensorFlow

TensorFlow教程之完整教程 2.3 MNIST入门

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 机器学习入门 这个教程的目标读者是对机器学习和TensorFlow都不太了解的新手. 当我们开始学习编程的时候,第一件事往往是学习打印"Hello World".就好比编程入门有Hello World,机器学习入门有MNIST. MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片: 它也包含每一张图片对应的标签,告诉我们这个是数字几.比如,上面这四张图片的标签分别是5,0,4,1. 在此