Remixmatch:具有分布对齐和增强锚定的半监督学习
原文:《ReMixMatch: Semi-Supervised Learning with Distribution Matching and Augmentation Anchoring》
摘要
我们通过引入两种新技术改进了最近提出的“MixMatch”半监督学习算法:分布对齐和增强锚定。 分布对齐鼓励未标记数据的预测边缘分布接近真实标签的边缘分布。 增强锚定将输入的多个强增强版本提供给模型,并鼓励每个输出接近同一输入的弱增强版本的预测。为了产生强大的增强,我们提出了 AutoAugment 的变体,它在训练模型时学习增强策略。 我们称为 ReMixMatch 的新算法比之前的工作具有更高的数据效率,需要 5 倍到 16 倍的数据才能达到相同的精度。 例如,在带有 250 个标记示例的 CIFAR10 上,我们达到了 93.73% 的准确率(相比之下,MixMatch 的准确率为 93.58%,带有 4,000 个示例),而每个类别只有四个标签的准确率中值为 84.92%。 我们在 https://github.com/google-research/remixmatch 上开源我们的代码和数据。
本文思路
半监督学习 (SSL) 提供了一种在只有有限的标记数据可用时利用未标记数据来提高模型性能的方法。 当标记数据昂贵或不方便时,这可以启用大型、强大的模型。 SSL 研究产生了多种方法,包括一致性正则化(Sajjadi 等人,2016 年;Laine 和 Aila,2017 年),它鼓励模型在输入受到扰动时产生相同的预测,以及熵最小化(Grandvalet 和 Bengio, 2005)鼓励模型输出高置信度的预测。 最近提出的“MixMatch”算法(Berthelot 等人,2019 年)将这些技术结合在一个统一的损失函数中,并在各种图像分类基准上实现了强大的性能。 在本文中,我们提出了两项可以轻松集成到 MixMatch 框架中的改进。
- 首先,我们引入了“分布对齐”,它鼓励模型的聚合类预测的分布与基本真相类标签的边际分布相匹配。Bridle等人(1992)引入了这个概念作为“公平”目标,其中一个相关的损失项显示出来自模型输入和输出之间相互信息的最大化。在回顾了这个理论框架之后,我们将展示如何通过使用模型预测的运行平均值修改“猜测的标签”来直接将分布对齐添加到MixMatch中。
- 其次,我们引入了“增强锚定”,它取代了MixMatch的一致性正则化组件。对于每个给定的未标记输入,增强锚定首先生成一个弱增强版本(例如,只使用翻转和裁剪),然后生成多个强增强版本。该模型对弱增强输入的预测被视为所有强增强版本的猜测标签的基础。 为了产生强大的增强,我们引入了一种基于控制理论的 AutoAugment 变体(Cubuk 等人,2018 年),我们称之为“CTAugment”。 与 AutoAugment 不同,CTAugment 在模型训练的同时学习增强策略,使其在 SSL 设置中特别方便。
我们将改进后的算法称为“ReMixMatch”,并在一套标准SSL图像基准测试上对其进行实验验证。
相关背景
半监督学习算法的目标是以提高标记数据性能的方式从未标记数据中学习。 实现这一目标的典型方法包括针对未标记数据的“猜测”标签进行训练,或优化不依赖于标签的启发式目标。 一致性正规化:许多 SSL 方法依赖于一致性正则化来强制模型输出在输入受到扰动时保持不变。最常见的扰动是应用特定领域的数据增强,用于衡量一致性的损失函数通常是模型针对扰动和非扰动输入的输出之间的均方误差或交叉熵。 熵最小化:Grandvalet & Bengio (2005) 认为应该使用未标记的数据来确保类被很好地分离。 这可以通过鼓励模型的输出分布对未标记数据具有低熵(即做出“高置信度”预测)来实现。例如,可以显式地添加一个损失项,以最小化模型在未标记数据上预测的类分布的熵(Grandvalet & Bengio, 2005;Miyato et al., 2018)。 标准正则化:在SSL设置之外,在过度参数化的情况下对模型进行正则化通常是有用的。在对有标签和无标签数据进行训练时,通常都可以应用这种正则化。例如,标准的“权重衰减”(Hinton & van Camp,1993),其中参数的 L2 范数被最小化,通常与 SSL 技术一起使用。 同样,强大的 MixUp 正则化 (Zhang et al., 2017) 最近已应用于 SSL(Berthelot et al., 2019;Verma et al., 2019),该模型训练输入和标签的线性插值模型。
Remixmatch
ReMixMatch的完整算法如算法1所示。
分布对齐
我们的第一个贡献是分布对齐(distribution
alignment),它强制对未标记数据的预测聚类与已标记数据的分布相匹配。这个总体思想在25年前首次提出(Bridle等人,1992年),但据我们所知,现代SSL技术还没有使用它。分布对齐示意图见图1。在回顾和扩展了该理论之后,我们将描述如何将其直接包含在ReMixMatch中。

输入输出互信息
SSL算法的主要目标是使用未标注的数据来提升模型的精度。Bridle等人(1992)首先提出了一种将这种直觉形式化的方法,即最大化未标记数据输入和输出之间的互信息。直观地说,一个好的分类器的预测应该尽可能地依赖于输入。根据Bridle等人(1992)的分析,我们可以将这个目标形式化为:
其中
Mixmatch中的分布对齐
MixMatch已经包含了一种通过“锐化”操作实现的熵最小化形式,这使得未标记数据的猜测标签具有更低的熵。因此,我们也有兴趣在
ReMixMatch 中加入一种“公平”形式。然而,注意到目标
改进的一致性正则化
一致性正则化是大多数SSL方法的基础。对于图像分类任务,通常要求对同一未标记图像的两个增强版本的输出要保持一致性。为了保证一致性,MixMatch对每个未标记的示例
增强锚定
我们推测带有AutoAugment的MixMatch不稳定的原因是:MixMatch对
在实验时,我们发现可以用标准的交叉熵损失来替换MixMatch的均方误差损失函数(对于unlabeled
data而言)。这样既保持了稳定性,又简化了实现。当MixMatch仅在
控制理论中的增强方法
AutoAugment是一种学习数据增强策略的方法,可获得较高的验证集精度。增强策略由一系列应用于每个图像的变换参数幅度元组组成。至关重要的是,AutoAugment
策略是在监督学习下的:转换的大小和顺序是通过在代理任务上训练许多模型来确定的,例如,在
CIFAR-10 上使用 4,000 个标签,在 SVHN 上使用 1,000
个标签。这使得AutoAugment难以应用于少标签的半监督学习任务。
因此,在这项工作中,我们开发了CTAugment,一个设计高性能增强策略的替代方法。像RandAugment一样,CTAugment也随机地选取图像变换的方式,但在训练过程中动态地改变图像变换的相关参数。由于CTAugment不需要在有监督的代理任务上进行优化,也没有超参数,我们可以直接将其融入我们的半监督模型中。直观地说,对于每个增强参数,CTAugment计算增强图像被归类为正确标签的可能性。利用这些可能性,CTAugment然后只采样在网络容忍范围内的增强。这个过程与Fast
AutoAugment中的密度匹配有关,以便增强图像的分布与训练集图像的相匹配。
首先,CTAugment对于每个图像变换涉及到的超参数的定义域进行区间划分,初始化每个区间的权重为1。在每个训练步骤中,对于每个图像,随机均匀地采样两个变换。为了增强用于训练的图像,对于这些转换的每个参数,我们生成一组修改后的
bin 权重
汇总
ReMixMatch 用于处理一批标记和未标记示例的算法如算法 1
所示。该算法的主要目的是生成集合
Pre-mixup unlabeled loss
我们将猜测的标签和预测输入到一个单独的交叉熵损失项中。
Rotation loss
最近的结果表明,将自监督学习的思想应用于SSL可以产生强大的性能。我们通过将每个图像
算法流程
- 输入:一批包含已标注数据(含标签)的数据,一批包含未标注数据的数据、超参数;
- 对于每一笔数据(因为两批数据的数量一致): 增强操作:对已标注数据进行一次强增强,对未标注数据进行K次强增强和1次弱增强; 为未标注数据生成猜测标签(分布对齐、锐化)
- 生成三个数据集:强增强后的已标注数据、强增强后的未标注数据和弱增强标注后的未标注数据;
- 对已标注数据和未标记数据进行MixUp;
- 返回三个数据集,计算损失函数:
第一项是常规的有监督损失;第二项和第三项是无标注数据的“监督损失”;第四项是旋转角度预测。
与Mixmatch对比:
- 增加了一个损失函数:角度预测;
- 伪标签生成的方式不同,引入了强增强和弱增强,并以弱增强作为label anchor;
- 引入CTAugment控制图像变换的参数取样和方式取样。