Stable Diffusion 里的 UNet 主要有三个输入,分别是:
Input | Dimensions |
---|---|
带噪的潜空间图像 | (1, 4, 64, 64) |
时间步 | (1,) |
文本提示编码 | (1, 77, 768) |
这三个输入经过 UNet 后,会得到一个输出,维度是 (1, 4, 64, 64)
。
本文的目标就是:解释这个输出具体是怎么得到的。
整体结构
UNet 的结构主要包含五个部分:
- conv_in
- Encoder
- Middle Block
- Decoder
- conv_out
这里的 conv_in 和 conv_out 其实就是一个普通的 3x3 卷积层,用来转换通道数。
写成代码就是一行的事:
conv_in = Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv_out = Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
所以 UNet 的大部份操作,其实都是在 (1, 320, 64, 64)
这个维度基础上完成的。
Component | Input Dimensions | Output Dimensions |
---|---|---|
Encoder | (1, 320, 64, 64) | (1, 1280, 8, 8) |
Middle Block | (1, 1280, 8, 8) | (1, 1280, 8, 8) |
Decoder | (1, 1280, 8, 8) | (1, 320, 64, 64) |
先分析 Encoder
输入 (1, 320, 64, 64)
进入Encoder后,会依次经过四个Block:
- CrossAttnDownBlock2D
- CrossAttnDownBlock2D
- CrossAttnDownBlock2D
- DownBlock2D
得到输出 (1, 1280, 8, 8)
。
可以看到,这里有两类 Block:
- CrossAttnDownBlock2D
- DownBlock2D
他俩的关系是:CrossAttnDownBlock2D 是加强版的 DownBlock2D,里面有 Attention Block ,而 DownBlock2D 里没有。
下面先介绍 DownBlock2D。
DownBlock2D 包含 N个ResBlock(N=2)和一个可选的下采样层。
特征会依次经过这N个ResBlock,最后经过下采样层(如果有)。
而 CrossAttnDownBlock2D 在每个ResBlock后,加了一个 Attention Block,用来计算图像特征的自注意力和图文特征的交叉注意力。
特征会依次经过这N对ResBlock-Attention Block,最后经过下采样层(如果有)。
对于Encoder里下采样层的设计是,前三个Block有,最后一个没有。
对于特征通道数的加倍是,中间两个会加倍,其他的不变。
下面是整体的设计:
对比指标 | Down Block 1 | Down Block 2 | Down Block 3 | Down Block 4 |
---|---|---|---|---|
模块类型 | CrossAttnDownBlock2D | CrossAttnDownBlock2D | CrossAttnDownBlock2D | DownBlock2D |
尺寸减半 | 是 | 是 | 是 | 否 |
尺寸变化 | 64x64 -> 32x32 | 32x32 -> 16x16 | 16x16 -> 8x8 | 8x8 |
通道数加倍 | 否 | 是 | 是 | 否 |
通道数变化 | 320 | 320 -> 640 | 640 -> 1280 | 1280 |
ResBlock 数量 | 2 | 2 | 2 | 2 |
Attention Block 数量 | 2 | 2 | 2 | 0 |
主要输入 |
conv_in、Encoder里的每个ResBlock以及下采样层的输出,都会存入down_block_res_samples,用于后续Decoder都跨层特征融合。
所以down_block_res_samples里一共有:1+2*4+1*3=1+8+3=12
个特征。
这12个特征的维度如下图:
序号 (Index) | 生产者 (Producer Layer) | 输出特征图 (通道数, 尺寸) |
---|---|---|
0 | conv_in | (320, 64x64) |
1 | down_blocks[0] 的第1个 ResBlock | (320, 64x64) |
2 | down_blocks[0] 的第2个 ResBlock | (320, 64x64) |
3 | down_blocks[0] 的 Downsampler | (320, 32x32) |
4 | down_blocks[1] 的第1个 ResBlock | (640, 32x32) |
5 | down_blocks[1] 的第2个 ResBlock | (640, 32x32) |
6 | down_blocks[1] 的 Downsampler | (640, 16x16) |
7 | down_blocks[2] 的第1个 ResBlock | (1280, 16x16) |
8 | down_blocks[2] 的第2个 ResBlock | (1280, 16x16) |
9 | down_blocks[2] 的 Downsampler | (1280, 8x8) |
10 | down_blocks[3] 的第1个 ResBlock | (1280, 8x8) |
11 | down_blocks[3] 的第2个 ResBlock | (1280, 8x8) |
再分析 Decoder
Encdoer 的输出会经过 Middle Block,并且特征维度和通道数都保持不变。
- 输入
x
:(1, 1280, 8, 8)
- 内部层: 1 x ResBlock, 1 x Attention Block, 1 x ResBlock。
- 输出
x
:(1, 1280, 8, 8)
Middle Block 的最终输出,也就是 Decoder 的输入,就需要和 Encdoer 的输出列表不断的进行拼接融合了。
输入在 Decoder 里也会经过四个 Block:
- UpBlock2D
- CrossAttnUpBlock2D
- CrossAttnUpBlock2D
- CrossAttnUpBlock2D
Decoder的设计和Encoder差不多,但是有以下区别:
- 前3个block要做上采样
- 后两个block做通道数减半
- 每个block包含3对ResBlock- Attention Block,而不是2对。
对比指标 | Up Block 1 | Up Block 2 | Up Block 3 | Up Block 4 |
---|---|---|---|---|
模块类型 | UpBlock2D | CrossAttnUpBlock2D | CrossAttnUpBlock2D | CrossAttnUpBlock2D |
尺寸加倍 | 是 | 是 | 是 | 否 |
尺寸变化 | 8x8 -> 16x16 | 16x16 -> 32x32 | 32x32 -> 64x64 | 64x64 -> 64x64 |
通道数减半 | 否 | 否 | 是 | 是 |
通道数变化 | 1280 + x -> 1280 | 1280 + x -> 1280 | 1280 + x -> 640 | 640 + x -> 320 |
ResBlock 数量 | 3 | 3 | 3 | 3 |
Attention Block 数量 | 0 | 3 | 3 | 3 |
主要输入 |
跨层分析
使用残差的时候,就是先拼接,再通过ResBlock融合特征通道。
因为前两个block对特征通道数不变,所以这三个ResBlock的输出维度是一样的。
up_blocks[0]
- 输入
sample_in
:[B, 1280, 8, 8]
(来自mid_block
)
步骤 | 操作 | 输入Tensor(s) 及其维度 | 关键层定义 (in_ch -> out_ch) | 输出维度 |
---|---|---|---|---|
1.1 | torch.cat | sample_in : [B, 1280, 8, 8] res[11] : [B, 1280, 8, 8] | [B, 2560, 8, 8] | |
1.2 | up_blocks[0].ResBlocks[0] | (上一步输出) : [B, 2560, 8, 8] | ResBlockBlock2D(conv1: 2560 -> 1280) | [B, 1280, 8, 8] |
1.3 | torch.cat | (上一步输出) : [B, 1280, 8, 8] res[10] : [B, 1280, 8, 8] | [B, 2560, 8, 8] | |
1.4 | up_blocks[0].ResBlocks[1] | (上一步输出) : [B, 2560, 8, 8] | ResBlockBlock2D(conv1: 2560 -> 1280) | [B, 1280, 8, 8] |
1.5 | torch.cat | (上一步输出) : [B, 1280, 8, 8] res[9] : [B, 1280, 8, 8] | [B, 2560, 8, 8] | |
1.6 | up_blocks[0].ResBlocks[2] | (上一步输出) : [B, 2560, 8, 8] | ResBlockBlock2D(conv1: 2560 -> 1280) | [B, 1280, 8, 8] |
1.7 | up_blocks[0].upsamplers[0] | (上一步输出) : [B, 1280, 8, 8] | Upsample2D(conv: 1280 -> 1280) | [B, 1280, 16, 16] |
这里也是,因为总体上前2个block的输出通道数是不变的,只是要根据残差的通道数来修改输入通道数。
up_blocks[1]
- 输入
sample_in
:[B, 1280, 16, 16]
(来自up_blocks[0]
)
步骤 | 操作 | 输入Tensor(s) 及其维度 | 关键层定义 (in_ch -> out_ch) | 输出维度 |
---|---|---|---|---|
2.1 | torch.cat | sample_in : [B, 1280, 16, 16] res[8] : [B, 1280, 16, 16] | [B, 2560, 16, 16] | |
2.2 | up_blocks[1].ResBlocks[0] | (上一步输出) : [B, 2560, 16, 16] | ResBlockBlock2D(conv1: 2560 -> 1280) | [B, 1280, 16, 16] |
2.3 | torch.cat | (上一步输出) : [B, 1280, 16, 16] res[7] : [B, 1280, 16, 16] | [B, 2560, 16, 16] | |
2.4 | up_blocks[1].ResBlocks[1] | (上一步输出) : [B, 2560, 16, 16] | ResBlockBlock2D(conv1: 2560 -> 1280) | [B, 1280, 16, 16] |
2.5 | torch.cat | (上一步输出) : [B, 1280, 16, 16] res[6] : [B, 640, 16, 16] | [B, 1920, 16, 16] | |
2.6 | up_blocks[1].ResBlocks[2] | (上一步输出) : [B, 1920, 16, 16] | ResBlockBlock2D(conv1: 1920 -> 1280) | [B, 1280, 16, 16] |
2.7 | up_blocks[1].upsamplers[0] | (上一步输出) : [B, 1280, 16, 16] | Upsample2D(conv: 1280 -> 1280) | [B, 1280, 32, 32] |
后两个Block对第一个ResBlock需要做到通道数减半。
up_blocks[2]
- 输入
sample_in
:[B, 1280, 32, 32]
(来自up_blocks[1]
)
步骤 | 操作 | 输入Tensor(s) 及其维度 | 关键层定义 (in_ch -> out_ch) | 输出维度 |
---|---|---|---|---|
3.1 | torch.cat | sample_in : [B, 1280, 32, 32] res[5] : [B, 640, 32, 32] | [B, 1920, 32, 32] | |
3.2 | up_blocks[2].ResBlocks[0] | (上一步输出) : [B, 1920, 32, 32] | ResBlockBlock2D(conv1: 1920 -> 640) | [B, 640, 32, 32] |
3.3 | torch.cat | (上一步输出) : [B, 640, 32, 32] res[4] : [B, 640, 32, 32] | N/A | [B, 1280, 32, 32] |
3.4 | up_blocks[2].ResBlocks[1] | (上一步输出) : [B, 1280, 32, 32] | ResBlockBlock2D(conv1: 1280 -> 640) | [B, 640, 32, 32] |
3.5 | torch.cat | (上一步输出) : [B, 640, 32, 32] res[3] : [B, 320, 32, 32] | N/A | [B, 960, 32, 32] |
3.6 | up_blocks[2].ResBlocks[2] | (上一步输出) : [B, 960, 32, 32] | ResBlockBlock2D(conv1: 960 -> 640) | [B, 640, 32, 32] |
3.7 | up_blocks[2].upsamplers[0] | (上一步输出) : [B, 640, 32, 32] | Upsample2D(conv: 640 -> 640) | [B, 640, 64, 64] |
第一个ResBlock要做到通道数减半。
up_blocks[3]
- 输入
sample_in
:[B, 640, 64, 64]
(来自up_blocks[2]
)
步骤 | 操作 | 输入Tensor(s) 及其维度 | 关键层定义 (in_ch -> out_ch) | 输出维度 |
---|---|---|---|---|
4.1 | torch.cat | sample_in : [B, 640, 64, 64] res[2] : [B, 320, 64, 64] | [B, 960, 64, 64] | |
4.2 | up_blocks[3].ResBlocks[0] | (上一步输出) : [B, 960, 64, 64] | ResBlockBlock2D(conv1: 960 -> 320) | [B, 320, 64, 64] |
4.3 | torch.cat | (上一步输出) : [B, 320, 64, 64] res[1] : [B, 320, 64, 64] | [B, 640, 64, 64] | |
4.4 | up_blocks[3].ResBlocks[1] | (上一步输出) : [B, 640, 64, 64] | ResBlockBlock2D(conv1: 640 -> 320) | [B, 320, 64, 64] |
4.5 | torch.cat | (上一步输出) : [B, 320, 64, 64] res[0] : [B, 320, 64, 64] | N/A | [B, 640, 64, 64] |
4.6 | up_blocks[3].ResBlocks[2] | (上一步输出) : [B, 640, 64, 64] | ResBlockBlock2D(conv1: 640 -> 320) | [B, 320, 64, 64] |