真实的国产乱ⅩXXX66竹夫人,五月香六月婷婷激情综合,亚洲日本VA一区二区三区,亚洲精品一区二区三区麻豆

成都創(chuàng)新互聯(lián)網(wǎng)站制作重慶分公司

使用PointNet++測試分類自己的數(shù)據(jù)集并可視化-創(chuàng)新互聯(lián)

我這里PointNet++的代碼用的是pytorch版本的,鏈接為?https://github.com/yanx27/Pointnet2_pytorch

公司專注于為企業(yè)提供成都網(wǎng)站制作、網(wǎng)站設計、微信公眾號開發(fā)、電子商務商城網(wǎng)站建設,微信平臺小程序開發(fā),軟件按需策劃設計等一站式互聯(lián)網(wǎng)企業(yè)服務。憑借多年豐富的經(jīng)驗,我們會仔細了解各客戶的需求而做出多方面的分析、設計、整合,為客戶設計出具風格及創(chuàng)意性的商業(yè)解決方案,成都創(chuàng)新互聯(lián)更提供一系列網(wǎng)站制作和網(wǎng)站推廣的服務。

將自己的數(shù)據(jù)集格式修改為和modelnet40_normal_resampled數(shù)據(jù)集格式一樣。

?由于源碼中測試腳本只是輸出了測試數(shù)據(jù)集的分類精確度,且測試數(shù)據(jù)集同樣的是有標簽的,沒有模型驗證腳本,由于個人實驗需要,希望當模型訓練完成后能用自己的無標簽數(shù)據(jù)輸入后輸出類別去檢測模型的分類效果,因此根據(jù)模型測試腳本,修改了一下代碼,可以實現(xiàn)輸入一個無標簽的數(shù)據(jù),從而輸出分類結(jié)果以及可視化,從而更直觀的驗證模型訓練的準確度。?

代碼如下,其中可視化部分參考這位博主的文章?pointconv pytorch modelnet40 點云分類結(jié)果可視化_對象被拋出的博客-博客_modelnet40可視化

from data_utils.ModelNetDataLoader_my import ModelNetDataLoader
import argparse
import numpy as np
import os
import torch
import logging
from tqdm import tqdm
import sys
import importlib
import matplotlib.pyplot as plt

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))

def pc_normalize(pc):  #點云數(shù)據(jù)歸一化
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.is_available())
def parse_args():
    '''PARAMETERS'''
    parser = argparse.ArgumentParser('Testing')
    parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--batch_size', type=int, default=4, help='batch size in training')
    parser.add_argument('--num_category', default=10, type=int, choices=[10, 40],  help='training on ModelNet10/40')
    parser.add_argument('--num_point', type=int, default=10000, help='Point Number')
    parser.add_argument('--log_dir', type=str, default='pointnet2_cls_msg', help='Experiment root')
    parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
    parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
    parser.add_argument('--num_votes', type=int, default=3, help='Aggregate classification scores with voting')
    return parser.parse_args()
#加載數(shù)據(jù)集
dataset='/home/niu/mysubject/Pointnet_Pointnet2_pytorch-master/evalset/aaa_1.txt'
pcdataset = np.loadtxt(dataset, delimiter=' ').astype(np.float32)#數(shù)據(jù)讀取
point_set = pcdataset[0:10000, :] #我的輸入數(shù)據(jù)設置為原始數(shù)據(jù)中10000個點
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) #歸一化數(shù)據(jù)
point_set = point_set[:, 0:3] 
point_set = point_set.transpose(1,0)#將數(shù)據(jù)由N*C轉(zhuǎn)換為C*N
#print(point_set.shape)
point_set = point_set.reshape(1, 3, 10000)
n_points = point_set
point_set = torch.as_tensor(point_set)#需要將數(shù)據(jù)格式變?yōu)閺埩?,不然會報錯
point_set = point_set.cuda()
#print(point_set.shape)
#print(point_set.shape)
#分類測試函數(shù)
def test(model,point_set, num_class=10, vote_num=1):
    #mean_correct = []
    classifier = model.eval()
    class_acc = np.zeros((num_class, 3))
    vote_pool = torch.zeros(1, 10).cuda()
    for _ in range(vote_num):
        pred, _ = classifier(point_set)
        print(pred)
        vote_pool += pred
    pred = vote_pool / vote_num
    # 對預測結(jié)果每行取大值得到分類
    pred_choice = pred.data.max(1)[1]
    print(pred_choice)
    #可視化
    file_dir = '/home/niu/mysubject/Pointnet_Pointnet2_pytorch-master/visualizer'
    save_name_prefix = 'pred'
    draw(n_points[:, 0, :], n_points[:, 1, :], n_points[:, 2, :], save_name_prefix, file_dir, color=pred_choice)
    return pred_choice
#定義可視化函數(shù)
def draw(x, y, z, name, file_dir, color=None):
    """
    繪制單個樣本的三維點圖
    """
    if color is None:
        for i in range(len(x)):
            ax = plt.subplot(projection='3d')  # 創(chuàng)建一個三維的繪圖工程
            save_name = name + '-{}.png'.format(i)
            save_name = os.path.join(file_dir,save_name)
            ax.scatter(x[i], y[i], z[i],s=0.1, c='r')
            ax.set_zlabel('Z')  # 坐標軸
            ax.set_ylabel('Y')
            ax.set_xlabel('X')
            plt.draw()
            plt.savefig(save_name)
            # plt.show()
    else:
        colors = ['red', 'blue', 'green', 'yellow', 'orange', 'tan', 'orangered', 'lightgreen', 'coral', 'aqua']
        for i in range(len(x)):
            ax = plt.subplot(projection='3d')  # 創(chuàng)建一個三維的繪圖工程
            save_name = name + '-{}-{}.png'.format(i, color[i])
            save_name = os.path.join(file_dir,save_name)
            ax.scatter(x[i], y[i], z[i],s=0.1, c=colors[color[i]])
            ax.set_zlabel('Z')  # 坐標軸
            ax.set_ylabel('Y')
            ax.set_xlabel('X')
            plt.draw()
            plt.savefig(save_name)
            # plt.show()

def main(args):
    def log_string(str):
        logger.info(str)
        print(str)
    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    '''CREATE DIR'''
    experiment_dir = '/home/niu/mysubject/Pointnet_Pointnet2_pytorch-master/log/classification/' + args.log_dir
    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    '''
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)
    '''
    num_class = args.num_category
    #選擇模型
    model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0]
    model = importlib.import_module(model_name)
    
    classifier = model.get_model(num_class, normal_channel=args.use_normals)
    if not args.use_cpu:
        classifier = classifier.cuda()
    #選擇訓練好的.pth文件
    checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
    classifier.load_state_dict(checkpoint['model_state_dict'])
    #預測分類
    with torch.no_grad():
         pred_choice = test(classifier.eval(), point_set, vote_num=args.num_votes, num_class=num_class)
         #log_string('pred_choice: %f' % (pred_choice))

if __name__ == '__main__':
    args = parse_args()
    main(args)

根據(jù)自己的數(shù)據(jù)格式修改自己對應的參數(shù)以及數(shù)據(jù)集路徑運行即可

分類輸出結(jié)果:

輸出為分類的數(shù)據(jù)類別3

可視化結(jié)果保存在visualizer文件下,可視化結(jié)果:

你是否還在尋找穩(wěn)定的海外服務器提供商?創(chuàng)新互聯(lián)www.cdcxhl.cn海外機房具備T級流量清洗系統(tǒng)配攻擊溯源,準確流量調(diào)度確保服務器高可用性,企業(yè)級服務器適合批量采購,新人活動首月15元起,快前往官網(wǎng)查看詳情吧


文章題目:使用PointNet++測試分類自己的數(shù)據(jù)集并可視化-創(chuàng)新互聯(lián)
瀏覽地址:http://weahome.cn/article/ijhej.html

其他資訊

在線咨詢

微信咨詢

電話咨詢

028-86922220(工作日)

18980820575(7×24)

提交需求

返回頂部