허깅페이스 PyTorch 프로파일링 2편, nn.Linear의 바이어스 융합 원리 해부
허깅페이스가 'PyTorch 프로파일링' 시리즈의 2편을 블로그에 공개했다. 1편에서 손으로 작성한 행렬곱-덧셈 쌍(torch.add(torch.matmul(x, w), b))으로 프로파일러 트레이스를 읽는 법을 다뤘다면, 이번 2편은 그 자리를 딥러닝 모델의 기본 구성요소인 nn.Linear(bias=True)로 대체한다. 나아가 활성화 함수를 사이에 끼운 세 개의 Linear 층을 쌓아 다층 퍼셉트론(MLP) 블록을 만든다.
예제는 02_linear.py, 03_simple_mlp.py, 03_kernels_mlp.py 스크립트로 구성되며, NVIDIA A100-SXM4-80GB GPU에서 실행했다. 허깅페이스는 Spaces의 Dev Mode나 Jobs 파이프라인으로 손쉽게 GPU를 띄워 실험할 수 있다고 안내했다.
글은 두 가지 전제를 반복적으로 활용한다. GPU 커널은 GPU의 수많은 스레드에서 병렬로 실행되는 프로그램이고, CPU는 이 커널들을 스케줄링하고 실행한다는 것이다. 프로파일러 트레이스에서 보이는 PyTorch 오버헤드의 대부분이 바로 이 스케줄링 작업이다.
nn.Linear는 1편에서 프로파일링한 것과 같은 행렬곱·덧셈을 감싼 모듈로, 가중치와 바이어스를 파라미터로 소유하고 forward 메서드를 노출한다. 연산은 y = x @ w.T + b로 표현된다. 트레이스를 확대하면 행렬곱·덧셈에 해당하는 aten::addmm 앞에 전치(transpose)인 aten::t 연산이 나타난다.
핵심은 aten::t가 실제로 데이터를 복사하거나 재배치하지 않는다는 점이다. 텐서는 데이터를 메모리에 하나의 연속된 흐름으로 저장하는 반면, 형태(shape)와 스트라이드(stride)는 그 위에 얹혀 데이터를 어떻게 훑을지 알려주는 메타데이터다. 전치는 이 스트라이드를 맞바꿔 같은 원본 데이터의 다른 뷰를 만들 뿐, 단 하나의 숫자도 옮기지 않는다. 따라서 aten::t는 CPU에서 메타데이터만 다시 쓰고 GPU 커널은 실행하지 않는다.
또 하나 주목할 점은 바이어스 덧셈에 해당하는 별도의 aten::add 커널이 디스패치 체인에 보이지 않는다는 것이다. 바이어스 덧셈이 '에필로그(epilogue)'라는 기법으로 행렬곱 커널 안에 접혀 들어갔기 때문이다. 에필로그는 GEMM(일반 행렬곱) 커널이 결과를 HBM(고대역폭 메모리)에 되쓰기 직전에 수행하는 작은 연산으로, 바이어스 더하기·활성화 적용·상수 스케일링 등이 대표적이다. 메모리를 두 번 오가는 비용을 피하는 것이 목적이다.
nn.Linear는 torch.nn.functional.linear를 거쳐 aten::linear를 호출하는데, aten::linear는 바이어스가 전달된 것을 보고 행렬곱과 덧셈을 따로 하는 대신 aten::addmm(bias, x, weight)을 디스패치한다. GPU에서 도는 cuBLAS GEMM 커널에는 바이어스 덧셈 변형이 내장돼 있어, 덧셈이 별도 커널로 나타나지 않고 행렬곱 커널의 되쓰기 단계에 포함된다.
이 대목에서 미묘한 사실이 드러난다. 1편에서 --compile로 보았던 addmm 커널이, 사실은 eager 모드의 nn.Linear가 이미 사용하는 바로 그 커널이라는 것이다. 실제로 단일 nn.Linear의 forward를 eager와 compile로 각각 트레이스하면, 같은 cuBLAS GEMM 커널과 같은 aten::addmm 연산이 나타나고 compile 쪽에만 CPU 레인에 몇 줄이 더 붙는다.
그렇다고 compile이 무의미한 것은 아니다. compile은 GPU 커널을 없앤 것이 아니라, 그 뷰를 디스패치하던 CPU 오버헤드를 제거한다. 컴파일러의 Inductor가 컴파일 시점에 뷰 체인을 추적해 결과 스트라이드를 한 번 계산하고, 그 값을 박아 넣은 aten::addmm 호출을 직접 내보낸다. GPU는 동일한 연산을 하되 CPU에서 수 마이크로초의 작업이 사라진다. 다만 입력 데이터가 컴파일러가 미리 계산한 스트라이드를 위반하면 오류가 발생한다.
결론적으로 모델이 느리게 느껴질 때마다 반사적으로 torch.compile을 떠올리기 쉽지만, 바이어스가 붙은 단일 GEMM에는 compile이 할 일이 거의 없다. 융합이 의미를 가지려면 둘 이상의 연산이 필요하며, 그 효과는 여러 Linear 층을 쌓은 MLP에서 비로소 드러난다.