transunet复现(swin transformer复现)

transunet复现(swin transformer复现)p 在使用过多个基于 transformer 的开源网络过后 包括 segformer objectformer swinunet transfuse 等 发现加载官方公布的预训练权重的 transunet 在自己的任务上表现最好 我的方向是图像复制篡改检测 就是一种二分类语义分割 amp p

大家好,我是讯享网,很高兴认识大家。



 <p>        在使用过多个基于transformer的开源网络过后&#xff0c;包括segformer、objectformer、swinunet、transfuse等&#xff0c;发现加载官方公布的预训练权重的transunet在自己的任务上表现最好&#xff08;我的方向是图像复制篡改检测&#xff0c;就是一种二分类语义分割&#xff09;&#xff0c;在不做任何修改的情况仅重新成功训练就有远超普通全卷积网络的能力&#xff0c;&#xff08;这里仅指transunet的R50-ViT-B_16这种卷积后接transformer的形式&#xff0c;而同样加载预训练权重的纯vit模型表现就不尽如人意了&#xff0c;很不理解&#xff1f;&#xff09;</p> 

讯享网

        先帮还没入手的的兄弟解决些复现的小问题。

        官方code:GitHub - Beckschen/TransUNet: This repository includes the official project of TransUNet, presented in our paper: TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation.

里面还有所有模型预训练权重和一些训练数据集的下载链接,就在readme哪儿。

        在自己的数据集上复现还是比较简单的,基本上重写一下dataloader就行了,我直接就贴我的代码了,改的另一位博主如雾如电的,稍微借鉴了原始的数据读取但我没有加任何图像预处理。

讯享网def own_data_loader(img_path, mask_path): 
img &#61; cv2.imread(img_path) img &#61; cv2.resize(img, (224,224), interpolation &#61; cv2.INTER_NEAREST) mask &#61; cv2.imread(mask_path, 0) mask &#61; cv2.resize(mask, (224,224), interpolation &#61; cv2.INTER_NEAREST) # img &#61; randomHueSaturationValue(img, # hue_shift_limit&#61;(-30, 30), # sat_shift_limit&#61;(-5, 5), # val_shift_limit&#61;(-15, 15)) # # img, mask &#61; randomShiftScaleRotate(img, mask, # shift_limit&#61;(-0.1, 0.1), # scale_limit&#61;(-0.1, 0.1), # aspect_limit&#61;(-0.1, 0.1), # rotate_limit&#61;(-0, 0)) # img, mask &#61; randomHorizontalFlip(img, mask) # img, mask &#61; randomVerticleFlip(img, mask) # img, mask &#61; randomRotate90(img, mask) # mask &#61; np.expand_dims(mask, axis&#61;2) img &#61; np.array(img, np.float32) / 255.0 * 3.2 - 1.6 # img &#61; np.array(img, np.float32) / 255.0 # mask &#61; np.array(mask, np.float32) mask &#61; np.array(mask, np.float32) / 255.0 mask[mask &gt;&#61; 0.5] &#61; 1 mask[mask &lt; 0.5] &#61; 0 # mask &#61; np.squeeze(mask,axis&#61;2) # pyplot.imshow(mask) # pyplot.show() img &#61; np.array(img, np.float32).transpose(2, 0, 1) # mask &#61; np.array(mask, np.float32).transpose(2, 0, 1) return img, mask</pre> 
讯享网def read_own_data(root_path, mode=‘train’): 
images &#61; [] masks &#61; [] image_root &#61; os.path.join(root_path&#43;&#39;/images&#39;&#43; &#39;/train&#39;) gt_root &#61; os.path.join(root_path&#43; &#34;/annotations&#34;&#43; &#39;/train_png&#39;) for image_name in os.listdir(gt_root): label_path &#61; os.path.join(gt_root, image_name) masks.append(label_path) for image_name in os.listdir(image_root): image_path &#61; os.path.join(image_root, image_name) images.append(image_path) return images, masks</pre> 
讯享网class ImageFolder(Dataset): 
def __init__(self,root_path, mode&#61;&#39;train&#39;): self.root &#61; root_path self.mode &#61; mode self.images, self.labels &#61; read_own_data(self.root, self.mode) def __getitem__(self, index): if self.mode &#61;&#61; &#39;test&#39;: img, mask &#61; own_data_test_loader(self.images[index], self.labels[index]) else: img, mask &#61; own_data_loader(self.images[index], self.labels[index]) # img &#61; torch.Tensor(img) # mask &#61; torch.Tensor(mask) return img, mask def __len__(self): # assert len(self.images) &#61;&#61; len(self.labels), &#39;The number of images must be equal to labels&#39; return len(self.images)</pre> 

讯享网
调用就一行,在trainer.py好像还要改一点点
讯享网db_train = ImageFolder( args.root_path,mode=‘train’)
iterator = tqdm(range(max_epoch), ncols=70) for epoch_num in iterator: 
讯享网progress_bar &#61; tqdm(trainloader) for image_batch, label_batch in trainloader: # image_batch, label_batch &#61; sampled_batch[&#39;image&#39;], sampled_batch[&#39;label&#39;] image_batch, label_batch &#61; image_batch.cuda(), label_batch.cuda() # print(image_batch.shape, &#34;333&#34;) outputs &#61; model(image_batch) # outputs &#61; nn.Sigmoid()(outputs) # print(outputs) #torch.Size([6, 2, 224, 224]) # print(label_batch.shape) #torch.Size([6, 1, 224, 224]) # 这里的ce_loss &#61; CrossEntropyLoss()常用于多分类&#xff0c;换成BCELoss # loss_ce &#61; ce_loss(outputs, label_batch[:].long()) # loss_dice &#61; dice_loss(outputs, label_batch, softmax&#61;True) # loss &#61; 0.4 * loss_ce &#43; 0.6 * loss_dice outputs &#61; torch.squeeze(outputs) label_batch &#61; torch.squeeze(label_batch) # loss_ce &#61; bce_loss(outputs, label_batch) # loss_dice &#61; dice_loss(outputs, label_batch) # loss &#61; 0.4 * loss_ce &#43; 0.6 * loss_dice loss &#61; bce_loss(outputs, label_batch) optimizer.zero_grad() loss.backward() optimizer.step() lr_ &#61; base_lr * (1.0 - iter_num / max_iterations) 0.9 for param_group in optimizer.param_groups: param_group[&#39;lr&#39;] &#61; lr_ iter_num &#61; iter_num &#43; 1 progress_bar.set_description( &#39; Epoch: {}/{}. Iteration: {}/{}. Mini loss: {:.5f} &#39;.format( epoch_num, 30, iter_num , max_iterations, loss.item(),))</pre> 
我这里也改了损失函数,用惯了Bce了,不过大家还是先用一下原始损失函数再根据自己任务调整。 其它基本就是改一些参数就行了,到了要加载预训练权重的时候可以发现,公布的权重并不是整个网络的,只包括resnet50三层和12个transformer层,不包括后续四个decoder层。这里就是我迫切想和诸位交流并咨询的一点,我做了大量实验表明,只有全部加载提供的预训练权重才能得到一个差不多的精度。不论你是只加载resnet50三层或者12个transformer层或者将resnet50三层换成同样有imagenet预训练权重的vgg等其它网络都会极大影响性能。         而且在resnet50与12个transformer之间新增一些模块或操作也会同样影响性能,不过对于在decoder部分做出的改动对于性能影响不大。         现在的问题就是resnet50与12个transformer深度绑定而且必须加载预训练权重才能保证性能,这就导致你不能对这两部分做出任何改动。所以想请教大家如何才能获得这篇论文的一个原始预训练数据,直接把imagenet拿过来做图像分类就行吗?         因为想靠transformer再混一篇出来,所以还想请教大家还有哪些带有imagenet预训练权重的transformer结构?因为以我的经验来看,哪些没带预训练权重的former完全没办法成功训练,同样的参数下它的损失很难收敛。
小讯
上一篇 2025-05-26 15:25
下一篇 2025-06-01 22:56

相关推荐

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/201619.html