主页
手机版
扫描查看手机站
所在位置:首页 → 教程资讯 → 噪声标签学习下的任务和实验设计

噪声标签学习下的任务和实验设计

发布: 更新时间:2024-06-29 15:52:57

噪声标签学习下的一个任务是:训练集上存在开集噪声和闭集噪声;然后在测试集上对闭集样本进行分类。

训练集中被加入的开集样本,会被均匀得打上闭集样本的标签充当开集噪声;而闭集噪声的设置与一般的噪声标签学习一致,分为对称噪声:随机将闭集样本的标签替换为其他类别;和非对称噪声:将闭集样本的标签替换为特定的类别。

论文实验中,常用cifar数据集模拟这类任务。目前已知有两类方法:

  • 第一类基于cifar100,将100个类的一部分,通常是20个类作为开集样本,将它们标签替换了前80个类作为开集噪声;然后对于后续80个类,选择部分样本设置为对称/非对称闭集噪声。CVPR2022的PNP: Robust Learning From Noisy Labels by Probabilistic Noise Prediction
    提供的代码中
    ,使用了这种方法。但是,如果要考虑非对称噪声,在cifar10上就很难实现,cifar10的类的顺序不像cifar100那样有规律,不好设置闭集噪声。

  • 第二类方法适用cifar10和cifar100,保持原始数据集的样本数不变,使用额外的数据集(通常是imagenet32、places365)代替部分样本作为开集噪声,对于剩下的非开集噪声样本再设置闭集噪声。ECCV2022的Embedding contrastive unsupervised features to cluster in-and out-of-distribution noise in corrupted image datasets
    提供的代码
    使用了这种方式。

places365可以使用

torchvision.datasets.Places365

下载,由于训练集较大,通常是用它的验证集作为辅助数据集。

imagenet32是imagnet的32x32版本,同样是1k类,但是类的具体含义的顺序与imagenet不同,imagenet32类的具体含义可见
这里
。image32下载地址在对应论文A downsampled variant of imagenet as an alternative to the cifar datasets
提供的链接。

实验

使用第二种方法构造含开集、闭集噪声数据集,开集噪声率

\(r_{ood}=0.2\)

,闭集噪声率

\(r_{id}=0.2\)

;辅助数据集使用imagenet32,基于cifar构造含开集闭集噪声的训练集。


设计imagenet32数据集

import os
import pickle
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

_train_list = ['train_data_batch_1',
               'train_data_batch_2',
               'train_data_batch_3',
               'train_data_batch_4',
               'train_data_batch_5',
               'train_data_batch_6',
               'train_data_batch_7',
               'train_data_batch_8',
               'train_data_batch_9',
               'train_data_batch_10']
_val_list = ['val_data']

# ...(代码内容较长,省略)...

目录结构:

imagenet32
├─ train_data_batch_1
├─ train_data_batch_10
├─ train_data_batch_2
├─ train_data_batch_3
├─ train_data_batch_4
├─ train_data_batch_5
├─ train_data_batch_6
├─ train_data_batch_7
├─ train_data_batch_8
├─ train_data_batch_9
└─ val_data


设计cifar数据集

import torchvision
import numpy as np
from dataset.imagenet32 import Imagenet32

# ...(代码内容较长,省略)...


查看统计结果

import pandas as pd
import altair as alt
from dataset.cifar import CIFAR10, CIFAR100

# ...(代码内容较长,省略)...


运行环境

# Name                    Version                   Build  Channel
altair                    5.3.0                    pypi_0    pypi
pytorch                   2.3.1           py3.12_cuda12.1_cudnn8_0    pytorch
pandas                    2.2.2                    pypi_0    pypi
软件上新 查看更多