TensorFlow教程之进阶指南 3.8 自定义数据读取

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权。

自定义数据读取

基本要求:

我们将支持文件格式的任务分成两部分:

  • 文件格式: 我们使用 Reader Op来从文件中读取一个 record (可以使任意字符串)。
  • 记录格式: 我们使用解码器或者解析运算将一个字符串记录转换为TensorFlow可以使用的张量。

例如, 读取一个 CSV 文件,我们使用 一个文本读写器, 然后是从一行文本中解析CSV数据的运算

主要内容

自定义数据读取

编写一个文件格式读写器

Reader 是专门用来读取文件中的记录的。TensorFlow中内建了一些读写器Op的实例:

你可以看到这些读写器的界面是一样的,唯一的差异是在它们的构造函数中。最重要的方法是 Read。 它需要一个行列参数,通过这个行列参数,可以在需要的时候随时读取文件名 (例如: 当 Read Op首次运行,或者前一个 Read` 从一个文件中读取最后一条记录时)。它将会生成两个标量张量: 一个字符串和一个字符串关键值。

新创建一个名为 SomeReader 的读写器,需要以下步骤:

  1. 在 C++ 中, 定义一个 tensorflow::ReaderBase的子类,命名为 "SomeReader".
  2. 在 C++ 中,注册一个新的读写器Op和Kernel,命名为 "SomeReader"。
  3. 在 Python 中, 定义一个 tf.ReaderBase 的子类,命名为 "SomeReader"。

你可以把所有的 C++ 代码放在 tensorflow/core/user_ops/some_reader_op.cc文件中. 读取文件的代码将被嵌入到C++ 的 ReaderBase 类的迭代中。 这个 ReaderBase 类 是在 tensorflow/core/kernels/reader_base.h 中定义的。 你需要执行以下的方法:

  • OnWorkStartedLocked:打开下一个文件
  • ReadLocked:读取一个记录或报告 EOF/error
  • OnWorkFinishedLocked:关闭当前文件
  • ResetLocked:清空记录,例如:一个错误记录

以上这些方法的名字后面都带有 "Locked", 表示 ReaderBase 在调用任何一个方法之前确保获得互斥锁,这样就不用担心线程安全(虽然只保护了该类中的元素而不是全局的)。

对于 OnWorkStartedLocked, 需要打开的文件名是 current_work() 函数的返回值。此时的 ReadLocked 的数字签名如下:

Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)

如果 ReadLocked 从文件中成功读取了一条记录,它将更新为:

  • *key: 记录的标志位,通过该标志位可以重新定位到该记录。 可以包含从 current_work() 返回值获得的文件名,并追加一个记录号或其他信息。
  • *value: 包含记录的内容。
  • *produced: 设置为 true。

当你在文件(EOF)末尾,设置 *at_end 为 true ,在任何情况下,都将返回 Status::OK()。 当出现错误的时候,只需要使用 tensorflow/core/lib/core/errors.h 中的一个辅助功能就可以简单地返回,不需要做任何参数修改。

接下来你讲创建一个实际的读写器Op。 如果你已经熟悉了添加新的Op 那会很有帮助。 主要步骤如下:

  • 注册Op。
  • 定义并注册 OpKernel。

要注册Op,你需要用到一个调用指令定义在 tensorflow/core/framework/op.h中的REGISTER_OP。

读写器 Op 没有输入,只有 Ref(string) 类型的单输出。它们调用 SetIsStateful(),并有一个 container 字符串和 shared_name 属性. 你可以在一个 Doc 中定义配置或包含文档的额外属性。 例如:详见tensorflow/core/ops/io_ops.cc等:

 #include "tensorflow/core/framework/op.h"
REGISTER_OP("TextLineReader")
    .Output("reader_handle: Ref(string)")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .Doc(R"doc(
A Reader that outputs the lines of a file delimited by '\n'.
)doc");

要定义一个 OpKernel, 读写器可以使用定义在tensorflow/core/framework/reader_op_kernel.h中的 ReaderOpKernel 的递减快捷方式,并运行一个叫 SetReaderFactory 的构造函数。 定义所需要的类之后,你需要通过 REGISTER_KERNEL_BUILDER(...) 注册这个类。

一个没有属性的例子:

 #include "tensorflow/core/framework/reader_op_kernel.h"
class TFRecordReaderOp : public ReaderOpKernel {
 public:
  explicit TFRecordReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    Env* env = context->env();
    SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
  }
};
REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
                        TFRecordReaderOp);

一个带有属性的例子:

 #include "tensorflow/core/framework/reader_op_kernel.h"
class TextLineReaderOp : public ReaderOpKernel {
 public:
  explicit TextLineReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    int skip_header_lines = -1;
    OP_REQUIRES_OK(context,
                   context->GetAttr("skip_header_lines", &skip_header_lines));
    OP_REQUIRES(context, skip_header_lines >= 0,
                errors::InvalidArgument("skip_header_lines must be >= 0 not ",
                                        skip_header_lines));
    Env* env = context->env();
    SetReaderFactory([this, skip_header_lines, env]() {
      return new TextLineReader(name(), skip_header_lines, env);
    });
  }
};
REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
                        TextLineReaderOp);

最后一步是添加 Python 包装器,你需要将 tensorflow.python.ops.io_ops 导入到tensorflow/python/user_ops/user_ops.py,并添加一个 io_ops.ReaderBase的衍生函数。

from tensorflow.python.framework import ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import io_ops
class SomeReader(io_ops.ReaderBase):
    def __init__(self, name=None):
        rr = gen_user_ops.some_reader(name=name)
        super(SomeReader, self).__init__(rr)
ops.NoGradient("SomeReader")
ops.RegisterShape("SomeReader")(common_shapes.scalar_shape)

你可以在 tensorflow/python/ops/io_ops.py中查看一些范例。

编写一个记录格式Op

一般来说,这是一个普通的Op, 需要一个标量字符串记录作为输入, 因此遵循 添加Op的说明。 你可以选择一个标量字符串作为输入, 并包含在错误消息中报告不正确的格式化数据。

用于解码记录的运算实例:

请注意,使用多个Op 来解码某个特定的记录格式也是有效的。 例如,你有一张以字符串格式保存在tf.train.Example 协议缓冲区的图像文件。 根据该图像的格式, 你可能从 tf.parse_single_example 的Op 读取响应输出并调用 tf.decode_jpeg, tf.decode_png, 或者 tf.decode_raw。通过读取 tf.decode_raw 的响应输出并使用tf.slice 和 tf.reshape 来提取数据是通用的方法。

时间: 2024-12-20 22:13:42

TensorFlow教程之进阶指南 3.8 自定义数据读取的相关文章

TensorFlow教程之进阶指南 3.1 总览

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 综述 Overview Variables: 创建,初始化,保存,和恢复 TensorFlow Variables 是内存中的容纳 tensor 的缓存.这一小节介绍了用它们在模型训练时(during training)创建.保存和更新模型参数(model parameters) 的方法. TensorFlow 机制 101 用 MNIST 手写数字识别作为一个小例子,一步一步的将使用 TensorFlow 基

TensorFlow教程之进阶指南 3.5 读取数据

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 数据读取 TensorFlow程序读取数据一共有3种方法: 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据. 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据. 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况). 目录 数据读取 供给数据(Feeding) 从文件读取数据 文

TensorFlow教程之进阶指南 3.2 变量:创建、初始化、保存和加载

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 当训练模型时,用变量来存储和更新参数.变量包含张量 (Tensor)存放于内存的缓存区.建模时它们需要被明确地初始化,模型训练后它们必须被存储到磁盘.这些变量的值可在之后模型训练和分析是被加载. 本文档描述以下两个TensorFlow类.点击以下链接可查看完整的API文档: tf.Variable 类 tf.train.Saver 类 创建 当创建一个变量时,你将一个张量作为初始值传入构造函数Variable(

TensorFlow教程之进阶指南 3.10 共享变量

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 共享变量 你可以在怎么使用变量中所描述的方式来创建,初始化,保存及加载单一的变量.但是当创建复杂的模块时,通常你需要共享大量变量集并且如果你还想在同一个地方初始化这所有的变量,我们又该怎么做呢.本教程就是演示如何使用tf.variable_scope() 和tf.get_variable()两个方法来实现这一点. 问题 假设你为图片过滤器创建了一个简单的模块,和我们的卷积神经网络教程模块相似,但是这里包括两个卷

TensorFlow教程之进阶指南 3.7 添加新的OP

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 增加一个新 Op 预备知识: 对 C++ 有一定了解. 已经下载 TensorFlow 源代码并有能力编译它. 如果现有的库没有涵盖你想要的操作, 你可以自己定制一个. 为了使定制的 Op 能够兼容原有的库 , 你必须做以下工作: 在一个 C++ 文件中注册新 Op. Op 的注册与实现是相互独立的. 在其注册时描述了 Op 该如何执行. 例如, 注册 Op 时定义了 Op 的名字, 并指定了它的输入和输出.

TensorFlow教程之进阶指南 3.6 线程和队列

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 在使用TensorFlow进行异步计算时,队列是一种强大的机制. 正如TensorFlow中的其他组件一样,队列就是TensorFlow图中的节点.这是一种有状态的节点,就像变量一样:其他节点可以修改它的内容.具体来说,其他节点可以把新元素插入到队列后端(rear),也可以把队列前端(front)的元素删除. 为了感受一下队列,让我们来看一个简单的例子.我们先创建一个"先入先出"的队列(FIFOQue

TensorFlow教程之进阶指南 3.3 TensorBoard:可视化学习

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. TensorBoard:可视化学习 TensorBoard 涉及到的运算,通常是在训练庞大的深度神经网络中出现的复杂而又难以理解的运算. 为了更方便 TensorFlow 程序的理解.调试与优化,我们发布了一套叫做 TensorBoard 的可视化工具.你可以用 TensorBoard 来展现你的 TensorFlow 图像,绘制图像生成的定量指标图以及附加数据. 当 TensorBoard 设置完成后,它应该

TensorFlow教程之进阶指南 3.9 使用 GPUs

本文档为TensorFlow参考文档,本转载已得到TensorFlow中文社区授权. 支持的设备 在一套标准的系统上通常有多个计算设备. TensorFlow 支持 CPU 和 GPU 这两种设备. 我们用指定字符串strings 来标识这些设备. 比如: "/cpu:0": 机器中的 CPU "/gpu:0": 机器中的 GPU, 如果你有一个的话. "/gpu:1": 机器中的第二个 GPU, 以此类推... 如果一个 TensorFlow

《从Excel到Python——数据分析进阶指南》一第1章 生成数据表

第1章 生成数据表从Excel到Python--数据分析进阶指南常见的生成数据表的方法有两种,第一种是导入外部数据,第二种是直接写入数据. Excel中的"文件"菜单中提供了获取外部数据的功能,支持数据库和文本文件和页面的多种数据源导入. Python支持从多种类型的数据导入.在开始使用Python进行数据导入前需要先导入pandas库,为了方便起见,我们也同时导入numpy库. import numpy as np import pandas as pd 导入数据表下面分别是从Exc