Yolo Binding的相关代码部分
yolo-binding的代码说明

文件组织

YOLO_BINDING
│  .gitignore
│  Cargo.lock
│  Cargo.toml
│  LICENSE
│  readme.md
│
└─src
    │  lib.rs
    │
    ├─core
    │      export.rs
    │      load.rs
    │      mod.rs
    │      predict.rs
    │
    └─utils
            HarmonyOS_Sans_Regular.ttf
            mod.rs
            picture.rs

一个好的项目要从文件开始设计(虽然咱是蒟蒻,也要有成为牛犇的理想哇!)

  • core是主推理环节,同时实现了预处理,推理,后处理三部分,是核心部分
  • utils是对解析出的数据进行下一步解析,比如画框,打标签等,后续也是可以拓展的

代码相关

这里不会展示所有代码,完整代码可以去仓库查看

预处理

在这个环节,我使用了tch库中封装好的函数。

但这个环节会对原图像进行拉伸,归一化,建议输入正方形图片

本来写了蒙灰色底版,但是DynamicImageTensor没有解决,后续再说吧

整体代码类似,这里只展示一个:

pub fn load_one_image(image_path: &str) -> Result<Tensor, Box<dyn Error>> {
    let image_tensor = vision::image::load(Path::new(&image_path))?;
    let image = tch::vision::image::resize(&image_tensor, 640, 640)
        .unwrap()
        .unsqueeze(0)
        .to_kind(tch::Kind::Float)
        / 255.;
    Ok(image)
}

YOLO的输入张量形状:[X,3,640,640],其中X是图片数量

模型引入&推理

由于libtorch的封装相当好,我们直接调用torchscript模型就可以

但是注意之前torch_cuda.dll的坑。

因为是自用的原因,我们直接调用在编写的时候,为了尽可能避免报错,允许在用户设置为Gpu时仍然调用Cpu,这点后续会删除,避免误导。

模型加载

let device = if cuda == true {
    let mut libtorch_path = env::var("LIBTORCH").unwrap();
    libtorch_path.push_str(r"\lib\torch_cuda.dll");
    if Path::new(&libtorch_path).exists() {
        let path = CString::from_str(&libtorch_path).unwrap();
        unsafe {
            LoadLibraryA(path.as_ptr() as *const c_char);            }
            Device::cuda_if_available()
    } else {
        panic!(
            "No {} exist,please check your libtorch version or set 'cuda' false instead",
            &libtorch_path
        );
    }
} else {
    Device::Cpu
}; // device choiced
let model = CModule::load_on_device(Path::new(model_path), device).expect("load model failed");

YOLO导出的torchscript模型中是有元信息的,但是我没有在torch-rs中找到有关接口,因此:

二进制,启动!!!

这部分代码比较多,我不在这里展示,不过大概阐述一下原理, 我看了两个导出的模型,元信息是包在开始的九个0x5A和结尾的0x504B中,因此我们可以直接读入,然后进行解析

事实上,这个torchscript文件是一个.zip的文件,因此你可以用zip进行解压,但是通过字节读取,我们可以得到更快的速度。

只是我没有见识,因此在下一版,我会把后识别符修改成0x504B

2025/02/06T23:12在刷B站时发现的

模型推理反倒是最简单的了,因为接口一步到位,只要注意张量设备转移就好。

模型推理

let input = input.to_device(device);
let output = model
    .yolo_model
    .forward_ts(&[input])
    .expect("forward failed");

后处理 之 core

这部分是将YOLO推理出的矩阵进行变换,得到一个格式化数据的过程。

这里的代码编写时,对Tensor切分的逻辑放在了mod.rs中,后续会进行修改。

YOLO本身的输出张量形状时这样的:[X,Y,8400],其中X是图片数量,Yxywh+分类类型

为了得到好看的框,我们在下面引入两个参量:confidencethreshold

  • confidence控制物品可信度筛选
  • threshold控制NMS筛选

我们就有以下逻辑;

  1. 通过conf筛去置信度低的框
  2. 通过threshold筛去重复的框

偷懒(✿◡‿◡),代码黏在下面了:

置信度筛选

fn filter_confidence(
    tensor: &Tensor,
    confidence: f64,
) -> Vec<(i64, i64, i64, i64, i64, f64)> {
    //Tensor [84, 8400]
    let pred = tensor.transpose(1, 0); //Tensor [8400, 84]
    let (npreds, pred_size) = pred.size2().unwrap();
    let full_xywh = pred.slice(1, 0, 4, 1); //Tensor [8400, 4]
    let mut filtered_results = Vec::new();
    for index in 0..npreds {
        // iterate all predictions
        let pred = pred.get(index); // Tensor [84]
        let max_conf_index = pred
            .narrow(0, 4, pred_size - 4)
            .argmax(0, true)
            .double_value(&[0]) as i64;
        let max_conf = pred.double_value(&[max_conf_index + 4]);
        if max_conf > confidence {
            let class_index = max_conf_index;
            filtered_results.push((
                full_xywh.double_value(&[index, 0]).round() as i64,
                full_xywh.double_value(&[index, 1]).round() as i64,
                full_xywh.double_value(&[index, 2]).round() as i64,
                full_xywh.double_value(&[index, 3]).round() as i64,
                class_index,
                max_conf,
            ));
        }
    }
    filtered_results
}

NMS筛选

fn nms(
    mut boxes: Vec<(i64, i64, i64, i64, i64, f64)>,
    threshold: f64,
) -> Vec<(i64, i64, i64, i64, i64, f64)> {
    boxes.sort_unstable_by(|a, b| b.5.partial_cmp(&a.5).unwrap_or(std::cmp::Ordering::Equal));
    let mut suppressed = vec![false; boxes.len()];
    let to_xyxy = |x: i64, y: i64, w: i64, h: i64| {
        let (x, y, w, h) = (x as f64, y as f64, w as f64, h as f64);
        let x1 = x - w / 2.0;
        let y1 = y - h / 2.0;
        let x2 = x + w / 2.0;
        let y2 = y + h / 2.0;
        (x1, y1, x2, y2)
    };
    let compute_iou =
        |a: &(i64, i64, i64, i64, i64, f64), b: &(i64, i64, i64, i64, i64, f64)| -> f64 {
            let a_rect = to_xyxy(a.0, a.1, a.2, a.3);
            let b_rect = to_xyxy(b.0, b.1, b.2, b.3);
            // Calculate intersection area
            let inter_x1 = a_rect.0.max(b_rect.0);
            let inter_y1 = a_rect.1.max(b_rect.1);
            let inter_x2 = a_rect.2.min(b_rect.2);
            let inter_y2 = a_rect.3.min(b_rect.3);

            let inter_area = (inter_x2 - inter_x1).max(0.0) * (inter_y2 - inter_y1).max(0.0);

            let a_area = (a_rect.2 - a_rect.0) * (a_rect.3 - a_rect.1);
            let b_area = (b_rect.2 - b_rect.0) * (b_rect.3 - b_rect.1);

            let union_area = a_area + b_area - inter_area;

            if union_area == 0.0 {
                0.0
            } else {
                inter_area / union_area
            }
        };

    for i in 0..boxes.len() {
        if suppressed[i] {
            continue;
        }
        for j in (i + 1)..boxes.len() {
            if suppressed[j] {
                continue;
            }
            if boxes[i].4 != boxes[j].4 {
                continue;
            }
            let iou = compute_iou(&boxes[i], &boxes[j]);
            if iou > threshold {
                suppressed[j] = true;
            }
        }
    }
    boxes 
        .into_iter()
        .enumerate()
        .filter(|(i, _)| !suppressed[*i])
        .map(|(_, b)| b)
        .collect()
}

核心逻辑都是一些简单的循环,主要是协调各个API,使程序正常工作。

后处理 之 utils

这部分主要实现画框的逻辑

主要依赖于Rust的imageproc

首先是将xywh转成合适的xyxy形式的坐标

let (x, y, w, h) = (x as f64, y as f64, w as f64, h as f64);
let (width, height) = picture.dimensions();
let (x1, y1, x2, y2) = (
    (x - w / 2.) / 640. * width as f64,
    (y - h / 2.) / 640. * height as f64,
    (x + w / 2.) / 640. * width as f64,
    (y + h / 2.) / 640. * height as f64,
);
let (x1, y1, x2, y2) = (x1 as f32, y1 as f32, x2 as f32, y2 as f32);

之后是画框

draw_line_segment_mut(&mut picture, (x1, y1), (x2, y1), Rgba([255, 0, 0, 255]));
draw_line_segment_mut(&mut picture, (x2, y1), (x2, y2), Rgba([255, 0, 0, 255]));
draw_line_segment_mut(&mut picture, (x2, y2), (x1, y2), Rgba([255, 0, 0, 255]));
draw_line_segment_mut(&mut picture, (x1, y2), (x1, y1), Rgba([255, 0, 0, 255]));

最后是画字

let scale = ab_glyph::PxScale::from(20.0);
draw_text_mut(
    &mut picture,
    Rgba([255, 0, 0, 255]),
    x1 as i32,
    y1 as i32,
    scale,
    &font,
    type_name,
        );

至此,我们就用Rust实现了基本的yolo-detection封装。


最后修改于 2025-02-07