搜索
bottom↓
回复: 3

【人工智能 图像分割 原创首发】第六章 将pb模型转换到移动端

[复制链接]

出25入84汤圆

发表于 2020-10-13 15:12:51 | 显示全部楼层 |阅读模式
本帖最后由 chun2495 于 2020-10-13 15:14 编辑

这一章就是将如何将已经训练好的网络进行固化,并采用apple公司提供的工具转换为iphone可以识别的网络。
因为之前一直用的deeplab是1.13.x,这个版本不被apple支持,所以这里需要重新建立一个环境。

1. 需要新建一个虚拟环境,因为官网(https://coremltools.readme.io/docs/what-are-coreml-tools)要求tensorflow必须大于1.14.0或者大于2.1.0。

首先在cmd中创建虚拟环境 
  1. conda create -n tf1.14.0 python=3.6.4
复制代码

然后激活
  1. conda activate tf1.14.0
复制代码

依次安装以下deeplabv3需要的库,因为后面需要在tensorflow1.14.0中训练deeplab v3网络。前面四步是安装deeplab需要的库,最后一行是安装tensorflow1.14.0。
  1. pip install numpy -i [url]https://pypi.tuna.tsinghua.edu.cn/simple[/url]
  2. pip install pillow -i [url]https://pypi.tuna.tsinghua.edu.cn/simple[/url]
  3. pip install Jupyter notebook -i [url]https://pypi.tuna.tsinghua.edu.cn/simple[/url]
  4. pip install Matplotlib -i [url]https://pypi.tuna.tsinghua.edu.cn/simple[/url]

  5. pip install tensorflow==1.14.0 -i [url]https://pypi.tuna.tsinghua.edu.cn/simple[/url]
复制代码


最后一步安装官网指导(https://coremltools.readme.io/docs/installation),安装转换工具coremltools。
  1. pip install coremltools==4.0b3 -i [url]https://pypi.tuna.tsinghua.edu.cn/simple[/url]
复制代码

2. 安装deeplab,并配置环境训练。

拷贝一份之前下载的deeplab。


打开anconda,选择刚才新建的环境tf1.14.0。然后打开PyCharm,选择到上一步的deeplab目录。


打开pycharm的终端,发现默认的环境还是在之前的deeplab里,当然也可以直接用conda activate tf1.14.0来激活新环境,但是切换一个shell,又得重新配置,而且新建的shell没办法切换环境,非常麻烦。所以我们需要配置一个新环境。按照下图箭头来更改就行了,因为我已经建好了,所以下划线那个需要改为envs/tf1.14.0。


3. 然后就是deeplab进行训练,这一步在我另一篇文章中介绍了,最后生成的为xxx.pb文件。这就是转换需要的文件了。

4. 新建一个shell,输入以下代码,运行后就生成了mlmodel。


5 转换代码如下:
  1. import tensorflow as tf
  2. import coremltools as ct

  3. print("tensorflow: ", tf.__version__)#输出tensorflow版本,只能在tf1.14.0以上来转换模型。

  4. image_wide = 400
  5. image_high = 320
  6. pbpath = "./exp/mydata_train/export/frozen_inference_graph_16358.pb"#输入的.pb模型
  7. mlpath = "./exp/mydata_train/export/MobileNetV2.mlmodel"#输出的.mlmodel模型

  8. image_input = ct.ImageType(shape=(1, image_high, image_wide, 3,))#输入图像的类型

  9. model = ct.convert(
  10.     pbpath, source="tensorflow", inputs=[image_input], outputs=["SemanticPredictions:0"],
  11. )

  12. # model.input_description["image_input"] = "Input image to be classified"#输入不需要描述,根据输入图像自动描述。
  13. model.output_description["SemanticPredictions:0"] = "预测膀胱图像,图像大小和输入图像一致,输出宽度*高度的点阵,数值为1或0"
  14. model.author = "chun"
  15. model.license = "none"
  16. model.short_description = "Detect Bladder"
  17. model.version = "1.0"

  18. model.save(mlpath)
复制代码


6. 对于输入输出的参数的确定需要再编写一个.py,来获取自己模型的信息。代码如下:

  1. import tensorflow as tf
  2. from tensorflow.core.framework import graph_pb2
  3. import operator
  4. import sys

  5. pbpath = "./exp/mydata_train/export/frozen_inference_graph_16358.pb"
  6. infopath = "./exp/mydata_train/export/pb-info.txt"


  7. def inspect(model_pb, output_txt_file):
  8.     graph_def = graph_pb2.GraphDef()
  9.     with open(model_pb, "rb") as f:
  10.         graph_def.ParseFromString(f.read())

  11.     tf.import_graph_def(graph_def)

  12.     sess = tf.Session()
  13.     OPS = sess.graph.get_operations()

  14.     ops_dict = {}

  15.     sys.stdout = open(output_txt_file, 'w')
  16.     for i, op in enumerate(OPS):
  17.         print(
  18.             '---------------------------------------------------------------------------------------------------------------------------------------------')
  19.         print("{}: op name = {}, op type = ( {} ), inputs = {}, outputs = {}".format(i, op.name, op.type, ", ".join(
  20.             [x.name for x in op.inputs]), ", ".join([x.name for x in op.outputs])))
  21.         print('@input shapes:')
  22.         for x in op.inputs:
  23.             print("name = {} : {}".format(x.name, x.get_shape()))
  24.         print('@output shapes:')
  25.         for x in op.outputs:
  26.             print("name = {} : {}".format(x.name, x.get_shape()))
  27.         if op.type in ops_dict:
  28.             ops_dict[op.type] += 1
  29.         else:
  30.             ops_dict[op.type] = 1

  31.     print(
  32.         '---------------------------------------------------------------------------------------------------------------------------------------------')
  33.     sorted_ops_count = sorted(ops_dict.items(), key=operator.itemgetter(1))
  34.     print('OPS counts:')
  35.     for i in sorted_ops_count:
  36.         print("{} : {}".format(i[0], i[1]))


  37. if __name__ == "__main__":
  38.     # 生成网络结构文件,可以通过这个文件查看你需要的输入和输出
  39.     inspect(pbpath, infopath)
  40.     # 转换成CoreML model
  41.     # pbtomlmodel.convert()
  42.     # 规范化CoreML model的输入输出名
  43.     # tool.rename_var()
复制代码


生成的文件为pb-info.txt,这里面存储了所有节点名称以及张量tensor,我们需要用的只是输出的节点名称,在最后可以看到输出名称为“SemanticPredictions:0”import那些不用管。
  1. 1035: op name = import/SemanticPredictions, op type = ( Identity ), inputs = import/Cast_2:0, outputs = import/SemanticPredictions:0
  2. @input shapes:
  3. name = import/Cast_2:0 : (1, ?, ?)
  4. @output shapes:
  5. name = import/SemanticPredictions:0 : (1, ?, ?)
复制代码


至此,模型转换完毕。
其实这一章是把前面第二章建立的环境重新建了一遍。然后就是生成了.mlmodel。下一节开始使用这个模型进行实战测试。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有帐号?注册

x

出0入0汤圆

发表于 2020-10-13 15:18:11 | 显示全部楼层
一脸懵逼,二脸懵逼。

出0入91汤圆

发表于 2020-10-13 15:26:15 | 显示全部楼层
配点效果图吧  否则懵逼路上你和我

出25入84汤圆

 楼主| 发表于 2020-10-13 15:28:00 | 显示全部楼层
ackyee 发表于 2020-10-13 15:26
配点效果图吧  否则懵逼路上你和我

下一节就有效果图了
回帖提示: 反政府言论将被立即封锁ID 在按“提交”前,请自问一下:我这样表达会给举报吗,会给自己惹麻烦吗? 另外:尽量不要使用Mark、顶等没有意义的回复。不得大量使用大字体和彩色字。【本论坛不允许直接上传手机拍摄图片,浪费大家下载带宽和论坛服务器空间,请压缩后(图片小于1兆)才上传。压缩方法可以在微信里面发给自己(不要勾选“原图),然后下载,就能得到压缩后的图片】。另外,手机版只能上传图片,要上传附件需要切换到电脑版(不需要使用电脑,手机上切换到电脑版就行,页面底部)。
您需要登录后才可以回帖 登录 | 注册

本版积分规则

手机版|Archiver|amobbs.com 阿莫电子技术论坛 ( 粤ICP备2022115958号, 版权所有:东莞阿莫电子贸易商行 创办于2004年 (公安交互式论坛备案:44190002001997 ) )

GMT+8, 2024-4-19 06:01

© Since 2004 www.amobbs.com, 原www.ourdev.cn, 原www.ouravr.com

快速回复 返回顶部 返回列表