常用的工具函数整理

王 茂南 2019年8月17日07:32:37
评论
2377字阅读7分55秒
摘要这里就是一些常用的工具函数, 可以在需要的时候快速的进行使用.

简介

这里放一些工具函数, 如计算准确率, 文件跳行读取的方式等, 方便自己之后的使用.

文件的读取和判读

图片文件的判断(判断图像的大小和通道个数)

有的时候, 我们会先判断一下我们数据集中的图片是否都是RGB的, 会是否都是灰度图. 我们可以通过下面的方式来进行判断.

  1. import os
  2. import cv2
  3. import numpy as np
  4. from tqdm import tqdm
  5. from PIL import Image
  6. # 获取图片的路径
  7. def get_img_path(path):
  8.     ls = os.listdir(path)
  9.     ls = [path+"/"+x for x in ls]
  10.     return ls
  11. # 对路径内所有图片遍历, 打印出不是三通道的图片
  12. def Preprocess(path):
  13.     imagePaths = get_img_path(path) # 获取图片路径
  14.     for imagePath in tqdm(imagePaths):
  15.         imageTest = Image.open(imagePath)
  16.         # imageTest = cv2.imread(imagePath, cv2.IMREAD_COLOR) # 读取图片
  17.         try:
  18.             if imageTest==None:
  19.                 print('图像错误: {}'.format(imagePath))
  20.                 imageTest.close()
  21.                 os.remove(imagePath)
  22.                 print('已删除!')
  23.             elif imageTest.mode=='L': # 查看图片大小
  24.                 print('图像通道错误: {}'.format(imagePath))
  25.                 imageTest.close()
  26.                 os.remove(imagePath)
  27.                 print('已删除!')
  28.             else:
  29.                 imageTest.close()
  30.         except:
  31.             pass
  32.     print('Finish!')
  33. if __name__ == "__main__":
  34.     Preprocess('./datasets/monet2photo/trainB')

 

csv文件读取方式

有的时候文件较大的时候我们可以跳行进行读取, 将一个大的文件分成几次读入, 我们可以使用下面的方式进行文件的读取.

  1. df = pd.read_csv(csv_path,
  2.                     header=None,
  3.                     nrows = base, # 从第base行开始读取
  4.                 skiprows=skiprow) # 跳过读取的行数

 

评价指标的计算

计算准确率

当target是numpy的数据类型.

  1. def accuracy(target, logit):
  2.     ''' Obtain accuracy for training round '''
  3.     target = target.argmax(axis=1) # convert from one-hot encoding to class indices
  4.     corrects = (logit == target).sum()
  5.     accuracy = 100.0 * corrects / len(logit)
  6.     return accuracy

使用pytorch的时候来计算准确率.

  1. def accuracy(target, logit):
  2.     ''' Obtain accuracy for training round '''
  3.     target = torch.max(target, 1)[1] # convert from one-hot encoding to class indices
  4.     corrects = (logit == target).sum()
  5.     accuracy = 100.0 * corrects / len(logit)
  6.     return accuracy

 

AUC的计算

关于更多AUC的计算, 可以查看链接: ROC曲线的绘制与AUC的计算

首先是关于AUC的计算,可以直接使用sklearn中AUC的计算方式来完成。参考链接 : [sklearn]性能度量之AUC值(from sklearn.metrics import roc_auc_curve)

下面是一个简单的例子,可以看到输入的预测值可以是概率也可以是类别。

  1. ### 真实值和预测值
  2. import numpy as np
  3. y_test = np.array([0,0,1,1])
  4. y_pred1 = np.array([0.3,0.2,0.25,0.7])
  5. y_pred2 = np.array([0,0,1,0])
  6. ### 性能度量auc
  7. from sklearn.metrics import roc_auc_score
  8. # 预测值是概率
  9. auc_score1 = roc_auc_score(y_test,y_pred1)
  10. print(auc_score1)
  11. # 预测值是类别
  12. auc_score2 = roc_auc_score(y_test,y_pred2)
  13. print(auc_score2)

 

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南
  • 本文由 发表于 2019年8月17日07:32:37
  • 转载请务必保留本文链接:https://mathpretty.com/10980.html
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: