如何正确结合 TensorFlow 的 Dataset API 和 Keras?

本文介绍了如何正确结合 TensorFlow 的 Dataset API 和 Keras?的处理方法,对大家解决问题具有一定的参考价值

问题描述

Keras 的 fit_generator() 模型方法需要一个生成器来生成形状(输入、目标)的元组,其中两个元素都是 NumPy 数组.文档 似乎暗示如果我简单地将 Dataset 迭代器 在生成器中,并确保将张量转换为 NumPy 数组,我应该很高兴.但是,此代码给了我一个错误:

Keras' fit_generator() model method expects a generator which produces tuples of the shape (input, targets), where both elements are NumPy arrays. The documentation seems to imply that if I simply wrap a Dataset iterator in a generator, and make sure to convert the Tensors to NumPy arrays, I should be good to go. This code, however, gives me an error:

import numpy as np
import os
import keras.backend as K
from keras.layers import Dense, Input
from keras.models import Model
import tensorflow as tf
from tensorflow.contrib.data import Dataset

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

with tf.Session() as sess:
    def create_data_generator():
        dat1 = np.arange(4).reshape(-1, 1)
        ds1 = Dataset.from_tensor_slices(dat1).repeat()

        dat2 = np.arange(5, 9).reshape(-1, 1)
        ds2 = Dataset.from_tensor_slices(dat2).repeat()

        ds = Dataset.zip((ds1, ds2)).batch(4)
        iterator = ds.make_one_shot_iterator()
        while True:
            next_val = iterator.get_next()
            yield sess.run(next_val)

datagen = create_data_generator()

input_vals = Input(shape=(1,))
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mean_squared_error')
model.fit_generator(datagen, steps_per_epoch=1, epochs=5,
                    verbose=2, max_queue_size=2)

这是我得到的错误:

Using TensorFlow backend.
Epoch 1/5
Exception in thread Thread-1:
Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2787, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/home/jsaporta/anaconda3/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/utils/data_utils.py", line 568, in data_generator_task
    generator_output = next(self._generator)
  File "./datagen_test.py", line 25, in create_data_generator
    yield sess.run(next_val)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1109, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 413, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 233, in for_fetch
    return _ListFetchMapper(fetch)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 340, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 241, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 277, in __init__
    'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Tensor 'IteratorGetNext:0' shape=(?, 1) dtype=int64> cannot be interpreted as a Tensor. (Tensor Tensor("IteratorGetNext:0", shape=(?, 1), dtype=int64) is not an element of this graph.)

Traceback (most recent call last):
  File "./datagen_test.py", line 34, in <module>
    verbose=2, max_queue_size=2)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 87, in wrapper
    return func(*args, **kwargs)
  File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 2011, in fit_generator
    generator_output = next(output_generator)
StopIteration

奇怪的是,在我初始化 datagen 的位置之后直接添加包含 next(datagen) 的行会导致代码运行得很好,没有错误.

Strangely enough, adding a line containing next(datagen) directly after where I initialize datagen causes the code to run just fine, with no errors.

为什么我的原始代码不起作用?为什么当我将该行添加到我的代码时它开始工作?有没有更有效的方法来将 TensorFlow 的数据集 API 与 Keras 结合使用,而不涉及将张量转换为 NumPy 数组然后再返回?

Why does my original code not work? Why does it begin to work when I add that line to my code? Is there a more efficient way to use TensorFlow's Dataset API with Keras that doesn't involve converting Tensors to NumPy arrays and back again?

推荐答案

确实有一种更有效的方法来使用 Dataset 而不必将张量转换为 numpy 数组.但是,它不是(还?)在官方文档中.从发行说明来看,这是 Keras 2.0.7 中引入的一项功能.您可能需要安装 keras>=2.0.7 才能使用它.

There is indeed a more efficient way to use Dataset without having to convert the tensors into numpy arrays. However, it is not (yet?) on the official documentation. From the release note, it's a feature introduced in Keras 2.0.7. You may have to install keras>=2.0.7 in order to use it.

x = np.arange(4).reshape(-1, 1).astype('float32')
ds_x = Dataset.from_tensor_slices(x).repeat().batch(4)
it_x = ds_x.make_one_shot_iterator()

y = np.arange(5, 9).reshape(-1, 1).astype('float32')
ds_y = Dataset.from_tensor_slices(y).repeat().batch(4)
it_y = ds_y.make_one_shot_iterator()

input_vals = Input(tensor=it_x.get_next())
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()])
model.fit(steps_per_epoch=1, epochs=5, verbose=2)

几个不同点:

  1. tensor 参数提供给 Input 层.Keras 将从该张量中读取值,并将其用作拟合模型的输入.
  2. target_tensors 参数提供给 Model.compile().
  3. 记得把 x 和 y 都转换成 float32.在正常使用情况下,Keras 会为你做这个转换.但现在你必须自己做.
  4. 批量大小是在Dataset 的构建过程中指定的.使用 steps_per_epochepochs 来控制何时停止模型拟合.
  1. Supply the tensor argument to the Input layer. Keras will read values from this tensor, and use it as the input to fit the model.
  2. Supply the target_tensors argument to Model.compile().
  3. Remember to convert both x and y into float32. Under normal usage, Keras will do this conversion for you. But now you'll have to do it yourself.
  4. Batch size is specified during the construction of Dataset. Use steps_per_epoch and epochs to control when to stop model fitting.

简而言之,使用 Input(tensor=...)model.compile(target_tensors=...)model.fit(x=None, y=None, ...) 如果要从张量中读取数据.

In short, use Input(tensor=...), model.compile(target_tensors=...) and model.fit(x=None, y=None, ...) if your data are to be read from tensors.

这篇关于如何正确结合 TensorFlow 的 Dataset API 和 Keras?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,WP2

admin_action_{$_REQUEST[‘action’]}

do_action( "admin_action_{$_REQUEST[‘action’]}" )动作钩子::在发送“Action”请求变量时激发。Action Hook: Fires when an ‘action’ request variable is sent.目录锚点:#说明#源码说明(Description)钩子名称的动态部分$_REQUEST['action']引用从GET或POST请求派生的操作。源码(Source)更新版本源码位置使用被使用2.6.0 wp-admin/admin.php:...

日期:2020-09-02 17:44:16 浏览:1127

admin_footer-{$GLOBALS[‘hook_suffix’]}

do_action( "admin_footer-{$GLOBALS[‘hook_suffix’]}", string $hook_suffix )操作挂钩:在默认页脚脚本之后打印脚本或数据。Action Hook: Print scripts or data after the default footer scripts.目录锚点:#说明#参数#源码说明(Description)钩子名的动态部分,$GLOBALS['hook_suffix']引用当前页的全局钩子后缀。参数(Parameters)参数类...

日期:2020-09-02 17:44:20 浏览:1032

customize_save_{$this->id_data[‘base’]}

do_action( "customize_save_{$this-&gt;id_data[‘base’]}", WP_Customize_Setting $this )动作钩子::在调用WP_Customize_Setting::save()方法时激发。Action Hook: Fires when the WP_Customize_Setting::save() method is called.目录锚点:#说明#参数#源码说明(Description)钩子名称的动态部分,$this->id_data...

日期:2020-08-15 15:47:24 浏览:775

customize_value_{$this->id_data[‘base’]}

apply_filters( "customize_value_{$this-&gt;id_data[‘base’]}", mixed $default )过滤器::过滤未作为主题模式或选项处理的自定义设置值。Filter Hook: Filter a Customize setting value not handled as a theme_mod or option.目录锚点:#说明#参数#源码说明(Description)钩子名称的动态部分,$this->id_date['base'],指的是设置...

日期:2020-08-15 15:47:24 浏览:866

get_comment_author_url

过滤钩子:过滤评论作者的URL。Filter Hook: Filters the comment author’s URL.目录锚点:#源码源码(Source)更新版本源码位置使用被使用 wp-includes/comment-template.php:32610...

日期:2020-08-10 23:06:14 浏览:903

network_admin_edit_{$_GET[‘action’]}

do_action( "network_admin_edit_{$_GET[‘action’]}" )操作挂钩:启动请求的处理程序操作。Action Hook: Fires the requested handler action.目录锚点:#说明#源码说明(Description)钩子名称的动态部分$u GET['action']引用请求的操作的名称。源码(Source)更新版本源码位置使用被使用3.1.0 wp-admin/network/edit.php:3600...

日期:2020-08-02 09:56:09 浏览:848

network_sites_updated_message_{$_GET[‘updated’]}

apply_filters( "network_sites_updated_message_{$_GET[‘updated’]}", string $msg )筛选器挂钩:在网络管理中筛选特定的非默认站点更新消息。Filter Hook: Filters a specific, non-default site-updated message in the Network admin.目录锚点:#说明#参数#源码说明(Description)钩子名称的动态部分$_GET['updated']引用了非默认的...

日期:2020-08-02 09:56:03 浏览:834

pre_wp_is_site_initialized

过滤器::过滤在访问数据库之前是否初始化站点的检查。Filter Hook: Filters the check for whether a site is initialized before the database is accessed.目录锚点:#源码源码(Source)更新版本源码位置使用被使用 wp-includes/ms-site.php:93910...

日期:2020-07-29 10:15:38 浏览:809

WordPress 的SEO 教学:如何在网站中加入关键字(Meta Keywords)与Meta 描述(Meta Description)?

你想在WordPress 中添加关键字和meta 描述吗?关键字和meta 描述使你能够提高网站的SEO。在本文中,我们将向你展示如何在WordPress 中正确添加关键字和meta 描述。为什么要在WordPress 中添加关键字和Meta 描述?关键字和说明让搜寻引擎更了解您的帖子和页面的内容。关键词是人们寻找您发布的内容时,可能会搜索的重要词语或片语。而Meta Description则是对你的页面和文章的简要描述。如果你想要了解更多关于中继标签的资讯,可以参考Google的说明。Meta 关键字和描...

日期:2020-10-03 21:18:25 浏览:1620

谷歌的SEO是什么

SEO (Search Engine Optimization)中文是搜寻引擎最佳化,意思近于「关键字自然排序」、「网站排名优化」。简言之,SEO是以搜索引擎(如Google、Bing)为曝光媒体的行销手法。例如搜寻「wordpress教学」,会看到本站的「WordPress教学:12个课程…」排行Google第一:关键字:wordpress教学、wordpress课程…若搜寻「网站架设」,则会看到另一个网页排名第1:关键字:网站架设、架站…以上两个网页,每月从搜寻引擎导入自然流量,达2万4千:每月「有机搜...

日期:2020-10-30 17:23:57 浏览:1263