Dataset 快速入门

tf.data 模块里包含了一组让你可以快速加载数据,操作数据,并把它传输到你的模型的类。这篇文章通过两个简单的例子来看一下这些API
1. 从内存中的numpy arrays读取数据
2. 从csv文件里读取数据

基本输入


从一个array里获取slices是入门tf.data的最简单的办法。
比如下边这个代码把features和labels组成dataset,这样就可以供Estimator使用。


def train_input_fun(features,labels,batch_size):
    """ An input function for training """
   
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features),labels))
   
    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)
   
    # Return the dataset
    return dataset

我们来详细看一下这个程序。

参数

这个方法有三个参数,参数可以接受任何可以被转化成numpy.array的array。tuple是一种特列,我们将会看到,它有特殊的含义。
1. features: 一个{‘feature_name’:array}的字典(或者 DataFrame)包含着原始的input features
2. labels: 一个array包含着每个样例的label
3. batch_size:一个表示batch size的整数

premade_estimator.py文件里我们通过iris_data.load_data()来获取iris data。你可以运行它,并像下边那样获取并解析数据:


import iris_data

# 获取数据
train,test = iris_data.load_data()
features,labels = train

然后我们把数据传给input function,像这样:


batch_size=100
iris_data.train_input_fn(features,labels,batch_size)

我们现在过一遍train_input_fn()

Slices

在方法的开始,通过tf.data.Dataset.from_tensor_slices方法创建了一个tf.data.Dataset用来表示数组的一个切片。数组是在第一个维度进行切分的。比如,一个数组的shape是(60000,28,28)。经过from_tensor_slices,会返回一个Dataset对象。它包含60000个切片。每一个都是28×28的数组。
返回这个Dataset的代码如下:


train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train

mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)

这会打印出像下边的内容,包含Dataset里元素的shape和type。但是Dataset并不知道它包含多少元素。


<TensorSliceDataset shapes:(28,28), types: tf.unit8>

上边的Dataset只是简单的表示了一组array。但是Dataset其实有更强的功能。一个Dataset可以自动的处理任何嵌套的dictionary或者tuple(或者namedtuple)
比如在把iris features转化为一个标准的python的dictionary后,你可以把dictionary的数组转化为一个dictionary的Dataset。就像下边这样:


dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)

输出为:


<TensorSliceDataset
    shapes:{
        SepalLength:(),PetalWidth:(),
        PetalLength:(),SepalWidth:()}
    types:{
        SepalLength: tf.float64, PetalWidth: tf.float64,
        PetalLength: tf.float64, SepalWidth: tf.float64}
    >
    }

这里我们看到当Dataset包含结构化的元素,Dataset的shapes和types有着同样的结构。这个dataset包含着一组标量的dictionary。所有的类型都是tf.float64.

上边方法train_input_fn的第一行用了相同的功能,但是添加了另外一层的结构。它创建了一个dataset包含着(features_dict,label)的结构。

下边的代码显示label是一个int64的标量。


# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)

<TensorSliceDataset
    shapes: (
        {
          SepalLength: (), PetalWidth: (),
          PetalLength: (), SepalWidth: ()},
        ()),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

操作

现在,Dataset可以对数据进行一次遍历。以固定的顺序,每次产生一个元素。它在进行模型训练前需要更多的处理。tf.data.Dataset提供了准备数据的方法。比如下边这行代码就用到了几个:


# Shuffle,repeat,and batch the example.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

shuffle方法用固定大小的buffer来对数据进行随机排序。在上边的例子里buffer size是1000,它比Dataset里含有的样本还多,这就保证了完全的shuffle。
repeat 方法在数据一遍迭代完后,可以重新开始,你也可以传入参数来规定可以repeat几次。
batch方法可以把dataset里的样本组成固定大小的batch。这样会给切片数据的shape前边增加一个维度。比如:


print(mnist_ds.batch(100))

<BatchDataset
  shapes: (?, 28, 28),
  types: tf.uint8>

?是因为最后一个batch含有不足一个batch size的元素。
在train_input_fn里,在batching Dataset之后,会是这样:


print(dataset)

<TensorSliceDataset
    shapes: (
        {
          SepalLength: (?,), PetalWidth: (?,),
          PetalLength: (?,), SepalWidth: (?,)},
        (?,)),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

返回值

现在Dataset包含(features_dict,lables)对,这是train和evaluate方法期待的格式。所以 input_fn返回这个dataset。

labels在进行predict的时候可以也应该被去掉。

从CSV文件读取数据


现实世界最常见的办法是从磁盘的文件流式读取数据到Dataset。tf.data模块包含了多样的文件读取功能。让我们看如何从csv 文件来读取iris data 到Dataset。

iris_datra.maybe_download如果需要会下载数据。并返回下载文件的路径。


import iris_data
train_path, test_path = iris_data.maybe_download()

iris_data.csv_input_fn方法包含了另一种用Dataset解析csv文件的实现。让我们看一下如何构建一个可以供Estimator使用的input function。它从本地文件读取数据。

构建Dataset

我们从构建一个TextLineDataset对象开始,这个对象每次读取文件的一行。我们可以调用skip(1)来跳过文件的第一行列名。


ds = tf.data.TextLineDataset(train_path).skip(1)

创建一个csv的行分析器

我们开始构建一个分析一行数据的方法。
下边的iris_data.parse_line方法通过tf.decode_csv方法和一些简单的python code完成了这个任务。
我们必须通过对dataset里每一行进行分析来生成必要的(features,label)对。下边的_parse_lie方法调用了tf.decode_csv来把一行分解成features和lable。因为Esitimators需要features表示成dictionary的形式。我们依赖Python内建的dict和zip方法来构建dictionary。feature的名字是dictionary的key。我们通过调用dictionary的pop方法来从features里去掉label列。


# Metadata describing the text columns
COLUMNS = ['SepalLength','SepalWidth','PetalLength','PetalWidth','label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):
    # Decode the line into its fields
    fields = tf.decode_csv(line, FIELD_DEFAULTS)

    # Pack the result into a dictionary
    features = dict(zip(COLUMNS,fields))

    # Separate the label from the features
    label = features.pop('label')

    return features, label

解析行

在数据被送到模型之前,Dataset有很多方法来操作数据。最常被用到的方法是map,这个操作对dataset里每个元素进行转化。

map method有一个map_func参数,这个参数描述了Dataset的每个元素如何被转化。

所以,为了从把csv文件里读取的行进行分析,我们传递我们的_parse_line方法给map方法:


ds = ds.map(_parse_line)
print(ds)

<MapDataset
shapes: (
    {SepalLength: (), PetalWidth: (), ...},
    ()),
types: (
    {SepalLength: tf.float32, PetalWidth: tf.float32, ...},
    tf.int32)>

现在,ds不再是一个简单的string标量,而是一个包含着(features,label)对的dataset。

试一下

这个方法可以被用来替代iris_data.train_input_fn. 它可以被用来给estimator来提供数据。


train_path,test_path = iris_data.maybe_download()

# All the inputs are numeric
feature_columns = [
    tf.feature_column.numeric_column(name)
    for name in iris_data.CSV_COLUMN_NAMES[:-1]]

# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,n_classes=3)

# Train the estimator
batch_size = 100
est.train(steps=1000,input_fn= lambda:iris_data.csv_input_fn(train_path,batch_size))

Estimators 需要一个没有参数的input_fn,这里我们用lambda来获取参数,并且提供期望的接口。

总结


tf.data 模块提供了一组类和方法让我们容易的从各种datasource来读取数据。并且tf.data还有简单但是强大的方法来进行用户定义的数据转化。

发表评论

电子邮件地址不会被公开。 必填项已用*标注

%d 博主赞过: