12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- # -*- coding: utf-8 -*-
- import torch
- from torch import nn
- from torch.autograd import Variable
- from torch.nn.parameter import Parameter
# Legacy CPU and CUDA tensor types for fp32 and fp16, usable as
# isinstance() targets.
FLOAT_TYPES = (
    torch.FloatTensor,
    torch.cuda.FloatTensor,
)
HALF_TYPES = (
    torch.HalfTensor,
    torch.cuda.HalfTensor,
)
def conversion_helper(val, conversion):
    """Map ``conversion`` over the leaves of a possibly nested tuple/list.

    Tuples and lists are walked recursively. A tuple input (including
    subclasses such as namedtuples) is rebuilt as a plain ``tuple``; a list
    input (including subclasses) is rebuilt as a plain ``list``. Any other
    object is a leaf and is returned as ``conversion(val)``.
    """
    if isinstance(val, (tuple, list)):
        converted = [conversion_helper(item, conversion) for item in val]
        return tuple(converted) if isinstance(val, tuple) else converted
    return conversion(val)
def fp32_to_fp16(val):
    """Cast every fp32 tensor in ``val`` to fp16.

    ``val`` may be a single object or an arbitrarily nested tuple/list of
    objects (walked via ``conversion_helper``; container types are rebuilt
    as plain tuples/lists). Tensors -- including ``nn.Parameter`` instances --
    whose dtype is ``torch.float32`` are converted with ``.half()``; tensors
    of any other dtype and non-tensor values are returned unchanged.
    """
    def half_conversion(item):
        # Check the dtype rather than isinstance() against the legacy
        # torch.FloatTensor / torch.cuda.FloatTensor types: those only match
        # CPU and CUDA tensors, so fp32 tensors on any other device would
        # silently skip the cast.
        if torch.is_tensor(item) and item.dtype == torch.float32:
            return item.half()
        return item

    return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
    """Cast every fp16 tensor in ``val`` to fp32.

    ``val`` may be a single object or an arbitrarily nested tuple/list of
    objects (walked via ``conversion_helper``; container types are rebuilt
    as plain tuples/lists). Tensors -- including ``nn.Parameter`` instances --
    whose dtype is ``torch.float16`` are converted with ``.float()``; tensors
    of any other dtype (e.g. bfloat16) and non-tensor values are returned
    unchanged.
    """
    def float_conversion(item):
        # Check the dtype rather than isinstance() against the legacy
        # torch.HalfTensor / torch.cuda.HalfTensor types: those only match
        # CPU and CUDA tensors, so fp16 tensors on any other device would
        # silently skip the cast.
        if torch.is_tensor(item) and item.dtype == torch.float16:
            return item.float()
        return item

    return conversion_helper(val, float_conversion)
class FP16Module(nn.Module):
    """Run a wrapped module in fp16 while exposing fp32 inputs/outputs.

    The wrapped module is converted with ``module.half()`` and registered as
    the child ``module``. ``forward`` casts fp32 tensors in the positional
    inputs to fp16 and casts fp16 tensors in the result back to fp32.
    ``state_dict`` / ``load_state_dict`` delegate to the wrapped module, so
    checkpoint keys do not carry a ``module.`` prefix.
    """

    def __init__(self, module):
        super(FP16Module, self).__init__()
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        # Only positional inputs are cast; keyword arguments are forwarded
        # unchanged, so fp32 tensors passed as kwargs reach the fp16 module
        # as fp32.
        return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs))

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped module's state dict (no ``module.`` prefix)."""
        # Pass by keyword: recent PyTorch deprecates positional arguments to
        # nn.Module.state_dict and warns on every such call.
        return self.module.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` into the wrapped module.

        Returns:
            Whatever the wrapped module's ``load_state_dict`` returns (for a
            standard ``nn.Module``, a named tuple of ``missing_keys`` and
            ``unexpected_keys``). The original wrapper discarded it.
        """
        return self.module.load_state_dict(state_dict, strict=strict)

    def get_param(self, item):
        # Delegates to the wrapped module; assumes it defines get_param.
        return self.module.get_param(item)

    def to(self, device, *args, **kwargs):
        # super().to() moves child parameters via _apply() and never calls a
        # child's own ``to``; the explicit call presumably exists so that any
        # override of ``to`` on the wrapped module still runs.
        self.module.to(device)
        return super().to(device, *args, **kwargs)
|