在Java中如何通过原始的模型依赖库完成模型的加载使用

1、ONNX是什么?

ONNX是一种开放的文件格式标准,专门用来存储和交换神经网络模型。就像 .jpg 是图片的格式,.mp4 是视频文件的格式一样,.onnx 就是 AI 模型的格式,里面规定了模型结构(算子、层)和参数(权重),不依赖任何特定框架。

Java模型

ONNX 则定义了一套统一的“普通话”格式(.onnx 文件)。它规定了如何描述一个神经网络的结构和权重。一个模型一旦被导出为 ONNX 格式,就相当于学会了说“普通话”,任何支持 ONNX 的“运行时”都能读懂它。

2、ONNX Runtime是什么?

在学习ONNX Runtime之前需要先了解推理引擎的概念,推理引擎类似在数据库中的SQL执行引擎,先以数据库系统中SQL执行引擎为例介绍引擎的用途。我们在执行一个SQL时会在数据库服务中发生什么呢?首先会校验连接信息,验证当前用户与权限信息;解析当前执行的SQL语句生成语法查询树后会生成最优执行计划交由执行器去按照执行计划执行最终返回结果。另外:MySql的执行引擎是针对MySql语句做出的工作,PostgreSql的执行引擎是针对PostgreSql语句做出的工作。

同理,在AI推理系统中也有对应的概念推理引擎。ONNX Runtime就是一种推理引擎,是专门用来高效执行.onnx文件的推理引擎。推理引擎的工作流程是这样的:

加载:加载模型文件;

优化:解析模型文件,进行算子融合、内存分配优化、常量折叠等工作;

调用:按照模型要求的格式把数据通过引擎的API喂给模型;

响应:引擎按照模型的输出结果响应,拿到响应结果之后转换成有意义的结果。

如果把 .onnx 文件比作一个“加密的智能合约”,那 ONNX Runtime 就是那个能解读并执行这份合约的“虚拟机”。

ONNX Runtime 提供了一个 Java API 包。我们在 Java 项目中引入这个包,就可以用 Java 代码来加载 .onnx 文件、喂入数据、拿回结果。

3、onnx格式的模型文件的前世

在大多数时候模型文件的训练都是在Python侧来完成,一个模型文件的产生流程大致是这样的:

在Python侧做训练

准备训练数据 -> 通过PyTorch训练框架进行训练 -> 得到产物model.pt文件 -> 通过Python进行格式转化得到model.onnx文件,得到转换成标准的onnx文件之后其他程序可以拿来做模型相关的任务。

在Java侧做推理

拿到onnx格式的模型文件之后可以在Java侧来做推理,通过onnxRuntime库完成onnx格式的模型文件加载使用。

4、使用ONNX Runtime

引入依赖 

 <dependency>
      <groupId>com.microsoft.onnxruntime</groupId>
      <artifactId>onnxruntime</artifactId>
      <version>1.20.0</version>
</dependency>

创建环境

OrtEnvironment environment = OrtEnvironment.getEnvironment();

加载模型

try(OrtSession.SessionOptions options = new OrtSession.SessionOptions();){
    OrtSession session = environment.createSession(modelPath, options);
    ...
}catch (Exception e) {
    throw new RuntimeException(e);
}

查看模型所需数据(输入、输出)

System.out.println("========== 模型输入信息 ==========");
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
    TensorInfo info = (TensorInfo) entry.getValue().getInfo();
    System.out.printf("  名称: " + entry.getKey());
    System.out.printf("  形状: " +  Arrays.toString(info.getShape()));
    System.out.printf("  类型: " +  info.type);
    System.out.printf("  总元素数: " +  info.getNumElements());
}
System.out.println("========== 模型输出信息 ==========");
for (Map.Entry<String, NodeInfo> entry : session.getOutputInfo().entrySet()) {
    TensorInfo info = (TensorInfo) entry.getValue().getInfo();
    System.out.printf("  名称: " +  entry.getKey());
    System.out.printf("  形状: " +  Arrays.toString(info.getShape()));
    System.out.printf("  类型: " +  info.type);
    System.out.printf("  总元素数: " +  info.getNumElements());
}
System.out.println("==================================");
//output
// ========== 模型输入信息 ==========
//   名称: images
//   形状: [1, 3, 1024, 1024]
//   类型: FLOAT
//   总元素数: 3145728
// ========== 模型输出信息 ==========
//   名称: output0
//   形状: [1, 20, 21504]
//   类型: FLOAT
//   总元素数: 430080
// ==================================

准备张量数据(模拟)

// 动态从模型读取输入形状
String inputName = session.getInputNames().iterator().next();
NodeInfo inputNodeInfo = session.getInputInfo().get(inputName);
TensorInfo inputTensorInfo = (TensorInfo) inputNodeInfo.getInfo();
long[] inputShape = inputTensorInfo.getShape();
long numInputElements = inputTensorInfo.getNumElements();
float[] rawData = new float[(int) numInputElements];
//准备一些模拟的数据作为流程能够跑通的示范,后续可以将图片文件预处理为张量数据输入给模型
for (int i = 0; i < numInputElements; i++) {
    rawData[i] = (float) (i % 256) / 255.0f;
}
//将float数组通过OnnxTensor包装为张量数据
OnnxTensor inputTensor = OnnxTensor.createTensor(
        environment,
        FloatBuffer.wrap(rawData),
        inputShape
);

将将要输入的图片文件预处理为float[]

一张彩色图片本质上是三种颜色(RBG)的组合,一张彩色图片由三个通道组成,每个通道是一个二维矩阵,每个格子是一个像素值(0,255)。

图片预处理是连接图片世界和张量世界的关键步骤,模型的输入要求是float[],按照CHW 排布(Channel-Height-Width),即先放 R 通道全部像素,再放 G 通道,再放 B 通道。在电脑磁盘上存储的图片是按照 HWC 排布的,所以预处理的核心就是要把排布格式做出对应的转换,同时也要将图片缩放到模型要求的尺寸,将数据归一化到 [0, 1] 区间。

 private static float[] preprocessImage(BufferedImage image, int targetWidth, int targetHeight) {
        // --- 第 1 步:缩放到目标尺寸 ---
        BufferedImage resized = new BufferedImage(targetWidth, targetHeight, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = resized.createGraphics();
        g.drawImage(image, 0, 0, targetWidth, targetHeight, null);
        g.dispose();
        System.out.println("[预处理-1] 图片已缩放至 " + targetWidth + "x" + targetHeight);
        // --- 第 2 步:逐像素提取 RGB,按 CHW 排布写入 float[] ---
        float[] tensorData = new float[3 * targetWidth * targetHeight];
        for (int y = 0; y < targetHeight; y++) {
            for (int x = 0; x < targetWidth; x++) {
                int rgb = resized.getRGB(x, y);
                int r = (rgb >> 16) & 0xFF;
                int gv = (rgb >> 8) & 0xFF;
                int b = rgb & 0xFF;
                // CHW 排布:channel 0=R, channel 1=G, channel 2=B
                int base = y * targetWidth + x;          // 像素在单通道中的偏移
                int channelStride = targetWidth * targetHeight; // 每个通道的像素数
                tensorData[0 * channelStride + base] = r / 255.0f; // R 通道
                tensorData[1 * channelStride + base] = gv / 255.0f; // G 通道
                tensorData[2 * channelStride + base] = b / 255.0f;  // B 通道
            }
        }
        System.out.println("CHW 排布转换完成,共 " + tensorData.length + " 个浮点数");
        System.out.println("第一个 R 值: " + tensorData[0]);
        System.out.println("第一个 G 值: " + tensorData[targetWidth * targetHeight]);
        System.out.println("第一个 B 值: " + tensorData[2 * targetWidth * targetHeight]);
        return tensorData;
    }

交给引擎执行推理

OrtSession.Result result = session.run(Map.of(inputName, inputTensor));

等待响应结果原始张量数据

Optional<OnnxValue> output0 = result.get("output0");
System.out.println(output0.get().getInfo().toString()); //TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 20, 21504])
TensorInfo outputInfo = outputTensor.getInfo();
long[] outputShape = outputInfo.getShape();
long numElements = outputInfo.getNumElements();
// 读取输出数据到 Java float 数组
float[] outputData = new float[(int) numElements];
outputTensor.getFloatBuffer().get(outputData);

拿到原始张量数据之后可以再去编写一段后处理的程序结合使用场景输出符合当前使用场景的真正结果。

后处理解析原始张量

private static List<Detection> postprocess(float[] output, long[] outputShape) {
    int numChannels = (int) outputShape[1];    // 20
    int numPredictions = (int) outputShape[2]; // 21504
    int numClasses = numChannels - 5;          // 15
    // --- 遍历所有候选框,计算 sigmoid 置信度 ---
    List<Detection> candidates = new ArrayList<>();
    for (int i = 0; i < numPredictions; i++) {
        float cx = output[0 * numPredictions + i];
        float cy = output[1 * numPredictions + i];
        float w = output[2 * numPredictions + i];
        float h = output[3 * numPredictions + i];
        // 找 15 个类别中置信度最高的
        int maxClassId = 0;
        float maxClassScore = Float.NEGATIVE_INFINITY;
        for (int c = 0; c < numClasses; c++) {
            float score = output[(4 + c) * numPredictions + i];
            if (score > maxClassScore) {
                maxClassScore = score;
                maxClassId = c;
            }
        }
        float angle = output[(4 + numClasses) * numPredictions + i];
        // --- 筛选---
        float confidence = maxClassScore;
        //CONF_THRESHOLD置信度的值可以根据需求进行调整,只保留大于执行度的检测结果
        if (confidence >= CONF_THRESHOLD) {
            Detection det = new Detection();
            det.cx = cx;
            det.cy = cy;
            det.w = w;
            det.h = h;
            det.angle = angle;
            det.classId = maxClassId;
            det.confidence = confidence;
            det.corners = getRotatedBoxCorners(cx, cy, w, h, angle);
            candidates.add(det);
        }
    }
    // --- NMS 去重 ---
    return nms(candidates);
}
//避免同一个物体北检测到多次导致会有多个框来描述这个物体,nms方法会保留置信度最高的那个框,去掉同一个物体中其他描述这个物体的框
private static List<Detection> nms(List<Detection> detections) {
    // 按置信度从高到低排序
    Collections.sort(detections, (a, b) -> Float.compare(b.confidence, a.confidence));
    List<Detection> result = new ArrayList<>();
    boolean[] suppressed = new boolean[detections.size()];
    for (int i = 0; i < detections.size(); i++) {
        if (suppressed[i]) continue;
        result.add(detections.get(i));
        for (int j = i + 1; j < detections.size(); j++) {
            if (suppressed[j]) continue;
            // 只有同类别才做去重
            if (detections.get(i).classId != detections.get(j).classId) continue;
            float iou = rotatedIoU(detections.get(i).corners, detections.get(j).corners);
            if (iou > NMS_THRESHOLD) {
                suppressed[j] = true;
            }
        }
    }
    return result;
}


上一篇: 已是第一篇
下一篇: 企业数字化转型痛点分析:二十条