7.9PyTorch里的Dataset和DataLoader

在深度学习项目中,数据处理是非常重要的一环,PyTorch提供这两个类来帮助用户高效地加载和处理数据。Dataset负责数据的读取和预处理,而DataLoader则负责将数据分成小批量,支持多线程加速,以及数据的打乱等。本节,我们就以Titanic数据集为例,讲解如何使用PyTorch里的Dataset和DataLoader类来处理数据。

7.9.1Dataset

Dataset类提供对数据集的抽象,任何自定义数据集都需要继承torch.utils.data.Dataset,并实现两个方法:__len____getitem__(idx)。其中__len__需要返回整个数据集样本的个数。__getitem__(idx)需要能根据样本的index返回具体的样本。

另外一般情况下,我们也把对数据的预处理工作放在自定义的Dataset类定义里。

接下来我们就来看一下对于Titanic数据集,我们的自定义Dataset类怎么实现,具体代码如下:

from torch.utils.data import Dataset
import pandas as pd
import torch


class TitanicDataset(Dataset):
    def __init__(self, file_path):
        self.file_path = file_path
        self.mean = {
            "Pclass": 2.236695,
            "Age": 29.699118,
            "SibSp": 0.512605,
            "Parch": 0.431373,
            "Fare": 34.694514,
            "Sex_female": 0.365546,
            "Sex_male": 0.634454,
            "Embarked_C": 0.182073,
            "Embarked_Q": 0.039216,
            "Embarked_S": 0.775910
        }

        self.std = {
            "Pclass": 0.838250,
            "Age": 14.526497,
            "SibSp": 0.929783,
            "Parch": 0.853289,
            "Fare": 52.918930,
            "Sex_female": 0.481921,
            "Sex_male": 0.481921,
            "Embarked_C": 0.386175,
            "Embarked_Q": 0.194244,
            "Embarked_S": 0.417274
        }

        self.data = self._load_data()
        self.feature_size = len(self.data.columns) - 1

    def _load_data(self):
        df = pd.read_csv(self.file_path)
        df = df.drop(columns=["PassengerId", "Name", "Ticket", "Cabin"]) ##删除不用的列
        df = df.dropna(subset=["Age"])##删除Age有缺失的行
        df = pd.get_dummies(df, columns=["Sex", "Embarked"], dtype=int)##进行one-hot编码

        ##进行数据的标准化
        base_features = ["Pclass", "Age", "SibSp", "Parch", "Fare"]
        for i in range(len(base_features)):
            df[base_features[i]] = (df[base_features[i]] - self.mean[base_features[i]]) / self.std[base_features[i]]
        return df

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        features = self.data.drop(columns=["Survived"]).iloc[idx].values
        label = self.data["Survived"].iloc[idx]
        return torch.tensor(features, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)

可以看到我们自定义了一个Dataset类:TitanicDataset。其中定义了一个方法_load_data,它负责加载数据,并且对数据进行预处理,以及标准化。

另外继承自Dataset类则必须实现的__len____getitem__我们也进行了对应的实现。这样我们的TitanicDataset就是一个合格的PyTorch里的Dataset了。

7.9.2DataLoader

一般情况下,我们不需要自定义DataLoader。PyTorch里默认实现的DataLoader就可以满足我们的使用,它定义了如何批量读取数据的功能。比如你可以通过batchsize设置每次读取数据的大小,通过shuffle参数设置是否对数据集进行打乱。另外如果你的Dataset的```_getitem__```比较费时,你可以通过num_workers参数指定多进程加载。

from torch.utils.data import DataLoader

dataset = TitanicDataset(r"E:\电子书\RethinkFun深度学习\data\titanic\train.csv")
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
for inputs, labels in dataloader:
    print(inputs.shape, labels.shape)
    break

输出为:

torch.Size([32, 10]) torch.Size([32])

results matching ""

    No results matching ""