TIN Code Modifications
# TIN: cuda_shift fails to compile
The compiler reports:
shift_cuda.cpp:18:26: error: ‘THCState_getCurrentStream’ was not declared in this scope
The fix is to replace:
```cpp
ShiftDataCudaForward(THCState_getCurrentStream(state),
```
with:
```cpp
ShiftDataCudaForward(at::cuda::getCurrentCUDAStream(),
```
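The new call drops in without changing the rest of the argument list, since `at::cuda::getCurrentCUDAStream()` returns a stream object that converts implicitly to `cudaStream_t`. Below is a minimal sketch of how the surrounding host wrapper might look; only `ShiftDataCudaForward` and the stream call come from the source, the wrapper name and launcher signature are placeholders, and the getter is declared in `ATen/cuda/CUDAContext.h`:

```cpp
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>  // declares at::cuda::getCurrentCUDAStream()

// Assumed kernel-launcher signature; the real one is defined in the .cu file.
void ShiftDataCudaForward(cudaStream_t stream, const float* data,
                          const int* shift, float* out, int n);

// Placeholder wrapper name, only for illustration.
void shift_forward_wrapper(at::Tensor data, at::Tensor shift, at::Tensor out) {
  // The returned CUDAStream converts implicitly to cudaStream_t, so it fits
  // directly into the slot that THCState_getCurrentStream(state) used to fill.
  ShiftDataCudaForward(at::cuda::getCurrentCUDAStream(),
                       data.data_ptr<float>(),
                       shift.data_ptr<int>(),
                       out.data_ptr<float>(),
                       static_cast<int>(data.numel()));
}
```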
# Reference link
# autograd error
The error message is:
RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)
Following the official documentation, add the @staticmethod decorator above the forward and backward methods in rtc_wrap.py:
```python
# Code for "Temporal Interlacing Network"
```
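For reference, a minimal sketch of what a new-style autograd Function with static forward and backward methods looks like. The class name ShiftFeatureFunc and the method bodies are assumptions rather than the verbatim contents of rtc_wrap.py; the real implementation calls the compiled cuda_shift kernels at the points marked in the comments:

```python
import torch


class ShiftFeatureFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, data, shift):
        # New-style: state needed by backward is stored on ctx, not on an instance.
        ctx.save_for_backward(shift)
        out = torch.zeros_like(data)
        # Real code: launch the cuda_shift forward kernel to fill `out`
        # from `data` and the integer offsets in `shift`.
        out.copy_(data)  # placeholder so the sketch runs end to end
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (shift,) = ctx.saved_tensors
        grad_input = torch.zeros_like(grad_output)
        # Real code: launch the cuda_shift backward kernel to fill `grad_input`.
        grad_input.copy_(grad_output)  # placeholder
        # No gradient is propagated to the integer shift offsets.
        return grad_input, None
```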
Then replace the linear_sampler function in temporal_interlace.py as follows:
```python
def linear_sampler(data, bias):
```
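With a new-style Function, the sampler invokes it through ShiftFeatureFunc.apply(...) instead of constructing ShiftFeatureFunc() instances. A simplified sketch of that change, with the import path, tensor shapes, and weight broadcasting assumed rather than copied from the original file:

```python
import torch

from rtc_wrap import ShiftFeatureFunc  # import path assumed


def linear_sampler(data, bias):
    # bias holds fractional temporal offsets; interpolate between the feature
    # maps shifted by the two neighbouring integer offsets.
    bias_0 = torch.floor(bias)
    bias_1 = bias_0 + 1

    # New-style invocation: no instance is created, just Function.apply(...).
    data_0 = ShiftFeatureFunc.apply(data, bias_0.int())
    data_1 = ShiftFeatureFunc.apply(data, bias_1.int())

    # Linear interpolation weights (assumed broadcastable against the data;
    # the original code reshapes them per group, which is omitted here).
    w_1 = bias - bias_0
    w_0 = 1.0 - w_1
    return w_0 * data_0 + w_1 * data_1
```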
# Reference documentation
The errors are resolved; the resulting accuracy still needs to be tested.