- 2024-08-04
-
回复了主题帖:
有谁和我一样,动过给产品加AI的心思?
wangerxian 发表于 2024-7-24 17:52
等于这个功能你们白做了呗,是效果不好吗?
不是,这个机器学习算法就是我们自己做的,相当于自己砸自己的饭碗。
不过没办法,因为友商们为了竞争,已经纷纷开始砸自己的饭碗了,我们砸的慢了,这套算法系统的钱都赚不到了。
大概就是所谓的,资本家为了利润会出售绞死自己的绳索。
-
回复了主题帖:
有谁和我一样,动过给产品加AI的心思?
soso 发表于 2024-7-24 17:48
这个尴尬了,自己给自己优化了。
没办法,你不优化自己,别人会优化你,等到了那天这套优化系统的钱都赚不到了
- 2024-07-24
-
回复了主题帖:
有谁和我一样,动过给产品加AI的心思?
本帖最后由 tinnu 于 2024-7-24 14:25 编辑
有了,我们业务做产线的,其中有个环节是测定产品某个参数。
然后现在通过机器学习算法可以从前面环节的数据推出这个环节的数据,然后所有产线的这个环节被客户砍掉了。
辛辛苦苦把这个功能做出来,公司营收反而下降了,由于这个环节是业务里面的一个大头,年终奖可能发不出来了。
- 2024-07-11
-
回复了主题帖:
Vllink Basic2 高速无线调试器开箱与初次使用教程
原来是AIC啊,之前一直以为是esp32c6
- 2024-06-01
-
回复了主题帖:
#AI挑战营终点站# RV1106 RKNN板端推理 + RTSP/LCD显示
对于白底黑字环境,只需要将原本的灰度图反相即可:
cropped = 255 - cropped;
效果:
[localvideo]00e3567b84b579265c44a38f88dcdebd[/localvideo]
- 2024-05-30
-
回复了主题帖:
【AI挑战营第二站】算法工程化部署打包成SDK
什么是ONNX模型:ONNX是微软和 Facebook 牵头定制的一种深度学习框架,比起pytorch这类框架来说,更加轻量。但比起RKNN来说更加通用,本身针对ARM、RISCV通用的加速指令集有所优化,而有很多不同架构的MPU也会专门针对ONNX优化。
什么是RKNN模型:RKNN是瑞芯微专门针对旗下MPU内置的NPU优化的神经网络加速推理框架,能够充分调用NPU资源加速推理深度神经网络模型,支持大部分常见的深度学习算子。瑞芯微提供了RKNN-TOOLKIT工具,可以通过简单的命令,把pytorch、ONNX等常见框架的模型转化为RKNN模型,以移植到瑞芯微带NPU的芯片上运行。
https://bbs.eeworld.com.cn/thread-1283043-1-1.html
#AI挑战营第二站# Ubuntu下训练pytorch模型导出ONNX及RKNN转化
- 2024-05-28
-
回复了主题帖:
#AI挑战营终点站# RV1106 RKNN板端推理 + RTSP/LCD显示
补充使用LCD显示的效果:
[localvideo]87031850c77a766829e8fbba63290349[/localvideo]
-
回复了主题帖:
#AI挑战营终点站# RV1106 RKNN板端推理 + RTSP/LCD显示
Jacktang 发表于 2024-5-28 07:44
为什么会出现由于编译的库和板端库不兼容导致呢
可能是luckfox官方镜像放的库太旧了吧,例程用的库更新了,但是没有放进镜像里面(也可能是反过来
- 2024-05-27
-
回复了主题帖:
【AI挑战营终点站】应用落地:部署手写数字识别应用到幸狐RV1106开发板
完成打卡
https://bbs.eeworld.com.cn/thread-1283051-1-1.html
#AI挑战营终点站# RV1106 RKNN板端推理 + RTSP/LCD显示
-
发表了主题帖:
#AI挑战营终点站# RV1106 RKNN板端推理 + RTSP/LCD显示
## 环境准备
- 下载
- [rk linux 升级工具](https://files.luckfox.com/wiki/Core3566/upgrade_tool_v2.17.zip)
- [升级工具网盘](https://pan.baidu.com/s/1Mhf5JMpkFuZo_TuaGSxBYg?pwd=2sf8)
- linux下烧录
- `sudo ./upgrade_tool uf uckfox_pico_pro_max_image/update.img`
- 烧录完后发现系统uname -a是一样的……其实不用更新,出厂就是最新了
- SDK
- [luckfox-pico SDK](https://gitee.com/LuckfoxTECH/luckfox-pico/tree/main)
- gdbserver
- 一开始自己编译了一个 gdbserver ,结果运行不起来,然后在SDK里面一顿搜索,发现如下内容
```shell
# Enable build gdb and gdbserver debug tool
CONFIG_SYSDRV_ENABLE_GDB=y
$(eval $(call MACRO_CHECK_ENABLE_PKG, RK_ENABLE_GDB))
```
- 直接在终端运行 gdbserver ,发现果然自带了! 再一看,发现甚至自带了python3.11 !
## RKNN调用
### 寻找可用的识别例程序
- RKNN调用方式自然要从官方教程借鉴过来
> cv工程师当然要多用cv,程序员的事怎么能叫抄呢?
- 在 rknn-toolkit2 仓库有许多教程,但像yolo这种是面向识别和分割的,不太合适。其中 mobilenet 的demo同为分类任务:rknpu2/examples/rknn_mobilenet_demo/src/main.cc
- 修改 CMakeLists.txt 的 CMAKE_C_COMPILER 为你的GCC路径
- 直接使用该工程编译出来的程序其实可以直接运行,但必须把生成的lib库也一同拷下去,不然就要手动替换oem库,否则会报错,详见下方
```shell
mkdir build
cd build
cmake ..
make -j6
make install
cd ../rknn_mobilenet_demo
sshpass -p "luckfox" scp -r * root@10.37.49.129:/root
```
- 运行 ./rknn_mobilenet_demo model8.rknn pic/4.jpg
- 其中 4.jpg 是从上面转化 dataset 时从mnist数据集拷出来的测试集文件
## 分析例程、修改与封装
- RKNN初始化:`ret = rknn_init(&ctx, model_path, 0, 0, NULL);`
- 获取RKNN模型信息:`rknn_query`
1. RKNN_QUERY_SDK_VERSION
2. RKNN_QUERY_IN_OUT_NUM
3. RKNN_QUERY_INPUT_ATTR
4. RKNN_QUERY_NATIVE_OUTPUT_ATTR
- 其中后面两个是对输出输出数据体初始化,必须封装在推理函数中,上面两个只需要初始化一次即可。
- 将摄像头的数据转化为cv::Mat类型,再传入识别,进行函数封装:
- 获取摄像头数据流:`void *data = RK_MPI_MB_Handle2VirAddr(stVpssFrame.stVFrame.pMbBlk);`
- 转化为cv::Mat:`cv::Mat frame(height, width, CV_8UC3, data);`
- 自行封装RKNN推理函数:`int TiRknnForward(const cv::Mat frame, rknn_input_output_num io_num, std::vector &oresult)`
- 推理
- rknn_query RKNN_QUERY_INPUT_ATTR 初始化 input_attrs
- rknn_query RKNN_QUERY_NATIVE_OUTPUT_ATTR 初始化 output_attrs
- rknn_create_mem 开辟缓存
- rknn_set_io_mem 设置输入、输出缓存->NPU
- rknn_run 推理
- rknn_query RKNN_QUERY_OUTPUT_ATTR 获取结果
- 处理输出结果
- 解码结果数据 `rknn_GetTopN`
- 将结果保存到vector结构体里,返回:`oresult.push_back(std::make_pair(MaxClass[j], fMaxProb[j]));`
- rknn_destroy_mem 注销缓存
- 我移植后运行的效果,相关代码使用仓库readpic分支(或者回退到最初分支也可以):
-
## 性能优化
- 直接识别从mnist里面取出来的数据,效果还是可以的,说明模型和RKNN部分调用本身没有问题。但是使用摄像头直接传输数据的时候,基本无法识别。猜测可能是由于周边环境与训练时有差异。
- 训练的数据是中心白,周围全黑,于是我通过获取 advanced morphological transformations 找最大的白块轮廓,把其他区域置为全黑。具体操作如下:
1. 将图片转为灰度图
2. 设定阈值,将190以内的区域全黑。
3. 找到最大轮廓
4. 过滤满足条件的轮廓
5. 轮廓以外区域全黑
- 找到轮廓和全黑图后返回,把合适的区域切割出来,添加1.5倍黑边,然后进行RKNN推理。
- 在PC上测试一下效果:
-
- 转换后:
-
## 显示支持
### RTSP支持
- 参考 [luckfox_pico_rtsp_opencv](https://github.com/luckfox-eng29/luckfox_pico_rtsp_opencv.git) 工程,使用 rkmpi rtsp vi 等库实现,代码低耦合,无需赘述
### LCD支持
- 使用SPI进行用户态驱动 ili9341 屏幕
- ![引脚分布](https://wiki.luckfox.com/zh/assets/images/LUCKFOX-PICO-PROMAX-GPIO-6dbf2d0d09106289a1d07098e39504c2.jpg)
- LCD
- SPI连接 12~16
- RST连接 5 GPIO1_C6
- DC连接 4 GPIO1_C7
- LED使能找个3.3V上拉
- 程序仓库:见下,`读图测试`自带LCD现实,`RTSP流测试`如需LCD现实,需要回退一个提交
- 效果
## 程序仓库
- 读图测试(注意切换readpic分支):https://gitlink.org.cn/tinnu/mpu_rv1106_rknn/tree/readpic
- RTSP流测试(主分支):https://gitlink.org.cn/tinnu/mpu_rv1106_rknn
- 效果:[视频演示](https://v.youku.com/v_show/id_XNjM5MzcxNTM2NA==.html "视频演示")
[localvideo]83f07325d2fc5ed6a714a6aec42ef5d5[/localvideo]
## 问题处理: failed to decode config data!
- 板端运行报错。
- 这个问题是由于编译的库和板端库不兼容导致的
- 检查板端库大小
```shell
# ls /oem/usr/lib/librga.so -al
-rw-rw-r-- 1 1002 1002 154244 Nov 16 2023 /oem/usr/lib/librga.so
# ls /oem/usr/lib/librknnmrt.so -al
-rw-r--r-- 1 1002 1002 141048 Nov 16 2023 /oem/usr/lib/librknnmrt.so
```
- 检查编译使用库大小
- 两个位置都有这个库,看情况用的应该是 3rdparty 下面那个
./3rdparty/rknpu2/Linux/armhf-uclibc/librknnmrt.so
./tinnu-rv1106-rknn/lib/librknnmrt.so
- ls -al
```shell
-rwxr-xr-x 1 tinnu tinnu 166732 5月 26 19:18 librga.so
-rwxr-xr-x 1 tinnu tinnu 190368 5月 26 19:18 librknnmrt.so
```
- 把编译的这个库拷贝到板端 /oem/usr/lib/ 下替换掉原本库
-
发表了主题帖:
#AI挑战营第二站# Ubuntu下训练pytorch模型导出ONNX及RKNN转化
本帖最后由 tinnu 于 2024-5-27 19:21 编辑
## rknn 环境搭建
> 由于我使用的是markdown编辑器,发表帖子后缩进消失,代码请直接查看附件或仓库
### conda安装
- 我所使用的ubuntu kylin 20.04.3 LTS 实体机系统。python版本为3.8,虽然也使用conda托管,但还是沿用3.8.10经典版本。
```shell
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
```
### 创建conda环境
```shell
export PATH="/home/anaconda3/bin":$PATH
source /home/anaconda3/bin/activate
conda create -n torch1.10_rknn2 python=3.8
source /home/anaconda3/bin/activate torch1.10_rknn2
conda activate torch1.10_rknn2
```
### introduct
- Rockchip 的 RKNN(Rockchip Neural Network) 是一款 AI 推理框架,能够在 Rockchip 的 SoC 上调用芯片内部的npu加速单元。它提供了两个工具:RKNN-Toolkits 和 RKNPU。
- RKNPU 工具提供了运行时和一些demo,其中包含一些已经转化为.rknn的模型;在rknn2组件中,运行时是一个rknn_server程序,而rknn1中不需要该运行时支持,仅提供了驱动npu的ko文件,可以用于更新npu的驱动。不过出厂镜像默认支持,这里就不作更新。
- 另一个工具是 RKNN-Toolkits,这是一组转换模型工具,可以用于将常见的深度学习模型转换为 RKNN 支持的模型格式。RKNN-Toolkits 提供了一些脚本和 API,可以方便地将 Caffe、TensorFlow 等深度学习框架的模型转换为 RKNN 支持的格式,并将其部署到 RKNN 平台上进行推理运算。
- RKNN又分为RKNN1和RKNN2,RKNN1 是基于 NPUv1 架构实现的,而 RKNN2 是基于新一代 NPUv2 架构实现的。其两者SDK不兼容。我们要使用的RV1106的RKNPU是基于RKNN2,而用于对比的RV1126则是基于RKNN1。
### 安装 rknn_toolkit2
- 移除其他版本
- pip3 uninstall rknn_toolkit
- pip3 uninstall rknn_toolkit2
- 最新版本的 rknn_toolkit2 已经是1.6.0版本。这个版本在 rknn_toolkit2 仓库里面集成了 rknn_toolkit2 rknn_toolkit2_lite 和 rknpu,只需要下着一个即可。
```shell
git clone https://github.com/rockchip-linux/rknn-toolkit2.git --depth=1
pip3 install -r rknn-toolkit2/packages/requirements_cp38-1.6.0.txt
pip3 install rknn-toolkit2/packages/rknn_toolkit2-1.6.0+81f21f4d-cp38-cp38-linux_x86_64.whl
pip3 list | grep rknn
```
## 转化模型
- 之前帖子演示了 pytorch->onnx 模型转化,但这个转出来的模型进行 onnx->rknn ,会报错
```shell
The input shape ['batch_size', 28, 28] of 'input' is not support!
Please set the 'inputs' / 'input_size_list' parameters of 'rknn.load_onnx', or set the 'dyanmic_input' parameter of 'rknn.config' to fix the input shape!
```
```py
input_example = torch.randn(1, 28, 28)
model.eval()
torch.onnx.export(
model, # The model to be exported
input_example, # Example input data
"model/model-3.onnx", # Output file path
opset_version=19, # ONNX operator set version (choose a suitable version based on your PyTorch version and model requirements)
do_constant_folding=True, # Whether to execute constant folding for optimization
input_names=["input"], # The model's input names
output_names=["output"], # The model's output names
dynamic_axes={
"input": {0: "batch_size"}, # Make batch size dynamic (allow different batch sizes at inference time)
"output": {0: "batch_size"} # Similarly, make output batch size dynamic
},
)
```
- 经过一番处理,终于发现是 pytorch->onnx 转化时 dyanmic_input 的锅,删掉即可:
```py
input_example = torch.randn(1, 1, 28, 28)
model.eval()
torch.onnx.export(
model, # The model to be exported
input_example, # Example input data
"../3.tran/model8.onnx", # Output file path
opset_version=11, # ONNX operator set version (choose a suitable version based on your PyTorch version and model requirements)
do_constant_folding=True, # Whether to execute constant folding for optimization
)
```
- onnx->rknn 转化脚本
```py
# %%
from rknn.api import RKNN
# 创建RKNN对象
rknn = RKNN()
rknn.config(mean_values=[[128]], std_values=[[128]], target_platform='rv1106')
# %% Load model
import os
ONNX_MODEL = 'model8.onnx'
RKNN_MODEL = 'model8.rknn'
if not os.path.exists(ONNX_MODEL):
print('NO ONNX model...')
else:
print('--> Loading model')
ret = rknn.load_onnx(model=ONNX_MODEL)
if ret != 0:
print('Load model failed!')
exit(ret)
print('done')
# %% Build model
print('Building RKNN model...')
rknn.build(do_quantization=True, dataset='subset_data.csv')
# %% Export rknn model
print('--> Export rknn model')
ret = rknn.export_rknn(export_path=RKNN_MODEL, cpp_gen_cfg=True)
if ret != 0:
print('Export rknn model failed!')
exit(ret)
print('done')
```
- 但这里有个问题,由于只有3588才支持不量化(do_quantization=True),必须量化的话,需要 dataset 文件,dataset生成需要两步,我把 datasetfile_com.py 脚本放进附录和工程里面。
1. 转化数据集为图片文件
2. 生成 dataset.csv
- 转化结束会报一些错误,但能征程生成rknn模型,暂时不管。
- 转化后还可以评估模型,这里不赘述
- [mpu_rv1106_rknn](https://gitlink.org.cn/tinnu/mpu_rv1106_rknn) 工程里面model文件夹有所有的转化脚本文件:
|文件|功能|
|-|-|
|train_3.py |训练模型,生成onnx|
|translate_1.py |将onnx转化为rknn|
|datasetfile_com.py |转化出 dataset 文件
|subset_data.csv |dataset 文件|
|model8.rknn |最终生成的rknn模型|
|model8.onnx |训练生成的onnx模型|
### 问题处理:
- 进行 onnx->rknn 模型转化,会报错:
```shell
The input shape ['batch_size', 28, 28] of 'input' is not support!
Please set the 'inputs' / 'input_size_list' parameters of 'rknn.load_onnx', or set the 'dyanmic_input' parameter of 'rknn.config' to fix the input shape!
```
- 这是由于 pytorch->onnx 转化时,设置了 dyanmic_input 的问题,去掉 dyanmic_input 即可成功
- 2024-05-09
-
回复了主题帖:
入围名单公布:嵌入式工程师AI挑战营(初阶),获RV1106 Linux 板+摄像头的名单
本帖最后由 tinnu 于 2024-5-9 16:11 编辑
个人信息已确认,领取板卡,可继续完成&分享挑战营第二站和第三站任务。
- 2024-04-13
-
回复了主题帖:
【AI挑战营第一站】模型训练:在PC上完成手写数字模型训练,免费申请RV1106开发板
本帖最后由 tinnu 于 2024-4-13 18:42 编辑
1、跟帖回复:用自己的语言描述,模型训练的本质是什么,训练最终结果是什么
单纯针对狭义的神经网络模型来说:
训练的本质:就是通过计算出来的偏差,反向调整模型内部的各种参数,使之不断向偏差减小的方向变化,最终实现输出结果与用户预想的一致。
训练最终结果:得到一个充分调整过内部参数的神经网络模型。并且这个模型被期望:在随后的测试中,输入测试样本,可以得到一个预设结果。
2、跟帖回复:PyTorch是什么?目前都支持哪些系统和计算平台?
PyTorch:一个基于python深度学习框架,提供各类深度学习算子运算,并基于此实现各类深度学习模型。类似的框架还有caffe、tensorflow、keras、darknet等,这些都是瑞芯微rknn支持转化的框架。还有国产框架诸如paddlepaddle,但目前不受支持,无法直接转化。
PyTorch系统支持:windows linux macos
PyTorch系统支持:Python、C++、Java
PyTorch目前支持在纯CPU模式或者GPU加速下运行。
CPU支持:X86、ARM在内的各种CPU架构。
GPU支持:英伟达的CUDA平台,AMD的ROCm、HIP框架。第三方有各类厂商自发的支持,GPU方面有国产之光摩尔线程的musa,github上有musa for pytorch的仓库,具体可以看我下面的帖子的介绍。
GPGPU支持:比如谷歌的TPU、国内各种GPGPU,但他们都有各种问题,其中TPU无法支持GPU那么多算子,而GPGPU的驱动和pytorch支持包基本不会对外开放,个人用户是无法获取的。
具体参考Pytorch官网:pytorch
3、动手实践:
动手实践:#AI挑战营第一站# pytorch环境+minst数据集训练
-
发表了主题帖:
#AI挑战营第一站# pytorch环境+minst数据集训练
本帖最后由 tinnu 于 2024-4-13 18:25 编辑
- 安装显卡加速支持
1. 英伟达
- 安装cuda
- [cuda](https://developer.nvidia.com/cuda-toolkit-archive)
- [cudnn](https://developer.nvidia.com/cudnn-downloads?target_os=Windows)
- 装完之后在命令行输入 nvcc --version 没有报错即通过
```shell
> nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:36:15_Pacific_Daylight_Time_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
```
2. AMD
- 安装ROCm
- [linux](https://github.com/ROCm/ROCm)
- [windows](https://www.amd.com/zh-cn/developer/resources/rocm-hub/hip-sdk.html)
3. 国产之光摩尔线程
- 安装musa
- https://developer.mthreads.com/sdk/download/musa?equipment=&os=Ubuntu&driverVersion=&version=
- PS:且不说摩尔线程硬件用的是IMG的IP,到底能不能算国产之类的话题,但这么多所谓自主的显卡,你就说他能不能pytorch吧,起码软件方面MTT是在脚踏实地地干活。
4. CPU
- 完全不用管
- 安装python
- pip3 install python3
- pip3 install matplotlib
- 根据pytorch官网提示安装pytorch
- https://pytorch.org/get-started/locally/
1. CPU安装
- win下纯cpu是这个命令:
- pip3 install torch torchvision torchaudio
2. 英伟达安装
- windows下安装:
- pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
3. AMD显卡安装
- LINUX+ROCm
- pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7
4. 摩尔线程显卡安装
- 按照摩尔线程开源仓库步骤安装:https://github.com/MooreThreads/torch_musa
- 因为我的显卡是N年前的MX130,装完之后发现跑得比CPU还慢,而且满三倍,于是选择纯CPU安装;安装完后一跑,发现下载贼久,因为minst的数据集在……你懂的,所以我觉得找个外面的云平台可能会比较快。
- 找来找去,能用的云端免费平台也就大名鼎鼎的kaggle,不过虽然能用,但注册的时候还是存在你懂的环节,自行解决不再赘述。所以只要用云平台,上面这些繁琐的步骤其实统统不用……
# 训练
- 首先加载一些程序库
- torch
- torchvision
- matplotlib
```py
import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
```
1. 加载数据集
- python 下可以通过 torchvision 库直接从网络下载数据集到本地,然后自动加载,前面说下载贼旧就是因为这个minst数据集。
- pytorch 里面加载数据有一套范式,一般是通过 torch.utils.data.DataLoader ,可以控制在训练的时候控制每一轮输出的数据量
- 以加载训练集为例:
```py
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(
"./data/",
train=True,
download=True,
transform=torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,)),
]
),
),
batch_size=batch_size_train,
shuffle=True,
)
```
2. 加载数据集
- minst 可以应用非常典型的网络结构,比如 两个卷积->全连接层->relu->droupout->全连接层
```py
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
```
- 配置一下优化器,这里可选收敛比较快的SGD,也可以用 Adam 反正minst训练也很快
- Adam收敛速度:
- 损失函数,这里没有定义损失函数,是直接在训练里面用 nll_loss (Negative Log Likelihood Loss) 说到底就是向量对数
3. 训练
```py
def train(epoch):
network.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = network(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
```
- 训练也很简单,非常标准的流程:数据对齐->前向传播->计算损失->反向传播->优化器计算
- 5条语句,5个功能,完事
4. 测试
```py
def test():
network.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = network(data)
pred = output.data.max(1, keepdim=True)[1]
```
5. 显示
```py
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
with torch.no_grad():
output = network(example_data)
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(example_data[0], cmap="gray", interpolation="none")
plt.title("Prediction: {}".format(output.data.max(1, keepdim=True)[1].item()))
plt.xticks([])
plt.yticks([])
plt.show()
```
6. 模型导出
```py
torch.save(network.state_dict(), "./model.pth")
torch.save(optimizer.state_dict(), "./optimizer.pth")
```
7. onnx转化
```py
def export_to_onnx(model, input_example, output_path="model.onnx"):
model.eval()
torch.onnx.export(
model,
input_example,
output_path,
opset_version=11,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size"},
"output": {0: "batch_size"}
},
)
print("Model exported to ONNX format successfully.")
input_example = torch.randn(1, 28, 28)
export_to_onnx(network, input_example=input_example)
```