這里inference兩個程序的連接,如目標(biāo)檢測,可以利用一個程序提取候選框,然后把候選框輸入到分類cnn網(wǎng)絡(luò)中。
這里常需要進(jìn)行一定的連接。
#加載訓(xùn)練好的分類CNN網(wǎng)絡(luò) model=torch.load('model.pkl') #假設(shè)proposal_img是我們提取的候選框,是需要輸入到CNN網(wǎng)絡(luò)的數(shù)據(jù) #先定義transforms對輸入cnn的網(wǎng)絡(luò)數(shù)據(jù)進(jìn)行處理,常包括resize、totensor等操作 data_transforms=transforms.Compose([transforms.RandomSizedCrop(224), transforms.ToTensor()]) #由于transforms是對PIL格式數(shù)據(jù)操作,所以必要時轉(zhuǎn)化格式 def tensor_to_PIL(tensor): image = tensor.cpu().clone() image = image.squeeze(0) image = unloader(image) return image #unqueeze(0)是加多一維,對應(yīng)原來batchsiaze data=data_transforms(proposal_img).unqueeze(0) #新版本pytorch已經(jīng)不用variable,可以省略這句 data=Variable(data) #貌似這句也是多余的 torch.no_grad() predict=F.softmax(model(data.cuda()).cuda())