pytorch JIT浅解析
概要 Torch Script中的核心数据结构是ScriptModule。 它是Torch的nn.Module的类似物,代表整个模型作为子模块树。 与普通模块一样,ScriptModule中的每个单独模块都可以包含子模块,参数和方法。 在nn.Modules中,方法是作为Python函数实现的,但在ScriptModules方法中通常实现为Torch Script函数,这是一个静态类型的Python子集,包含PyTorch的所有内置Tensor操作。 这种差异允许您运行ScriptModules代码而无需Python解释器。 ScriptModules和Torch Script函数可以通过两种方式创建: Tracing: 使用torch.jit.trace,您可以获取现有模块或python函数,提供示例输入,然后运行该函数,记录在所有张量上执行的操作。 我们将生成的记录转换为Torch Script方法,该方法作为ScriptModule的正向方法安装。 该模块还包含原始模块所具有的任何参数。 Example: import torch def foo(x, y): return 2*x + y traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) 1 2 3 4 注意: