V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
Waihinchan
V2EX  ›  机器学习

关于 pytorch 模型可视化的问题

  •  
  •   Waihinchan · 2020-07-18 07:31:19 +08:00 · 2505 次点击
    这是一个创建于 1370 天前的主题,其中的信息可能已经有所发展或是发生改变。

    最近在学习 pix2pixHD 的模型.基本上大体都对模型的细节有所了解了,但是想像在 tensorboard 里面 plot 一下模型.发现很多问题.想请教一下有没有什么比较好的方法可以把模型的结构图给打出来呢? 我首先试了一下 pytorch 的 torch.utils.tensorboard,用的 add_graph 这个要求写一个 dummyinput 我就直接从 train 里面的 dataloader 直接 load 了一个数据,然后出现的错误如下: Converting a tensor to a Python index might cause the trace to be incorrect…This means that the trace might not generalize to other inputs! 大概意思就是说把 tensor 转成 python 就会变成常量什么的,然后图就打不出来.

    我本来已经 hack 了一下模型的源代码,因为 add_graph 要求只有一个 input,但是它在 train.py 脚本中是有 4 个 input 的,我把它打包成一个参数之后再在执行中自己再解开了. 实际上出现问题的好像是它本身在 forward 中对 tensor 的一些转换出现了问题.

    它的源代码写的有点抽象,各种包装,加上初学实在想象不出来这个模型(已经大概有个数了,但是想打印出来对照一下自己理解有没有错)

    有没有什么比较快速的方法可以先让我预览一下这个模型呢?(实在不行我就只能一个一个模块在 tf 里面复现,或者把各个层重新自己写一次再用 summarywriter 来查看了..)

    10 条回复    2020-07-18 13:19:04 +08:00
    askfermi
        1
    askfermi  
       2020-07-18 07:53:34 +08:00   ❤️ 1
    我猜这个应该可以满足你的需要: https://github.com/lutzroeder/netron
    Waihinchan
        2
    Waihinchan  
    OP
       2020-07-18 08:41:21 +08:00
    @askfermi 感谢! 我试了一下 能 plot 出来但是似乎少了很多个层,好像对于一些 resnet 的结构支持的不是很好的 我看了一下作者说对于 pytorch 是试验性的支持.不过起码能有工具了~
    leimao
        3
    leimao  
       2020-07-18 08:43:54 +08:00
    PyTorch 不是有 TensorBoard 吗?
    leimao
        4
    leimao  
       2020-07-18 08:45:16 +08:00
    OK,我没仔细读,你是懒得用 TensorBoard 。那 Netron 是你最好的选择了。
    Waihinchan
        5
    Waihinchan  
    OP
       2020-07-18 08:58:44 +08:00
    @leimao 不是懒得用...是 plot 不出来 每次都说 trace 报错 但是找不到哪里有问题 拆了东墙补西墙都还是会报错 是 input 好像不太符合它的要求, 但是本身模型已经连数据都有了预处理的方式了 而且太多参数 实在没办法一下子改出一个能放在 tensorboard 里的
    我个人还是希望能用 tensorboard 的 但每个人写的模型都有点不同.. 类似参数的数量什么的还要调 动一动直接模型都加载不出来就很头大..
    leimao
        6
    leimao  
       2020-07-18 09:18:04 +08:00
    @Waihinchan 这个 tracing 的 那个东西是一个 warning,我没记错的话,使用来生成 jit 或者 onnx 用的,也就是用 dummy input 把模型跑一遍。你如果 implement 的不注意,就像 warning 说的,这个不能 generalize,这很可能是因为你中间用了 Python 的变量而不是 PyTorch 的变量导致的。
    leimao
        7
    leimao  
       2020-07-18 09:19:07 +08:00
    @Waihinchan 换句话说,你的模型中间有一部分东西是属于 Python 的而不是属于模型的 graph 的,所以你打印模型的时候会少那一部分,甚至是报错。
    leimao
        8
    leimao  
       2020-07-18 09:20:43 +08:00   ❤️ 1
    @Waihinchan 你要修炼到看代码就心中能看到 graph 的境界,graph visualization 这东西,工程上用的多,科研上没觉得有多大用途。
    Waihinchan
        9
    Waihinchan  
    OP
       2020-07-18 09:51:42 +08:00 via Android
    @leimao 我再试试…我本来输入的时候还特意打印了一下类型验证过 可能是中间有些环节出了问题又给转换回去了…
    duaneya
        10
    duaneya  
       2020-07-18 13:19:04 +08:00 via iPhone
    导出 onnx 看吧
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   我们的愿景   ·   实用小工具   ·   5285 人在线   最高记录 6543   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 28ms · UTC 07:45 · PVG 15:45 · LAX 00:45 · JFK 03:45
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.