# torch_function

This document explains some important points about `__torch_function__` and its role in this project.

The `__torch_function__` API was first introduced in v1.5.0 (see here).

- Subclass preservation for some important operators was introduced here. This property is very important since the `torchview` package keeps track only of tensors of the `RecorderTensor` subclass.
- Some important fixes were introduced here. For instance, support for `F.embedding` was included. Otherwise, `F.embedding` called under the `__torch_function__` of a subclass would return `NotImplemented`, and the output would fall back to a plain `torch.Tensor`, which is not desired. To prevent this issue (and to support PyTorch versions < 1.10), we added the code below in `recorder_tensor.py`:

```python
# This is necessary for torch version < 1.10
if func in [F.linear, F.embedding]:
    # Call Parameter's __torch_function__ instead and cast the result
    # back to RecorderTensor
    out = nn.parameter.Parameter.__torch_function__(
        func, types, args, kwargs).as_subclass(RecorderTensor)
else:
    # use original torch_function; otherwise,
    # it leads to infinite recursive call of torch_function
    out = super().__torch_function__(func, types, args, kwargs)
```

To be precise about the versions:

- `F.linear` returns `NotImplemented` for versions 1.7.1, 1.8 and 1.9.
- `F.embedding` returns `NotImplemented` for versions 1.7.1, 1.8 and 1.9.
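
For readers who want to try the two-branch pattern in isolation, here is a minimal, self-contained sketch, assuming a recent PyTorch release. `SketchRecorderTensor` and its `seen` log are hypothetical names used for illustration only; this is not torchview's actual `RecorderTensor`. Calls to `F.linear` and `F.embedding` go through `nn.Parameter.__torch_function__` and are cast back with `as_subclass`, while every other call is delegated to `torch.Tensor.__torch_function__`, which re-wraps tensor outputs in the subclass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchRecorderTensor(torch.Tensor):
    """Hypothetical stand-in for torchview's RecorderTensor (illustration only)."""

    # Class-level log of the names of functions routed through __torch_function__.
    seen = []

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        cls.seen.append(getattr(func, "__name__", repr(func)))
        if func in (F.linear, F.embedding):
            # Same workaround as in the text: run func through Parameter's
            # __torch_function__, then cast the result back to this subclass.
            out = nn.Parameter.__torch_function__(func, types, args, kwargs)
            return out.as_subclass(cls)
        # Default path; calling super() avoids infinite recursion, and
        # torch.Tensor.__torch_function__ re-wraps tensor outputs as cls.
        return super().__torch_function__(func, types, args, kwargs)


# Usage: modules whose weights are nn.Parameter, fed subclass inputs.
layer = nn.Linear(3, 4)
x = torch.randn(2, 3).as_subclass(SketchRecorderTensor)
y = layer(x)  # handled by the F.linear branch
z = y * 2     # handled by the default branch

emb = nn.Embedding(10, 3)
idx = torch.tensor([1, 2]).as_subclass(SketchRecorderTensor)
e = emb(idx)  # handled by the F.embedding branch

print(type(y).__name__, type(z).__name__, type(e).__name__)
print("linear" in SketchRecorderTensor.seen, "embedding" in SketchRecorderTensor.seen)
```

The explicit `as_subclass(cls)` call is what keeps the `F.linear` and `F.embedding` outputs typed as the subclass in this sketch; without it, the Parameter branch would hand back a plain tensor.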

This package does not support torch versions <= 1.6, since `torch.Tensor` does not have `__torch_function__` as a class method in those versions.
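
The version boundaries above can be expressed as a small helper. This is a sketch under assumptions: `major_minor`, `is_supported` and `needs_parameter_workaround` are hypothetical names, and torchview itself may not perform any such check. It only encodes the ranges stated in this document (no support for <= 1.6, workaround needed below 1.10) and parses strings such as `torch.__version__`.

```python
import re
from typing import Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '1.9.0' or '2.1.0+cu118'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def is_supported(version: str) -> bool:
    # torch <= 1.6 lacks a classmethod torch.Tensor.__torch_function__.
    return major_minor(version) >= (1, 7)


def needs_parameter_workaround(version: str) -> bool:
    # The Parameter.__torch_function__ branch is only needed below 1.10.
    return (1, 7) <= major_minor(version) < (1, 10)


for v in ["1.6.0", "1.7.1", "1.9.0+cu111", "1.10.2", "2.1.0"]:
    print(v, is_supported(v), needs_parameter_workaround(v))
```

Comparing `(major, minor)` tuples avoids the pitfall of comparing version strings lexically, where `"1.10"` would sort before `"1.9"`.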