AI

Multi-GPU, Parallel | 용어와 개념 간단한 코드[PyTorch] 와 함께 한방에 정리

cstory-bo 2024. 1. 12. 14:40

이번에는 multi gpu를 사용하면서 나올 용어들과 개념들을 정리해보았다.
multi gpu하면 꼭 같이 나오는 parallel을 data와 model로 나누어 정리하였다.
코드

Multi-GPU

원래 옛날에는 GPU를 어떻게 하면 적게 쓸까를 고민했지만

최근에는 성능에 초점을 두면서 엄청난 양의 GPU를 쓰고 있다.

용어 정리

Single : 1개 사용 vs Multi : 2개 이상 사용

GPU vs Node : Node는 1대의 컴퓨터를 말한다.

Single Node Single GPU : 1대의 컴퓨터에서 1개의 GPU 사용

Single Node Multi GPU : 1대의 컴퓨터에서 여러 개의 GPU 사용

Multi Node Multi GPU : 서버실에서 여러 개의 GPU 사용

Model parallel

  • 다중 GPU에 학습을 분산하는 두 가지 방법
    • 모델 나누기 / 데이터 나누기
  • alexnet부터 모델 병렬화를 연구해왔다.
  • 모델의 병목 파이프라인의 어려움 등으로 고난이도 과제로 여겨지고 있다.
class ModelParallelResNet(ResNet):
	def __init__():
		...
		self.seq1 = nn.Sequential(
			self.conv2,...
		).to('cuda:0') # cuda0에 할당

		self.seq2 = nn.Sequential(
			self.conv2,...
		).to('cuda:1') # cuda1에 할당

		...
	def forward(self,x):
		x = self.seq2(self.seq1(x).to('cuda:1')) # cuda0에서 cuda1로 복사한다

		

Data parallel

Data Parallel이란

minibatch와 비슷한 느낌이다.

minibatch처럼 데이터를 나누어서 병렬적으로 돌리고

이를 나중에 합쳐서 평균을 구하는 방법이다.

Forward 과정

  1. 데이터를 나누어 각 GPU에 할당한다.
  2. 병렬적으로 연산을 한다.
  3. 모두 마친 후 한 곳의 GPU에 결과들을 모은다.

Backward 과정

  1. 한 곳에 모인 각 loss들을 가지고 각자 grad를 구한다.
  2. 다시 이 grad들을 각 GPU에 다시 보낸다.
  3. 각 GPU에서 병렬적으로 backward를 실행한다.
  4. 다시 모아서 하나의 GPU에 보내고 이들의 평균을 내서 업데이트한다.

특징

  • 단순히 데이터를 분배하고 평균을 취하기 때문에 GPU사용 불균형 문제가 발생한다.
  • GIL(Global Interpreter Lock)

PyTorch 코드

parallel_model = torch.nn.DataParallel(model)
...
# backward 할 때 Average GPU-losses + backward pass
loss.mean().backward()

Distributed Data Parallel

각 GPU에 CPU도 할당해줘서 각 CPU마다 process 생성하여 개별 GPU에 할당한다.

⇒ 기본적으로 DataParallel로 하나 개별적으로 연산의 평균을 낸다.

train_sampler = torch.utils.data.distributed.DistributedSampler(data)

loader = torch.utils.data.DataLoader(data,batch_size=20,shuffle = False,
	pin_memory = True,sampler=train_sampler)

def main():
	n_gpus = torch.cuda.device_count()
	torch.multiprocessing.spawn(main_worker,nprocs=n_gpus,args=(n_gpus,)

def main_worker(gpu,n_gpus):
	...
	# 멀티프로세싱 통신 규약 정의
	torch.distributed.init_process_group(backend='nccl',
	init_method='tcpL//~',world_size=n_gpus,rank=gpu)
	
	# Distributed DataParallel 정의
	torch.nn.parallel.DistributedDataParallel(model,device_idx=[gpu])

추가로, num_workers는 GPU x 4를 많이한다.