File organization
YOLO_BINDING
│  .gitignore
│  Cargo.lock
│  Cargo.toml
│  LICENSE
│  readme.md
│
└─src
    │  lib.rs
    │
    ├─core
    │      export.rs
    │      load.rs
    │      mod.rs
    │      predict.rs
    │
    └─utils
            HarmonyOS_Sans_Regular.ttf
            mod.rs
            picture.rs
A good project starts with its file layout (I may be a noob, but one can still dream of becoming a pro!)
core is the main inference pipeline, implementing the three stages of preprocessing, inference, and post-processing; it is the heart of the crate.
utils does further processing on the parsed results, e.g. drawing boxes and labels; it can be extended later.
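As a rough sketch of how the tree above hangs together (the repository's actual lib.rs may differ), the top level would be something like:
pub mod core;  // preprocessing, inference, post-processing (load.rs, predict.rs, export.rs)
pub mod utils; // downstream handling of results: boxes, labels (picture.rs)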
The code
Not all of the code is shown here; the complete source is in the repository.
Preprocessing
For this stage I used the ready-made functions from the tch crate.
Note that this step stretches the image to 640x640 while normalizing it, which distorts non-square images, so square inputs are recommended.
I originally wrote a gray letterbox backing, but never solved the DynamicImage-to-Tensor conversion, so that will have to wait.
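For reference, one way to sidestep DynamicImage entirely is to letterbox on the Tensor itself. The sketch below only illustrates that idea and is not code from the repository (the helper name letterbox and the 114 gray value are my own choices):
use tch::{Device, Kind, Tensor};

fn letterbox(image: &Tensor, size: i64) -> Result<Tensor, Box<dyn std::error::Error>> {
    let (_c, h, w) = image.size3()?;
    // Scale so the longer side becomes `size`, keeping the aspect ratio.
    let scale = size as f64 / h.max(w) as f64;
    let (new_h, new_w) = ((h as f64 * scale) as i64, (w as f64 * scale) as i64);
    let resized = tch::vision::image::resize(image, new_w, new_h)?;
    // Gray canvas, then paste the resized image centered on it.
    let canvas = Tensor::full(&[3, size, size], 114i64, (Kind::Uint8, Device::Cpu));
    let (top, left) = ((size - new_h) / 2, (size - new_w) / 2);
    let mut region = canvas.narrow(1, top, new_h).narrow(2, left, new_w);
    region.copy_(&resized);
    Ok(canvas)
}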
The loading functions are all similar, so only one is shown here:
use std::error::Error;
use std::path::Path;
use tch::{vision, Tensor};

pub fn load_one_image(image_path: &str) -> Result<Tensor, Box<dyn Error>> {
    // Load the image as a [3, H, W] uint8 tensor.
    let image_tensor = vision::image::load(Path::new(&image_path))?;
    // Resize to 640x640, add a batch dimension, and normalize to [0, 1].
    let image = tch::vision::image::resize(&image_tensor, 640, 640)?
        .unsqueeze(0)
        .to_kind(tch::Kind::Float)
        / 255.;
    Ok(image)
}
YOLO's input tensor shape is [X, 3, 640, 640], where X is the number of images.
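So to run a batch, you just stack the per-image tensors along dimension 0. A minimal sketch, assuming img_a, img_b, img_c each came from load_one_image above (so each is [1, 3, 640, 640]):
use tch::Tensor;

// Concatenate along the batch dimension to get [X, 3, 640, 640].
let batch = Tensor::cat(&[&img_a, &img_b, &img_c], 0);
assert_eq!(batch.size(), [3, 3, 640, 640]); // X = 3 images here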
Model loading & inference
libtorch's bindings are complete enough that we can simply call the torchscript model directly.
Just watch out for the torch_cuda.dll pitfall mentioned earlier.
Since this is for my own use, the code was written to avoid errors wherever possible: it will silently fall back to the CPU even when the user has set Gpu. This behavior will be removed later, as it is misleading.
Model loading
use std::env;
use std::ffi::CString;
use std::os::raw::c_char;
use std::path::Path;
use std::str::FromStr;
use tch::{CModule, Device};
use winapi::um::libloaderapi::LoadLibraryA; // assuming the winapi crate provides this binding

let device = if cuda {
    // Manually load torch_cuda.dll first, otherwise cuda_if_available() sees no GPU.
    let mut libtorch_path = env::var("LIBTORCH").unwrap();
    libtorch_path.push_str(r"\lib\torch_cuda.dll");
    if Path::new(&libtorch_path).exists() {
        let path = CString::from_str(&libtorch_path).unwrap();
        unsafe {
            LoadLibraryA(path.as_ptr() as *const c_char);
        }
        Device::cuda_if_available()
    } else {
        panic!(
            "No {} exists, please check your libtorch version or set 'cuda' to false instead",
            &libtorch_path
        );
    }
} else {
    Device::Cpu
}; // device chosen

let model = CModule::load_on_device(Path::new(model_path), device).expect("load model failed");
The torchscript model exported by YOLO does contain metadata, but I couldn't find any interface for it in torch-rs, so:
binary parsing, let's go!!!
That part involves a fair amount of code, so I won't show it here, but the idea is roughly this: in the two exported models I inspected, the metadata is wrapped between nine 0x5A bytes at the start and 0x504B at the end, so we can read the file in directly and parse it out.
In fact, a torchscript file is a .zip archive, so you could simply unzip it; reading the raw bytes is just faster. I hadn't realized this, so in the next version I will change the trailing marker to 0x504B (the zip "PK" magic).
(Discovered at 2025/02/06T23:12 while scrolling Bilibili.)
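Since that code isn't shown here, the following is only a minimal sketch of the byte-scanning idea (the function name is mine, and it assumes exactly the layout described above: nine 0x5A bytes, then the metadata, then 0x50 0x4B):
use std::fs;

fn read_metadata(model_path: &str) -> Option<Vec<u8>> {
    let bytes = fs::read(model_path).ok()?;
    // Start: the position right after a run of nine consecutive 0x5A bytes.
    let start = bytes.windows(9).position(|w| w.iter().all(|&b| b == 0x5A))? + 9;
    // End: the first 0x50 0x4B ("PK", the zip magic) after that.
    let end = start + bytes[start..].windows(2).position(|w| w[0] == 0x50 && w[1] == 0x4B)?;
    Some(bytes[start..end].to_vec())
}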
Inference itself turns out to be the easiest part, since the interface does it in one call; just remember to move the tensors to the right device.
Model inference
// Move the input to the same device as the model before the forward pass.
let input = input.to_device(device);
let output = model
    .yolo_model
    .forward_ts(&[input])
    .expect("forward failed");
Post-processing: core
This stage transforms the matrix that YOLO outputs into structured data.
When this code was written, the Tensor-slicing logic ended up in mod.rs; that will be reorganized later.
YOLO's own output tensor shape is [X, Y, 8400], where X is the number of images and Y is xywh plus the class scores (for the 80-class COCO model, Y = 4 + 80 = 84).
To get clean boxes, we introduce two parameters below: confidence and threshold.
confidence controls the detection-confidence filter; threshold controls the NMS filter.
That gives the following logic:
- use confidence to discard low-confidence boxes
- use threshold to discard duplicate, overlapping boxes
Being lazy (✿◡‿◡), the code is just pasted below:
Confidence filtering
fn filter_confidence(
    tensor: &Tensor,
    confidence: f64,
) -> Vec<(i64, i64, i64, i64, i64, f64)> {
    // Tensor [84, 8400]
    let pred = tensor.transpose(1, 0); // Tensor [8400, 84]
    let (npreds, pred_size) = pred.size2().unwrap();
    let full_xywh = pred.slice(1, 0, 4, 1); // Tensor [8400, 4]
    let mut filtered_results = Vec::new();
    for index in 0..npreds {
        // iterate all predictions
        let pred = pred.get(index); // Tensor [84]
        let max_conf_index = pred
            .narrow(0, 4, pred_size - 4)
            .argmax(0, true)
            .double_value(&[0]) as i64;
        let max_conf = pred.double_value(&[max_conf_index + 4]);
        if max_conf > confidence {
            let class_index = max_conf_index;
            filtered_results.push((
                full_xywh.double_value(&[index, 0]).round() as i64,
                full_xywh.double_value(&[index, 1]).round() as i64,
                full_xywh.double_value(&[index, 2]).round() as i64,
                full_xywh.double_value(&[index, 3]).round() as i64,
                class_index,
                max_conf,
            ));
        }
    }
    filtered_results
}
NMS filtering
fn nms(
    mut boxes: Vec<(i64, i64, i64, i64, i64, f64)>,
    threshold: f64,
) -> Vec<(i64, i64, i64, i64, i64, f64)> {
    // Sort by confidence, descending.
    boxes.sort_unstable_by(|a, b| b.5.partial_cmp(&a.5).unwrap_or(std::cmp::Ordering::Equal));
    let mut suppressed = vec![false; boxes.len()];
    let to_xyxy = |x: i64, y: i64, w: i64, h: i64| {
        let (x, y, w, h) = (x as f64, y as f64, w as f64, h as f64);
        let x1 = x - w / 2.0;
        let y1 = y - h / 2.0;
        let x2 = x + w / 2.0;
        let y2 = y + h / 2.0;
        (x1, y1, x2, y2)
    };
    let compute_iou =
        |a: &(i64, i64, i64, i64, i64, f64), b: &(i64, i64, i64, i64, i64, f64)| -> f64 {
            let a_rect = to_xyxy(a.0, a.1, a.2, a.3);
            let b_rect = to_xyxy(b.0, b.1, b.2, b.3);
            // Calculate intersection area
            let inter_x1 = a_rect.0.max(b_rect.0);
            let inter_y1 = a_rect.1.max(b_rect.1);
            let inter_x2 = a_rect.2.min(b_rect.2);
            let inter_y2 = a_rect.3.min(b_rect.3);
            let inter_area = (inter_x2 - inter_x1).max(0.0) * (inter_y2 - inter_y1).max(0.0);
            let a_area = (a_rect.2 - a_rect.0) * (a_rect.3 - a_rect.1);
            let b_area = (b_rect.2 - b_rect.0) * (b_rect.3 - b_rect.1);
            let union_area = a_area + b_area - inter_area;
            if union_area == 0.0 {
                0.0
            } else {
                inter_area / union_area
            }
        };
    for i in 0..boxes.len() {
        if suppressed[i] {
            continue;
        }
        for j in (i + 1)..boxes.len() {
            if suppressed[j] {
                continue;
            }
            // Only suppress boxes of the same class.
            if boxes[i].4 != boxes[j].4 {
                continue;
            }
            let iou = compute_iou(&boxes[i], &boxes[j]);
            if iou > threshold {
                suppressed[j] = true;
            }
        }
    }
    boxes
        .into_iter()
        .enumerate()
        .filter(|(i, _)| !suppressed[*i])
        .map(|(_, b)| b)
        .collect()
}
The core logic is just simple loops; the real work is coordinating the various APIs so that everything behaves correctly.
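To tie the pieces together, each image's slice of the output tensor goes through filter_confidence and then nms. A sketch of that glue, with variable names and the 0.25 / 0.45 values chosen by me as illustrative defaults:
// output: Tensor [X, 84, 8400] from the forward pass shown earlier.
let output = output.to_device(tch::Device::Cpu);
for i in 0..output.size()[0] {
    let per_image = output.get(i); // Tensor [84, 8400]
    let candidates = filter_confidence(&per_image, 0.25);
    let boxes = nms(candidates, 0.45);
    println!("image {i}: kept {} boxes", boxes.len());
}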
Post-processing: utils
This part mainly implements the box-drawing logic, relying on Rust's imageproc crate.
First, convert the xywh values into xyxy coordinates scaled back to the original image:
let (x, y, w, h) = (x as f64, y as f64, w as f64, h as f64);
let (width, height) = picture.dimensions();
// Map from the 640x640 model space back to the original image size.
let (x1, y1, x2, y2) = (
    (x - w / 2.) / 640. * width as f64,
    (y - h / 2.) / 640. * height as f64,
    (x + w / 2.) / 640. * width as f64,
    (y + h / 2.) / 640. * height as f64,
);
let (x1, y1, x2, y2) = (x1 as f32, y1 as f32, x2 as f32, y2 as f32);
Then draw the box:
draw_line_segment_mut(&mut picture, (x1, y1), (x2, y1), Rgba([255, 0, 0, 255]));
draw_line_segment_mut(&mut picture, (x2, y1), (x2, y2), Rgba([255, 0, 0, 255]));
draw_line_segment_mut(&mut picture, (x2, y2), (x1, y2), Rgba([255, 0, 0, 255]));
draw_line_segment_mut(&mut picture, (x1, y2), (x1, y1), Rgba([255, 0, 0, 255]));
Finally, draw the label text:
let scale = ab_glyph::PxScale::from(20.0);
draw_text_mut(
    &mut picture,
    Rgba([255, 0, 0, 255]),
    x1 as i32,
    y1 as i32,
    scale,
    &font,
    type_name,
);
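As for the font, it ships inside src/utils (the HarmonyOS_Sans_Regular.ttf in the tree above); a plausible way to load it with ab_glyph, sketched here rather than taken from the repository:
use ab_glyph::FontRef;

// Embed the bundled font at compile time; the path is relative to this source file.
let font = FontRef::try_from_slice(include_bytes!("HarmonyOS_Sans_Regular.ttf"))
    .expect("invalid font data");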
And with that, we have a basic yolo-detection wrapper implemented in Rust.
Last modified: 2025-02-07