机器学习-用远程树莓派摄像头做图片分类训练

  • 基于本地摄像头的图片分类模型训练时,由于采集样本使用的摄像头(本地摄像头)和识别所使用的摄像头(连接在树莓派上的摄像头)不同, 图片的背景也会有差异,导致在树莓派上的识别效果可能没那么好,而且下载好的模型还需要额外拷贝到树莓派(有点麻烦有木有~~)。
  • 是否可以直接使用树莓派的摄像头进行采样训练,然后训练好的模型也能自动下载到树莓派上呢? 本章介绍的使用远程树莓派摄像头的图片分类训练就可以做到这些。

准备阶段

  • 首先,把USB摄像头模块连接到树莓派的USB接口上。
图1

(图1)USB摄像头实物连接图


  • 树莓派接通电源,点击连接设备,正常连接树莓派。

    图2

    (图2)连接树莓派


  • 按顺序点击更多功能——机器学习——远程使用树莓派摄像头进行图片分类。

    图3

    (图3)进入图片分类界面

注:远程使用树莓派摄像头进行图片分类的功能需要古德微树莓派的版本在3.0.0及以上才可以使用,并且只有在正常连接树莓派获取树莓派IP地址后才可以使用

训练模型

  • 接下来开始训练模型,以“苹果”模型和“梨”模型为例。首先,进行图片采样和标注,将需要识别的物体放在背景下,再点击摄像头标志打开摄像头,点击拍照按钮进行采样(可以转动物体模型,多角度对物体进行采样,这样训练识别效果会更好,采样10-20张为宜)。采样结束后将类别名修改为目标类别。比如,上面放的苹果模型就标注为苹果,重复上述步骤采样及标注另外的物体模型。

    图4

    (图4)开启摄像头及标注类别

    图5

    (图5)拍照采样


  • 采样结束后,点击“开始训练”对样本图片进行训练。(训练需要一定时间)

    图6

    (图6)点击开始训练


  • 训练完成后右方可以查看效果预览,来验证训练模型识别的准确性。 我们通过置信度的高低判断拍摄物体与训练物体的相似程度,并以此来判断物体的类别。效果如下图:

    图7

    (图7)图片分类模型预览效果


  • 如果训练模型预览的效果较好(能够准确识别且识别率稳定在80%以上),那就可以将模型下载到树莓派上,以便在程序中调用。

    图8

    (图8)下载模型到树莓派

下载模型需要一定时间,下载完成后会提示模型文件和标签文件保存的地址。下载的目录下有同名文件时会将原有文件改名为文件名加时间后缀再将模型文件保存到目录。

图片分类案例

  • 如何使用图片分类模型进行识别

    • 模型训练完成并下载到树莓派后,调用以下积木,即可进行图片分类调用以下积木,即可进行图片分类。

      图7

      (图9)图片分类简单

    • 在加载图片分类模型的积木块中,分别输入要加载的模型文件和标签文件路径。将要识别的物体放在摄像头下,点击运行,在右侧调试区查看运行结果。

    • 点击这里下载本案例代码。

  • 图片分类应用案例:水果分类器

    • 案例简介:使用180舵机控制平板,当把苹果放到平板上时,舵机控制平板往一侧翻转;当把梨放到平板上时,舵机控制平板往另一侧翻转。
    • 代码截图:
      图8

      (图10)水果分类器代码

    • 效果演示:
      图9

      (图11)水果分类器效果演示

    • 点击这里下载本案例代码。
  • 使用python加载图片分类模型案例

    • 案例简介:使用python加载图片分类模型,实时对画面进行识别分类,将分类结果显示到画面中。
    • 效果演示
      图9

      (图12)水果分类识别演示

    • python代码

        import time
        import numpy as np
        import cv2
        from tflite_runtime.interpreter import Interpreter
        from PIL import Image, ImageDraw, ImageFont
        def load_labels(path):
            with open(path, 'r') as f:
                return {i: line.strip() for i, line in enumerate(f.readlines())}
        def cv2ImgAddText(img, text, left, top, textColor=(0, 255, 0), textSize=20):
            if (isinstance(img, np.ndarray)):  # 判断是否OpenCV图片类型
                img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            # 创建一个可以在给定图像上绘图的对象
            draw = ImageDraw.Draw(img)
            # 字体的格式
            fontStyle = ImageFont.truetype('/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', textSize)
            # 绘制文本
            draw.text((left, top), text, textColor, font=fontStyle)
            # 转换回OpenCV格式
            return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
        #加载水果分类模型
        labels_path = '/home/pi/model/image_classification/labels.txt'
        model_path = '/home/pi/model/image_classification/model.tflite'
        labels = load_labels(labels_path)
        print("#load model")
        interpreter = Interpreter(model_path = model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        _, height, width, _ = input_details[0]['shape']
        try:
            cap = cv2.VideoCapture(0)
            while cap.isOpened():
                ret, new_img = cap.read()        
                origin_img = new_img.copy()
                new_img = cv2.cvtColor(new_img, cv2.COLOR_BGR2RGB)
                new_img = cv2.resize(new_img,(width,height))
                new_img = np.expand_dims(new_img, axis=0)
                start_time = time.time()
                interpreter.set_tensor(input_details[0]['index'], new_img)
                # 开始预测
                interpreter.invoke()   
                # 获取预测的结果
                output_data = np.squeeze(interpreter.get_tensor(output_details[0]['index']))        
                max_label_id = np.argmax(output_data)
                if output_details[0]['dtype'] == np.uint8:
                    scale, zero_point = output_details[0]['quantization']
                    output_data = scale * (output_data - zero_point)            
                elapsed_ms = (time.time() - start_time)
                origin_img = cv2ImgAddText(origin_img, '%s accuracy:%.2f fps:%d' % (labels[max_label_id], output_data[max_label_id], int(1/elapsed_ms)),30,30,(0,255,0), 40)
                cv2.imshow("frame", origin_img)
                if cv2.waitKey(1) == ord('q'):
                    break
        finally:
            cv2.destroyAllWindows()
      
Copyright © 古德微 2023 all right reserved,powered by GDWRobot本课修订时间: 2023-10-25

results matching ""

    No results matching ""