本发明涉及图像识别技术领域,尤其涉及基于双重空间对比学习的开放世界半监督图像分类方法,包括如下步骤:初始化网络模型;构建特征图及预测图;进行双重空间对比迭代学习训练,计算监督损失、特征空间自监督对比损失、预测空间自监督对比损失、特征空间与预测空间之间的空间引导对比损失及总体损失,根据总体损失对网络模型进行梯度反向传播及网络模型参数更新;重复迭代得到训练好的网络模型;对无标签数据集中的每一个无标签样本进行推理,得出每一个无标签样本预测的标记结果。本发明提供的方法提升了网络模型对于无标签数据的表征能力,从而提升对于无标签数据的标记准确率。
1.基于双重空间对比学习的开放世界半监督图像分类方法,其特征在于:包括如下步骤:
S1:整理待标记图像数据集,将待标记图像数据集中带标签样本划分至带标签数据集,将待标记图像数据集中无标签样本划分至无标签数据集,并初始化网络模型;
S2:根据待标记图像数据集中的所有样本在特征空间的关联性构建特征图,根据待标记图像数据集中的所有样本在预测空间的关联性构建预测图;
S3:利用步骤S1中的带标签数据集、无标签数据集以及步骤S2中的特征图、预测图对步骤S1中的初始化网络模型进行双重空间对比迭代学习训练,计算监督损失、特征空间自监督对比损失、预测空间自监督对比损失及特征空间与预测空间之间的空间引导对比损失,并对监督损失、特征空间自监督对比损失、预测空间自监督对比损失及特征空间与预测空间之间的空间引导对比损失加权求和得到总体损失,再根据总体损失对网络模型进行梯度反向传播及网络模型参数更新;
S4:重复迭代步骤S2至步骤S3,并判断是否达到设定迭代轮数,若达到设定迭代轮数,则得到训练好的网络模型进行进入下一步骤;
S5:利用步骤S4训练好的网络模型对步骤S1无标签数据集中的每一个无标签样本进行推理,得出每一个无标签样本预测的标记结果。
2.根据权利要求1所述的基于双重空间对比学习的开放世界半监督图像分类方法,其特征在于:步骤S2中按照如下方法构建特征图:
S21:相连样本对的筛选:对无标签数据集中的每一个无标签样本采集其在特征空间中具有最近邻特征向量的样本,并将具有最近邻特征向量的样本与该无标签样本在特征图中相连,形成该无标签样本的相连样本对,对带标签数据集中的每一个带标签样本随机选择一个与其具有相同标签的带标签样本和该带标签样本在特征图中相连,形成该带标签样本的相连样本对;
S22:不相连样本对的筛选:将待标记图像数据集中所有样本两两构成样本对,然后计算每个样本对之间的特征向量余弦相似度,并将特征向量余弦相似度与设定阈值进行比较,若特征向量余弦相似度小于设定阈值,则该样本对在特征图中为不相连样本对;
S23:将无标签样本的相连样本对、带标签样本的相连样本对及不相连样本对之外的样本对定义为在特征图中关联性未知。
3.根据权利要求2所述的基于双重空间对比学习的开放世界半监督图像分类方法,其特征在于:步骤S2中按照如下方法构建预测图:
S24:根据预测向量最大值对应的类别估计无标签数据集中无标签样本类别,并选择无标签样本类别中置信度高于置信度设定值且满足预设比例系数的样本组成高质量无标签数据集,其余无标签样本组成低质量无标签数据集;
S25:将高质量无标签数据集中的每一个样本其预测向量最大值对应的类别定义为该样本的伪标签,低质量无标签数据集中的样本伪标签定义为未知;
S26:对待标记图像数据集中的所有样本构建预测图,对于任意两个样本,若至少有一个样本属于低质量无标签数据集,则将这两个样本关系定义为在预测图中的关联性未知,若两样本均属于高质量无标签数据集或带标签数据集,则继续判断这两个样本的伪标签或者标签是否为同一类别,若属于同一类别则这两个样本在预测图中相连,若不属于同一类别则在预测图中不相连。
4.根据权利要求3所述的基于双重空间对比学习的开放世界半监督图像分类方法,其特征在于:步骤S24中的预设比例系数随着迭代轮数的增加而渐进增加。
5.根据权利要求1所述的基于双重空间对比学习的开放世界半监督图像分类方法,其特征在于:步骤S3中计算总体损失的方法包括如下步骤:
S31:对步骤S1中的带标签数据集、无标签数据集中的样本进行分批次采样,并对采样批次中的图像数据进行增广操作,获取弱增广图像及强增广图像;
S32:对于弱增广图像及强增广图像通过深度骨干网络模型与分类器得到弱增广特征向量、强增广特征向量、弱增广预测向量及强增广预测向量,并将弱增广特征向量、强增广特征向量根据采样批次整理为特征集,将弱增广预测向量及强增广预测向量根据采样批次整理为预测集;
S33:根据监督目标函数的公式计算监督损失;
S34:特征空间自监督对比学习:将特征集内的弱增广特征向量、强增广特征向量映射为弱增广投影向量和强增广投影向量,并将弱增广投影向量和强增广投影向量组成投影向量集,对投影向量集中的投影向量进行对比学习,将强增广投影向量及弱增广投影向量视为同一样本在不同增广下的两种不同表征代入对比损失函数中,得到特征空间自监督对比损失;
S35:预测空间自监督对比学习:将预测集内的弱增广预测向量以行向量堆叠的方式组成弱增广预测矩阵,将预测集内的强增广预测向量以行向量堆叠的方式组成强增广预测矩阵,再将弱增广预测矩阵沿列方向拆解为弱增广预测列向量,将强增广预测矩阵沿列方向拆解为强增广预测列向量,并将弱增广预测列向量和强增广预测列向量组成预测列向量集,然后对预测列向量集中的预测列向量进行对比学习,将强增广预测列向量、弱增广预测列向量视为不同增广下的两种不同表征代入对比损失函数中,得到预测列向量的对比损失,再对该批次内所有的弱增广预测向量和强增广预测向量分别取平均之后并求得弱增广预测向量熵及强增广预测向量熵,并将两个熵值相加作为熵损失,最终将预测列向量的对比损失减去熵损失得到预测空间自监督对比损失;
S36:空间引导对比学习:通过从预测空间到特征空间的图引导对比学习计算从预测空间到特征空间的总引导对比损失,再通过从特征空间到预测空间的图引导对比学习计算从特征空间到预测空间的引导对比损失;
S37:对步骤S33的监督损失、步骤S34的特征空间自监督对比损失、步骤S35的预测空间自监督对比损失、步骤S36的从预测空间到特征空间的总引导对比损失、从特征空间到预测空间的总引导对比损失求和计算总体损失。
6.根据权利要求5所述的基于双重空间对比学习的开放世界半监督图像分类方法,其特征在于:步骤S36中从预测空间到特征空间的图引导对比学习方法如下:
S361:将特征集中与第个样本的弱增广特征向量或强增广特征向量在预测图中产生关联的特征收集至相应的正关联特征集,将特征集中与第个样本的弱增广特征向量或强增广特征向量在预测图中不产生关联的特征收集至相应的负关联特征集,并将相应的正关联特征集及负关联特征集的并集组成相应的正负关联特征集;
S362:根据式(1)计算对第个样本的弱增广特征向量从预测空间到特征空间的引导对比损失,根据式(2)计算对第个样本的强增广特征向量从预测空间到特征空间的引导对比损失:
(1);
(2);
其中:表示特征集第个样本弱增广之后所对应的特征向量,表示从预测空间到特征空间的引导对比损失,表示总的采样批次值,表示的正关联特征集,表示的正负关联特征集,表示属于的正关联特征集的所有特征向量,表示属于的正负关联特征集的所有特征向量,表示特征集中第个样本的强增广特征向量,表示从预测空间到特征空间的引导对比损失,表示的正关联特征集,表示的正负关联特征集,表示属于的正关联特征集的所有特征向量,表示属于的正负关联特征集的所有特征向量;
S363:根据式(3)计算从预测空间到特征空间的总引导对比损失:
(3)。
7.根据权利要求6所述的基于双重空间对比学习的开放世界半监督图像分类方法,其特征在于:从特征空间到预测空间的引导对比学习方法如下:
S364:将预测集中与第个样本的弱增广预测向量或强增广预测向量在特征图中产生关联的特征收集至相应的正关联预测集,将预测集与第个样本的弱增广预测向量或强增广预测向量在特征图中不产生关联的特征收集至相应的负关联预测集,并将相应的正关联预测集及负关联预测集的并集组成相应的正负关联预测集;
S365:根据式(4)计算对第个样本的弱增广预测向量从特征空间到预测空间的引导对比损失,根据式(5)计算对第个样本的强增广预测向量从特征空间到预测空间的引导对比损失:
(4);
(5);
其中:表示预测集中第个样本的弱增广预测向量,表示从特征空间到预测空间的引导对比损失,表示的正关联预测集,表示的正负关联预测集,表示属于的正关联预测集的所有预测向量,表示属于的正负关联预测集的所有预测向量,表示预测集中第个样本的强增广预测向量,表示从特征空间到预测空间的引导对比损失,表示的正关联预测集,表示的正负关联预测集,表示属于的正关联预测集的所有预测向量,表示属于的正负关联预测集的所有预测向量;
S366:根据式(6)计算从特征空间到预测空间的总引导对比损失:
(6)。
8.根据权利要求7所述的基于双重空间对比学习的开放世界半监督图像分类方法,其特征在于:步骤S3中按照式(7)计算总体损失:
(7);
其中:表示监督损失,表示特征空间自监督对比损失,表示预测空间自监督对比损失。
技术领域
[0001]本发明涉及图像识别技术领域,尤其涉及基于双重空间对比学习的开放世界半监督图像分类方法。
背景技术
[0002]近年来,深度学习技术在计算机视觉领域取得了巨大的进展,尽管如此,大多数深度学习模型仍然需要大量高质量的标注数据,获取这些数据既耗时又昂贵。为了克服这个限制,研究人员广泛研究了半监督图像分类方法,这些方法充分利用了大量未标记的数据以达成减少手动标注需求的目的,其已经引起了研究和工业界的广泛关注。
[0003]尽管半监督图像分类方法的相关研究取得了一些进展,但是它们的应用仍然受到环境的限制,其中未标记的样本类别是预先定义且不变的。而在现实场景中,未标记样本的类别分布通常与标记样本不匹配,研究开放世界下的半监督图像分类问题具有重要的现实意义。
[0004]在开放集假设下,图像数据集由有限数量的已知类别的带标记图像和大量具有更广泛类别的未标记图像组成。标记数据集中出现的类别被称为旧类,在标记数据集中不存在但在未标记集合中出现的类别被称为新类。开放世界半监督分类的最终目的是不仅能够区分已知旧类,还能够对新类别中的样本进行聚类,进而为无标签样本提供合理的标注建议。与传统半监督标注方法相比,这种方法的核心难题在于在缺乏监督的情况下捕捉新类样本的特征。
[0005]现有研究已有将对比学习应用于开放世界半监督学习的成功案例,比如GCD,OpenCon等等,这些方法将对比学习的过程融合进半监督学习中,有益于对于无标签数据习得可分性强且具有语义特性的表征,从而促进对这些数据的准确标记,但仍然存在两个关键问题尚未解决,导致性能不佳:
1)预测空间的充分利用不足。先前的研究主要集中在特征空间中学习紧凑表示,未对预测空间进行进一步探索。这些方法通常使用原型学习生成分类预测,并根据k-means聚类准则更新原型。然而,这个过程是单向的,导致预测空间中的优化无法通过梯度传播直接影响到特征空间,从而影响最终结果。
[0006]2)忽视双重空间的互补属性。实际上,利用样本之间的相关性可以显著提升开放世界半监督学习的性能。然而,目前的方法通常限于在单一空间中提取样本相关性,例如在特征空间或预测空间中,我们发现在两个空间中构建样本关联的常见方法可能导致异质性。
发明内容
[0007]本发明所要解决的技术问题是提供基于双重空间对比学习的开放世界半监督图像分类方法,引入了特征空间和预测空间内自监督对比学习,并且利用特征空间和预测空间中的互补信息,提出了特征空间和预测空间之间的空间引导对比学习,挖掘一个空间的图结构信息并让其引导另一个空间表征的学习,从而实现异质信息的融合与高效利用,提升了网络模型对于无标签数据的表征能力,从而提升对于无标签数据的标记准确率。
[0008]本发明是通过以下技术方案予以实现:
[0009]基于双重空间对比学习的开放世界半监督图像分类方法,其包括如下步骤:
S1:整理待标记图像数据集,将待标记图像数据集中带标签样本划分至带标签数据集,将待标记图像数据集中不带标签样本划分至无标签数据集,并初始化网络模型;
S2:根据待标记图像数据集中的所有样本在特征空间的关联性构建特征图,根据待标记图像数据集中的所有样本在预测空间的关联性构建预测图;
S3:利用步骤S1中的带标签数据集、无标签数据集以及步骤S2中的特征图、预测图对步骤S1中的初始化网络模型进行双重空间对比迭代学习训练,计算监督损失、特征空间自监督对比损失、预测空间自监督对比损失及特征空间与预测空间之间的空间引导对比损失,并对监督损失、特征空间自监督对比损失、预测空间自监督对比损失及特征空间与预测空间之间的空间引导对比损失加权求和得到总体损失,再根据总体损失对网络模型进行梯度反向传播及网络模型参数更新;
S4:重复迭代步骤S2至步骤S3,并判断是否达到设定迭代轮数,若达到设定迭代轮数,则得到训练好的网络模型进行进入下一步骤;
S5:利用步骤S4训练好的网络模型对步骤S1无标签数据集中的每一个无标签样本进行推理,得出每一个无标签样本预测的标记结果。
[0010]进一步,步骤S2中按照如下方法构建特征图:
S21:相连样本对的筛选:对无标签数据集中的每一个无标签样本采集其在特征空间中具有最近邻特征向量的样本,并将具有最近邻特征向量的样本与该无标签样本在特征图中相连,形成该无标签样本的相连样本对,对带标签数据集中的每一个带标签样本随机选择一个与其具有相同标签的带标签样本和该带标签样本在特征图中相连,形成该带标签样本的相连样本对;
S22:不相连样本对的筛选:将待标记图像数据集中所有样本两两构成样本对,然后计算每个样本对之间的特征向量余弦相似度,并将特征向量余弦相似度与设定阈值进行比较,若特征向量余弦相似度小于设定阈值,则该样本对在特征图中为不相连样本对;
S23:将无标签样本的相连样本对、带标签样本的相连样本对及不相连样本对之外的样本对定义为在特征图中关联性未知。
[0011]进一步,步骤S2中按照如下方法构建预测图:
S24:根据预测向量最大值对应的类别估计无标签数据集中无标签样本类别,并选择无标签样本类别中置信度高于置信度设定值且满足预设比例系数的样本组成高质量无标签数据集,其余无标签样本组成低质量无标签数据集;
S25:将高质量无标签数据集中的每一个样本其预测向量最大值对应的类别定义为该样本的伪标签,低质量无标签数据集中的样本伪标签定义为未知;
S26:对待标记图像数据集中的所有样本构建预测图,对于任意两个样本,若至少有一个样本属于低质量无标签数据集,则将这两个样本关系定义为在预测图中的关联性未知,若两样本均属于高质量无标签数据集或带标签数据集,则继续判断这两个样本的伪标签或者标签是否为同一类别,若属于同一类别则这两个样本在预测图中相连,若不属于同一类别则在预测图中不相连。
[0012]优化的,步骤S24中的预设比例系数随着迭代轮数的增加而渐进增加。
[0013]进一步,步骤S3中计算总体损失的方法包括如下步骤:
S31:对步骤S1中的带标签数据集、无标签数据集中的样本进行分批次采样,并对采样批次中的图像数据进行增广操作,获取弱增广图像及强增广图像;
S32:对于弱增广图像及强增广图像通过深度骨干网络模型与分类器得到弱增广特征向量、强增广特征向量、弱增广预测向量及强增广预测向量,并将弱增广特征向量、强增广特征向量根据采样批次整理为特征集,将弱增广预测向量及强增广预测向量根据采样批次整理为预测集;
S33:根据监督目标函数的公式计算监督损失;
S34:特征空间自监督对比学习:将特征集内的弱增广特征向量、强增广特征向量映射为弱增广投影向量和强增广投影向量,并将弱增广投影向量和强增广投影向量组成投影向量集,对投影向量集中的投影向量进行对比学习,将强增广投影向量及弱增广投影向量视为同一样本在不同增广下的两种不同表征代入对比损失函数中,得到特征空间自监督对比损失;
S35:预测空间自监督对比学习:将预测集内的弱增广预测向量以行向量堆叠的方式组成弱增广预测矩阵,将预测集内的强增广预测向量以行向量堆叠的方式组成强增广预测矩阵,再将弱增广预测矩阵沿列方向拆解为弱增广预测列向量,将强增广预测矩阵沿列方向拆解为强增广预测列向量,并将弱增广预测列向量和强增广预测列向量组成预测列向量集,然后对预测列向量集中的预测列向量进行对比学习,将强增广预测列向量、弱增广预测列向量视为不同增广下的两种不同表征代入对比损失函数中,得到预测列向量的对比损失,再对该批次内所有的弱增广预测向量和强增广预测向量分别取平均之后并求得弱增广预测向量熵及强增广预测向量熵,并将两个熵值相加作为熵损失,最终将预测列向量的对比损失减去熵损失得到预测空间自监督对比损失;
S36:空间引导对比学习:通过从预测空间到特征空间的图引导对比学习计算从预测空间到特征空间的总引导对比损失,再通过从特征空间到预测空间的图引导对比学习计算从特征空间到预测空间的引导对比损失;
S37:对步骤S33的监督损失、步骤S34的特征空间自监督对比损失、步骤S35的预测空间自监督对比损失、步骤S36的从预测空间到特征空间的总引导对比损失、从特征空间到预测空间的总引导对比损失求和计算总体损失。
[0014]进一步,步骤S36中从预测空间到特征空间的图引导对比学习方法如下:
S361:将特征集中与第个样本的弱增广特征向量或强增广特征向量在预测图中产生关联的特征收集至相应的正关联特征集,将特征集中与第个样本的弱增广特征向量或强增广特征向量在预测图中不产生关联的特征收集至相应的负关联特征集,并将相应的正关联特征集及负关联特征集的并集组成相应的正负关联特征集;
[0015]S362:根据式(1)计算对第个样本的弱增广特征向量从预测空间到特征空间的引导对比损失,根据式(2)计算对第个样本的强增广特征向量从预测空间到特征空间的引导对比损失:
(1);
(2);
其中:表示特征集中第个样本弱增广之后所对应的特征向量,表示从预测空间到特征空间的引导对比损失,表示总的采样批次值,表示的正关联特征集,表示的正负关联特征集,表示属于的正关联特征集的所有特征向量,表示属于的正负关联特征集的所有特征向量,表示特征集中第个样本的强增广特征向量,表示从预测空间到特征空间的引导对比损失,表示的正关联特征集,表示的正负关联特征集,表示属于的正关联特征集的所有特征向量,表示属于的正负关联特征集的所有特征向量;
[0016]S363:根据式(3)计算从预测空间到特征空间的总引导对比损失:
(3)。
[0017]进一步,从特征空间到预测空间的引导对比学习方法如下:
S364:将预测集中与第个样本的弱增广预测向量或强增广预测向量在特征图中产生关联的特征收集至相应的正关联预测集,将预测集与第个样本的弱增广预测向量或强增广预测向量在特征图中不产生关联的特征收集至相应的负关联预测集,并将相应的正关联预测集及负关联预测集的并集组成相应的正负关联预测集;
[0018]S365:根据式(4)计算对第个样本的弱增广预测向量从特征空间到预测空间的引导对比损失,根据式(5)计算对第个样本的强增广预测向量从特征空间到预测空间的引导对比损失:
(4);
(5);
其中:表示预测集中第个样本的弱增广预测向量,表示从特征空间到预测空间的引导对比损失,表示的正关联预测集,表示的正负关联预测集,表示属于的正关联预测集的所有预测向量,表示属于的正负关联预测集的所有预测向量,表示预测集中第个样本的强增广预测向量,表示从特征空间到预测空间的引导对比损失,表示的正关联预测集,表示的正负关联预测集,表示属于的正关联预测集的所有预测向量,表示属于的正负关联预测集的所有预测向量;
[0019]S366:根据式(6)计算从特征空间到预测空间的总引导对比损失:
(6)。
[0020]进一步,步骤S3中按照式(7)计算总体损失:
(7);
其中:表示监督损失,表示特征空间自监督对比损失,表示预测空间自监督对比损失。
[0021]发明的有益效果:
本发明开创性地构建了一套基于双重空间对比学习的开放世界半监督图像分类方法,其核心在于对特征空间和预测空间施加对比学习范式,具体包括特征空间、预测空间的空间内自监督对比学习和特征空间、预测空间之间的图引导对比学习,空间内自监督对比学习在特征空间和预测空间中同时使用自监督学习方法,可以缓解预测空间监督信息不足的问题;特征空间、预测空间之间的图引导对比学习挖掘一个空间的图结构信息并让其引导另一个空间表征的学习,提升了网络模型的表达能力,从而实现两个空间中异质信息的融合与高效利用,提升了网络模型对于无标签数据的表征能力且提升了无标签数据的标记准确率。
附图说明
[0022]图1是本发明流程示意图。
具体实施方式
[0023]基于双重空间对比学习的开放世界半监督图像分类方法,流程图如图1所示,具体包括如下步骤:
S1:整理待标记图像数据集,将待标记图像数据集中带标签样本划分至带标签数据集,将待标记图像数据集中不带标签样本划分至无标签数据集,并初始化网络模型。具体地,可以引入深度骨干网络模型及分类器,骨干网络可以选用残差神经网络ResNet18,并使用SimCLR预训练的参数来初始化深度骨干网络模型的参数,分类器为归一化线性分类器,其中存储了所有类别的类别中心向量,而分类类别总数设置为待标记数据集中所包含的总类别数,由于无标签数据集中存在带标签数据集中未知的类别,因此总类别数大于带标签数据集中类别数目,一般由经验估计。
[0024]S2:根据待标记图像数据集中的所有样本在特征空间的关联性构建特征图,根据待标记图像数据集中的所有样本在预测空间的关联性构建预测图;构建特征图及预测图时,可以首先利用步骤S1中的深度骨干网络模型及分类器对待标记图像数据集中的所有样本计算特征向量和预测向量,然后根据这些特征和预测向量之间的关联性构建特征图和预测图,以构建特征空间和预测空间中不同样本之间的联系;对于预测图可以采用渐进式学习模式来构图,对于特征图可以参考在特征空间样本之间的余弦相似度进行构图。
[0025]具体可以按照如下方法构建特征图及预测图:
[0026]特征图的构建:
[0027]特征图挖掘的目标是发现特征向量之间存在的局部结构信息,在以前的研究中,一种经典的构图方法是在每个采样批次中,为所有未标记样本的特征向量寻找其批次内在特征空间中的最近邻特征向量,并定义其对应样本在图中相连。然而,由于搜索空间有限,连边的可靠性会受到很大限制。为了提高图的质量,本发明引入了整个数据集中的特征向量最近邻信息,而不是每个批次中的特征向量最近邻信息,具体方法如下:
S21:相连样本对的筛选:对无标签数据集中的每一个无标签样本采集其在特征空间中具有最近邻特征向量的样本,并将具有最近邻特征向量的样本与该无标签样本在特征图中相连,形成该无标签样本的相连样本对,对带标签数据集中的每一个带标签样本随机选择一个与其具有相同标签的带标签样本和该带标签样本在特征图中相连,形成该带标签样本的相连样本对;
S22:不相连样本对的筛选:将待标记图像数据集中所有样本两两构成样本对,然后计算每个样本对之间的特征向量余弦相似度,并将特征向量余弦相似度与设定阈值进行比较,若特征向量余弦相似度小于设定阈值,则该样本对在特征图中为不相连样本对,这里的设定阈值可以使用大津法阈值估计方法对其估计一个二分阈值;
S23:将无标签样本的相连样本对、带标签样本的相连样本对及不相连样本对之外的样本对定义为在特征图中关联性未知。
[0028]预测图的构建:
预测图是根据样本的预测判定两个样本之间是否存在连边的图,其可以有效捕获被分至同类样本的全局社区信息。可以首先筛选出一部分高置信的样本,置信度表示样本的预测向量中最大值,代表了预测向量的可信程度,然后记录这些样本预测向量最大值对应的类别作为伪标签,再根据其伪标签的类别是否相同判断其之间是否存在连边。
[0029]一般的筛选方法是使用一个固定阈值来筛选每个样本的置信度,也即预测向量的最大值。但相关工作证明了在训练前期,即使置信度较高样本也可能被分配错误伪标签。为提升伪标签分配的准确率,可以采用一种渐进筛选范式来选择伪标签,具体预测图的构建方法如下:
S24:根据预测向量最大值对应的类别估计无标签数据集中无标签样本类别,并选择无标签样本类别中置信度高于置信度设定值且满足预设比例系数的样本组成高质量无标签数据集,其余无标签样本组成低质量无标签数据集;具体就是选择每个样本类别中样本,按照置信度从高到低排序,然后筛选出在靠前一定预设比例系数范围内且置信度高于置信度设定值的样本组成高质量无标签数据集。
[0030]这里的预设比例系数随着迭代轮数的增加而渐进增加,具体的,预设比例系数可以从0开始,可以采取每20个轮次增加0.1的渐进增加方式,置信度设定值可以设为0.9。
S25:将高质量无标签数据集中的每一个样本其预测向量最大值对应的类别定义为该样本的伪标签,低质量无标签数据集中的样本伪标签定义为未知;
[0031]S26:对待标记图像数据集中的所有样本构建预测图,对于任意两个样本,若至少有一个样本属于低质量无标签数据集,则将这两个样本关系定义为在预测图中的关联性未知,若两样本均属于高质量无标签数据集或带标签数据集,则继续判断这两个样本的伪标签或者标签是否为同一类别,若属于同一类别则这两个样本在预测图中相连,若不属于同一类别则在预测图中不相连。
S3:利用步骤S1中的带标签数据集、无标签数据集以及步骤S2中的特征图、预测图对步骤S1中的初始化网络模型进行双重空间对比迭代学习训练,计算监督损失、特征空间自监督对比损失、预测空间自监督对比损失及特征空间与预测空间之间的空间引导对比损失,并对监督损失、特征空间自监督对比损失、预测空间自监督对比损失及特征空间与预测空间之间的空间引导对比损失加权求和得到总体损失,再根据总体损失对网络模型进行梯度反向传播及网络模型参数更新;
[0032]具体的计算总体损失的方法包括如下步骤:
S31:对步骤S1中的带标签数据集、无标签数据集中的样本进行分批次采样,并对采样批次中的图像数据进行增广操作,获取弱增广图像及强增广图像,带标签数据集的样本采样批次大小可以为128,无标签数据集的样本采样批次大小为384,总批次大小为512,弱增广代表对原始图像进行较小程度的变换,通常不会改变图像的基本特征,而强增广代表对原始图像进行较大程度的变换,通常会改变图像的一些基本特征。弱增广可以使用随机裁剪和随机翻转函数,而强增广可以使用随机裁剪,随机翻转,随机亮度对比度调整,随机曝光和随机均衡化函数。
S32:对于弱增广图像及强增广图像通过深度骨干网络模型与分类器得到弱增广特征向量、强增广特征向量、弱增广预测向量及强增广预测向量,并将弱增广特征向量、强增广特征向量根据采样批次整理为特征集,将弱增广预测向量及强增广预测向量根据采样批次整理为预测集;
[0033]具体的,弱增广图像及强增广图像可以先通过深度骨干网络模型得到弱增广特征向量及强增广特征向量,再将其通过分类器中得到弱增广预测向量及强增广预测向量。而分类器中包括了所有分类类别的中心向量,在使用分类器计算预测向量时,将样本的特征向量与分类器中每个类别中心向量计算余弦相似度,将所得相似度除以温度系数之后经过一个标准柔性最大传递函数(SoftMax)函数即可得到对应样本预测向量,预测向量中每一个元素对应一个分类类别,其含义代表此样本隶属于该类别的概率,本发明实例中,温度系数可以为设为0.1。
[0034]S33:根据监督目标函数的公式计算监督损失;在半监督图像分类中,利用标记数据是至关重要的。经典的监督损失函数通常是交叉熵函数,它仅应用于带标签数据,在开放世界半监督图像分类中,带标记数据集中图像的所属类别统称为旧类,而无标签数据集中的图像可能会隶属于不同于任意一个旧类的新的类别,统称为新类。由于带标签数据只包含已知的旧类,仅使用交叉熵函数进行训练可能导致新旧类别之间的不平衡。因此,我们可以使用具有不确定性自适应机制的监督目标函数来缓解训练不平衡问题。该监督目标函数引入了自适应不确定性余量机制,用于阻塞旧类的训练,以减少旧类和新类之间训练程度的偏差。
[0035]监督目标函数的公式如式(8)所示:
(8);
其中:表示监督损失,表示总分类类别个数,表示温度系数,表示样本序号,代表第个样本的标签,表示第个样本的标签类别序号,表示分类器中第类的归一化类中心向量,表示除了第类以外的其他标签类别序号,表示分类器中第类的类中心向量,代表特征集中第个样本弱增广之后所对应的特征向量,代表带标签样本的批次大小,表示不确定度,不确定度为1减去所有无标签样本置信度的平均值,训练前期,由于训练的不完全,置信度较低,因此不确定度较高,此时不确定度会阻碍旧类的训练,而到了训练后期不确定度会逐渐降低,监督目标函数逐渐退化为一般交叉熵函数。
S34:特征空间自监督对比学习:将特征集内的弱增广特征向量、强增广特征向量映射为弱增广投影向量和强增广投影向量,并将弱增广投影向量和强增广投影向量组成投影向量集,对投影向量集中的投影向量进行对比学习,将强增广投影向量及弱增广投影向量视为同一样本在不同增广下的两种不同表征代入对比损失函数中,得到特征空间自监督对比损失;
[0036]具体的,可以经过额外两层的多层感知机(MLP,Multilayer Perceptron)将特征集内的弱增广特征向量、强增广特征向量分别映射为弱增广投影向量和强增广投影向量,并对投影向量进行对比学习训练,可提高学习到的特征质量,提升特征向量的可判别性。
[0037]具体可以根据式(9)计算对第个样本的弱增广特征向量的特征空间自监督对比损失,根据式(10)计算对第个样本的强增广特征向量的特征空间自监督对比损失,根据式(11)计算特征空间自监督对比损失:
(9);
(10);
(11);
其中:表示内积,B表示总的采样批次大小,表示第个样本的弱增广投影向量,第个样本的强增广投影向量,表示投影向量集合的投影向量,表示向量在投影向量集合的序号,表示第个样本的弱增广特征向量的特征空间自监督对比损失,表示第个样本的强增广特征向量的特征空间自监督对比损失,表示特征空间自监督对比损失。
S35:预测空间自监督对比学习:将预测集内的弱增广预测向量以行向量堆叠的方式组成弱增广预测矩阵,将预测集内的强增广预测向量以行向量堆叠的方式组成强增广预测矩阵,再将弱增广预测矩阵沿列方向拆解为弱增广预测列向量,将强增广预测矩阵沿列方向拆解为强增广预测列向量,并将弱增广预测列向量和强增广预测列向量组成预测列向量集,然后对预测列向量集中的预测列向量进行对比学习,将强增广预测列向量、弱增广预测列向量视为不同增广下的两种不同表征代入对比损失函数中,得到预测列向量的对比损失,再对该批次内所有的弱增广预测向量和强增广预测向量分别取平均之后并求得弱增广预测向量熵及强增广预测向量熵,并将两个熵值相加作为熵损失,最终将预测列向量的对比损失减去熵损失得到预测空间自监督对比损失;
[0038]每个弱增广预测列向量都对应一个类别,并且记录了批次内所有弱增广预测向量在该类别的概率值,每个强增广预测列向量也都对应一个类别,并且每个强增广预测列向量也记录了批次内所有强增广预测向量在该类别的概率值。由于每个样本只属于一个唯一的类别,因此我们引入的目标是在不同的增广下,最大化同一类别列向量之间的相似性,同时最小化不同类别列向量之间的相似性,从而使得不同类相互远离,产生更加尖锐的预测,将预测向量平均之后求负熵损失,是为了防止促进不同类别之间的平衡性。
[0039]具体的可以根据式(12)计算第个弱增广预测列向量的预测空间自监督对比损失,根据式(13)计算对第个强增广预测列向量的预测空间自监督对比损失:
[0040]根据式(14)计算预测空间自监督对比损失:
(12);
(13);
(14);
其中:表示第个弱增广预测列向量,表示第个强增广预测列向量,表示预测列向量集合的预测列向量,表示向量在投影向量集合的序号,表示熵损失。
S36:空间引导对比学习:通过从预测空间到特征空间的图引导对比学习计算从预测空间到特征空间的总引导对比损失,再通过从特征空间到预测空间的图引导对比学习计算从特征空间到预测空间的引导对比损失;
[0041]S37:对步骤S33的监督损失、步骤S34的特征空间自监督对比损失、步骤S35的预测空间自监督对比损失、步骤S36的从预测空间到特征空间的总引导对比损失、从特征空间到预测空间的总引导对比损失求和计算总体损失。
[0042]进一步,步骤S36中从预测空间到特征空间的图引导对比学习方法如下:
S361:将特征集中与第个样本的弱增广特征向量或强增广特征向量在预测图中产生关联的特征收集至相应的正关联特征集,将特征集中与第个样本的弱增广特征向量或强增广特征向量在预测图中不产生关联的特征收集至相应的负关联特征集,并将相应的正关联特征集及负关联特征集的并集组成相应的正负关联特征集;由于关联性未知的样本所能够提供的信息本就不够可靠,因此,不对预测图中与第个样本的弱增广特征向量或强增广特征向量关联性未知的特征做处理,有益于后续的学习过程。
[0043]S362:根据式(1)计算对第个样本的弱增广特征向量从预测空间到特征空间的引导对比损失,根据式(2)计算对第个样本的强增广特征向量从预测空间到特征空间的引导对比损失:
(1);
(2);
其中:表示第特征集中第个样本弱增广之后所对应的特征向量,表示从预测空间到特征空间的引导对比损失,表示总的采样批次值,表示的正关联特征集,表示的正负关联特征集,表示属于的正关联特征集的所有特征向量,表示属于的正负关联特征集的所有特征向量,表示特征集中第个样本的强增广特征向量,表示从预测空间到特征空间的引导对比损失,表示的正关联特征集,表示的正负关联特征集,表示属于的正关联特征集的所有特征向量,表示属于的正负关联特征集的所有特征向量;
[0044]S363:根据式(3)计算从预测空间到特征空间的总引导对比损失:
(3)。
[0045]从预测空间到特征空间的图引导对比学习,旨在充分利用预测图中全局社区性图结构信息,可以将预测图相邻样本的特征拉到一起,将非相邻样本的特征推开,由于关联性未知样本对不在、之内,故而在从预测空间到特征空间的引导对比损失计算中不对预测图中关系定义为关联性未知的样本对做任何处理,这是因为这些样本之间的关系并不明确,将其分离出来就是为了防止贸然对于其关系进行监督会导致错误信息的汇入,从而提升网络模型对于无标签数据的表征能力,提升无标签数据的标记准确率。
[0046]进一步,从特征空间到预测空间的引导对比学习方法如下:
S364:将预测集中与第个样本的弱增广预测向量或强增广预测向量在特征图中产生关联的特征收集至相应的正关联预测集,将预测集与第个样本的弱增广预测向量或强增广预测向量在特征图中不产生关联的特征收集至相应的负关联预测集,并将相应的正关联预测集及负关联预测集的并集组成相应的正负关联预测集;由于关联性未知的样本对所能够提供的信息本就不够可靠,因此,不对特征图中与第个样本的弱增广预测向量或强增广预测向量关联性未知的预测做处理,这有益于后续的学习过程。
[0047]S365:根据式(4)计算对第个样本的弱增广预测向量从特征空间到预测空间的引导对比损失,根据式(5)计算对第个样本的强增广预测向量从特征空间到预测空间的引导对比损失:
(4);
(5);
其中:表示预测集中第个样本的弱增广预测向量,表示从特征空间到预测空间的引导对比损失,表示的正关联预测集,表示的正负关联预测集,表示属于的正关联预测集的所有预测向量,表示属于的正负关联预测集的所有预测向量,表示预测集中第个样本的强增广预测向量,表示从特征空间到预测空间的引导对比损失,表示的正关联预测集,表示的正负关联预测集,表示属于的正关联预测集的所有预测向量,表示属于的正负关联预测集的所有预测向量;
[0048]S366:根据式(6)计算从特征空间到预测空间的总引导对比损失:
(6)。
[0049]从特征空间到预测空间的图引导对比学习,旨在利用特征图中局部社区性图结构信息,由于定义为在特征图中未知的样本对不在、之内,故而在从特征空间到预测空间的引导对比损失计算中不对特征图中关系定义为关联性未知的样本对做任何处理,这是因为这些样本之间的关系并不明确,将其分离出来就是为了防止贸然对于其关系进行监督会导致错误信息的汇入,从而提升网络模型对于无标签数据的表征能力,提升无标签数据的标记准确率。
[0050]进一步,步骤S3中按照式(7)计算总体损失:
(7);
其中:表示监督损失,表示特征空间自监督对比损失,表示预测空间自监督对比损失。
[0051]式(7)中相当于在计算总体损失时,每一项损失都乘以一个权重,、、、的权重均为1,而的权重为2,具体权重可以根据经验确定或者根据调参结果确定。
S4:重复迭代步骤S2至步骤S3,并判断是否达到设定迭代轮数,若达到设定迭代轮数,则得到训练好的网络模型进行进入下一步骤;
[0052]S5:利用步骤S4训练好的网络模型对步骤S1无标签数据集中的每一个无标签样本进行推理,得出每一个无标签样本预测的标记结果。具体地就是对于待标记图像数据集中的每一个图像,让其经过骨干网络模型得到特征,再经过分类器得到预测向量,最后根据预测向量最大值所对应的类别得出每一个无标签样本预测的标记结果。
[0053]本发明提供的基于双重空间对比学习的开放世界半监督图像分类方法,通过在一个公开的数据集CIFAR100上进行实例印证,即将数据集CIFAR100作为待标记图像数据集,CIFAR100共有100个类别,选取其中50个作为旧类,再从旧类中采样50%的数据作为带标签数据集,剩余数据作为无标签数据集。训练完毕之后对于无标签数据计算三个指标:
新类准确率:无标签数据集中网络模型对于实际标签为新类的样本的聚类准确率;
旧类准确率:无标签数据集中网络模型对于实际标签为旧类的样本的分类准确率;
全类准确率:无标签数据集中网络模型对于所有样本的聚类准确率。
[0054]这里将带标签数据集中的图像所属类别统称为“旧类”,而在开放世界的环境中,无标签数据集中的图像可能会隶属于不同于任意一个旧类的新的类别,这些类别统称为“新类”。
[0055]实验的结果如表1所示,表中对比方法ORCA表示基于不确定度的自适应间隔方法,OpenCon表示开放世界对比学习方法,GCD为通用类别挖掘方法。
表1
[0056]有表1可以看出,本发明提供的基于双重空间对比学习的开放世界半监督图像分类方法,相比于最先进的方法OpenCon,在新类、旧类、全类准确率上分别有5.6%,5.7%,6.4%的提升。
[0057]综上所述,本发明提供基于双重空间对比学习的开放世界半监督图像分类方法,通过特征空间与预测空间内的自监督对比学习和特征空间与预测空间之间的相互引导对比学习,空间内自监督对比学习在特征空间和预测空间中同时使用自监督学习方法,可以缓解之前工作中预测空间的监督信息不足的问题;空间之间的引导对比学习挖掘一个空间的图结构信息并让其引导另一个空间表征的学习,提升了模型的表达能力,从而实现两个空间中异质信息的融合与高效利用,提升了网络模型对于无标签数据的表征能力且提升了无标签数据的标记准确率。
[0058]以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。