监督学习算法2.3.5-决策树

mac2026-01-20  9

全文代码如下

需要先从 GitHub 下载相关数据集:下载整个仓库压缩包并解压,在 data 目录中找到 ram_prices.csv,然后把它放到运行脚本的当前目录下(代码按相对路径读取该文件)。

点这下载

"""Decision tree walkthrough: breast-cancer classification and RAM price regression.

The script
1. fits an unrestricted and a depth-limited ``DecisionTreeClassifier`` on the
   scikit-learn breast cancer data and prints train/test accuracy for both,
2. exports the depth-limited tree to ``tree.dot`` and displays it,
3. prints and plots the feature importances of that tree,
4. displays the graph returned by ``mglearn.plots.plot_tree_not_monotone()``,
5. compares a ``DecisionTreeRegressor`` with ``LinearRegression`` on the RAM
   price data, training on rows with ``date < 2000`` and predicting for all rows.

``ram_prices.csv`` is read from the current working directory.
"""
import graphviz
import matplotlib.pyplot as plt
import mglearn
import numpy as np
import pandas as pd
from IPython import display
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.tree import (
    DecisionTreeClassifier,
    DecisionTreeRegressor,
    export_graphviz,
)

# Relative path of the RAM price data set (must match the downloaded file name).
RAM_PRICES_CSV = "ram_prices.csv"


def plot_feature_importances_cancer(model):
    """Draw a horizontal bar chart of ``model.feature_importances_``.

    One bar is drawn per column of the module-level ``cancer.data`` and the
    bars are labelled with ``cancer.feature_names``. The figure is shown
    immediately via ``plt.show()``.

    Args:
        model: A fitted estimator exposing ``feature_importances_`` with one
            entry per feature of the ``cancer`` data set.
    """
    n_features = cancer.data.shape[1]
    plt.barh(range(n_features), model.feature_importances_, align='center')
    plt.yticks(np.arange(n_features), cancer.feature_names)
    plt.xlabel('feature importance')
    plt.ylabel("feature")
    plt.show()


# --- Decision tree classification on the breast cancer data -----------------
cancer = load_breast_cancer()
x_train, x_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=42
)

# Tree without a depth limit.
tree = DecisionTreeClassifier(random_state=0)
tree.fit(x_train, y_train)
print('accuracy on training set:{:.3f}'.format(tree.score(x_train, y_train)))
print('accuracy on test set:{:.3f}'.format(tree.score(x_test, y_test)))

# Same model with the tree depth limited to 4.
tree = DecisionTreeClassifier(max_depth=4, random_state=0)
tree.fit(x_train, y_train)
print('accuracy on training set:{:.3f}'.format(tree.score(x_train, y_train)))
print('accuracy on test set:{:.3f}'.format(tree.score(x_test, y_test)))

# Export the depth-limited tree to Graphviz DOT format and display it.
export_graphviz(
    tree,
    out_file='tree.dot',
    class_names=['malignant', 'benign'],
    feature_names=cancer.feature_names,
    impurity=False,
    filled=True,
)
with open('tree.dot', encoding='utf-8') as f:
    dot_graph = f.read()
# A bare ``graphviz.Source(...)`` expression is discarded when it is not the
# last statement of a notebook cell, so hand the graph to display() explicitly.
display.display(graphviz.Source(dot_graph))

print('feature importances:{}'.format(tree.feature_importances_))
plot_feature_importances_cancer(tree)

# Graph produced by the mglearn helper; kept under its own name so the fitted
# classifier above is not shadowed by a non-model object.
not_monotone_graph = mglearn.plots.plot_tree_not_monotone()
display.display(not_monotone_graph)
plt.show()

# --- Computer memory (RAM) prices --------------------------------------------
ram_prices = pd.read_csv(RAM_PRICES_CSV)
plt.semilogy(ram_prices.date, ram_prices.price)
plt.xlabel('year')
plt.ylabel('price in $/mbyte')
plt.show()

# Train on rows dated before 2000; rows from 2000 onwards are the test data.
data_train = ram_prices[ram_prices.date < 2000]
data_test = ram_prices[ram_prices.date >= 2000]

# Estimators need a 2-D feature array. Convert to NumPy before adding the
# axis: multi-dimensional indexing directly on a Series is not supported by
# current pandas.
x_train = data_train.date.to_numpy()[:, np.newaxis]
# Fit on the logarithm of the price; predictions are mapped back with exp().
y_train = np.log(data_train.price)

tree = DecisionTreeRegressor().fit(x_train, y_train)
linear_reg = LinearRegression().fit(x_train, y_train)

# Predict for every row, i.e. including dates the models were not trained on.
x_all = ram_prices.date.to_numpy()[:, np.newaxis]
pred_tree = tree.predict(x_all)
pred_lr = linear_reg.predict(x_all)

# Undo the log transform so predictions are on the original price scale.
price_tree = np.exp(pred_tree)
price_lr = np.exp(pred_lr)

plt.semilogy(data_train.date, data_train.price, label='training data')
plt.semilogy(data_test.date, data_test.price, label='test data')
plt.semilogy(ram_prices.date, price_tree, label='tree prediction')
plt.semilogy(ram_prices.date, price_lr, label='linear prediction')
plt.legend()
plt.show()

最新回复(0)