-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodel.py
84 lines (74 loc) · 2.74 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/python
# -*- coding: utf-8 -*-
# @Time : 2023/4/19 18:22
# @Author : achieve_dream
# @File : model.py
# @Software: Pycharm
import cv2
import numpy as np
from abc import ABC, abstractmethod
from matplotlib import pyplot as plt
from sklearn.metrics import roc_curve, auc
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei"] # 设置字体
plt.rcParams["axes.unicode_minus"] = False # 该语句解决图像中的“-”负号的乱码问题
class Model(ABC):
"""
机器学习, 图像分类模型
"""
def __init__(self, img_nums: int):
"""
构造方法, 初始化
:param img_nums: 加载图片的数量
"""
assert (img_nums % 10 == 0) and (img_nums > 10), "图片数量必须大于10, 并且为10的整数倍"
self.dataset = self.pure_load_dataset(img_nums)
self.img_nums = img_nums
@staticmethod
def pure_load_dataset(img_nums: int) -> np.ndarray:
"""
只进行数字处理, 不进行IO操作
:return: dataset(选取的图像的下标)
"""
# return np.array([[i for i in range(j, j + 10)] for j in range(1, img_nums + 1, 10)])
return np.array([i for i in range(1, img_nums + 1)]).reshape(-1, 10)
@staticmethod
def read_img(img_index: int) -> np.ndarray:
return cv2.imread("dataset/" + str(img_index).rjust(5, '0') + ".bmp", cv2.IMREAD_GRAYSCALE)
@staticmethod
def softmax(distance: np.ndarray, eps: float = 1e-6) -> np.ndarray:
"""
根据图像的距离, 算它每个类别的概率
:param distance: 一个图像到每个类别的距离
:param eps: 一个很小的数, 防止距离为0而导致数据过大
:return: 概率数组
"""
# 由于距离和概率成反比, 因此取距离的倒数
d = 1 / (distance + eps)
e_x = np.exp(d - np.max(d)) # 防止指数爆炸
return e_x / e_x.sum(axis=0)
@abstractmethod
def compute(self, *args, **kwargs) -> tuple[np.ndarray, np.ndarray, float]:
"""
计算方法结果
:return: 真实标签, 预测标签, 准确率
"""
...
@staticmethod
def plot(target_labels: np.ndarray, predict_scores: np.ndarray, file_name: str):
"""
绘制图形
:return: None
"""
fpr, tpr, _ = roc_curve(target_labels, predict_scores)
# roc面积
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, label=f'ROC 曲线 (面积 = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1])
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.title(f'{file_name.upper()} 曲线')
plt.legend()
plt.savefig(f'runs/{file_name}.svg')
plt.show()