【python】only one element tensors can be converted to Python scalars

mac2024-11-04  14

1-问题:only one element tensors can be converted to Python scalars

2-分析问题

for e in range(start, self.num_epochs): for i, (input_data, labels) in enumerate(zip(tqdm(self.data_loader))): iter_ctr += 1 start = time.time() input_data = self.to_var(input_data) # print('input_data.size:',input_data.size) # input_data=input_data.reshape[256,-1] input_data=input_data.view(input_data.size(0),-1) total_loss,sample_energy, recon_error, cov_diag = self.dagmm_step(input_data)

在训练网络时数据输入用了torch.utils.data.DataLoader函数,用自己的数据库进行数据封装。由于先用了zip(),所以输出是tuple元组格式,用np.array将其转换为数组格式就可以将其导入,但是问题来了,每次只能转一个batch_size的数据之后的不能连续转换

3-解决方案

for e in range(self.num_epochs): print('Epoch ({}/{})----------------------------------------------------------------------------'.format(e, self.num_epochs)) batch_idxs = len(self.data_loader)// self.batch_size #sample 367200,batch_size=256 for i in tqdm(range(0, batch_idxs)): iter_ctr += 1 start = time.time() batch = self.data_loader[i * self.batch_size:(i + 1) * self.batch_size] input_data = np.array(batch).astype(np.float32) # print('batch_images', input_data.shape) input_data=torch.tensor(input_data) input_data = self.to_var(input_data) # print('input_data.size:',input_data.size) # input_data=input_data.reshape[256,-1] input_data=input_data.view(input_data.size(0),-1) total_loss,sample_energy, recon_error, cov_diag = self.dagmm_step(input_data)

将上述代码改写了一下,解决了该问题,但是感觉不是最优解,我看了其他博客他们的解决方案主要是:数组和张量相互转化,经过实验证明在我的问题上并未能解决问题

要是大家有更好的解决方案欢迎共享~

最新回复(0)