我在用pytorch读取mnist数据集时,采用了两种方法:官方下载和读取本地制作好的数据集,现在读出来的图片的torch.szie()大小不同,分别是torch.Size([1, 28, 28])和torch.Size([3, 28, 28]),请问有什么办法可以把3,28,28变成1,28,28,谢谢!
train_dataset1 = datasets.MNIST(\nroot=\'./data\', train=True, transform=transforms.ToTensor(), download=True)\n\ntrain\\_loader1 = DataLoader\\(train\\_dataset\\, batch\\_size=batch\\_size\\, shuffle=True\\)\n\nfrom torchvision.datasets import ImageFolder\nbatch_size = 128\npath=\'D:/work/\'\ntrain_dataset2 = ImageFolder(path,transform=transforms.ToTensor())\ntrain\\_loader2 = DataLoader\\(train\\_dataset\\, batch\\_size=batch\\_size\\, shuffle=False\\)\\print(train_dataset1[0][0].size())\nprint(train_dataset2[0][0].size())out:
torch.Size([1, 28, 28])\ntorch.Size([3, 28, 28]) |