假设我们的批次大小为 $N$,那么一个批次的数据包含:
- $N$ 张图像: [$\text{img}_1$, $\text{img}_2$, …, $\text{img}_N$]
- $N$ 段文本: [$\text{text}_1$, $\text{text}_2$, …, $\text{text}_N$]
这里的关键是,$\text{img}_i$ 和 $\text{text}_i$ 是一个匹配的对(例如,$\text{img}_1$是猫的图片,$\text{text}_1$是“一只猫躺在沙发上”)。
第 1 步:通过编码器提取特征向量
- 对于图像:
- 将 $N$ 张图像 [$\text{img}_1$ , …, $\text{img}_N$] 分别输入到 图像编码器 (Image Encoder) 中(例如 ViT 或 ResNet)。
- 编码器会对每张图片输出一个高维的特征向量。
- 我们得到 $N$ 个图像特征向量:[$I_1$, $I_2$, …, $I_N$]。
- 每个向量(比如 $I_1$)都是一长串数字,它“浓缩”了 $\text{img}_1$ 的视觉信息。
- 对于文本:
- 将 $N$ 段文本 [$\text{text}_1$, …, $\text{text}_N$] 分别输入到 文本编码器 (Text Encoder) 中(例如 Transformer)。
- 编码器同样对每段文本输出一个高维的特征向量(维度通常与图像特征向量相同)。
- 我们得到 $N$ 个文本特征向量:[$T_1$, $T_2$, …, $T_N$]。
- 每个向量(比如 $T_1$)“浓缩”了 $\text{text}_1$ 的语义信息。
第 2 步:L2 归一化
在计算相似度之前,我们对上一步得到的所有特征向量进行 L2 归一化。
- 对每个图像向量 $I_i$,计算其归一化版本:
$$ \hat{I}_i = \frac{I_i}{|I_i|_2} $$
- 对每个文本向量 $T_j$,计算其归一化版本:
$$ \hat{T}_j = \frac{T_j}{|T_j|_2} $$
这一步之后,所有的特征向量长度都为 1。
第 3 步:计算余弦相似度分数矩阵
我们需要知道批次里的每一张图和每一段文本有多相似。最有效的方法是计算一个 $N \times N$ 的相似度矩阵。
- 我们将 $N$ 个归一化的图像向量 $\hat{I}$ 组成一个矩阵(大小为 $N \times D$,$D$是特征维度)。
- 我们将 $N$ 个归一化的文本向量 $\hat{T}$ 组成一个矩阵(大小为 $N \times D$)。
- 通过一次矩阵乘法,我们就能得到完整的相似度矩阵 $S$:
$$ S = \text{Matrix}(\hat{I}) \cdot \text{Matrix}(\hat{T})^T $$
($T$ 在这里代表矩阵的转置 Transpose)
这个 $N \times N$ 的矩阵 $S$ 的每一个元素 $S_{ij}$ 的值,就是第 $i$ 张图像和第 $j$ 段文本之间的余弦相似度分数。因为向量都归一化了,所以这个值就是它们的点积: $$ S_{ij} = \hat{I}_i \cdot \hat{T}_j $$ 举例,$S$ 看起来像这样:
T_1 T_2 T_3 ...
I_1 [ S_11, S_12, S_13, ...] <-- 这是图像1作为锚点,和所有文本的相似度
I_2 [ S_21, S_22, S_23, ...]
I_3 [ S_31, S_32, S_33, ...]
...
第 4 步:定义锚点和正负样本
现在我们有了这个分数矩阵 $S$,就可以定义对比学习中的要素了。
以第 $i$ 行为例(即以 $\text{img}_i$ 作为锚点):
- 正样本: 与 $\text{img}_i$ 匹配的文本是 $\text{text}_i$。
- 所以,它们的分数 $S_{ii}$(矩阵对角线上的元素)就是正样本分数。
- 负样本: 所有其他文本 $\text{text}_j$ 都是负样本
- 所以,这一行里所有非对角线元素 $S_{ij}$ 都是负样本分数(当 $j \neq i$) 。
第 5 步:计算损失函数
现在万事俱备,我们可以计算损失了。
对每一行计算损失(以图像为锚点):
- 对第 $i$ 行的所有分数 [$S_{i1}$, $S_{i2}$, …, $S_{ii}$, …] 应用 InfoNCE 损失公式(也就是交叉熵损失)。
$$ \text{Loss}_i^{\text{image}} = -\log\left( \frac{\exp(S_{ii}/\tau)}{\sum_{j=1}^{N} \exp(S_{ij}/\tau)} \right) $$
- 我们对所有 $N$ 行都计算这个损失,得到 $Loss_{image}$。
对每一列计算损失(以文本为锚点):
- 完全对称地,我们再对每一列计算损失(此时是把文本作为锚点,图像作为正负样本)。
- 我们对所有 $N$ 列都计算这个损失,得到 $Loss_{text}$。
最终总损失:
$$ \text{Total\_Loss} = \frac{\text{Loss\_image} + \text{Loss\_text}}{2} $$
这个 $Total_Loss$ 就是最终用来通过反向传播更新图像编码器和文本编码器参数的数值。模型的目标就是让这个损失值不断变小,从而迫使匹配的对(对角线上的 $S_{ii}$)分数变高,不匹配的对(非对角线上的 $S_{ij}$)分数变低。
补充1:InfoNCE 损失
InfoNCE 损失看起来很像 softmax,下面解释他们的关系。
Softmax 的目标是: 将一堆任意的实数分数(logits)转换成一个“概率分布”。
假设你有一个包含 $K$ 个分数的列表 $z = [z_1, z_2, …, z_K]$。Softmax 计算第 $i$ 个分数对应的概率 $p_i$ 的公式是:
$$ p_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}} $$
- 分子 $e^{z_i}$:取第 $i$ 个分数的指数。$e$ 是自然常数(约2.718)。
- 分母 $\sum_{j=1}^{K} e^{z_j}$:把所有分数的指数加起来,作为归一化项。
- 结果:用分子除以分母后,得到的 $p_i$ 就有很好的特性:
- 每个 $p_i$ 都在 0 和 1 之间。
- 所有的 $p_i$ 加起来等于 1,完全符合概率的定义。
- 指数函数会让大的分数变得更大,拉开与小分数的差距。
我们对比一下 Softmax 和 InfoNCE:
- 原始 Softmax 的输入分数是:$[z_1, z_2, …, z_K]$
- 在 CLIP 损失函数里,我们用的分数是第 $i$ 行的相似度分数:$[S_{i1}, S_{i2}, …, S_{iN}]$
我们想计算“正确匹配”(也就是第 $i$ 个文本)的概率。根据 Softmax 公式,这个概率就是:
$$ p_{\text{correct}} = \frac{e^{S_{ii}}}{\sum_{j=1}^{N} e^{S_{ij}}} $$
这和损失函数里那一坨 exp(...) / Σexp(...)
几乎一模一样,只是多了一个 $τ$。
3. tau
($τ$) 是啥?
$τ$ (tau) 是一个被称为 温度 (Temperature) 的超参数。它是一个需要我们提前设定的数值(在 CLIP 的实现中它也可以是一个可学习的参数)。
它的作用是:调控 Softmax 的“敏感度”或“锐利度”。
带有温度 $τ$ 的 Softmax 公式变为:
$$ p_i = \frac{\exp(z_i / \tau)}{\sum_{j=1}^{K} \exp(z_j / \tau)} $$
它在 exp
计算之前,先用 $τ$ 去除以每个分数。这会带来什么影响?
低温 ($τ$ 较小,比如 0.1):
- 分数 $z_i$ 会被 $τ$ 放大(例如 $z_i / 0.1 = 10 \cdot z_i$)。
exp
函数会把这个放大的差距变得极其巨大。- 结果:Softmax 的输出会非常锐利 (sharp),概率会极度集中在分数最高的那个选项上,几乎是“赢家通吃”。这会迫使模型必须非常有信心地把正样本和所有负样本区分开。
高温 ($τ$ 较大,比如 1.0):
- 分数 $z_i$ 基本不变或被缩小。
exp
函数看到的分数差距不大。- 结果:Softmax 的输出会非常平滑 (smooth),概率会比较均匀地分布在多个选项上。模型不需要那么“自信”。
Softmax 给了我们一个概率 $p_{\text{correct}}$,代表模型认为“这是正确匹配”的可信度。但我们怎么根据这个可信度来惩罚模型呢?答案就是用 交叉熵损失 (Cross-Entropy Loss),而它的简化形式就是 $-\log()$。
我们来分析 $Loss = -\log(p_{\text{correct}})$ 这个函数:
- 理想情况:模型非常确信,算出 $p_{\text{correct}} \approx 1$。那么 $Loss = -\log(1) = 0$。模型没有受到惩罚。
- 糟糕情况:模型非常不确定,算出 $p_{\text{correct}} = 0.1$。那么 $Loss = -\log(0.1) \approx 2.3$。模型受到一个中等大小的惩罚。
- 灾难情况:模型完全搞错,算出 $p_{\text{correct}} = 0.001$。那么 $Loss = -\log(0.001) \approx 6.9$。模型受到巨大的惩罚。
所以,$-\log$ 的作用就是:一个惩罚函数。你预测的正确概率越低,我就给你一个越大的惩罚值(损失)。 模型在训练中会拼命调整自己,目的就是让这个惩罚值变得尽可能小。