Global Filter Networks for Image Classification
Yongming Rao* Wenliang Zhao* Zheng Zhu Jiwen Lu Jie Zhou
Tsinghua University
[Paper (arXiv)] [Code (GitHub)]
Figure 1: The overall architecture of the Global Filter Network. Our architecture is based on Vision Transformer (ViT) models with some minimal modifications. We replace the self-attention sub-layer with the proposed global filter layer, which consists of three key operations: a 2D discrete Fourier transform to convert the input spatial features to the frequency domain, an element-wise multiplication between frequency-domain features and the global filters, and a 2D inverse Fourier transform to map the features back to the spatial domain. The efficient fast Fourier transform (FFT) enables us to learn arbitrary interactions among spatial locations with log-linear complexity.
Abstract
Recent advances in self-attention and pure multi-layer perceptron (MLP) models for vision have shown great potential in achieving promising performance with fewer inductive biases. These models are generally based on learning interactions among spatial locations from raw data. The complexity of self-attention and MLP grows quadratically as the image size increases, which makes these models hard to scale up when high-resolution features are required. In this paper, we present the Global Filter Network (GFNet), a conceptually simple yet computationally efficient architecture that learns long-term spatial dependencies in the frequency domain with log-linear complexity. Our architecture replaces the self-attention layer in vision transformers with three key operations: a 2D discrete Fourier transform, an element-wise multiplication between frequency-domain features and learnable global filters, and a 2D inverse Fourier transform. Our models exhibit favorable accuracy/complexity trade-offs on both ImageNet and downstream tasks. Our results demonstrate that GFNet can be a very competitive alternative to transformer-style models and CNNs in efficiency, generalization ability, and robustness.
Global Filter Layer
GFNet is a conceptually simple yet computationally efficient architecture consisting of a stack of Global Filter Layers and Feedforward Networks (FFN). Thanks to the highly efficient Fast Fourier Transform (FFT) algorithm, the Global Filter Layer mixes tokens with log-linear complexity. The layer is easy to implement:
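import torch
import torch.nn as nn
import torch.fft


class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Learnable filter stored as (real, imag) pairs; w = W // 2 + 1 because rfft2 keeps half the spectrum.
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        # x: (batch, height, width, channels)
        B, H, W, C = x.shape
        # 2D FFT over the spatial dimensions
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        # Element-wise multiplication with the global filter
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        # Inverse 2D FFT back to the spatial domain
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x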
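To show how the layer above might sit inside one GFNet block, here is a minimal sketch. The class name `GFBlock`, the `mlp_ratio` parameter, and the exact placement of the layer normalizations and the residual connection are assumptions made for illustration; the page only states that GFNet stacks Global Filter Layers and FFNs. `GlobalFilter` is repeated so that the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.fft


class GlobalFilter(nn.Module):
    """Global filter layer as listed above (repeated so this sketch is self-contained)."""

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        B, H, W, C = x.shape
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        x = x * torch.view_as_complex(self.complex_weight)
        return torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')


class GFBlock(nn.Module):
    """Hypothetical block: LayerNorm -> global filter -> LayerNorm -> FFN, with one residual."""

    def __init__(self, dim, h=14, w=8, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.filter = GlobalFilter(dim, h=h, w=w)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )

    def forward(self, x):
        # x: (B, H, W, C); LayerNorm and Linear act on the last (channel) axis.
        return x + self.ffn(self.norm2(self.filter(self.norm1(x))))


torch.manual_seed(0)
block = GFBlock(dim=384)
tokens = torch.randn(2, 14, 14, 384)  # batch of 2, 14x14 tokens, 384 channels
out = block(tokens)
print(out.shape)
```

The block keeps the input shape, so several of them can be stacked. For a feature map of size H×W, the filter should be created with `h=H` and `w=W // 2 + 1`, since `rfft2` keeps only half of the spectrum along the last transformed axis.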
Table 1: Comparisons of the proposed Global Filter with prevalent operations in deep vision models.
Figure 2: Comparisons among GFNet, ViT [8] and ResMLP [36] in (a) FLOPs, (b) latency and (c) GPU memory with respect to the number of tokens (feature resolution). The dotted lines indicate the estimated values when the GPU memory has run out. The latency and GPU memory are measured using a single NVIDIA RTX 3090 GPU with batch size 32 and feature dimension 384.
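The scaling trend in panel (a) can be approximated with a back-of-the-envelope calculation. The sketch below uses rough per-layer estimates that are assumptions for illustration, not numbers taken from this page: about 4ND² + 2N²D operations for self-attention over N tokens of dimension D, and about ND⌈log₂N⌉ + ND for a global filter (FFT term plus element-wise product). The function names are hypothetical.

```python
import math


def self_attention_cost(n_tokens, dim):
    # Q, K, V and output projections (4*N*D^2) plus attention map and weighted sum (2*N^2*D).
    return 4 * n_tokens * dim ** 2 + 2 * n_tokens ** 2 * dim


def global_filter_cost(n_tokens, dim):
    # FFT term N*D*ceil(log2 N) plus the element-wise multiplication N*D.
    return n_tokens * dim * math.ceil(math.log2(n_tokens)) + n_tokens * dim


dim = 384
for side in (14, 28, 56, 112):
    n = side * side
    ratio = self_attention_cost(n, dim) / global_filter_cost(n, dim)
    print(f"{side}x{side} tokens: attention/filter cost ratio ~ {ratio:.0f}")
```

Under these estimates, doubling the side length quadruples N, so the quadratic attention term grows by a factor of 16 while the filter cost grows only slightly faster than 4. Operation counts are only a proxy; measured latency and memory, as in panels (b) and (c), also depend on the hardware and the FFT implementation.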
Results
Our models achieve favorable accuracy/complexity trade-offs on both ImageNet and downstream tasks.
The results show that GFNet is a competitive alternative to transformer-style models and CNNs in terms of efficiency, generalization ability, and robustness.
Figure 3: ImageNet accuracy vs computational complexity of transformer-style models.
Figure 4: ImageNet accuracy vs computational complexity of hierarchical models.
Table 2: Semantic segmentation results on ADE20K. We report the mIoU on the validation set. All models are equipped with Semantic FPN [17] and trained for 80K iterations following [40]. FLOPs are measured with 1024×1024 input. We compare models with similar computational cost and divide them into three groups: 1) tiny models using ResNet-18, PVT-Ti and GFNet-H-Ti; 2) small models using ResNet-50, PVT-S, Swin-Ti and GFNet-H-S; and 3) base models using ResNet-101, PVT-M, Swin-S and GFNet-H-B.
Figure 5: Visualization of the learned global filters in GFNet-XS. We visualize the original frequency-domain global filters in (a) and show the corresponding spatial-domain filters for the first 6 columns in (b). The patterns are clearer in the frequency domain than in the spatial domain.
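The link between the two panels is the inverse FFT: a frequency-domain filter corresponds to a spatial kernel that is applied by circular convolution over the whole feature map. The sketch below illustrates this on a small example. It uses a random filter as a stand-in (not trained GFNet-XS weights), a hypothetical 6×6 single-channel feature map, and double precision, and it checks numerically that the frequency-domain path matches a direct circular convolution with the spatial kernel.

```python
import torch

torch.manual_seed(0)
H, W = 6, 6  # small hypothetical feature map
x = torch.randn(H, W, dtype=torch.float64)

# Random stand-in for one channel of a learned filter (NOT trained weights).
# rfft2 keeps only W // 2 + 1 columns because the input is real.
K = torch.view_as_complex(torch.randn(H, W // 2 + 1, 2, dtype=torch.float64))

# Frequency-domain path, as in the global filter layer.
y_freq = torch.fft.irfft2(torch.fft.rfft2(x, norm='ortho') * K, s=(H, W), norm='ortho')

# Spatial-domain kernel: inverse FFT of the filter with the default normalization.
k = torch.fft.irfft2(K, s=(H, W))

# Direct circular convolution:
# y[n, m] = sum over (p, q) of k[p, q] * x[n - p, m - q], indices taken modulo H and W.
y_conv = torch.zeros_like(x)
for p in range(H):
    for q in range(W):
        y_conv += k[p, q] * torch.roll(x, shifts=(p, q), dims=(0, 1))

print(torch.allclose(y_freq, y_conv, atol=1e-10))

# Move the zero offset to the centre before plotting the kernel as an image.
k_centered = torch.fft.fftshift(k)
```

Because the kernel `k` has the same size as the feature map, every output location depends on every input location, which is what makes the filter global.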
Table 3: Evaluation of robustness and generalization ability. We measure robustness from several aspects: adversarial robustness under the FGSM and PGD attack algorithms, and performance on corrupted/out-of-distribution datasets, namely ImageNet-A (top-1 accuracy) and ImageNet-C (mCE, lower is better). Generalization ability is evaluated on ImageNet-V2 and ImageNet-Real.
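For readers unfamiliar with the adversarial attacks named in the caption, the following is a minimal sketch of one-step FGSM. The toy linear classifier, the random inputs, and the step size `eps = 1/255` are assumptions for illustration only; they are not the models or attack settings used for Table 3. PGD can be seen as repeating such a step several times with a projection back into the allowed perturbation range.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fgsm_attack(model, images, labels, eps):
    """One-step FGSM: move each pixel by eps in the direction that increases the loss."""
    images = images.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(images), labels)
    grad, = torch.autograd.grad(loss, images)
    adv = images + eps * grad.sign()
    # Keep the result a valid image with pixel values in [0, 1].
    return adv.clamp(0.0, 1.0).detach()


torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))  # toy stand-in classifier
images = torch.rand(4, 3, 8, 8)
labels = torch.randint(0, 10, (4,))
eps = 1.0 / 255  # hypothetical perturbation budget

adv = fgsm_attack(model, images, labels, eps)

with torch.no_grad():
    clean_loss = F.cross_entropy(model(images), labels)
    adv_loss = F.cross_entropy(model(adv), labels)
print(float(clean_loss), float(adv_loss))
```

Adversarial robustness is then typically reported as the accuracy of the model on `adv` instead of `images`.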
BibTeX
@article{rao2021global,
title={Global Filter Networks for Image Classification},
author={Rao, Yongming and Zhao, Wenliang and Zhu, Zheng and Lu, Jiwen and Zhou, Jie},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}