Skip to content Skip to footer

ELA(Efficient Local Attention)注意力机制

2024年出了一个新的随插随用的注意力机制ELA(Efficient Local Attention)。

具体结构:

(a)SE模块

(b)CA模块

(c)ELA模块

文中总结了CA模块的不足之处并针对CA的缺点提出了ELA模块。

为了优化 ELA 的性能,同时考虑参数数量(参数为:conv1d中的Kernel_size,groups,GN中的num_groups),文中引入了四种方案:

ELA-Tiny(ELA-T), ELA-Base(ELAB), ELA-Small(ELA-S), ELA-Large(ELA-L)

代码实现:

import torch

import torch.nn as nn

class ELA(nn.Module):

def __init__(self, in_channels, phi):

super(ELA, self).__init__()

'''

ELA-T 和 ELA-B 设计为轻量级,非常适合网络层数较少或轻量级网络的 CNN 架构

ELA-B 和 ELA-S 在具有更深结构的网络上表现最佳

ELA-L 特别适合大型网络。

'''

Kernel_size = {'T': 5, 'B': 7, 'S': 5, 'L': 7}[phi]

groups = {'T': in_channels, 'B': in_channels, 'S': in_channels//8, 'L': in_channels//8}[phi]

num_groups = {'T': 32, 'B': 16, 'S': 16, 'L': 16}[phi]

pad = Kernel_size//2

self.con1 = nn.Conv1d(in_channels, in_channels, kernel_size=Kernel_size, padding=pad, groups=groups, bias=False)

self.GN = nn.GroupNorm(num_groups, in_channels)

self.sigmoid = nn.Sigmoid()

def forward(self, input):

b, c, h, w = input.size()

x_h = torch.mean(input, dim=3, keepdim=True).view(b,c,h)

x_w = torch.mean(input, dim=2, keepdim=True).view(b,c,w)

x_h = self.con1(x_h) # [b,c,h]

x_w = self.con1(x_w) # [b,c,w]

x_h = self.sigmoid(self.GN(x_h)).view(b, c, h, 1) # [b, c, h, 1]

x_w = self.sigmoid(self.GN(x_w)).view(b, c, 1, w) # [b, c, 1, w]

return x_h * x_w * input

if __name__ == "__main__":

# 创建一个形状为 [batch_size, channels, height, width] 的虚拟输入张量

input = torch.randn(2, 256, 40, 40)

ela = ELA(in_channels=256, phi='T')

output = ela(input)

print(output.size())

Copyright © 2088 乒乓球世界杯几年一次_世界杯冠军 - salooo.com All Rights Reserved.
友情链接