[STM32L5] 【STM32L562E-DK试用】用串口实现手写数字体识别

[复制链接]
703|0
 楼主| 傅沈骁 发表于 2025-3-1 15:55 | 显示全部楼层 |阅读模式
本帖最后由 傅沈骁 于 2025-3-1 15:54 编辑

虽然STM32L562推出有些年头了,但是它依然能支持CubeAI这样的边缘AI部署
本次测试基于b站教程:https://www.bilibili.com/video/BV1eg4y167G6/?spm_id_from=333.337.search-card.all.click&vd_source=30af65e26f8054bfda43260a9879957f
up主工程开源地址:https://github.com/colin2135/STM32G070_AI_TEST
上位机测试软件地址:https://github.com/colin2135/HandWriteApp

模型训练与保存

作为深度学习的入门教程,现在网上介绍MNIST手写数字体识别的教程已经很多了。这里贴一段用keras生成.h5文件的代码,不过为了和跟随up主的教程,我最后用了GitHub上的.tflite模型文件。

  1. from keras.datasets import mnist
  2. import matplotlib.pyplot as plt
  3. from keras.models import Sequential
  4. from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
  5. from keras.utils import np_utils
  6. import tensorflow as tf

  7. config = tf.compat.v1.ConfigProto()
  8. config.gpu_options.allow_growth = True
  9. sess = tf.compat.v1.Session(config=config)

  10. # 设定随机数种子,使得每个网络层的权重初始化一致
  11. # np.random.seed(10)

  12. # x_train_original和y_train_original代表训练集的图像与标签, x_test_original与y_test_original代表测试集的图像与标签
  13. (x_train_original, y_train_original), (x_test_original, y_test_original) = mnist.load_data()

  14. """
  15. 数据可视化
  16. """

  17. # 原始数据量可视化
  18. print('训练集图像的尺寸:', x_train_original.shape)
  19. print('训练集标签的尺寸:', y_train_original.shape)
  20. print('测试集图像的尺寸:', x_test_original.shape)
  21. print('测试集标签的尺寸:', y_test_original.shape)

  22. """
  23. 数据预处理
  24. """

  25. # 从训练集中分配验证集
  26. x_val = x_train_original[50000:]
  27. y_val = y_train_original[50000:]
  28. x_train = x_train_original[:50000]
  29. y_train = y_train_original[:50000]
  30. # 打印验证集数据量
  31. print('验证集图像的尺寸:', x_val.shape)
  32. print('验证集标签的尺寸:', y_val.shape)
  33. print('======================')

  34. # 将图像转换为四维矩阵(nums,rows,cols,channels), 这里把数据从unint类型转化为float32类型, 提高训练精度。
  35. x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
  36. x_val = x_val.reshape(x_val.shape[0], 28, 28, 1).astype('float32')
  37. x_test = x_test_original.reshape(x_test_original.shape[0], 28, 28, 1).astype('float32')

  38. # 原始图像的像素灰度值为0-255,为了提高模型的训练精度,通常将数值归一化映射到0-1。
  39. x_train = x_train / 255
  40. x_val = x_val / 255
  41. x_test = x_test / 255

  42. print('训练集传入网络的图像尺寸:', x_train.shape)
  43. print('验证集传入网络的图像尺寸:', x_val.shape)
  44. print('测试集传入网络的图像尺寸:', x_test.shape)

  45. # 图像标签一共有10个类别即0-9,这里将其转化为独热编码(One-hot)向量
  46. y_train = np_utils.to_categorical(y_train)
  47. y_val = np_utils.to_categorical(y_val)
  48. y_test = np_utils.to_categorical(y_test_original)

  49. """
  50. 定义网络模型
  51. """


  52. def CNN_model():
  53.     model = Sequential()
  54.     model.add(Conv2D(filters=16, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))
  55.     model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
  56.     model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))
  57.     model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
  58.     model.add(Flatten())
  59.     model.add(Dense(100, activation='relu'))
  60.     model.add(Dense(10, activation='softmax'))

  61.     print(model.summary())
  62.     return model


  63. """
  64. 训练网络
  65. """

  66. model = CNN_model()

  67. # 编译网络(定义损失函数、优化器、评估指标)
  68. model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

  69. # 开始网络训练(定义训练数据与验证数据、定义训练代数,定义训练批大小)
  70. train_history = model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, batch_size=32, verbose=2)

  71. # 模型保存
  72. model.save('model.h5')

  73. # 定义训练过程可视化函数(训练集损失、验证集损失、训练集精度、验证集精度)
  74. def show_train_history(train_history, train, validation):
  75.     plt.plot(train_history.history[train])
  76.     plt.plot(train_history.history[validation])
  77.     plt.title('Train History')
  78.     plt.ylabel(train)
  79.     plt.xlabel('Epoch')
  80.     plt.legend(['train', 'validation'], loc='best')
  81.     plt.show()

  82. show_train_history(train_history, 'accuracy', 'val_accuracy')
  83. show_train_history(train_history, 'loss', 'val_loss')
安装CubeAI
1598767c2b17610b98.png
在CubeMX上方Software Packs下拉选择Select Components,选择其中的X-CUBE-AI
5006967c2b13b6562d.png
在左侧菜单栏选择Middleware and Software Packs,选择其中的X-CUBE-AI,导入模型并分析。如果这个模型过大,超过了flash的大小,可能还需要对模型进行压缩,并配置外部flash。
1556367c2b1f53bb3d.png
串口配置
观察开发板原理图可以发现,PA9和PA10可以做虚拟串口使用,对应的是UART1。
61267c2b25a2710a.png
开启UART1并设置为异步模式。由于需要串口收发,所以还要使能串口接收中断。
6731767c2b2bf6b7c5.png
最后使能DEBUG功能
2647967c2b31b28949.png
代码编写
首先包含相关头文件
  1. #include "stdio.h"
  2. #include "string.h"
  3. #include "ai_platform.h"
  4. #include "network.h"
  5. #include "network_data.h"
由于需要串口收发数据,因此需要对printf进行重定向
  1. #ifdef __GNUC__
  2. #define PUTCHAR_PROTOTYPE int __io_putchar(int ch)
  3. #else
  4. #define PUTCHAR_PROTOTYPE int fputc(int ch, FILE *f)
  5. #endif
  6. PUTCHAR_PROTOTYPE
  7. {
  8.     HAL_UART_Transmit(&huart1, (uint8_t *)&ch, 1, 0xFFFF);
  9.     return ch;
  10. }
定义AI模型相关参数,并声明后续使用到的一些函数
  1. ai_handle network;
  2. float aiInData[AI_NETWORK_IN_1_SIZE];
  3. float aiOutData[AI_NETWORK_OUT_1_SIZE];
  4. ai_u8 activations[AI_NETWORK_DATA_ACTIVATIONS_SIZE];

  5. ai_buffer * ai_input;
  6. ai_buffer * ai_output;

  7. static void AI_Init(void);
  8. static void AI_Run(float *pIn, float *pOut);
  9. void PictureCharArrayToFloat(uint8_t *srcBuf,float *dstBuf,int len);

  10. void Uart_send(char * str);
  11. #define UART_BUFF_LEN 1024
  12. #define ONE_FRAME_LEN 1+784+2
  13. uint16_t uart_rx_length = 0;
  14. uint8_t uart_rx_byte = 0;
  15. uint8_t uart_rx_buffer[UART_BUFF_LEN];
  16. volatile uint8_t goRunning = 0;
定义串口中断回调函数
  1. void HAL_UART_RxCpltCallback(UART_HandleTypeDef *UartHandle)
  2. {
  3.         if(goRunning ==0)
  4.         {
  5.                 if (uart_rx_length < UART_BUFF_LEN)
  6.                 {
  7.                         uart_rx_buffer[uart_rx_length] = uart_rx_byte;
  8.                         uart_rx_length++;

  9.                         if (uart_rx_byte == '\n')
  10.                         {
  11.                                 goRunning = 1;
  12.                         }
  13.                 }
  14.                 else
  15.                 {
  16.                         //rt_kprintf("rx len over");
  17.                         uart_rx_length = 0;
  18.                 }
  19.         }
  20.         HAL_UART_Receive_IT(&huart1, (uint8_t *)&uart_rx_byte, 1);
  21. }
定义串口发送函数
  1. void Uart_send(char * str)
  2. {
  3.         HAL_UART_Transmit(&huart1, (uint8_t *)str, strlen(str),0xffff);
  4. }
定义AI模型初始化函数
  1. static void AI_Init(void)
  2. {
  3.   ai_error err;

  4.   /* Create a local array with the addresses of the activations buffers */
  5.   const ai_handle act_addr[] = { activations };
  6.   /* Create an instance of the model */
  7.   err = ai_network_create_and_init(&network, act_addr, NULL);
  8.   if (err.type != AI_ERROR_NONE) {
  9.     printf("ai_network_create error - type=%d code=%d\r\n", err.type, err.code);
  10.     Error_Handler();
  11.   }
  12.   ai_input = ai_network_inputs_get(network, NULL);
  13.   ai_output = ai_network_outputs_get(network, NULL);
  14. }
定义AI模型运行函数
  1. static void AI_Run(float *pIn, float *pOut)
  2. {
  3.         char logStr[100];
  4.         int count = 0;
  5.         float max = 0;
  6.   ai_i32 batch;
  7.   ai_error err;

  8.   /* Update IO handlers with the data payload */
  9.   ai_input[0].data = AI_HANDLE_PTR(pIn);
  10.   ai_output[0].data = AI_HANDLE_PTR(pOut);

  11.   batch = ai_network_run(network, ai_input, ai_output);
  12.   if (batch != 1) {
  13.     err = ai_network_get_error(network);
  14.     printf("AI ai_network_run error - type=%d code=%d\r\n", err.type, err.code);
  15.     Error_Handler();
  16.   }
  17.   for (uint32_t i = 0; i < AI_NETWORK_OUT_1_SIZE; i++) {

  18.           sprintf(logStr,"%ld  %8.6f\r\n",i,aiOutData[i]);
  19.           Uart_send(logStr);
  20.           if(max<aiOutData[i])
  21.           {
  22.                   count = i;
  23.                   max= aiOutData[i];
  24.           }
  25.   }
  26.   sprintf(logStr,"current number is %d\r\n",count);
  27.   Uart_send(logStr);
  28. }
定义将串口收到的uint8_t类型数据转换为float类型函数
  1. void PictureCharArrayToFloat(uint8_t *srcBuf,float *dstBuf,int len)
  2. {
  3.         for(int i=0;i<len;i++)
  4.         {
  5.                 dstBuf[i] = srcBuf[i];//==1?0:1;
  6.         }
  7. }
主函数部分,需要完成外设初始化以及模型运行逻辑的书写
  1. int main(void)
  2. {

  3.   /* USER CODE BEGIN 1 */

  4.   /* USER CODE END 1 */

  5.   /* MCU Configuration--------------------------------------------------------*/

  6.   /* Reset of all peripherals, Initializes the Flash interface and the Systick. */
  7.   HAL_Init();

  8.   /* USER CODE BEGIN Init */

  9.   /* USER CODE END Init */

  10.   /* Configure the system clock */
  11.   SystemClock_Config();

  12.   /* USER CODE BEGIN SysInit */

  13.   /* USER CODE END SysInit */

  14.   /* Initialize all configured peripherals */
  15.   MX_GPIO_Init();
  16.   MX_ICACHE_Init();
  17.   MX_USART1_UART_Init();
  18.   /* USER CODE BEGIN 2 */
  19.   __HAL_RCC_CRC_CLK_ENABLE();
  20.   AI_Init();
  21.   memset(uart_rx_buffer,0,784);
  22.   HAL_UART_Receive_IT(&huart1, (uint8_t *)&uart_rx_byte, 1);
  23.   /* USER CODE END 2 */

  24.   /* Infinite loop */
  25.   /* USER CODE BEGIN WHILE */
  26.   while (1)
  27.   {
  28.     /* USER CODE END WHILE */

  29.     /* USER CODE BEGIN 3 */
  30.         if(goRunning>0)
  31.         {
  32.                 if(uart_rx_length == ONE_FRAME_LEN)
  33.                 {
  34.                         PictureCharArrayToFloat(uart_rx_buffer+1,aiInData,28*28);
  35.                         AI_Run(aiInData, aiOutData);
  36.                 }
  37.                 memset(uart_rx_buffer,0,784);
  38.                 goRunning = 0;
  39.                 uart_rx_length = 0;
  40.         }
  41.   }
  42.   /* USER CODE END 3 */
  43. }
最终实现效果如下,0-9的数字均能很好识别。
2737167c2bcc37e76e.png

您需要登录后才可以回帖 登录 | 注册

本版积分规则

5

主题

17

帖子

0

粉丝
快速回复 在线客服 返回列表 返回顶部