由于工作涉及到了ROC曲线,自己不是很懂,就在网上找了资料自学了一下。 明白每一行代码,甚至可以改代码的时候,应该就学得差不多了。 这次我选用了著名的iris鸢尾花数据集作为数据源。(R,python,学习经常用到) iris包含花瓣长度、花瓣宽度、花萼长度、花萼宽度四个特征, “setosa”,“versicolor”,"virginica"3个种类的鸢尾花,一共150个数据样本。 Sklearn.datasets机器学习包可以直接得到。大概长这样。 你期待的代码:
import numpy as np #画图用的包 import matplotlib.pyplot as plt # 支持向量机分类算法 from sklearn import svm,datasets #roc 2分类曲线 from sklearn.metrics import roc_curve,auc from sklearn.model_selection import train_test_split # 下载iris数据集 iris = datasets.load_iris() # 获取数据特征 X = iris.data # 获取数据标签(0,1,2)分别代表不同的种类的鸢尾花 y = iris.target # 由于数据是3分类的,我们需要转换为2分类 #变为2分类,我取了(0,1) X, y = X[y != 2], y[y != 2] # 增加噪音特征,是问题稍稍点挑战性 # 可通过用Numpy工具包生成模拟数据集,使用RandomState获得随机数生成器 # 参数0为随机种子,当多次运行此段代码能够得到完全一样的结果。 random_state = np.random.RandomState(0) # 获取X数据矩阵行和列(100,4), n_samples, n_features = X.shape # np.c_是按行连接两个矩阵,就是把两矩阵左右相加,要求行数相等 X.shape=(100,800) X = np.c_[X, random_state.randn(n_samples, 200 * n_features)] # 将数据划分为训练集和测试集,test_size=.3表示30%的测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3,random_state=0) # 学习预测其他类 # svm:作用分类classification,回归regression,异常检测outlier detection # svm:support vector machine(支持向量机) # svm.SVC()全称C—support Vector Classification # kernel:核函数,可选linear(线性核函数)、poly、rbf、sigmoid、precomputed,默认为rbf(高斯核)。 # probability:是否采用概率估计,默认为False。 # random_state:数据随机洗牌时的种子值,默认为缺省。 clf = svm.SVC(kernel='linear', probability=True, random_state=random_state) # 训练模型 clf_fit=clf.fit(X_train,y_train) #通过decision_function()计算得到的y_score的值,用在roc_curve()函数中 # decision_function代表的是参数实例到各个类所代表的超平面的距离; # 在梯度下滑里面特有的(随机森林里面没有decision_function),这个返回的距离,或者说是分值; # 后续的对于这个值的利用方式是指定阈值来进行过滤: y_score = clf_fit.decision_function(X_test) # 计算每个类的ROC曲线和ROC面积 fpr, tpr, threshold = roc_curve(y_test, y_score) #fpr,tpr,thresholds 分别为假正率、真正率和阈值 roc_auc = auc(fpr, tpr) #计算auc的值:0.8133333333333334 #画ROC曲线图 plt.figure(figsize=(10, 10)) plt.plot(fpr, tpr, color='darkorange', lw=1, label='ROC curve (area = %0.2f)' % roc_auc) ###假正率为横坐标,真正率为纵坐标做曲线 plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('Receiver operating characteristic example') plt.legend(loc="lower right") plt.show()还有很多不足之处,还望不吝赐教。谢谢。