# coding=utf-8

import os
from collections import Counter

import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn import tree
from sklearn.metrics import precision_recall_curve  # precision and recall
from sklearn.model_selection import train_test_split
# import graphviz

# Append a hard-coded Windows Graphviz 2.38 install directory to PATH,
# presumably so the Graphviz executables can be found when the tree is
# rendered to PDF later in this file.
# .get() is used instead of indexing so that an environment without any PATH
# variable does not raise KeyError at import time.
os.environ["PATH"] = (
    os.environ.get("PATH", "")
    + os.pathsep
    + 'C:/Program Files (x86)/Graphviz2.38/bin/'
)
12
13
14
def get_data(file_path="Iris.xlsx", output_pdf="workiris.pdf", random_state=None):
    """Train a decision tree on a spreadsheet and export the fitted tree.

    The last column of the sheet is used as the class label and every other
    column as a feature. The data is split with ``train_test_split``, a
    ``DecisionTreeClassifier`` is fitted on the training part, and a
    classification report is computed on the held-out part. The fitted tree
    is rendered to a PDF via ``pydotplus``.

    Along the way the function prints the column count, the generated
    feature names, the distinct class labels, the training features, the
    test labels and the classification report.

    Args:
        file_path: Path of the Excel file to read. The first row is taken
            as the header by ``pandas.read_excel``.
        output_pdf: Path of the PDF file the tree is written to.
        random_state: Optional seed forwarded to both the train/test split
            and the classifier; ``None`` keeps the default (non-seeded)
            behaviour.

    Returns:
        The text of the classification report for the test split.

    Raises:
        ValueError: If the sheet has fewer than two columns, i.e. there is
            no feature column in addition to the label column.
    """
    data = pd.read_excel(file_path)
    ncol = data.shape[1]
    print(ncol)
    if ncol < 2:
        raise ValueError(
            "expected at least one feature column plus a label column, "
            "got {} column(s)".format(ncol)
        )

    # Generic names (feature0, feature1, ...) are used to label tree nodes.
    feature_names = ["feature" + str(i) for i in range(ncol - 1)]
    print(feature_names)

    # read_excel has already consumed the header row, so every remaining row
    # is a sample; slicing from row 1 would silently drop the first one.
    iris_x = data.iloc[:, :ncol - 1]  # all columns except the last
    iris_y = data.iloc[:, ncol - 1]   # last column only

    # Show which distinct classes are present in the label column.
    counter = Counter(iris_y)
    print(counter.keys())

    x_train, x_test, y_train, y_test = train_test_split(
        iris_x, iris_y, random_state=random_state
    )
    print(x_train)
    print(y_test)

    clf = tree.DecisionTreeClassifier(random_state=random_state)
    clf.fit(x_train, y_train)

    # Evaluate on the held-out split only.
    classify_report = metrics.classification_report(y_test, clf.predict(x_test))
    print(classify_report)

    # Imported here so the module can be imported without pydotplus
    # installed; it is only needed for the PDF export below.
    import pydotplus

    dot_data = tree.export_graphviz(
        clf,
        out_file=None,
        feature_names=feature_names,
        filled=True,
        rounded=True,
        special_characters=True,
        precision=4,
    )
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf(output_pdf)
    return classify_report
70
71
if __name__ == "__main__":
    # Run the whole pipeline with the default input and output paths.
    # get_data() prints the classification report itself; the returned text
    # is kept only for interactive inspection.
    report = get_data()
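# Data source: http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
# After saving the data, remember to fill in the header row.
# Reposted from: https://www.cnblogs.com/shizhenqiang/p/8204986.html
# Related resource: decision-tree model for the Iris dataset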