发布: 更新时间:2024-09-26 10:09:56
数据集蒸馏旨在从大型数据集中合成每类(
IPC
)少量图像,以在最小性能损失的情况下近似完整数据集训练。尽管在非常小的
IPC
范围内有效,但随着
IPC
增加,许多蒸馏方法变得不太有效甚至性能不如随机样本选择。论文对各种
IPC
范围下的最先进的基于轨迹匹配的蒸馏方法进行了研究,发现这些方法在增加
IPC
的情况下很难将更难样本的复杂、罕见特征纳入合成数据集中,导致了容易和难的测试样本之间持续存在的覆盖差距。受到这些观察的启发,论文提出了
SelMatch
,一种能够有效随
IPC
扩展的新型蒸馏方法。
SelMatch
使用基于选择的初始化和通过轨迹匹配进行部分更新来管理合成数据集,以适应针对
IPC
范围定制的期望难度级别。在对
CIFAR-10
/
100
和
TinyImageNet
的测试中,
SelMatch
在
5%
到
30%
的子集比率上始终优于主流的仅选择和仅蒸馏方法。来源:晓飞的算法工程笔记 公众号,转载请注明出处
论文: SelMatch: Effectively Scaling Up Dataset Distillation via Selection-Based Initialization and Partial Updates by Trajectory Matching
数据集缩减对于数据高效学习至关重要,它涉及从大型数据集中合成或选择较少数量的样本,同时确保在这个缩减后的数据集上训练的模型性能与在完整数据集上训练的相比保持可比性或性能降低最小化。这种方法解决了在大型数据集上训练神经网络时所面临的挑战,如高计算成本和内存需求。
在这一领域中一种重要的技术是数据集蒸馏,也被称为数据集凝聚。这种方法将大型数据集提炼为一个更小的合成数据集。与核心集选择方法相比,数据蒸馏在图像分类任务中表现出显著的性能,特别是在极小规模上。例如,匹配训练轨迹(
MTT
)算法仅使用
CIFAR-10
数据集的
1%
,在简单的
ConvNet
上实现了
71.6%
的准确率,接近完整数据集的
84.8%
准确率。这种显著的效率来自于优化过程,在这个过程中,合成样本在连续空间中被最优地学习,而不是直接从原始数据集中选择。
然而,最近的研究表明,随着合成数据集的规模或每类图像(
IPC
)的增加,许多数据集蒸馏方法失去了效果,甚至表现不如随机样本选择。这一现象令人费解,考虑到蒸馏相对于离散样本选择提供的更大优化自由度。具体来说,
DATM
通过分析最先进的
MTT
方法的训练轨迹来调查这一现象,指出了在合成数据集过程中方法所关注的训练轨迹阶段如何显著影响蒸馏数据集的有效性。特别是,在早期轨迹中学习到的简单模式和在后期阶段学习到的困难模式明显影响了
MTT
在不同
IPC
情况下的性能。
论文进一步通过比较在不同
IPC
情况下,
MTT
方法涵盖合成数据集中简单和困难真实样本的情况,发现随着
IPC
增加,蒸馏方法未能充分将困难样本的稀有特征纳入合成数据集中,这导致了简单样本与困难样本之间的一致覆盖差距。在更高
IPC
范围内,数据集蒸馏方法效果降低的部分原因是它们倾向于专注于数据集中更简单、更具代表性的特征。相反,随着
IPC
的增加,涵盖更难、更稀有的特征对于在缩减数据集上训练的模型的泛化能力变得更加关键,这点在数据选择研究中得到了实证和理论上的验证。
受到这些观察的启发,论文提出了一种新颖的方法,名为
SelMatch
,作为有效扩展数据集蒸馏方法的解决方案。随着
IPC
的增加,合成数据集应该涵盖真实数据集更复杂和多样化的特征,具有适当的难度水平。通过基于选择的初始化和通过轨迹匹配的部分更新,管理合成数据集的期望难度级别。
IPC
IPC
IPC
IPC
IPC
在
CIFAR-10
/
100
和
TinyImageNet
上评估了
SelMatch
,并展示了在从
5%
到
30%
的子集比例设置中,与最先进的仅选择和仅蒸馏方法相比的优越性。值得注意的是,在
CIFAR-100
中,当每类有
50
张图像(
10%
比例)的情况下,与领先方法相比,
SelMatch
的测试准确率提高了
3.5
%。
数据集减少中的两种主要方法:样本选择和数据集蒸馏。
在样本选择中,主要有两种方法:基于优化和基于评分的选择。
基于优化的选择旨在识别一个小的核心集,有效地代表完整数据集的各种特征。例如,
Herding
和
K-center
选择一个近似于完整数据集分布的核心集。
Craig
和
GradMatch
寻求一个核心集,在神经网络训练中,它能够最小化与完整数据集的平均梯度差异。尽管在小到中等
IPC
范围内有效,但是与基于评分的选择相比,这些方法在可伸缩性和性能方面常常面临问题,特别是随着
IPC
的增加。
基于评分的选择能够根据神经网络训练中每个实例的难度或影响分配值。例如,
Forgetting
通过计算先前被正确分类但在之后的多个时期被误分类的次数来评估实例的学习难度。
C-score
将困难性评估为从训练集中删除样本时误分类的概率。这些方法优先考虑困难样本,捕捉罕见和复杂的特征,并在较大的
IPC
规模下优于基于优化的选择方法。这些研究表明,随着
IPC
的增加,引入更难或更稀有的特征对于模型的泛化能力的提高变得越来越重要。
数据集蒸馏旨在创建一个小的合成集
\(\mathcal{S}\)
,以便在
\(\mathcal{S}\)
上训练的模型
\(\theta^\mathcal{S}\)
能够实现良好的泛化性能,在完整数据集
\(\mathcal{T}\)
上表现良好:
这里,
\(\mathcal{L}^\mathcal{T}\)
和
\(\mathcal{L}^\mathcal{S}\)
分别是
\(\mathcal{T}\)
和
\(\mathcal{S}\)
上的损失。为了应对双层优化的计算复杂性和内存需求,现有的工作采用了两种方法:基于替代的匹配和基于核的方法。基于替代的匹配将复杂的原始目标替换为更简单的代理任务。例如,
DC
、
DSA
和
MTT
旨在通过匹配梯度或轨迹,使在
\(\mathcal{S}\)
上训练的模型
\(\theta^\mathcal{S}\)
的轨迹与完整数据集
\(\mathcal{T}\)
的轨迹保持一致。
DM
确保
\(\mathcal{S}\)
和
\(\mathcal{T}\)
在特征空间中具有相似的分布。另外,基于核的方法利用核方法近似神经网络对
\(\theta^\mathcal{S}\)
的训练,并为内部优化推导出闭式解。例如,
KIP
使用神经切线核(
NTK
)进行核岭回归,
FrePo
通过仅专注于最后一个可学习层的回归来减少训练成本。然而,随着
IPC
的增加,基于替代的匹配和基于核的方法在可扩展性或性能方面都难以有效扩展。
DC-BENCH
指出,与高
IPC
情况下的随机样本选择相比,这些方法性能不佳。
近期的研究致力于解决最先进的
MTT
方法的可扩展性问题,主要关注计算方面,通过降低内存需求,或性能方面,通过在后续时期利用完整数据集的训练轨迹。具体而言,
DATM
发现与早期训练轨迹保持一致可增强在低
IPC
制度下的性能,而与后期轨迹保持一致对于高
IPC
制度下更有益。基于这一观察,
DATM
根据
IPC
优化了轨迹匹配范围,从而自适应地将专家轨迹中更容易或更困难的模式纳入,从而提高了
MTT
的可扩展性。虽然
DATM
可有效地确定轨迹匹配范围的下限和上限,但在这些范围之外的匹配损失变化趋势上,明确量化或搜寻所需的训练轨迹困难水平仍然是一个具有挑战性的任务。相比之下,论文的
SelMatch
利用基于选择的初始化和通过轨迹匹配进行部分更新,以纳入适合每个
IPC
的难样本的复杂特征。尤其是,论文的方法引入了一种新颖的策略,即针对每个
IPC
范围为合成样本初始化定制的困难水平,这是在以往的数据集蒸馏文献中尚未探讨的。此外,与专门设计用于增强
MTT
的
DATM
不同,
SelMatch
的主要组成部分,即基于选择的初始化和部分更新,在各种蒸馏方法中具有更广泛的适用性。
最先进的数据集蒸馏方法
MTT
将作为基准,用于分析传统数据集蒸馏方法在大
IPC
范围内的局限性。
MTT
的目标是通过匹配真实数据集
\(\mathcal{D}_\textrm{real}\)
和合成数据集
\(\mathcal{D}_\textrm{syn}\)
之间的训练轨迹来生成合成数据集。在每个蒸馏迭代中,合成数据集会被更新,以最小化匹配损失,该损失以真实数据集
\(\mathcal{D}_\textrm{real}\)
的训练轨迹
\(\{\theta_t^*\}\)
和合成数据集
\(\mathcal{D}_\textrm{syn}\)
的训练轨迹
\(\{\hat{\theta}_t\}\)
为定义。
其中,
\(\theta_t^*\)
是在第
\(t\)
步上在
\(\mathcal{D}_\textrm{real}\)
上训练的模型参数。从
\(\hat{\theta}_{t}=\theta_t^*\)
开始,
\(\hat{\theta}_{t+N}\)
是通过在
\(\mathcal{D}_\textrm{syn}\)
上训练
\(N\)
步后获得的模型参数,而
\({\theta}^*_{t+M}\)
是在
\(\mathcal{D}_\textrm{real}\)
上训练
\(M\)
步后获得的参数。
首先分析
MTT
生成的合成数据的模式如何随着每类图像(
IPC
)的增加而演变。要使数据集蒸馏方法在更大的合成数据集中保持有效,蒸馏过程在每类图像增加时应继续向合成样本提供真实数据集的新颖和复杂模式。轨迹匹配方法在低
IPC
水平上虽然处于最先进地位,但在实现这一目标方面还存在不足。
论文通过检查真实(测试)数据集的“覆盖率”来展示这一点。“覆盖率”被定义为在特征空间内与合成样本距离小于一定半径(
\(r\)
)的真实样本的比例,半径
\(r\)
被设置为特征空间内真实训练样本的平均最近邻距离。较高的覆盖率表明合成数据集捕获了真实样本的多样特征,使得在合成数据集上训练的模型能够学习到真实数据集中不仅是简单,还有复杂模式。
图
1a
(左)展示了随着
CIFAR-10
数据集的每类图像数量(
IPC
)增加,覆盖率如何变化。此外,在图
1a
(右)中,针对两组样本进行了分析。“简单”
50%
和“困难”
50%
(根据遗忘分数对真实样本进行的难度衡量)。
观察结果显示,使用
MTT
的覆盖率并没有有效地随
IPC
扩展,始终低于随机选择的覆盖率。此外,困难样本组的覆盖率远远低于简单样本组的覆盖率。这表明,即使
IPC
增加,
MTT
也无法有效地将困难和复杂的数据模式嵌入到合成样本中,这可能是
MTT
性能不佳的缩放原因。而论文的方法
SelMatch
展示了更优越的总体覆盖率,特别是在
IPC
增加时,困难组覆盖率明显提升。
另一个重要发现是,随着蒸馏迭代次数的增多,
MTT
的覆盖率在减少,如图
1b
所示。这一观察进一步表明,传统的蒸馏方法主要在多次迭代过程中捕获“简单”模式,使得合成数据集随着蒸馏迭代次数的增加变得缺乏多样性。相比之下,即使迭代次数增加,使用
SelMatch
的覆盖率仍然保持稳定。如图
1c
所示,覆盖率也影响测试准确性。简单测试样本组和困难测试样本组之间覆盖率显著差异导致两组之间测试准确性存在显著差距。
SelMatch
提高了两组的覆盖率,从而提高了总体测试准确性,特别是在
IPC
增加时,对困难组的测试准确性有所提升。
图
2
展示了
SelMatch
的核心思想,该方法将基于选择的初始化与通过轨迹匹配进行部分更新相结合。传统的轨迹匹配方法通常使用随机选择的真实数据集
\(\mathcal{D}_\textrm{real}\)
的子集对合成数据集
\(\mathcal{D}_\textrm{syn}\)
进行初始化,没有任何特定的选择标准。在每次蒸馏迭代过程中,整个
\(\mathcal{D}_\textrm{syn}\)
都会被更新,以最小化定义在公式
1
中的匹配损失
\(\mathcal{L}(\mathcal{D}_\textrm{syn}, \mathcal{D}_\textrm{real})\)
。
相比之下,
SelMatch
首先使用精心选择的子集
\(\mathcal{D}_\textrm{initial}\)
对
\(\mathcal{D}_\textrm{syn}\)
进行初始化,该子集包含量身定制的适合于合成数据集规模的样本,具有适当的困难级别。然后,在每次蒸馏迭代中,
SelMatch
仅更新
\(\mathcal{D}_\textrm{syn}\)
的一个特定部分,表示为
\(\alpha\in[0,1]\)
,(称为
\(\mathcal{D}_\textrm{distill}\)
),而数据集的剩余部分(称为
\(\mathcal{D}_\textrm{select}\)
)保持不变。这个过程旨在最小化公式
1
中的相同匹配损失
\(\mathcal{L}(\mathcal{D}_\textrm{syn}, \mathcal{D}_\textrm{real})\)
,但现在
\(\mathcal{D}_\textrm{syn}\)
是
\(\mathcal{D}_\textrm{distill}\)
和
\(\mathcal{D}_\textrm{select}\)
的组合。
图
1
中的一个重要观察是,传统的轨迹匹配方法倾向于关注完整数据集中简单和具代表性的模式,而不是复杂的数据模式,导致在更大的
IPC
设置中扩展性较差。为了克服这一问题,论文提出了使用一个经过精心选择的难度级别对合成数据集
\(\mathcal{D}_\textrm{syn}\)
进行初始化,该难度级别在
IPC
增加时包括来自真实数据集更复杂的模式。因此,挑战在于如何选择真实数据集
\(\mathcal{D}_\textrm{real}\)
的一个子集,其复杂度水平适当,同时考虑
\(\mathcal{D}_\textrm{syn}\)
的规模。
为了解决这个问题,论文设计了一个滑动窗口算法。根据预先计算的困难度分数(在
CIFAR-10
/
100
上利用预先计算的
C-score
,而在
Tiny Imagenet
上使用
Forgetting score
作为困难分数。),按照困难程度的降序(从最困难到最容易)排列训练样本。然后,通过在不同起始点上的每个窗口子集训练模型来比较测试准确度,评估这些样本的窗口子集。对于给定阈值
\(\beta\in[0,100]\%\)
,在排除最困难的
\(\beta\)
%样本后,窗口子集包括来自
\([\beta, \beta+r]\)
%范围内的样本,其中
\(r=(|\mathcal{D}_\textrm{syn}|/|\mathcal{D}_\textrm{real}|)\times 100\%\)
,
\(|\mathcal{D}_\textrm{syn}|\)
等于
IPC
乘以类别数。在这里,确保每个窗口子集包含相同数量的来自每个类别的样本。
如图
3
所示,窗口的起始点对应于困难程度的级别,显著影响模型的泛化能力(通过测试准确度来衡量)。特别是对于较小的窗口(
5-10%
范围),测试准确度根据窗口起始位置的不同可以出现高达
40%
的偏差。此外,表现最好的窗口子集,即实现最高测试准确度的子集,倾向于在子集大小增加时包含更困难的样本(较小的
\(\beta\)
)。这符合这样一种直觉,即随着
IPC
的增加,将来自真实数据集的复杂模式纳入模型可以增强其泛化能力。
基于这一观察,将
\(\mathcal{D}_\textrm{syn}\)
的初始化设置为
\(\mathcal{D}_\textrm{initial}\)
,其中
\(\mathcal{D}_\textrm{initial}\)
是由滑动窗口算法为给定
\(\mathcal{D}_\textrm{syn}\)
大小确定的表现最佳的窗口子集。这种方法确保了随后的提取过程从特定
IPC
制度下经过优化的难度级别的图像开始。
在用滑动窗口算法选择的最佳窗口子集
\(\mathcal{D}_\textrm{initial}\)
对合成数据集
\(\mathcal{D}_\textrm{syn}\)
进行初始化后,下一个目标是通过数据集蒸馏来更新
\(\mathcal{D}_\textrm{syn}\)
,以便有效地将来自整个真实数据集
\(\mathcal{D}_\textrm{real}\)
的信息嵌入其中。传统上,匹配训练轨迹(
MTT
)算法通过对
\(N\)
个模型更新进行反向传播,以最小化匹配损失公式
1
,从而更新
\(\mathcal{D}_\textrm{syn}\)
中的所有样本。然而,如图
1b
所示,这种方法倾向于数据集中更简单的模式,导致在连续提取迭代中覆盖范围的减少。因此,为了解决这个问题并保持一些真实样本的独特和复杂特征(对于模型在更大
IPC
范围内的泛化能力至关重要),论文引入了对
\(\mathcal{D}_\textrm{syn}\)
的部分更新。
根据每个样本的难度分数,将初始合成数据集
\(\mathcal{D}_\textrm{syn}=\mathcal{D}_\textrm{initial}\)
划分为两个子集
\(\mathcal{D}_\textrm{select}\)
和
\(\mathcal{D}_\textrm{distill}\)
。子集
\(\mathcal{D}_\textrm{select}\)
包含
\((1-\alpha) \times |\mathcal{D}_\textrm{syn}|\)
个难度较高的样本,剩下的
\(\alpha\)
部分样本分配到
\(\mathcal{D}_\textrm{distill}\)
,其中
\(\alpha\in[0,1]\)
是根据
IPC
调整的超参数。
在提取迭代过程中,保持
\(\mathcal{D}_\textrm{select}\)
不变,只更新
\(\mathcal{D}_\textrm{distill}\)
子集。更新的目标是最小化整个
\(\mathcal{D}_\textrm{syn}=\mathcal{D}_\textrm{select}\cup \mathcal{D}_\textrm{distill}\)
和
\(\mathcal{D}_\textrm{real}\)
之间的匹配损失,即:
与最小化
\(\mathcal{L}(\mathcal{D}_\textrm{distill}, \mathcal{D}_\textrm{real})\)
不同,仅更新部分
\(\mathcal{D}_\textrm{syn}\)
的损失策略鼓励
\(\mathcal{D}_\textrm{distill}\)
浓缩在
\(\mathcal{D}_\textrm{select}\)
中不存在的知识,从而丰富
\(\mathcal{D}_\textrm{syn}\)
中的整体信息。
在创建合成数据集
\(\mathcal{D}_\textrm{syn}\)
后,通过使用这个数据集训练一个随机初始化的神经网络来评估其有效性。通常情况下,先前的蒸馏方法采用了
Dif- ferentiable Siamese Augmentation
(
DSA
)来评估合成数据集。这种方法涉及比用于真实数据集常见的简单方法(如随机裁剪和水平翻转)更复杂的增强技术,在合成数据方面取得了更好的结果。这种提升的性能可能是因为合成数据集主要捕获了更简单的模式,使它们更适合于通过
DSA
进行更强的增强。
然而,在整个合成数据集
\(\mathcal{D}_\textrm{syn}\)
上应用
DSA
可能并非理想,特别是考虑到包含难以处理样本的子集
\(\mathcal{D}_\textrm{select}\)
的存在。为了解决这个问题,论文提出了一种专门针对论文的合成数据集定制的综合增强策略。具体而言,将
DSA
应用于精炼部分
\(\mathcal{D}_\textrm{distill}\)
,并对选择的、更复杂的子集
\(\mathcal{D}_\textrm{select}\)
使用更简单、更传统的增强技术。这种综合方法旨在利用两种增强方法的优势,以提高合成数据集的整体性能。
将所有内容整合起来,
SelMatch
在算法
1
中进行了总结。
如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】