tf.data的API可以让你通过简单的,可重用的模块来构造一个复杂的输入管道。比如,对于一个图像模型的管道可能需要从一个分布式的文件系统收集数据,给每个图片增加随机的扰动,并且随机选择一些图片来作为一个batch用来训练。对于一个文本的管道可能需要引入从原始文本提取符号,从lookup表里,把这些符号转化为embedding的值。并且把不同长度的序列组成一个batch。tf.data API可以处理大规模的数据,读取不同的数据格式,并进行复杂的转化。

tf.data的API引入了tf.data.Dataset来抽象表示一个元素序列。每个元素包含了一个或多个组件。比如对于一个图片的管道,一个元素是一个训练样本,这个样本包含了图片以及它的label。

有两种不同的方法来创建一个dataset:

  • 从数据源构造一个Dataset,比如从内存或者文件(一个或多个)构造。
  • 从一个或多个tf.data.Dataset转化生成一个dataset。

基本机制

你必须从一个数据源开始定义一个输入管道。比如,从内存中的数据创建一个Dataset,你可以用tf.data.Dataset.from_tensors()或者tf.data.Dataset.from_tensor_slices(). 或者,如果你的输入数据是以建议的TFRecord格式存储在一个文件里。你可以用tf.data.TFRecordDataset().

一旦你有一个Dataset对象,你可以通过链式调用它上边的方法把它转化为一个新的Dataset。比如,你可以通过调用对每个元素级别的转化,比如Dataset.map(),以及对多个元素的转化,比如Dataset.batch().可以通过查看tf.data.Dataset来查看一个完整的转化列表。

Dataset对象实现了Python的iterable接口。这样你可以通过一个for循环来消费数据。


dataset = tf.data.Dataset.from_tensor_slices([8,3,0,8,2,1])
dataset

<TensorSliceDataset shape:(), types:tf.int32>

for elem in dataset:
    print(elem.numpy())

8
3
0
8
2
1

或者通过iter来显式创建一个Python的迭代器,并且使用next来消费它的元素


it = iter(dataset)
print(next(it).numpy())

8

还有一种就是你使用reduce操作去消费dataset的元素。这样会使用所有的元素生成一个单一的结果。下边的例子演示了如何用reduce去计算一个整数类型dataset的加和。


print(dataset.reduce(0,lambda state,value:state+value).numpy())

22

Dataset 结构

一个dataset包含的元素每个都有同样的(嵌套)结构。这个结构每个独立的部分都是可以用tf.TypeSpec表示的类型。包括tf.Tensor,tf.sparse.SparseTensor,tf.RaggedTensor,tf.TensorArray,或者tf.data.Dataset.
Dataset.element_spec属性允许你去查看元素每个模块的类型。这个属性返回一个tf.TypeSpec的嵌套结构。和元素的结构相对应。可能是一个单一模块,或者是一个模块元组,或者是一个嵌套的模块元组。比如:


dataset1 = tf.data.Dataset.from_tensor_slices(rf.random.uniform([4,10]))
dataset1.element_spec

TensorSpec(shape(10,), dtype=tf.float32, name=None)

dataset2 = tf.data.Dataset.from_tensor_slices((tf.random.uniform([4]),tf.random.uniform([4,100],maxval=100,dtype=tf.int32))
dataset2.element_spec

(TensorSpec(shape=(),dtype=tf.float32,name=None),
TensorSpec(shape=(100,),dtype=tf.int32,name=None))

dataset3 = tf.data.Dataset.zip((dataset1,dataset2))
dataset3.element_spec

(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
 (TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))

dataset4=tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0,0],[1,2]],values=[1,2],dense_shape=[3,4]))
dataset4.element_spec

SparseTensorSpec(TensorShape([3,4]),tf.int32)

#使用value_type去查看element spec表示的类型
dataset4.element_spec.value_type

tensorflow.python.framework.sparse_tensor.SparseTensor

Dataset 的转化支持任意类型的dataset,当你使用Dataset.map(),Dataset.filter()方法。会对每个元素应用这个方法。元素的结构决定了方法的参数。


dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))

dataset1

<TensorSliceDataset shapes: (10,), types: tf.int32>

for z in dataset1:
  print(z.numpy())

[6 4 2 6 4 7 8 5 2 3]
[3 6 2 3 8 6 1 2 3 4]
[8 4 4 2 7 3 1 7 7 8]
[9 9 3 7 9 1 9 1 1 7]

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2

<TensorSliceDataset shapes: ((), (100,)), types: (tf.float32, tf.int32)>

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
dataset3

<ZipDataset shapes: ((10,), ((), (100,))), types: (tf.int32, (tf.float32, tf.int32))>

for a, (b,c) in dataset3:
  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))

shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)

读取输入数据

消费Numpy数组

通过查看加载Numpy数组了解更多的例子。
如果你所有的数据可以一次性加载在内存中,则从这些数据中生成Dataset是通过把他们转化成tf.Tensor对象,然后使用Dataset.from_tensor_slices()


train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset

<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>

上边的代码会通过tf.constant()把features和labels嵌入在TensorFlow的graph里。这对于小的数据集是可以的,但是会浪费内存,因为数组的元素会被多次拷贝。并且受到tf.GraphDef的protocal buffer 2GB内存的限制。

消费Python Generators

另一个可以方便的被作为tf.data.Dataset数据源的是python generator

虽然这是一种方便的方法,但它的可移植性和安全性有限。它必须在创建生成器的python进程中运行,并且仍然服从python GIL。


def count(stop):
    i=0
    while i<stop:
        yield i
        i += 1
for n in count(5):
    print(n)

0
1
2
3
4

Dataset.from_generator 构造器把 python generator转化成一个tf.data.Dataset。
这个构造器把迭代器的生成方法作为参数,而不是迭代器本身,这可以让他在迭代器到达终点后从新启动迭代器。他需要一个可选的迭代器生成方法的参数。

output_types参数是必填的,因为tf.data内部构建了一个tf.Graph。Graph的边需要一个tf.dtype.


ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())

[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

output_shape参数不是必须的,但是强烈推荐加上,因为很多tenorflow的操作不支持未知rank的tensor。如果特定维度的长度是未知的或者是可变的,则可以在output_shape里把它设置为None。这里同样需要注意output_shapes和output_types和其他dataset方法遵循一样的嵌套规则。
这里有一个样例的generator,它演示了这两个方面,它返回数组的元组。第二个数组是一个有着未知长度的向量。


def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1
for i, series in gen_series():
  print(i, ":", str(series))
  if i &gt; 5:
    break

0 : [-0.119  -0.1705  0.6298  1.6472]
1 : [0.4102 0.858 ]
2 : [ 1.1845 -0.4004 -0.5682  0.7648 -2.0806  1.2242 -2.0966  0.8446]
3 : [-2.2548  1.4521  0.8092]
4 : [ 0.8081  0.7604 -0.6913  0.5043  1.0684]
5 : [-0.1358  0.3046  0.0346]
6 : []

的一个输出是int32,第二个是float32。
第一个元素时个标量,shape为(),第二个元素是一个未知长度的向量。shape为(None,)


ds_series = tf.data.Dataset.from_generator(
    gen_series,
    output_types=(tf.int32, tf.float32),
    output_shapes=((), (None,)))

ds_series

&lt;FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)&gt;

现在它就可以像一个常规的tf.data.Dataset一样去使用。需要注意的是,如果你想把有着变化长度的dataset作为一个batch使用,你需要用Dataset.padded_batch.


ds_series_batch = ds_series.shuffle(20).padded_batch(10)

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())

[16 19 15  2  0 23 12  6 27 20]

[[ 1.7982  0.6075  0.8935  1.5496 -0.9427  0.9026 -1.0912  0.      0.    ]
 [ 0.      0.      0.      0.      0.      0.      0.      0.      0.    ]
 [-0.3835 -0.7199  0.6697 -0.5535 -1.2751  1.0727 -1.6975 -1.8727  0.6132]
 [-0.7542  2.4467  1.6744  0.6795  0.3545  0.4416  0.      0.      0.    ]
 [-1.1485  0.2186 -0.5212 -0.6324 -1.1361 -2.0172  0.      0.      0.    ]
 [ 0.3903  0.0763  0.1203 -1.5849  0.      0.      0.      0.      0.    ]
 [-0.361   0.      0.      0.      0.      0.      0.      0.      0.    ]
 [-0.3172 -0.3045 -0.4404  1.0413 -2.2203 -0.9137  0.      0.      0.    ]
 [ 0.1888  1.432   1.4547  0.5166 -0.2852  2.3501 -0.3476 -0.1323 -0.0585]
 [ 0.0677  0.5017  1.5934  0.5202  0.      0.      0.      0.      0.    ]]

一个更加真实的例子是把preprocessing.image.ImageDataGenerator包装成一个tf.data.Dataset.
首先是下载数据:


flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

创建image.ImageDataGenerator


img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))

Found 3670 images belonging to 5 classes.

print(images.dtype, images.shape)
print(labels.dtype, labels.shape)

float32 (32, 256, 256, 3)
float32 (32, 5)

ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory, args=[flowers],
    output_types=(tf.float32, tf.float32),
    output_shapes=([32,256,256,3], [32,5])
)

ds

&lt;FlatMapDataset shapes: ((32, 256, 256, 3), (32, 5)), types: (tf.float32, tf.float32)&gt;

消费TFRecord data

一个端到端加载TFRecords的例子

tf.data 的API支持非常多的文件格式来支持不适合一次读入内存的大数据集。比如,TFRecord 文件格式是一个简单的面向记录的二进制文件格式。很多TensorFlow的应用程序都用它来作为训练数据。tf.data.TFRecordDataset类让你流式的读取一个或者多个TFRecord文件来作为输入流的一部分。

这里有一个用French Street Name Sign(FSNS)作为测试文件的例子。


# Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")

TFRecordDataset的初始化函数的filenames参数可以是一个string,一个string的list,或者是strings的tf.Tensor。因此如果你有两组数据,分别用来做训练和验证,你可以创建一个工厂方法来生成dataset。用文件名作为输入参数。


dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset

&lt;TFRecordDatasetV2 shapes: (), types: tf.string&gt;

很多TensorFlow的工程在他们的TFRcord文件里使用序列化的tf.train.Example 条目。这些都要在使用前进行解码。


raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

parsed.features.feature['image/text']

bytes_list {
  value: "Rue Perreyon"
}

消费一个文本数据

一个加载文本数据的例子。
很多数据集是一个或者多个分布式的文本文件。tf.data.TextLineDataset提供了一种简单的方法来从一个或者多个文本文件里抽取行。提供一个或者多个文件名,一个TextLineDataset可以从这些文件中每行生成一个字符串类型的元素。


directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]
dataset = tf.data.TextLineDataset(file_paths)
for line in dataset.take(5):
    print(line.numpy())

b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b'His wrath pernicious, who ten thousand woes'
b"Caused to Achaia's host, sent many a soul"
b'Illustrious into Ades premature,'
b'And Heroes gave (so stood the will of Jove)'

为了在多个文件中交替选择行,可以使用Dataset.interleave. 这样让在多个文件做shuffle变得容易。这里是每个翻译的第一,二,三行


files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
  if i % 3 == 0:
    print()
  print(line.numpy())

b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'

b'His wrath pernicious, who ten thousand woes'
b'The vengeance, deep and deadly; whence to Greece'
b'countless ills upon the Achaeans. Many a brave soul did it send'

b"Caused to Achaia's host, sent many a soul"
b'Unnumbered ills arose; which many a soul'
b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'

默认情况下,一个TextLineDataset对每个文件的每一行都会生成数据。这个可能并不是你想要的。比如,如果文件有文件头列,或者包含着注释。这些行可以通过Dataset.skip()或者Dataset.filter()来进行处理。这里你可以跳过第一行,然后过滤数据,只留下你想要的。


titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
for line in titanic_lines.take(10):
  print(line.numpy())

b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone'
b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n'
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y'
b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'

def survived(line):
  return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)
for line in survivors.take(10):
  print(line.numpy())

b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y'
b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y'
b'1,male,28.0,0,0,35.5,First,A,Southampton,y'
b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'

消费CSV数据

更多示例关于加载CSV文件加载Pandas Dataframes
比如:


titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file, index_col=None)
df.head()
survived sex age n_siblings_spouses parch fare class deck embark_town alone
0 0 male 22.0 1 0 7.2500 Third unknown Southampton n
1 1 female 38.0 1 0 71.2833 First C Cherbourg n
2 1 female 26.0 0 0 7.9250 Third unknown Southampton y
3 1 female 35.0 1 0 53.1000 First C Southampton n
4 0 male 28.0 0 0 8.4583 Third unknown Queenstown y

如果你的数据可以在内存放得下,Dataset.from_tensor_slices对于字典数据也可以用。这样可以简单的导入数据。


titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))

  'survived'          : 0
  'sex'               : b'male'
  'age'               : 22.0
  'n_siblings_spouses': 1
  'parch'             : 0
  'fare'              : 7.25
  'class'             : b'Third'
  'deck'              : b'unknown'
  'embark_town'       : b'Southampton'
  'alone'             : b'n'

一个更加可伸缩的方案是按需从磁盘读取数据。
tf.data模块提供了从一个或者多个遵循RFC 4180CSV文件提取记录的方法。
experimental.make_csv_dataset方法是一个高层的接口用来读取一组csv文件。它支持列类型推断和其他很多特性。比如batching和shuffling,这样使用起来更方便。


titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))

'survived': [0 0 0 0]
features:
  'sex'               : [b'male' b'male' b'male' b'male']
  'age'               : [18. 19. 31. 34.]
  'n_siblings_spouses': [0 0 0 1]
  'parch'             : [0 0 0 0]
  'fare'              : [ 8.3    8.05   7.775 26.   ]
  'class'             : [b'Third' b'Third' b'Third' b'Second']
  'deck'              : [b'unknown' b'unknown' b'unknown' b'unknown']
  'embark_town'       : [b'Southampton' b'Southampton' b'Southampton' b'Southampton']
  'alone'             : [b'y' b'y' b'y' b'n']

你可以使用select_columns参数,如果你只需要选取一部分列。


titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))

'survived': [1 0 0 1]
  'fare'              : [ 20.25     9.     110.8833  59.4   ]
  'class'             : [b'Third' b'Third' b'First' b'First']

也有一个低级别的接口experimental.CsvDataset类来提供更细粒度的控制。他不支持列类型的推断,需要你指定每个列的类型。


titanic_types  = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string]
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types , header=True)

for line in dataset.take(10):
  print([item.numpy() for item in line])

[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
[1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
[0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']

如果有的列有空值,这个低级别接口允许你提供默认值,而不需要指定列的类型。


%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,

# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values.

record_defaults = [999,999,999,999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset

&lt;MapDataset shapes: (4,), types: tf.int32&gt;

for line in dataset:
  print(line.numpy())

[1 2 3 4]
[999   2   3   4]
[  1 999   3   4]
[  1   2 999   4]
[  1   2   3 999]
[999 999 999 999]

默认情况下,一个CsvDataset对每个文件的每行每列都读取,这个可能并不是你想要的,比如如果文件头,或者一些列你不想作为输入,这些行和列可以通过header和select_col参数来进行去除。


# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [999, 999] # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset

&lt;MapDataset shapes: (2,), types: tf.int32&gt;

for line in dataset:
  print(line.numpy())

[2 4]
[2 4]
[999   4]
[2 4]
[  2 999]
[999 999]

消费文件集

有很多数据集以文件的方式分散存储。每个文件是一个样本。


flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)

根目录下每个不同分类都存在一个文件夹下:


for item in flowers_root.glob("*"):
  print(item.name)

sunflowers
daisy
LICENSE.txt
roses
tulips
dandelion

每个分类文件夹下的文件都是样本:


list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

for f in list_ds.take(5):
  print(f.numpy())

b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/4933229095_f7e4218b28.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/18282528206_7fb3166041.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/daisy/3711723108_65247a3170.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/4019748730_ee09b39a43.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/roses/475936554_a2b38aaa8e.jpg'

使用tf.io.read_file来读取数据,从文件路径抽取label。生成(iamge,label)对


def process_path(file_path):
  label = tf.strings.split(file_path, os.sep)[-2]
  return tf.io.read_file(file_path), label

labeled_ds = list_ds.map(process_path)

for image_raw, label_text in labeled_ds.take(1):
  print(repr(image_raw.numpy()[:100]))
  print()
  print(label_text.numpy())

b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xe2\x0cXICC_PROFILE\x00\x01\x01\x00\x00\x0cHLino\x02\x10\x00\x00mntrRGB XYZ \x07\xce\x00\x02\x00\t\x00\x06\x001\x00\x00acspMSFT\x00\x00\x00\x00IEC sRGB\x00\x00\x00\x00\x00\x00'

b'roses'

发表评论

电子邮件地址不会被公开。 必填项已用*标注

%d 博主赞过: