본문 바로가기

Python

ImportError: cannot import name '_NewEmptyTensorOp' from 'torchvision.ops.misc'

반응형

해당 에러는 torchvision의 version 확인을 제대로 못해서 발생하는 문제이다.

 

import torch
import torch.nn as nn
import torch.distributed as dist
from torch import Tensor

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if float(torchvision.__version__[:3]) < 0.5:
    import math
    from torchvision.ops.misc import _NewEmptyTensorOp
    def _check_size_scale_factor(dim, size, scale_factor):
        # type: (int, Optional[List[int]], Optional[float]) -> None
        if size is None and scale_factor is None:
            raise ValueError("either size or scale_factor should be defined")
        if size is not None and scale_factor is not None:
            raise ValueError("only one of size or scale_factor should be defined")
        if not (scale_factor is not None and len(scale_factor) != dim):
            raise ValueError(
                "scale_factor shape must match input shape. "
                "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
            )
    def _output_size(dim, input, size, scale_factor):
        # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
        assert dim == 2
        _check_size_scale_factor(dim, size, scale_factor)
        if size is not None:
            return size
        # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
        assert scale_factor is not None and isinstance(scale_factor, (int, float))
        scale_factors = [scale_factor, scale_factor]
        # math.floor might return float in py2.7
        return [
            int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
        ]
elif float(torchvision.__version__[:3]) < 0.7:
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size

위에서 아래와 관련된 부분이 문제라는 뜻.

float(torchvision.__version__[:3]) < 0.5:

 

실제로 내가 설치한 torchvision의 version은 0.12인데 torchvision.__version__[:3]은 0.1을 반환하여 위 조건문이 true가 되고 _NewEmptyTensorOp를 import하려 한다. 그런데 아마도 0.12에서는 _NewEmptyTensorOp를 지원하지 않는듯?

 

다행이 조건문이 elseif로 끝난다!! 이말은 version이 0.7보다 크거나 같은경우는 조건문이 무시된다는 뜻!!

결국 misc.py에서 해당 조건문에 해당되는 부분을 모두 주석처리하면 문제없이 실행된다!!!