TensorFlow中的那些高级API

TensorFlow拥有很多库,比如KerasTFLearnSonnet,对于模型训练来说,使用这些库比使用低级功能更简单。尽管Keras的API目前正在添加到TensorFlow中去,但TensorFlow本身就提供了一些高级构件,而且最新的1.3版本中也引入了一些新的构件。

在这篇文章中,我们将看到一个使用了这些最新的高级构件的例子,包括Estimator(估算器)、Experiment(实验)和Dataset(数据集)。值得注意的是,你可以独立地使用Experiment和Dataset。我在这里假设你已经了解TensorFlow的基础知识;如果没有的话,那么TensorFlow官网上提供的教程值得学习。


Experiment、Estimator和DataSet框架以及它们之间的交互。

我们在本文中将使用MNIST作为数据集。这是一个使用起来很简单的数据集,可以从TensorFlow官网获取到。你可以在这个gist中找到完整的代码示例。使用这些框架的其中一个好处是,我们不需要直接处理会话

Estimator(估算器)类

Estimator类代表了一个模型,以及如何对这个模型进行训练和评估。我们可以像下面这段代码创建一个Estimator:

return tf.estimator.Estimator(
    model_fn=model_fn,  # First-class function
    params=params,  # HParams
    config=run_config  # RunConfig
)

要创建Estimator,需要传入一个模型函数、一组参数和一些配置。

  • 传入的参数应该是模型超参数的一个集合。这可以是一个dictionary,但是我们将在这个例子中把它表示成一个HParams对象,就像namedtuple一样。
  • 传入的配置用于指定如何运行训练和评估,以及在哪里存储结果。这个配置是一个RunConfig对象,该对象会把模型运行环境相关的信息告诉Estimator。
  • 模型函数是一个Python函数,它根据给定的输入构建模型。

模型函数

模型函数是一个Python函数,并作为一级函数传递给Estimator。稍后我们会看到,TensorFlow在其他地方也使用了一级函数。将模型表示为一个函数的好处是可以通过实例化函数来多次创建模型。模型可以在训练过程中用不同的输入重新创建,例如,在训练过程中运行验证测试。

模型函数把输入特征作为参数,将相应的标签作为张量。它也能以某种方式来告知用户模型是在训练、评估或是在执行推理。模型函数的最后一个参数是超参数集合,它们与传递给Estimator的超参数集合相同。模型函数返回一个EstimatorSpec对象,该对象定义了一个完整的模型。

EstimatorSpec对象用于对操作进行预测、损失、训练和评估,因此,它定义了一个用于训练、评估和推理的完整的模型图。由于EstimatorSpec只可用于常规的TensorFlow操作,因此,我们可以使用像TF-Slim这样的框架来定义模型。

Experiment(实验)类

Experiment类定义了如何训练模型,它与Estimator完美地集成在一起。我们可以像如下代码创建一个Experiment对象:

experiment = tf.contrib.learn.Experiment(
    estimator=estimator,  # Estimator
    train_input_fn=train_input_fn,  # First-class function
    eval_input_fn=eval_input_fn,  # First-class function
    train_steps=params.train_steps,  # Minibatch steps
    min_eval_frequency=params.min_eval_frequency,  # Eval frequency
    train_monitors=[train_input_hook],  # Hooks for training
    eval_hooks=[eval_input_hook],  # Hooks for evaluation
    eval_steps=None  # Use evaluation feeder until its empty
)

以下几种情况会把Experiment对象作为输入:

  • 一个estimator(例如我们上面定义的)。
  • 作为一级函数训练和评估数据。这里使用了与前面提到的模型函数相同的概念。如果需要的话,通过传入函数而不是操作,可以重新创建输入图。稍后我们还会谈到这个。
  • 训练和评估hook(钩子)。钩子可用于保存或监视特定的内容,或者在图或会话中设置某些操作。例如,我们将其传入到操作中,帮助初始化数据加载器。
  • 描述需要训练多久以及何时评估的各种参数。

一旦定义了experiment,我们就可以像下面这段代码那样使用learn_runner.run来运行它训练和评估模型:

learn_runner.run(
    experiment_fn=experiment_fn,  # First-class function
    run_config=run_config,  # RunConfig
    schedule="train_and_evaluate",  # What to run
    hparams=params  # HParams
)

与模型函数和数据函数一样,learn_runner将一个创建experiment的函数作为参数传入。

Dataset(数据集)类

我们将使用Dataset类和相应的Iterator来表示数据的训练和评估,以及创建在训练过程中迭代数据的数据馈送器。 在本示例中,我们将使用在Tensorflow中可用的MNIST数据,并为其构建一个Dataset包装。例如,我们将把训练输入数据表示为:

# Define the training inputs
def get_train_inputs(batch_size, mnist_data):
    """Return the input function to get the training data.
    Args:
        batch_size (int): Batch size of training iterator that is returned
                          by the input function.
        mnist_data (Object): Object holding the loaded mnist data.
    Returns:
        (Input function, IteratorInitializerHook):
            - Function that returns (features, labels) when called.
            - Hook to initialise input iterator.
    """
    iterator_initializer_hook = IteratorInitializerHook()

    def train_inputs():
        """Returns training set as Operations.
        Returns:
            (features, labels) Operations that iterate over the dataset
            on every evaluation
        """
        with tf.name_scope('Training_data'):
            # Get Mnist data
            images = mnist_data.train.images.reshape([-1, 28, 28, 1])
            labels = mnist_data.train.labels
            # Define placeholders
            images_placeholder = tf.placeholder(
                images.dtype, images.shape)
            labels_placeholder = tf.placeholder(
                labels.dtype, labels.shape)
            # Build dataset iterator
            dataset = tf.contrib.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            dataset = dataset.repeat(None)  # Infinite iterations
            dataset = dataset.shuffle(buffer_size=10000)
            dataset = dataset.batch(batch_size)
            iterator = dataset.make_initializable_iterator()
            next_example, next_label = iterator.get_next()
            # Set runhook to initialize iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: images,
                               labels_placeholder: labels})
            # Return batched (features, labels)
            return next_example, next_label

    # Return function and hook
    return train_inputs, iterator_initializer_hook

调用这个get_train_inputs将返回一个一级函数,用于在TensorFlow图中创建数据加载操作,以及返回一个用于初始化迭代器的Hook

本示例中使用的MNIST数据最初是一个Numpy数组。我们创建了一个占位符张量来获取数据;使用占位符的目的是为了避免数据的复制。接下来,我们在from_tensor_slices的帮助下创建一个切片数据集。我们要确保该数据集可以运行无限次数,并且数据被重新洗牌并放入指定大小的批次中。

要迭代数据,就需要从数据集中创建一个迭代器。由于我们正在使用占位符,因此需要使用NumPy数据在相关会话中对占位符进行初始化。可以通过创建一个可初始化的迭代器来实现这个。在创建图的时候,将创建一个自定义的IteratorInitializerHook对象来初始化迭代器:

class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook to initialise data iterator after Session is created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        """Initialise the iterator after the session has been created."""
        self.iterator_initializer_func(session)

IteratorInitializerHook继承自SessionRunHook。这个钩子将在相关会话创建后立即调用after_create_session,并使用正确的数据初始化占位符。这个钩子由我们的get_train_inputs函数返回,并在创建时传递给Experiment对象。

train_inputs函数返回的数据加载操作是TensorFlow的操作,该操作每次评估时都会返回一个新的批处理。

运行代码

现在,我们已经定义了所有内容,可以使用下面这个命令运行代码了:

python mnist_estimator.py --model_dir ./mnist_training --data_dir ./mnist_data

如果不传入参数,它将使用文件开头的默认标志来确定数据和模型保存的位置。

在训练过程中,在终端上会输出这段时间内的全​​局步骤、损失和准确性等信息。除此之外,Experiment和Estimator框架将记录TensorBoard可视化的某些统计信息。如果我们运行这个命令:

tensorboard --logdir='./mnist_training'

那么我们可以看到所有的训练统计数据,如训练损失、评估准确性、每个步骤的时间,以及模型图。


TensorBoard可视化中的评估准确度

我写这篇文章,是因为我在编写代码示例时,无法找到有关Tensorflow Estimator 、Experiment和Dataset框架太多的信息和示例。我希望这篇文章能向你简要介绍一下这些框架是如何工作的,它们采用了什么样的抽象方法以及如何使用它们。如果你对使用这些框架感兴趣,下面我将介绍一些注意点和其他的文档。

有关Estimator、Experiment和Dataset框架的注意点

文章原标题《Higher-Level APIs in TensorFlow》,作者:Peter Roelants,译者:夏天,审校:主题曲。

文章为简译,更为详细的内容,请查看原文需要爬梯,不方便的同学也可以下载下方的PDF附件,阅读原文内容。

时间: 2024-08-01 19:08:52

TensorFlow中的那些高级API的相关文章

在Android应用中使用百度地图api

本篇通过一个简单的示例一步步介绍如何在Android应用中使用百度地图api. 1)下载百度地图移动版 API(Android)开发包 要在Android应用中使用百度地图API,就需要在工程中引用百度地图API开发包,这个 开发包包含两个文件:baidumapapi.jar和libBMapApiEngine.so.下载地址: http://dev.baidu.com/wiki/static/imap/files/BaiduMapApi_Lib_Android_1.0.zip 2)申请API K

TensorFlow中RNN实现的正确打开方式

上周写的文章<完全图解RNN.RNN变体.Seq2Seq.Attention机制>介绍了一下RNN的几种结构,今天就来聊一聊如何在TensorFlow中实现这些结构,这篇文章的主要内容为: 一个完整的.循序渐进的学习TensorFlow中RNN实现的方法.这个学习路径的曲线较为平缓,应该可以减少不少学习精力,帮助大家少走弯路. 一些可能会踩的坑 TensorFlow源码分析 一个Char RNN实现示例,可以用来写诗,生成歌词,甚至可以用来写网络小说!(项目地址:https://github.

在TensorFlow中对比两大生成模型:VAE与GAN(附测试代码)

项目链接:https://github.com/kvmanohar22/ Generative-Models 变分自编码器(VAE)与生成对抗网络(GAN)是复杂分布上无监督学习最具前景的两类方法. 本项目总结了使用变分自编码器(Variational Autoencode,VAE)和生成对抗网络(GAN)对给定数据分布进行建模,并且对比了这些模型的性能.你可能会问:我们已经有了数百万张图像,为什么还要从给定数据分布中生成图像呢?正如 Ian Goodfellow 在 NIPS 2016 教程中

在Visual Web应用程序中使用Java Persistence API

借助 NetBeans IDE 6.0 和 Visual Web 工具,您可以使用 Visual Web 数据提供程 序组件以及 Java Persistence API(JPA)来编写连接到数据库表的应用程序.建立了到 数据库表的连接之后,可以使用 Java Persistence API 执行数据库 CRUD 操作(即创建 .读取.更新和删除操作).在开发基于数据库的应用程序时,使用 Java Persistence API 能提供更高的灵活性. 本文是本系列文章的第 1 篇(共两篇),主要

[译] 探索 Swift 4 中新的 String API

本文讲的是[译] 探索 Swift 4 中新的 String API, WWDC 已经结束了(我觉得是自 2014 年来最好的一场 WWDC),同时 Xcode 9 beta 版也发布了,很多开发者已经开始把玩 Swift 4 ,今年的新版本真心不错,这是一个改进版本而不是重构版本(像 Swift 2 和 3),因此大多数代码升级起来会更容易. 其中一个改进是 String 的 API,在 Swift 4 中更易用,也更强大.在过去的 Swift 版本中,String API 经常被提出为一个例

java、api-如何获得其他网页的pm2.5的数据传到自己数据库中,有相关api接口,可不会用

问题描述 如何获得其他网页的pm2.5的数据传到自己数据库中,有相关api接口,可不会用 如何获得其他网页的pm2.5的数据传到自己数据库中,有相关api接口,可不会用 解决方案 有相关接口那就很简单啊,比如说你调用一个接口,然后按照里面的规则传参数,然后获取回来一般有json数据,如果是android的话你可以用Gson解析成为对象.如果不是android的话,你可以用最直接的方法那就是字符串的截取,你看看怎么适合咯.

《面向机器智能的TensorFlow实践》一3.2 在TensorFlow中定义数据流图

3.2 在TensorFlow中定义数据流图 在本书中,你将接触到多样化的以及相当复杂的机器学习模型.然而,不同的模型在TensorFlow中的定义过程却遵循着相似的模式.当掌握了各种数学概念,并学会如何实现它们时,对TensorFlow核心工作模式的理解将有助于你脚踏实地开展工作.幸运的是,这个工作流非常容易记忆,它只包含两个步骤: 1)定义数据流图. 2)运行数据流图(在数据上). 这里有一个显而易见的道理,如果数据流图不存在,那么肯定无法运行它.头脑中有这种概念是很有必要的,因为当你编写代

云计算系统中对开发者的API设计问题

本文讲的是云计算系统中对开发者的API设计问题,[IT168 资讯]近年来,随着互联网应用的普及与深化,网络信息与服务趋于海量,用户体验需求不断增长,数据海量.分布异构.处理复杂.使用繁琐等问题逐渐突显,旨在解决这些问题的云计算(Cloud Computing)相关技术得到了迅猛发展.云计算概念的提出在成为新的发展机遇的同时也在云计算技术方面受到挑战.特别是云计算系统中的API设计问题受到极大挑战. 云计算是分布式处理(Distributed Computing).并行处理(Parallel C

Win10开发:微软详解在应用中使用新型OneDrive API

从Win8开始OneDrive被深度集成到系统中,成为了Windows系统中的一个组件,为用户提供了一个免费的云存储服务.对于开发者来说,也可以在自己的应用中使用OneDrive API,从而实现内容的同步等功能. 在WP8开发框架中,OneDrive团队已经提供了一款非常方便好用的SDK,但仍存在很多限制.例如,使用内置按钮控件才可以实现登录机制,开发者也无法更改外观和行为.更不方便的地方在于,无法在各平台之间共享代码. 现在,微软OneDrive团队基于HTTP请求(GET.POST和PUT