前言
这一章来学习如何使用 java 调用 python 机器学习模块,毕竟 python 在算法方法好用,但是做 web 项目还是 java 更优,最近有个项目想要集成机器学习算法,这里简单记录一下(默认使用idea开发工具,默认都会创建maven项目,不会自行百度)。
java 调用 python,分三步来学习:
- 第一步:java 调用 python 语句
- 第二步:java 调用 python 脚本
- 第三步:java 调用 python 脚本函数(如何传递参数)
- 第四步:java 调用 python 机器学习模块并运行
上面三步都需要调用 jython 库,两种加入项目方法:
- 从官网下载 jar 包,手动加入
- 在 pom 文件中直接加入依赖
1 2 3 4 5 |
<dependency> <groupId>org.python</groupId> <artifactId>jython-standalone</artifactId> <version>2.7.0</version> </dependency> |
1、java 调用 python 语句
首先要在idea中导入jython库,上面提到的两种方法,我使用的是第二种,即在 pom 文件中加入依赖,简单明了。
创建一个 javaRunPython 类,执行下列代码:
1 2 3 |
PythonInterpreter interpreter = new PythonInterpreter(); interpreter.exec("a=1+2; "); interpreter.exec("print(a);"); |
运行结果如下:
另外,我发现一个有趣的现象,无论是"print a",还是"print(a)",居然都没有报错,也就是说 jython 兼容支持 python2 和 python3 两种语法。
2、java 调用 python 脚本
2.1 PythonInterpreter 调用 python 脚本
首先要写一个 python 脚本,内容随意,下面是我写的:
1 2 |
hello = 'hello world, this is using java to pring python word' print(hello) |
然后在 maven 项目中,运行:
1 2 3 |
import org.python.util.PythonInterpreter; PythonInterpreter interpreter = new PythonInterpreter(); interpreter.execfile("E:\\pythonTest.py"); |
运行结果如下:
上面调用的只是普通的 python 脚本,如果脚本中导入了第三方库,还能不能运行呢?测试一下,写一个简单的生成矩阵 python 脚本:
1 2 3 4 5 6 |
print("sdafd") import numpy as np n = np.arange(0, 30, 2) n = n.reshape(3, 5) print(n) print("dafdafda") |
运行结果:
由上图可以看到,只执行了第一句,当导入第三方库的时候报错了,这个是因为在 jython 库中不存在 numpy 模块,自然会报错。
由此可以得出:PythonInterpreter 可以简单执行普通的 python 脚本,但是对于带有第三方库的 python 脚本就不行了。
2.2 使用 Runtime 调用 python 脚本(推荐)
先来个普通脚本:
1 2 |
hello = 'hello world, this is using java to pring python word' print(hello) |
文件名为 Runtime.py
Runtime 方法的 java 代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; Process proc; try { proc = Runtime.getRuntime().exec("python E:\\Runtime.py"); BufferedReader in = new BufferedReader(new InputStreamReader(proc.getInputStream())); String line = null; while ((line = in.readLine()) != null) { System.out.println(line); } in.close(); proc.waitFor(); } catch (IOException e) { e.printStackTrace(); } catch (InterruptedException e) { e.printStackTrace(); } |
执行结果和上面一样
再来看看带有第三方库的脚本:
1 2 3 4 5 6 |
print("sdafd") import numpy as np n = np.arange(0, 30, 2) n = n.reshape(3, 5) print(n) print("dafdafda") |
因为文件名和位置没有变化,Runtime 的代码不变,再次执行,结果如下:
可以看到,完美执行,没有报错。
这里说一下为什么没有报错:
1 |
proc = Runtime.getRuntime().exec("python E:\\Runtime.py"); |
该方法有一个缺点,它得到 python 执行结果是通过数据流得到的,每次读取一行,相当于每执行一次就要读取一次结果,这就会导致我们运行的过程中很耗时。
3、java 调用 python 脚本函数(如何传递参数)
这次写一个带参数的 python 脚本——简单的两数相加:
1 2 |
def add(x,y): return x+y |
脚本名称为 add.py,java 调用 python 函数代码为:
1 2 3 4 5 6 7 8 9 |
PythonInterpreter interpreter = new PythonInterpreter(); interpreter.execfile("E:\\RunTime.py"); // 第一个参数为期望获得的函数(变量)的名字,第二个参数为期望返回的对象类型 PyFunction pyFunction = interpreter.get("add", PyFunction.class); int a = 5, b = 10; //调用函数,如果函数需要参数,在Java中必须先将参数转化为对应的“Python类型” PyObject pyobj = pyFunction.__call__(new PyInteger(a), new PyInteger(b)); System.out.println("the anwser is: " + pyobj); |
运行结果:
4、java 调用 python 机器学习模型
java 调用 python 机器学习模型,我总结了一下共有四种:
- 利用上面的 java 调用 python 脚本——训练集和测试集写入文本中,python 脚本进行读取
- 将 python 训练的模型参数保存到文本中,用 java 代码重现模型的预测算法。这种工作量很大,而且出现的 bug 几率大大增加。最重要的是很多深度学习的框架就没办法用了。
- 使用 python 进程运行深度学习中训练的模型,在 java 应用程序中调用 python 进程提供的服务。这种方法没尝试过。python 语言写得程序毕竟还是在 python 环境中执行最有效率。而且 python 应用和 java 应用可以运行在不同的服务器上,通过进程的远程访问调用。
- 将机器学习模型保存为 pmml 文件,然后 java 调用 pmml 文件。这种方法是网上最常见的方法,进行上线部署的时候,不会依赖于 python 环境,推荐使用。
上面四种方法,前三种都需要依赖于 python 环境,如果要部署的系统中存在 python 环境那么使用前三种是可以的,如果没有,那么第四种方法是最优的。
下面对于 pmml 进行介绍和实例。
4.1 pmml 介绍
PMML:Predictive Model Markup Language 预测模型标记语言。data mining group 推出的,有十多年的历史了。是一种可以呈现预测分析模型的事实标准语言。标准东西的好处就是,各种开发语言都可以使用相应的包,把模型文件转成这种中间格式,而另外一种开发语言,可以使用相应的包导入该文件做线上预测。
PMML 是数据挖掘的一种通用的规范,它用统一的 XML 格式来描述我们生成的机器学习模型。这样无论你的模型是 sklearn,R 还是 Spark MLlib 生成的,我们都可以将其转化为标准的 XML 格式来存储。当我们需要将这个 PMML 的模型用于部署的时候,可以使用目标环境的解析 PMML 模型的库来加载模型,并做预测。
可以看出,要使用 PMML,需要两步的工作,第一块是将离线训练得到的模型转化为 PMML 模型文件,第二块是将 PMML 模型文件载入在线预测环境,进行预测。这两块都需要相关的库支持。
4.2 实例代码实现
针对 CIC-IDS2017 数据集,sklearn 决策树机器学习模型保存为 pmml 文件实现代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import pandas as pd from sklearn2pmml.pipeline import PMMLPipeline from sklearn.tree import DecisionTreeClassifier from sklearn.model_selection import train_test_split from sklearn2pmml import sklearn2pmml # 加载数据 raw_data_filename = "data/clearData/total_expend.csv" print("Loading raw data...") raw_data = pd.read_csv(raw_data_filename, header=None,low_memory=False) # 随机抽取比例 # raw_data=raw_data.sample(frac=0.03) # 将非数值型的数据转换为数值型数据 raw_data[last_column_index], attacks = pd.factorize(raw_data[last_column_index], sort=True) # 对原始数据进行切片,分离出特征和标签,第1~41列是特征,第42列是标签 features = raw_data.iloc[:, :raw_data.shape[1] - 1] # pandas中的iloc切片是完全基于位置的索引 labels = raw_data.iloc[:, raw_data.shape[1] - 1:] # 数据标准化 # features = preprocessing.scale(features) # features = pd.DataFrame(features) # 将多维的标签转为一维的数组 labels = labels.values.ravel() # 将数据分为训练集和测试集,并打印维数 df = pd.DataFrame(features) X_train, X_test, y_train, y_test = train_test_split(df, labels, train_size=0.8, test_size=0.2, stratify=labels) pipeline = PMMLPipeline([("classifier", DecisionTreeClassifier(criterion='entropy', max_depth=12, min_samples_leaf=1, splitter="best"))]) pipeline.fit(X_train, y_train) sklearn2pmml(pipeline, "data/pmml/DecisionTreeIris.pmml", with_repr = True) |
存储的 pmml 文件内容:
java 调用生成的 pmml 文件,并进行预测新数据代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import org.dmg.pmml.FieldName; import org.dmg.pmml.PMML; import org.jpmml.evaluator.*; import org.jpmml.evaluator.support_vector_machine.VoteDistribution; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.xml.sax.SAXException; import javax.xml.bind.JAXBException; import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; import java.io.InputStream; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.io.*; /** * 分类模型测试 */ class PMMLDemo { /** * 加载模型 */ private Evaluator loadPmml() { PMML pmml = new PMML(); InputStream inputStream = null; try { // 读取resources文件夹下的pmml文件 Resource resource = new ClassPathResource("DecisionTreeIris.pmml"); inputStream = resource.getInputStream(); } catch (IOException e) { e.printStackTrace(); } if (inputStream == null) { return null; } InputStream is = inputStream; try { pmml = org.jpmml.model.PMMLUtil.unmarshal(is); } catch (SAXException | JAXBException e1) { e1.printStackTrace(); } finally { //关闭输入流 try { is.close(); } catch (IOException e) { e.printStackTrace(); } } ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); return modelEvaluatorFactory.newModelEvaluator(pmml); } /** * 分类预测 */ private void predict(Evaluator evaluator,Map<String, Double> featuremap) { List<InputField> inputFields = evaluator.getInputFields(); System.out.println(inputFields); // 从原始特征获取数据,作为模型输入 Map<FieldName, FieldValue> arguments = new LinkedHashMap<FieldName, FieldValue>(); for (InputField inputField : inputFields) { // 特征名称 FieldName inputFieldName = inputField.getName(); // 特征值 Object rawValue = featuremap.get(inputFieldName.getValue()); FieldValue inputFieldValue = inputField.prepare(rawValue); arguments.put(inputFieldName, inputFieldValue); } // 预测结果 Map<FieldName, ?> results = evaluator.evaluate(arguments); List<TargetField> targetFields = evaluator.getTargetFields(); for (TargetField targetField : targetFields) { FieldName targetFieldName = targetField.getName(); Object targetFieldValue = results.get(targetFieldName); System.err.println("target: " + targetFieldName.getValue() + " value: " + targetFieldValue); } } //读取csv文件 private static String readCSV(){ //第一步:先获取csv文件的路径,通过BufferedReader类去读该路径中的文件 File csv = new File("E:\\ideaProject\\javaRunPython\\src\\main\\resources\\1.csv"); String lineDta = ""; try{ //第二步:从字符输入流读取文本,缓冲各个字符,从而实现字符、数组和行(文本的行数通过回车符来进行判定)的高效读取。 BufferedReader textFile = new BufferedReader(new FileReader(csv)); //第三步:将文档的下一行数据赋值给lineData,并判断是否为空,若不为空则输出 while (textFile.readLine()!= null){ lineDta = textFile.readLine(); } textFile.close(); }catch (FileNotFoundException e){ System.out.println("没有找到指定文件"); }catch (IOException e){ System.out.println("文件读写出错"); } return lineDta; } public static void main(String args[]){ PMMLDemo demo = new PMMLDemo(); Evaluator model = demo.loadPmml(); Map<String, Double> data = new HashMap<String, Double>(); //读取测试数据(一行),并对其进行处理 String test = readCSV(); System.out.println(test); String[] tests=test.split("\t"); for(int i=0;i<tests.length-1;i++){ data.put(""+i,Double.valueOf(tests[i])); } //将测试数据data放入模型中进行预测 demo.predict(model,data); } } |
先说一下最终结果:
打印的时候,用的是红色 err 打印的(注意这里不是报错),显示的数据 data 对于各个类别预测概率,其中预测类别为 6 的概率最大为 99.4%(标红框)。
4.3 代码实例解释
对于 python 的代码这里不解释了,比较简单。
具体说一下 java 的,主要分为三部分:
(1)java 读取 pmml 文件并将其转换为 java 机器学习模型 model
简单来说,就是利用流读取 pmml 文件,然后 java 导入的 pmml 库将其实现为一个机器学习模块。
(2)读取 csv 测试数据,并进行数据处理
因为 CIC-IDS2017 数据集比较大(共有 80 个特征),所以不能手写,看网上的例子,一般是通过如下进行手动写的:
1 2 3 4 |
data.put("x1", 5.1); data.put("x2", 3.5); data.put("x3", 1.4); data.put("x4", 0.2); |
这里需要说一下为什么数据集是一个 map 形式,要知道在 python 里面,数据集就是单纯数据列(只有 value,没有 key),这是因为在 python 进行数据训练的时候,sklearn 的决策树模型会自动加一个 key,这个 key 就是数据的索引。
我再看网上的例子的时候,发现有一部分手写测试数据的时候,并没有写全,但是依然能够运行,这个是因为机器学习模型并没有全部应用特征数据,就比如上面的决策树,它的层数只有十几层,也就是用到了只是十几个特征,你只需要将这十几个特征写入数据集就行,前提是这十几个特征的索引你要写对,其实可以打印一下模型中的特征,也就是下面的 inputFields 的索引:
1 2 3 4 5 6 7 8 |
for (InputField inputField : inputFields) { // 特征名称 FieldName inputFieldName = inputField.getName(); // 特征值 Object rawValue = featuremap.get(inputFieldName.getValue()); FieldValue inputFieldValue = inputField.prepare(rawValue); arguments.put(inputFieldName, inputFieldValue); } |
在对数据集进行处理的时候,曾遇到一个问题,就是读取 CSV 文件后获得一个字符,如何切分的问题。读取后的字符串是这个样子:
我一开始用单空格切分,发现不行,双空格也不行,打印的时候,发现字符串中有很多\t,于是就用\t 分割,果然正确,网上查了一下\t 表示:
看了这个解释,顿时明白了,csv 文件本身就是表格性质的,\t 相当于补全了。
(3)读取测试数据 data,进行预测
这一步理解起来就比较简单了,因为模型是决策树,从根节点出发,读取特征(inputFields),根据所需特征从 data 中读取出来,然后进行预测。
4.4 遇到的问题
遇到的问题主要是 java 这边的
(1)java-source1.5 中不支持 multi-catch 语句
解决方案参考:https://blog.csdn.net/qq_39793857/article/details/106925721
(2)Exception in thread “main“ java.lang.IllegalArgumentException : http://www.dmg.org / PMML-4_4 is not support
这是版本问题,要修改 pmml 文件的表头,具体参考:https://blog.csdn.net/qq_32113189/ article/details/107542225
评论