這篇文章將為大家詳細(xì)講解有關(guān)Python sklearn KFold如何生成交叉驗證數(shù)據(jù)集,小編覺得挺實用的,因此分享給大家做個參考,希望大家閱讀完這篇文章后可以有所收獲。
創(chuàng)新互聯(lián)-專業(yè)網(wǎng)站定制、快速模板網(wǎng)站建設(shè)、高性價比海豐網(wǎng)站開發(fā)、企業(yè)建站全套包干低至880元,成熟完善的模板庫,直接使用。一站式海豐網(wǎng)站制作公司更省心,省錢,快速模板網(wǎng)站建設(shè)找我們,業(yè)務(wù)覆蓋海豐地區(qū)。費用合理售后完善,十余年實體公司更值得信賴。源起:
1.我要做交叉驗證,需要每個訓(xùn)練集和測試集都保持相同的樣本分布比例,直接用sklearn提供的KFold并不能滿足這個需求。
2.將生成的交叉驗證數(shù)據(jù)集保存成CSV文件,而不是直接用sklearn訓(xùn)練分類模型。
3.在編碼過程中有一的誤區(qū)需要注意:
這個sklearn官方給出的文檔
>>> import numpy as np >>> from sklearn.model_selection import KFold >>> X = ["a", "b", "c", "d"] >>> kf = KFold(n_splits=2) >>> for train, test in kf.split(X): ... print("%s %s" % (train, test)) [2 3] [0 1] [0 1] [2 3]
我之前犯的一個錯誤是將train,test理解成原數(shù)據(jù)集分割成子數(shù)據(jù)集之后的子數(shù)據(jù)集索引。而實際上,它就是原始數(shù)據(jù)集本身的樣本索引。
源碼:
# -*- coding:utf-8 -*- # 得到交叉驗證數(shù)據(jù)集,保存成CSV文件 # 輸入是一個包含正常惡意標(biāo)簽的完整數(shù)據(jù)集,在讀數(shù)據(jù)的時候分開保存到datasetBenign,datasetMalicious # 分別對兩個數(shù)據(jù)集進行KFold,最后合并保存 from sklearn.model_selection import KFold import csv def writeInFile(benignKFTrain, benignKFTest, maliciousKFTrain, maliciousKFTest, i, datasetBenign, datasetMalicious): newTrainFilePath = "E:\\hadoopExperimentResult\\5KFold\\AllDataSetIIR10\\dataset\\ImbalancedAllTraffic-train-%s.csv" % i newTestFilePath = "E:\\hadoopExperimentResult\\5KFold\\AllDataSetIIR10\\dataset\\IImbalancedAllTraffic-test-%s.csv" % i newTrainFile = open(newTrainFilePath, "wb")# wb 為防止空行 newTestFile = open(newTestFilePath, "wb") writerTrain = csv.writer(newTrainFile) writerTest = csv.writer(newTestFile) for index in benignKFTrain: writerTrain.writerow(datasetBenign[index]) for index in benignKFTest: writerTest.writerow(datasetBenign[index]) for index in maliciousKFTrain: writerTrain.writerow(datasetMalicious[index]) for index in maliciousKFTest: writerTest.writerow(datasetMalicious[index]) newTrainFile.close() newTestFile.close() def getKFoldDataSet(datasetPath): # CSV讀取文件 # 開始從文件中讀取全部的數(shù)據(jù)集 datasetFile = file(datasetPath, 'rb') datasetBenign = [] datasetMalicious = [] readerDataset = csv.reader(datasetFile) for line in readerDataset: if len(line) > 1: curLine = [] curLine.append(float(line[0])) curLine.append(float(line[1])) curLine.append(float(line[2])) curLine.append(float(line[3])) curLine.append(float(line[4])) curLine.append(float(line[5])) curLine.append(float(line[6])) curLine.append(line[7]) if line[7] == "benign": datasetBenign.append(curLine) else: datasetMalicious.append(curLine) # 交叉驗證分割數(shù)據(jù)集 K = 5 kf = KFold(n_splits=K) benignKFTrain = []; benignKFTest = [] for train,test in kf.split(datasetBenign): benignKFTrain.append(train) benignKFTest.append(test) maliciousKFTrain=[]; maliciousKFTest=[] for train,test in kf.split(datasetMalicious): maliciousKFTrain.append(train) maliciousKFTest.append(test) for i in range(K): print "======================== "+ str(i)+ " ========================" print benignKFTrain[i], benignKFTest[i] print maliciousKFTrain[i],maliciousKFTest[i] writeInFile(benignKFTrain[i], benignKFTest[i], maliciousKFTrain[i], maliciousKFTest[i], i, datasetBenign, datasetMalicious) datasetFile.close() if __name__ == "__main__": getKFoldDataSet(r"E:\hadoopExperimentResult\5KFold\AllDataSetIIR10\dataset\ImbalancedAllTraffic-10.csv")
關(guān)于“Python sklearn KFold如何生成交叉驗證數(shù)據(jù)集”這篇文章就分享到這里了,希望以上內(nèi)容可以對大家有一定的幫助,使各位可以學(xué)到更多知識,如果覺得文章不錯,請把它分享出去讓更多的人看到。