设置种子
1 2 3 4 5 6 def set_seed (args ): random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) if args.n_gpu > 0 : torch.cuda.manual_seed_all(args.seed)
设置随机种子
将种子赋予np
将种子赋予torch
将种子赋予cuda
GPU设置
1 2 3 4 5 6 7 8 9 10 if args.local_rank == -1 : device = torch.device('cuda' , args.gpu_id) args.world_size = 1 args.n_gpu = torch.cuda.device_count() else : torch.cuda.set_device(args.local_rank) device = torch.device('cuda' , args.local_rank) torch.distributed.init_process_group(backend='nccl' ) args.world_size = torch.distributed.get_world_size() args.n_gpu = 1
根据local_rank决定是否采取分布式。如果local_rank=-1,说明分布式失效;如果local_rank不等于-1,则根据不同的卡配置不同的进程数;获取设备device方便后续将数据和模型加载在上面(代码为.to(device));初始化设置分布式的后端等。
torch.distributed.barrier()的使用:
①数据集:
1 2 3 4 5 6 7 8 9 if args.local_rank not in [-1 , 0 ]: torch.distributed.barrier() labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset]( args, './data' ) if args.local_rank == 0 : torch.distributed.barrier()
有些操作是不需要多卡同时运行的,如数据集和模型的加载。因此,PyTorch对非主进程的卡上面的运行进行了barrier设置。如果是在并行训练非主卡上,其它进行需要先等待主进程读取并缓存数据集,再从缓存中读取数据,以同步不同进程的数据,避免出现数据处理不同步的现象。
②模型
1 2 3 4 5 6 7 if args.local_rank not in [-1 , 0 ]: torch.distributed.barrier() model = create_model(args) if args.local_rank == 0 : torch.distributed.barrier()
先对其余进程设置一个障碍,等到主进程加载完模型和数据后,再对主进程设置障碍,使所有进程都处于同一“出发线”,最后再同时释放。
数据集划分
本代码使用的数据集分为三类:带标签的训练集,不带标签的训练集,测试集。虽然表面上需要一个训练集是“不带标签”的,但是PyTorch并没有直接舍去标签的数据集设置。一开始我在想,如果是我自己来写代码,应该要怎么处理呢?后来发现代码根本没有拘泥于“不带标签”这个事情,因为在返回数据集和标签时,使用“_”直接代替掉标签即可,损失函数也不需要使用标签。
核心API: 1 labeled_dataset, unlabeled_dataset, test_dataset = DATASET_GETTERS[args.dataset](args, './data' )
dataset->cifar.py,发现调用了如下函数(get_cifar10和get_cifar100极其类似,只是数据集分类的类别数不一样而已。下面仅以get_cifar100为例):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 def get_cifar100(args, root): # 图像变换 transform_labeled = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(size=32, padding=int(32*0.125), padding_mode='reflect'), transforms.ToTensor(), transforms.Normalize(mean=cifar100_mean, std=cifar100_std)]) transform_val = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=cifar100_mean, std=cifar100_std)]) # 数据集设置 base_dataset = datasets.CIFAR100( root, train=True, download=True) train_labeled_idxs, train_unlabeled_idxs = x_u_split( args, base_dataset.targets) train_labeled_dataset = CIFAR100SSL( root, train_labeled_idxs, train=True, transform=transform_labeled) train_unlabeled_dataset = CIFAR100SSL( root, train_unlabeled_idxs, train=True, transform=TransformFixMatch(mean=cifar100_mean, std=cifar100_std)) test_dataset = datasets.CIFAR100( root, train=False, transform=transform_val, download=False) return train_labeled_dataset, train_unlabeled_dataset, test_dataset
get_cifar100函数包括两部分:transform的设置和数据集设置。
(1)transform
对于测试集和带标签的训练集,可以根据论文[1]的介绍进行设置。但是对于不带标签的训练集,代码调用了TransformFixMatch类,因为这部分的训练集需要使用弱增强和强增强的方法,两种方法是不同的,所以需要特别设置一个callable的类,能够将两种transform手段凑在一块。当构建dataset调用transform时,可以直接调用call函数,直接返回两个增强手段处理后的图像。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 class TransformFixMatch (object ): def __init__ (self, mean, std ): self.weak = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(size=32 , padding=int (32 *0.125 ), padding_mode='reflect' )]) self.strong = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(size=32 , padding=int (32 *0.125 ), padding_mode='reflect' ), RandAugmentMC(n=2 , m=10 )]) self.normalize = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]) def __call__ (self, x ): weak = self.weak(x) strong = self.strong(x) return self.normalize(weak), self.normalize(strong)
(2)数据索引设置
怎么从原始的CIFAR数据集提取出带标签的训练集和无标签的训练集?注意到PyTorch数据集类有一个函数成员def
getitem (self,
index),核心参数是index,所以我们构建以上两个训练集,本质上是构建训练集对应的索引值。下面是索引生成函数x_u_split的代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 def x_u_split(args, labels): label_per_class = args.num_labeled // args.num_classes labels = np.array(labels) #每个label是一个数字 labeled_idx = [] unlabeled_idx = np.array(range(len(labels))) for i in range(args.num_classes): idx = np.where(labels == i)[0] #有[0]是因为np.where得到的是一个tuple,需要把tuple的元素提取出来 idx = np.random.choice(idx, label_per_class, False) labeled_idx.extend(idx) labeled_idx = np.array(labeled_idx) assert len(labeled_idx) == args.num_labeled if args.expand_labels or args.num_labeled < args.batch_size: num_expand_x = math.ceil( #向上取整 args.batch_size * args.eval_step / args.num_labeled) #等于17 #将参数元组的元素数组按水平方向进行叠加 labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)]) np.random.shuffle(labeled_idx) return labeled_idx, unlabeled_idx
每个类带标签数据的个数是均衡的,每个类带标签的数据个数 =
带标签数据总个数//类数,所以,使用一个循环(10个类)。
对于每一个类,找出他们在总数据(labels)中的数据索引,然后将labels(原本是列表)转换为numpy数组。并用random.choice随机选择label_per_class个数据,将他们加入到带标签的数据索引labeled_idx中。
对于不带标签的数据,原文作者使用了所有的数据(包含带标签的数据),所以他的索引为全部数据的索引,unlabeled_idx可以直接对应全体数据。
需要注意的一个点是,args.expand_labels参数作者默认为true的,所以我们要进行数据重复。
或者num_labeled比batch_size还小,则对数组进行扩充。
这里重复的次数num_expand_x为 64(batch_size )* 1024(eval_step)/ 4000
(num_labeled)=17次 所以带标签的数据为
68000个(每个索引都重复了17次)。
(3)数据集设置
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 class CIFAR100SSL (datasets.CIFAR100): def __init__ (self, root, indexs, train=True , transform=None , target_transform=None , download=False ): super ().__init__(root, train=train, transform=transform, target_transform=target_transform, download=download) if indexs is not None : self.data = self.data[indexs] self.targets = np.array(self.targets)[indexs] def __getitem__ (self, index ): img, target = self.data[index], self.targets[index] img = Image.fromarray(img) if self.transform is not None : img = self.transform(img) if self.target_transform is not None : target = self.target_transform(target) return img, target
scheduler
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 def get_cosine_schedule_with_warmup (optimizer, num_warmup_steps, num_training_steps, num_cycles=7. /16. , last_epoch=-1 ): def _lr_lambda (current_step ): if current_step < num_warmup_steps: return float (current_step) / float (max (1 , num_warmup_steps)) no_progress = float (current_step - num_warmup_steps) / \ float (max (1 , num_training_steps - num_warmup_steps)) return max (0. , math.cos(math.pi * num_cycles * no_progress)) return LambdaLR(optimizer, _lr_lambda, last_epoch)
scheduler是为了动态调整训练期间的学习率,使模型更好地收敛。论文使用的是带有warmup性质的余弦退火学习率调整器。核心是返回了一个自定义函数的学习率调整器,调整的函数是_lr_lambda,如果当前的step少于warmup的步数,则使用线性递增的策略一直增加到初始学习率;而后使用余弦变化的策略改变学习率:
是初始学习率, 是当前的步数, 是总步数。
混合精度
本代码使用的是英伟达开发的apex库,可以通过使用混合精度,在保证精度丢失很少的情况下,减少内存,增快训练速度。混合精度涉及对模型和优化器的重初始化、损失函数的反向传播等。代码如下:
1 2 3 4 5 from apex import ampmodel, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level) with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward()
指数移动平均(EMA)
EMA在本代码是用于更新模型权重的,核心公式就这一条: 这里的参数 代表测试用模型的参数权重。训练时,原模型就按照正常的节奏来训练、更新权重,而另外开辟一个EMA模型,在原模型更新权重的同时也跟着更新权重,并作为最后使用的模型,检测在测试集上的表现。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 class ModelEMA (object ): def __init__ (self, args, model, decay ): self.ema = deepcopy(model) self.ema.to(args.device) self.ema.eval () self.decay = decay self.ema_has_module = hasattr (self.ema, 'module' ) self.param_keys = [k for k, _ in self.ema.named_parameters()] self.buffer_keys = [k for k, _ in self.ema.named_buffers()] for p in self.ema.parameters(): p.requires_grad_(False ) def update (self, model ): needs_module = hasattr (model, 'module' ) and not self.ema_has_module with torch.no_grad(): msd = model.state_dict() esd = self.ema.state_dict() for k in self.param_keys: if needs_module: j = 'module.' + k else : j = k model_v = msd[j].detach() ema_v = esd[k] esd[k].copy_(ema_v * self.decay + (1. - self.decay) * model_v) for k in self.buffer_keys: if needs_module: j = 'module.' + k else : j = k esd[k].copy_(msd[j])
权重衰减(Weight Decay)
1 2 3 4 5 6 7 8 9 grouped_parameters = [ {'params' : [p for n, p in model.named_parameters() if not any ( nd in n for nd in no_decay)], 'weight_decay' : args.wdecay}, {'params' : [p for n, p in model.named_parameters() if any ( nd in n for nd in no_decay)], 'weight_decay' : 0.0 } ]
核心算法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 labeled_iter = iter (labeled_trainloader) unlabeled_iter = iter (unlabeled_trainloader) model.train() for epoch in range (args.start_epoch, args.epochs): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() losses_x = AverageMeter() losses_u = AverageMeter() mask_probs = AverageMeter() if not args.no_progress: p_bar = tqdm(range (args.eval_step), disable=args.local_rank not in [-1 , 0 ]) for batch_idx in range (args.eval_step): try : inputs_x, targets_x = labeled_iter.next () except : if args.world_size > 1 : labeled_epoch += 1 labeled_trainloader.sampler.set_epoch(labeled_epoch) labeled_iter = iter (labeled_trainloader) inputs_x, targets_x = labeled_iter.next () try : (inputs_u_w, inputs_u_s), _ = unlabeled_iter.next () except : if args.world_size > 1 : unlabeled_epoch += 1 unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch) unlabeled_iter = iter (unlabeled_trainloader) (inputs_u_w, inputs_u_s), _ = unlabeled_iter.next () data_time.update(time.time() - end) batch_size = inputs_x.shape[0 ] inputs = interleave( torch.cat((inputs_x, inputs_u_w, inputs_u_s)), 2 *args.mu+1 ).to(args.device) targets_x = targets_x.to(args.device) logits = model(inputs) logits = de_interleave(logits, 2 *args.mu+1 ) logits_x = logits[:batch_size] logits_u_w, logits_u_s = logits[batch_size:].chunk(2 ) del logits Lx = F.cross_entropy(logits_x, targets_x, reduction='mean' ) pseudo_label = torch.softmax(logits_u_w.detach()/args.T, dim=-1 ) max_probs, targets_u = torch.max (pseudo_label, dim=-1 ) mask = max_probs.ge(args.threshold).float () Lu = (F.cross_entropy(logits_u_s, targets_u, reduction='none' ) * mask).mean() loss = Lx + args.lambda_u * Lu
其中torch.max的用法参考如下
1 2 a = torch.randn(24 ).reshape(2 ,3 ,4 ) print (a)
运行结果:
1 2 3 4 5 6 7 tensor([[[-0.9135 , 1.3096 , 0.2803 , -0.9314 ], [-0.2687 , -0.0968 , -0.7156 , -0.8814 ], [-1.0099 , 1.6910 , 0.3458 , -0.6547 ]], [[-0.4334 , -0.0464 , -1.9236 , 0.3148 ], [ 0.3628 , -0.7063 , -0.1750 , 1.5068 ], [ 1.1270 , -0.9374 , -0.8419 , -0.0050 ]]])
1 2 3 A = torch.softmax(a.detach()/1 , dim=-1 ) print (A)
1 2 3 4 5 6 7 tensor([[[0.0689 , 0.6362 , 0.2273 , 0.0677 ], [0.2968 , 0.3525 , 0.1899 , 0.1608 ], [0.0472 , 0.7025 , 0.1830 , 0.0673 ]], [[0.2079 , 0.3061 , 0.0468 , 0.4392 ], [0.1974 , 0.0678 , 0.1153 , 0.6196 ], [0.6294 , 0.0799 , 0.0879 , 0.2029 ]]])
1 2 3 4 5 max_probs, targets_u = torch.max (A, dim=-1 ) print (max_probs)print (max_probs.shape)print (targets_u)print (targets_u.shape)
1 2 3 4 5 6 7 8 9 tensor([[0.6362 , 0.3525 , 0.7025 ], [0.4392 , 0.6196 , 0.6294 ]]) torch.Size([2 , 3 ]) tensor([[1 , 1 , 1 ], [3 , 3 , 0 ]]) torch.Size([2 , 3 ])
模型保存与加载
模型保存过程
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 if args.local_rank in [-1 , 0 ]: test_loss, test_acc = test(args, test_loader, test_model, epoch) args.writer.add_scalar('train/1.train_loss' , losses.avg, epoch) args.writer.add_scalar('train/2.train_loss_x' , losses_x.avg, epoch) args.writer.add_scalar('train/3.train_loss_u' , losses_u.avg, epoch) args.writer.add_scalar('train/4.mask' , mask_probs.avg, epoch) args.writer.add_scalar('test/1.test_acc' , test_acc, epoch) args.writer.add_scalar('test/2.test_loss' , test_loss, epoch) is_best = test_acc > best_acc best_acc = max (test_acc, best_acc) model_to_save = model.module if hasattr (model, "module" ) else model if args.use_ema: ema_to_save = ema_model.ema.module if hasattr ( ema_model.ema, "module" ) else ema_model.ema save_checkpoint({ 'epoch' : epoch + 1 , 'state_dict' : model_to_save.state_dict(), 'ema_state_dict' : ema_to_save.state_dict() if args.use_ema else None , 'acc' : test_acc, 'best_acc' : best_acc, 'optimizer' : optimizer.state_dict(), 'scheduler' : scheduler.state_dict(), }, is_best, args.out) test_accs.append(test_acc) logger.info('Best top-1 acc: {:.3f}' .format (best_acc)) logger.info('Mean top-1 acc: {:.3f}\n' .format ( np.mean(test_accs[-20 :])))
状态字典:state_dict:
在PyTorch中,torch.nn.Module模型的可学习参数(即权重和偏差)包含在模型的参数中,(使用model.parameters()可以进行访问)。
state_dict是Python字典对象,它将每一层映射到其参数张量。注意,只有具有可学习参数的层(如卷积层,线性层等)的模型才具有state_dict这一项。目标优化torch.optim也有state_dict属性,它包含有关优化器的状态信息,以及使用的超参数。
因为state_dict的对象是Python字典,所以它们可以很容易的保存、更新、修改和恢复,为PyTorch模型和优化器添加了大量模块。
下面通过从简单模型训练一个分类器中来了解一下state_dict的使用。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 import torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimclass TheModelClass (nn.Module): def __init__ (self ): super (TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3 , 6 , 5 ) self.pool = nn.MaxPool2d(2 , 2 ) self.conv2 = nn.Conv2d(6 , 16 , 5 ) self.fc1 = nn.Linear(16 * 5 * 5 , 120 ) self.fc2 = nn.Linear(120 , 84 ) self.fc3 = nn.Linear(84 , 10 ) def forward (self, x ): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1 , 16 * 5 * 5 ) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x model = TheModelClass() optimizer = optim.SGD(model.parameters(), lr=0.001 , momentum=0.9 ) print ("Model's state_dict:" )for param_tensor in model.state_dict(): print (param_tensor, "\t" , model.state_dict()[param_tensor].size()) print ("Optimizer's state_dict:" )for var_name in optimizer.state_dict(): print (var_name, "\t" , optimizer.state_dict()[var_name])
输出:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) Optimizer' s state_dict:state {} param_groups [{'lr' : 0.001 , 'momentum' : 0.9 , 'dampening' : 0 , 'weight_decay' : 0 , 'nesterov' : False , 'maximize' : False , 'params' : [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ]}]
定义save_checkpoint保存完整模型
1 2 3 4 5 def save_checkpoint (state, is_best, checkpoint, filename='checkpoint.pth.tar' ): filepath = os.path.join(checkpoint, filename) torch.save(state, filepath) if is_best: shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar' ))
当保存好模型用来推断的时候,只需要保存模型学习到的参数,使用torch.save()函数来保存模型state_dict。
在 PyTorch 中最常见的模型保存使‘.pt’或者是‘.pth’作为模型文件扩展名。
在运行推理之前,务必调用model.eval()去设置 dropout 和 batch
normalization 层为评估模式。如果不这么做,可能导致模型推断结果不一致。
注意:
load_state_dict()函数只接受字典对象,而不是保存对象的路径。这就意味着在你传给load_state_dict()函数之前,你必须反序列化你保存的state_dict。例如,你无法通过
model.load_state_dict(PATH)来加载模型。
保存和加载 Checkpoint
用于推理/继续训练
保存Checkpoint:
1 2 3 4 5 6 7 8 9 save_checkpoint({ 'epoch' : epoch + 1 , 'state_dict' : model_to_save.state_dict(), 'ema_state_dict' : ema_to_save.state_dict() if args.use_ema else None , 'acc' : test_acc, 'best_acc' : best_acc, 'optimizer' : optimizer.state_dict(), 'scheduler' : scheduler.state_dict(), }, is_best, args.out)
当保存成 Checkpoint
的时候,可用于推理或者是继续训练,保存的不仅仅是模型的state_dict。保存优化器的state_dict也很重要,
因为它包含作为模型训练更新的缓冲区和参数。你也许想保存其他项目,比如最新记录的训练损失,外部的torch.nn.Embedding层等等。
要保存多个组件,请在字典中组织它们并使用torch.save()来序列化字典。PyTorch
中常见的保存checkpoint是使用 .tar 文件扩展名。
要加载项目,首先需要初始化模型和优化器,然后使用torch.load()来加载本地字典。这里,你可以非常容易的通过简单查询字典来访问你所保存的项目。
请记住在运行推理之前,务必调用model.eval()去设置 dropout 和
batch normalization 为评估。如果不这样做,有可能得到不一致的推断结果。
如果你想要恢复训练,请调用model.train()以确保这些层处于训练模式。
加载Checkpoint:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 if args.resume: logger.info("==> Resuming from checkpoint.." ) assert os.path.isfile( args.resume), "Error: no checkpoint directory found!" args.out = os.path.dirname(args.resume) checkpoint = torch.load(args.resume) best_acc = checkpoint['best_acc' ] args.start_epoch = checkpoint['epoch' ] model.load_state_dict(checkpoint['state_dict' ]) if args.use_ema: ema_model.ema.load_state_dict(checkpoint['ema_state_dict' ]) optimizer.load_state_dict(checkpoint['optimizer' ]) scheduler.load_state_dict(checkpoint['scheduler' ])
加载最优模型:
1 2 3 4 5 6 filepath = os.path.join(args.out, 'model_best.pth.tar' ) assert os.path.isfile(filepath), "Error: no model_best directory found!" model.load_state_dict(torch.load(filepath)['state_dict' ]) if args.use_ema: ema_model.ema.load_state_dict(torch.load(filepath)['ema_state_dict' ])
accuracy
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 def accuracy (output, target, topk=(1 , ) ): """Computes the precision@k for the specified values of k""" maxk = max (topk) batch_size = target.size(0 ) _, pred = output.topk(maxk, dim=1 , largest=True , sorted =True ) pred = pred.t() correct = pred.eq(target.reshape(1 , -1 ).expand_as(pred)) """ correct=([[0, 0, 1, ..., 0, 0, 0], [1, 0, 0, ..., 0, 0, 0], [0, 0, 0, ..., 1, 0, 0], [0, 0, 0, ..., 0, 0, 0], [0, 1, 0, ..., 0, 0, 0]], device='cuda:0', dtype=torch.uint8) """ res = [] for k in topk: correct_k = correct[:k].reshape(-1 ).float ().sum (0 ) res.append(correct_k.mul_(100.0 / batch_size)) return res
topk函数:
1 torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
input (Tensor):输入张量,一个tensor数据 k
(int):指明是得到前k个数据以及其index dim (int, optional):
指定在哪个维度上排序, 默认是最后一个维度 largest (bool,
optional):如果为True,按照大到小排序; 如果为False,按照小到大排序
sorted (bool, optional) :控制返回值是否排序 out (tuple,
optional):可选输出张量 (Tensor, LongTensor)
例如:
1 2 3 4 5 a = torch.tensor([[ 0 , 1 , 1 , 0 ], [ 0 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 0 ], [ 1 , 0 , 0 , 0 ], [ 0 , 0 , 0 , 1 ]])
1 2 3 4 5 6 tensor([[0 , 1 , 1 , 0 ], [0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 ], [1 , 0 , 0 , 0 ], [0 , 0 , 0 , 1 ]]) torch.Size([5 , 4 ])
1 2 3 4 5 6 7 8 9 correct_1 = a[:1 ].reshape(-1 ) print (correct_1.shape)print (correct_1)correct_1 = correct_1.float ().sum (0 ) print (correct_1.shape)print (correct_1)correct_1.mul_(100.0 / 4 ) print (correct_1.shape)print (correct_1)
1 2 3 4 5 6 torch.Size([4 ]) tensor([0 , 1 , 1 , 0 ]) torch.Size([]) tensor(2. ) torch.Size([]) tensor(50. )
1 2 3 4 5 6 7 8 9 correct_5 = a[:5 ].reshape(-1 ) print (correct_5.shape)print (correct_5)correct_5 = correct_5.float ().sum (0 ) print (correct_5.shape)print (correct_5)correct_5.mul_(100.0 / 4 ) print (correct_5.shape)print (correct_5)
1 2 3 4 5 6 torch.Size([20 ]) tensor([0 , 1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 1 ]) torch.Size([]) tensor(4. ) torch.Size([]) tensor(100. )