fp16.py

# -*- coding: utf-8 -*-
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

# CPU and CUDA tensor types used for the isinstance checks below.
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn


def fp32_to_fp16(val):
    """Convert fp32 `val` to fp16"""
    def half_conversion(val):
        # Unwrap Parameter/Variable so the type check sees the underlying tensor.
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, FLOAT_TYPES):
            val = val.half()
        return val
    return conversion_helper(val, half_conversion)


def fp16_to_fp32(val):
    """Convert fp16 `val` to fp32"""
    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, HALF_TYPES):
            val = val.float()
        return val
    return conversion_helper(val, float_conversion)
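

# Behavior note (added comment, not in the original source): both converters
# preserve the nesting of their argument, e.g.
#     fp32_to_fp16((x, [y, 3]))  ->  (x.half(), [y.half(), 3])
# for fp32 tensors x and y; non-float values pass through unchanged, and
# tuples stay tuples while lists stay lists.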


class FP16Module(nn.Module):
    """Wrapper that runs `module` in half precision: parameters are cast to
    fp16 at construction, fp32 inputs are cast to fp16 on the way in, and
    fp16 outputs are cast back to fp32 on the way out."""

    def __init__(self, module):
        super(FP16Module, self).__init__()
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)

    def get_param(self, item):
        return self.module.get_param(item)

    def to(self, device, *args, **kwargs):
        self.module.to(device)
        return super().to(device, *args, **kwargs)
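

if __name__ == "__main__":
    # Minimal usage sketch (illustrative addition, not part of the original
    # file). Assumes a CUDA device, since half-precision layers are normally
    # run on the GPU; `nn.Linear` stands in for an arbitrary fp32 model.
    if torch.cuda.is_available():
        model = FP16Module(nn.Linear(16, 4).cuda())
        x = torch.randn(2, 16).cuda()   # fp32 input from the caller
        y = model(x)                    # forward pass runs in fp16 internally
        print(y.dtype)                  # torch.float32: output cast back by the wrapper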