首页 / 行业
利用Python和PyTorch处理面向对象的数据集
2021-08-25 15:30:00
本篇是利用 Python 和 PyTorch 处理面向对象的数据集系列博客的第 2 篇。
如需阅读第 1 篇:原始数据和数据集,请参阅此处。
我们在第 1 部分中已定义 MyDataset 类,现在,让我们来例化 MyDataset 对象
此可迭代对象是与原始数据交互的接口,在整个训练过程中都有巨大作用。
第 2 部分:创建数据集对象
■输入 [9]:
mydataset = MyDataset(isValSet_bool = None, raw_data_path = raw_data_path, norm = False, resize = True, newsize = (64, 64))
以下是该对象的一些使用示例:
■输入 [10]:
# 对象操作示例。
# 此操作用于调用 method __getitem__ 并从第 6 个样本获取标签
mydataset[6][1]
■输出 [10]:
0
■输入 [11]:
# 此操作用于在类声明后打印注释
MyDataset.__doc__
■输出 [11]:
‘Interface class to raw data, providing the total number of samples in the dataset and a preprocessed item’
■输入 [12]:
# 此操作用于调用 method __len__
len(mydataset)
■输出 [12]:
49100
■输入 [13]:
# 此操作用于触发 method __str__
print(mydataset)
原始数据路径为 。/raw_data/data_images/《raw samples》
可迭代对象的重要性
训练期间,将向模型提供多批次样本。可迭代的 mydataset 是获得高级轻量代码的关键。
以下提供了可迭代对象的 2 个使用示例。
示例 1:
我们可以直接获取第 3 个样本张量:
■输入 [14]:
mydataset.__getitem__(3)[0].shape
■输出 [14]:
torch.Size([3, 64, 64])
与以下操作作用相同
■输入 [15]:
mydataset[3][0].shape
■输出 [15]:
torch.Size([3, 64, 64])
示例 2:
我们可以对文件夹中的图像进行解析,并移除黑白图像:
■输入 [ ]:
# 数据集访问示例:创建 1 个包含标签的新文件,移除黑白图像
if os.path.exists(raw_data_path + ‘/’+ “labels_new.txt”):
os.remove(raw_data_path + ‘/’+ “labels_new.txt”)
with open(raw_data_path + ‘/’+ “labels_new.txt”, “a”) as myfile:
for item, info in mydataset:
if item != None:
if item.shape[0]==1:
# os.remove(raw_data_path + ‘/’ + info.SampleName)
print(‘C = {}; H = {}; W = {}; info = {}’.format(item.shape[0], item.shape[1], item.shape[2], info))
else:
#print(info.SampleName + ‘ ’ + str(info.SampleLabel))
myfile.write(info.SampleName + ‘ ’ + str(info.SampleLabel) + ‘ ’)
■输入 [ ]:
# 查找具有非期望格式的样本
with open(raw_data_path + ‘/’+ “labels.txt”, “a”) as myfile:
for item, info in mydataset:
if item != None:
if item.shape[0]!=3:
# os.remove(raw_data_path + ‘/’ + info.SampleName)
print(‘C = {}; H = {}; W = {}; info = {}’.format(item.shape[0], item.shape[1], item.shape[2], info))
修改标签文件后,请务必更新缓存:
■输入 [ ]:
if os.path.exists(raw_data_path + ‘/’+ “labels_new.txt”):
os.rename(raw_data_path + ‘/’+ “labels.txt”, raw_data_path + ‘/’+ “labels_orig.txt”)
os.rename(raw_data_path + ‘/’+ “labels_new.txt”, raw_data_path + ‘/’+ “labels.txt”)
@functools.lru_cache(1)
def getSampleInfoList(raw_data_path):
sample_list = []
with open(str(raw_data_path) + ‘/labels.txt’, “r”) as f:
reader = csv.reader(f, delimiter = ‘ ’)
for i, row in enumerate(reader):
imgname = row[0]
label = int(row[1])
sample_list.append(DataInfoTuple(imgname, label))
sample_list.sort(reverse=False, key=myFunc)
return sample_list
del mydataset
mydataset = MyDataset(isValSet_bool = None, raw_data_path = ‘。./。./raw_data/data_images’, norm = False)
len(mydataset)
您可通过以下链接阅读了解有关 PyTorch 中的可迭代数据库的更多信息:
https://pytorch.org/docs/stable/data.html
归一化
应对所有样本张量计算平均值和标准差。
如果数据集较小,可以尝试在内存中对其进行直接操作:使用 torch.stack 即可创建 1 个包含所有样本张量的栈。
可迭代对象 mydataset 支持简洁精美的代码。
使用“view”即可保留 R、G 和 B 这 3 个通道,并将其余所有维度合并为 1 个维度。
使用“mean”即可计算维度 1 的每个通道的平均值。
请参阅附件中有关 dim 使用的说明。
■输入 [16]:
imgs = torch.stack([img_t for img_t, _ in mydataset], dim = 3)
■输入 [17]:
#im_mean = imgs.view(3, -1).mean(dim=1).tolist()
im_mean = imgs.view(3, -1).mean(dim=1)
im_mean
■输出 [17]:
tensor([0.4735, 0.4502, 0.4002])
■输入 [18]:
im_std = imgs.view(3, -1).std(dim=1).tolist()
im_std
■输出 [18]:
[0.28131285309791565, 0.27447444200515747, 0.2874436378479004]
■输入 [19]:
normalize = transforms.Normalize(mean=[0.4735, 0.4502, 0.4002], std=[0.28131, 0.27447, 0.28744])
# free memory
del imgs
下面,我们将再次构建数据集对象,但这次将对此对象进行归一化:
■输入 [21]:
mydataset = MyDataset(isValSet_bool = None, raw_data_path = raw_data_path, norm = True, resize = True, newsize = (64, 64))
由于采用了归一化,因此张量值被转换至范围 0..1 之内,并进行剪切操作。
■输入 [22]:
original = Image.open(‘。./。./raw_data/data_images/img_00009111.JPEG’)
fig, axs = plt.subplots(1, 2, figsize=(10, 3))
axs[0].set_title(‘clipped tensor’)
axs[0].imshow(mydataset[5][0].permute(1,2,0))
axs[1].set_title(‘original PIL image’)
axs[1].imshow(original)
plt.show()
将输入数据剪切到含 RGB 数据的 imshow 的有效范围内,以 [0..1] 表示浮点值,或者以 [0..255] 表示整数值。
使用 torchvision.transforms
进行预处理
现在,我们已经创建了自己的变换函数或对象(原本用作为加速学习曲线的练习),我建议使用 Torch 模块 torchvision.transforms:
“此模块定义了一组可组合式类函数对象,这些对象可作为实参传递到数据集(如 torchvision.CIFAR10),并在加载数据后 __getitem__ 返回数据之前,对数据执行变换”。
以下列出了可能的变换:
■输入 [23]:
from torchvision import transforms
dir(transforms)
■输出 [23]:
[‘CenterCrop’,
‘ColorJitter’,
‘Compose’,
‘FiveCrop’,
‘Grayscale’,
‘Lambda’,
‘LinearTransformation’,
‘Normalize’,
‘Pad’,
‘RandomAffine’,
‘RandomApply’,
‘RandomChoice’,
‘RandomCrop’,
‘RandomErasing’,
‘RandomGrayscale’,
‘RandomHorizontalFlip’,
‘RandomOrder’,
‘RandomPerspective’,
‘RandomResizedCrop’,
‘RandomRotation’,
‘RandomSizedCrop’,
‘RandomVerticalFlip’,
‘Resize’,
‘Scale’,
‘TenCrop’,
‘ToPILImage’,
‘ToTensor’,
‘__builtins__’,
‘__cached__’,
‘__doc__’,
‘__file__’,
‘__loader__’,
‘__name__’,
‘__package__’,
‘__path__’,
‘__spec__’,
‘functional’,
‘transforms’]
在此示例中,我们使用变换来执行了以下操作:
1) ToTensor - 从 PIL 图像转换为张量,并将输出格式定义为 CxHxW
2) Normalize - 将张量归一化
最新内容
手机 |
相关内容
STC15W芯片A/D、D/A转换的简单使用
STC15W芯片A/D、D/A转换的简单使用,简单使用,转换,芯片,模拟,输入,输出,STC15W系列芯片是一种高性能的单片机芯片,具有丰富的外设资硅谷:设计师利用生成式 AI 辅助芯片
硅谷:设计师利用生成式 AI 辅助芯片设计,芯片,生成式,硅谷,优化,修改,方法,在硅谷,设计师们正在利用生成式人工智能(AI)来辅助芯片设计清华研制出首个全模拟光电智能计算
清华研制出首个全模拟光电智能计算芯片ACCEL,芯片,智能计算,模拟,清华,混合,研发,清华大学最近成功研制出了一款全模拟光电智能计算如何利用示波器快速测量幅频特性?有
如何利用示波器快速测量幅频特性?有何注意事项?,测量,示波器,连接,输入,信号,频率,利用示波器快速测量幅频特性是一种常用的方法,可以什么是差动放大器,差动放大器的组成
什么是差动放大器,差动放大器的组成、特点、原理、分类、操作规程及发展趋势,分类,发展趋势,负载,信号,调节,输入,NTA4153NT1G差动放PLC控制器的主要抗干扰措施
PLC控制器的主要抗干扰措施,抗干扰,控制器,隔离,能力,输入,滤波器,PLC控制器作为一种专门用于工业自动化控制的设备,其稳定性和抗干功率放大器如何驱动超声波换能器
功率放大器如何驱动超声波换能器,较好,输入,装置,失真,方法,输出,功率放大器用于驱动超声波换能器,以实现超声波的发射和接收。TPS62对于初次使用的buck电源芯片,如何做
对于初次使用的buck电源芯片,如何做模块性能测试?,性能测试,模块,芯片,初次,确保,输入,对于初次使用的buck电源芯片,模块性能测试是非