

[PyTorch] ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 32])

Traceback (most recent call last):
  File "train.py", line 193, in <module>
    g_loss = torch.mean(torch.abs(netD(gen_imgs) - gen_imgs))
  File "/home/cvmi-koo/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/cvmi-koo/code/Face-Generator/libs/models.py", line 77, in forward
    self.bn1(out)
  File "/home/cvmi-koo/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/cvmi-koo/.local/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 135, in forward
    return F.batch_norm(
  File "/home/cvmi-koo/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 2147, in batch_norm
    _verify_batch_size(input.size())
  File "/home/cvmi-koo/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 2114, in _verify_batch_size
    raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 32])

Environment: Python 3.8.5, torch 1.8.1

I was training with a batch size of 32, but the last batch happened to contain only a single sample.

In training mode, nn.BatchNorm1d computes a per-channel mean and variance over the batch, and a batch of size 1 leaves it with nothing to compute those statistics from; that is exactly what _verify_batch_size rejects when it sees torch.Size([1, 32]). Suspecting the single-sample last batch as the cause, I searched for a fix.
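The error can be reproduced with a bare BatchNorm layer; a minimal sketch (only torch is assumed):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(32)
bn.train()                 # training mode requires batch statistics
x = torch.randn(1, 32)     # a "last batch" holding a single sample
try:
    bn(x)                  # raises the ValueError shown above
except ValueError as e:
    print(e)

bn.eval()                  # eval mode uses the running statistics instead
print(bn(x).shape)         # torch.Size([1, 32]); no error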

 

https://discuss.pytorch.org/t/error-expected-more-than-1-value-per-channel-when-training/26274

 

Error: Expected more than 1 value per channel when training

I have a model that works perfectly when there are multiple input. However, if there is only one datapoint as input, I will get the above error. Does anyone have an idea on what’s going on here?


 

Solution 1. Set drop_last=True on the DataLoader so that the incomplete final batch is discarded (sketch below).
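A minimal sketch of this option, using a hypothetical TensorDataset whose length is deliberately not a multiple of the batch size:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 97 samples, and 97 % 32 == 1, so without drop_last=True
# the final batch would hold a single sample and trip the BatchNorm check.
dataset = TensorDataset(torch.randn(97, 64))
loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

for (batch,) in loader:
    print(batch.shape)  # always torch.Size([32, 64]); the leftover sample is skipped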

 

Solution 2. Temporarily switch the BatchNorm layer to eval() mode for the offending batch, so it normalizes with its running statistics instead of batch statistics. (This is the approach I used.)

# Before
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        ...  # (omitted)
        self.bn = nn.BatchNorm1d(32, 0.8)
        ...  # (omitted)

    def forward(self, img):
        ...  # (omitted)
        out = self.bn(out)
        ...  # (omitted)
        return out
# After
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        ...  # (omitted)
        self.bn = nn.BatchNorm1d(32, 0.8)
        ...  # (omitted)

    def forward(self, img):
        ...  # (omitted)
        try:
            out = self.bn(out)
        except ValueError:
            # A single-sample batch cannot provide batch statistics,
            # so fall back to the running statistics for this batch only.
            self.bn.eval()
            out = self.bn(out)
            self.bn.train()
        ...  # (omitted)
        return out
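One trade-off worth noting: while self.bn is in eval() mode it normalizes with its running mean and variance rather than the current batch's statistics, so the final partial batch of each epoch is handled slightly differently from the others. If losing up to batch_size - 1 samples per epoch is acceptable, drop_last=True is the simpler fix; the try/except fallback keeps every sample while still avoiding the error.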