1.项目背景
在全球范围内,视力障碍与眼部疾病已成为影响人类生活质量的重要因素。特别是在孟加拉国等医疗资源相对匮乏的地区,成年人失明率(1.5%)与视力低下率(21.6%)居高不下,这使得眼部疾病的早期筛查显得尤为紧迫。早期的病理干预能有效防止如视网膜脱离、青光眼或糖尿病视网膜病变等导致的不可逆损伤,但传统的专家诊断模式往往受限于地理位置与昂贵的诊疗成本。
本项目旨在利用深度学习中的计算机视觉技术,构建一套基于轻量化架构 MobileNetV3 的全自动眼疾识别流水线。我们采用了由安瓦拉·哈米达眼科医院等多家专业机构协助采集的 5335 张高质量临床眼底图像,并由医疗专家完成了包含近视、黄斑瘢痕、视盘水肿在内的 10 余种病理类别的精准标注。为了使模型具备更强的泛化能力,我们通过旋转、平移及对比度增强等手段将数据集扩充至 16242 张,以模拟真实的拍摄场景。选择 MobileNetV3 作为主干网络,核心初衷在于平衡诊断精度与运算效率,探索如何将复杂的眼病筛查算法无缝集成到低功耗的手持医疗设备中。通过对医学影像中细微纹理特征的深度挖掘,本项目不仅致力于实现 86% 以上的临床初筛准确率,更希望为偏远地区的防盲治盲工作提供一种可落地的数字化辅助诊断范式。
2.数据集介绍
本实验数据集来源于Kaggle,该数据集为各种眼部疾病的眼底图像数据集。关于数据集:
在全球范围内,眼部疾病被认为是导致非致命性残疾的重要因素。在孟加拉国,1.5%的成年人失明,21.6%的成年人视力低下。因此,眼部疾病的早期发现对于保护视力、预防失明和维护整体健康至关重要。早期发现能够及时进行干预和治疗,防止不可逆的损伤,并保障患者的生活质量。通过分析数据集,研究人员将能够识别疾病趋势、开发诊断算法、评估治疗效果并制定预防措施。
目前,计算机视觉方法在执行此类分类和检测任务方面展现出巨大的潜力。
为了开发基于计算机视觉的算法,本文提供了一个包含多种眼病图像的大型数据集,其中包括视网膜色素变性、视网膜脱离、翼状胬肉、近视、黄斑瘢痕、青光眼、视盘水肿、糖尿病视网膜病变、中心性浆液性脉络膜视网膜病变以及健康眼部图像的原始数据集和增强数据集。该数据集的分类工作由一位来自医疗机构的领域专家协助完成。
在医院方面的协助下,我们从法里德布尔的安瓦拉·哈米达眼科医院和BNS·扎鲁尔·哈克眼科医院收集了共计5335张健康眼部和患病眼部的图像。然后,我们利用旋转、宽度偏移、高度偏移、平移、翻转和缩放等技术,从这些原始图像中生成了共计16242张增强图像,以增加数据量。
3.技术工具
Python版本:3.9
代码编辑器:jupyter notebook
4.实验过程
4.1导入数据
在眼科医学影像处理中,数据的规范化和设备兼容性是核心考量点。我们首先导入了基于 PyTorch 的深度学习生态系统,利用其强大的 torchvision 模块来加载预训练模型权重。为了确保模型能够处理不同来源的眼底扫描图或眼部照片,我们统一将图像尺寸调整为 224x224。此外,我们采用了严格的 7:2:1 比例划分数据集,不仅确保了模型有足够的样本进行梯度下降(训练集),还通过独立的验证集调整超参数,并留出 10% 的原始数据作为最终的“盲测”考卷。
# --- 导入基础库与系统工具 ---import osimport timeimport randomimport copyimport numpy as npimport torchimport torch.nn as nnimport torch.optim as optim# --- 导入视觉处理与 PyTorch 核心模块 ---import torchvisionimport torchvision.transforms as transformsfrom torchvision import transforms, datasetsimport torchvision.models as modelsfrom torch.utils.data import DataLoader, random_split# --- 导入辅助分析与可视化库 ---import matplotlib.pyplot as pltimport seaborn as snsfrom PIL import Imagefrom sklearn.metrics import confusion_matrix, classification_report# 显式导入权重点配置,确保模型加载的兼容性from torchvision.models import VGG16_Weights, MobileNet_V3_Large_Weights, DenseNet121_Weights# --- 硬件配置:优先使用 CUDA 加速 ---device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# --- 数据集路径设置 ---dataset_path = "/kaggle/input/data/Augmented_Dataset"# --- 数据转换与标准化预处理 ---# 眼病影像通常色调较为统一,标准化的均值与标准差设为 0.5 可帮助模型更快适应像素分布transform = transforms.Compose([transforms.Resize((224, 224)), # 适配 MobileNetV3 标准输入尺寸transforms.ToTensor(), # 转化为张量格式transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化处理])# 加载原始增强后的数据集full_dataset = datasets.ImageFolder(root=dataset_path, transform=transform)# --- 科学划分数据集 ---# 定义划分比例:70% 训练, 20% 验证, 10% 测试train_ratio, val_ratio, test_ratio = 0.7, 0.2, 0.1total_size = len(full_dataset)train_size = int(train_ratio * total_size)val_size = int(val_ratio * total_size)test_size = total_size - train_size - val_size# 执行随机切分,random_split 保证了切分的随机性与不重复性train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])# --- 构建批处理加载器 (DataLoaders) ---batch_size = 32train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 确认类别信息num_classes = len(full_dataset.classes)print(f"检测到类别: {full_dataset.classes}, 类别总数: {num_classes}")# --- 样本均衡性统计函数 ---def count_images_per_class(dataset):# 统计数据集中每个疾病类别的样本数量,防止样本倾斜class_counts = {cls: 0 for cls in dataset.dataset.classes}# 遍历索引路径统计真实样本分布for _, label in dataset.dataset.samples:class_counts[dataset.dataset.classes[label]] += 1return class_counts# 获取并输出训练集的样本分布情况class_counts = count_images_per_class(train_dataset)for class_name, count in class_counts.items():print(f"疾病类别 '{class_name}': 包含 {count} 张训练样本")
在完成数据导入与初步统计后,我们已经建立了一个结构严密的“医学影像仓库”。通过打印的样本分布,我们可以检查是否存在某些罕见眼疾样本不足的问题,这直接决定了后续模型权重的偏置处理。由于医学图像对局部细节(如视网膜血管纹理或晶状体混浊度)极为敏感,这种标准化的预处理和有序的批次加载(DataLoader)将支撑起 MobileNetV3 在高维度特征空间中的梯度平稳更新,为构建轻量且精准的医疗辅助诊断系统打下地基。
4.2数据可视化
类别分布统计
为了直观掌握眼科图像数据集的构成情况,我们首先利用 Matplotlib 绘制了训练集的类别分布直方图。由于医疗数据常面临某些疾病样本获取难度大的挑战,通过 Viridis 色彩映射的条形图,我们可以清晰地判断各病种(如 Cataract、Diabetic Retinopathy、Glaucoma 等)的样本量是否存在显著失衡。这种全局视角有助于我们在后续训练中决定是否需要引入加权损失函数,从而防止模型对高频类别产生过度偏好。
# 获取训练集中的各类别样本计数train_class_counts = count_images_per_class(train_dataset)# 根据类别数量生成 Viridis 色系,增加可视化的专业辨识度colors = plt.cm.viridis(np.linspace(0, 1, len(train_class_counts)))# 绘制类别分布柱状图plt.figure(figsize=(13, 5))plt.bar(train_class_counts.keys(), train_class_counts.values(), color=colors)plt.xlabel("疾病类别 (Class)")plt.ylabel("图像数量 (Number of Images)")plt.title("训练数据集类别分布图")plt.xticks(rotation=45) # 旋转标签防止重叠plt.show()
典型病灶样本展示
在确认了样本分布后,我们编写了 show_random_images 函数,从每个疾病类别中随机抽取具有代表性的眼部原始图像进行平铺展示。通过这种方式,我们可以初步审查数据集的图像质量、拍摄角度以及病灶特征的显著程度。对于 MobileNetV3 这种轻量化网络而言,这种“视觉初探”能帮我们预判模型在捕捉眼底细微病理变化时的潜在难点,确保预处理后的图像仍保留了关键的诊断信息。
def show_random_images(dataset, num_images=1):"""随机展示每个类别的典型眼部图像,用于视觉质量审查"""# 根据类别总数计算子图行数,保持布局整齐rows = len(dataset.dataset.classes) // 5 + (len(dataset.dataset.classes) % 5 > 0)fig, axes = plt.subplots(rows, 5, figsize=(20, 5 * rows))axes = axes.flatten()# 遍历每个疾病类别for idx, cls in enumerate(dataset.dataset.classes):# 筛选出属于当前类别的所有图像索引class_indices = [i for i, (_, label) in enumerate(dataset.dataset.samples) if dataset.dataset.classes[label] == cls]# 随机抽取一张图像并加载random_idx = random.choice(class_indices)img_path, _ = dataset.dataset.samples[random_idx]img = Image.open(img_path)axes[idx].imshow(img)axes[idx].axis("off") # 隐藏轴坐标axes[idx].set_title(cls) # 标注对应疾病名称# 隐藏多余的空白子图区域for i in range(len(dataset.dataset.classes), len(axes)):axes[i].axis("off")plt.tight_layout()plt.show()# 执行展示,观察不同眼部疾病的视觉差异show_random_images(train_dataset)
4.3构建模型
为了充分利用预训练权重的优势,我们编写了灵活的模型构建函数。以 MobileNetV3 为例,我们并没有生硬地冻结所有卷积层,而是采用了“选择性解冻”策略:冻结前部捕获基础纹理的层,仅释放最后两个特征块(Blocks)参与梯度更新。这种做法能让模型在保留通用视觉常识的同时,深度学习眼部疾病特有的病理特征(如细微的出血点或混浊区域)。此外,我们重新设计了分类头,引入了线性层与 Dropout 层,有效增强了模型的非线性映射能力并防止了过拟合。
def get_model_mobilenetv3(num_classes, freeze_layers=True):"""构建 MobileNetV3-Large 模型:结合轻量化 NAS 搜索架构,适配移动端医疗诊断需求"""# 加载带有默认预训练权重的 MobileNetV3model = models.mobilenet_v3_large(weights=MobileNet_V3_Large_Weights.DEFAULT)if freeze_layers:total_blocks = len(model.features)for idx, module in enumerate(model.features):# 差异化冻结:仅解冻最后两个特征提取块进行微调if idx < total_blocks - 2:for param in module.parameters():param.requires_grad = Falseelse:for param in module.parameters():param.requires_grad = True# 自定义分类层:引入 512 维中间层并配合 Dropout 增强泛化性model.classifier = nn.Sequential(nn.Linear(model.classifier[0].in_features, 512),nn.ReLU(),nn.Dropout(0.4),nn.Linear(512, num_classes))return model.to(device)def get_model_densenet121(num_classes, freeze_layers=True):"""构建 DenseNet121 模型:利用密集连接机制,提升眼病图像中细微特征的传递效率"""model = models.densenet121(weights=DenseNet121_Weights.DEFAULT)if freeze_layers:features = list(model.features.children())total_blocks = len(features)for idx, module in enumerate(features):# 同样采用部分冻结策略,平衡预训练知识与新特征学习if idx < total_blocks - 2:for param in module.parameters():param.requires_grad = Falseelse:for param in module.parameters():param.requires_grad = Truemodel.classifier = nn.Sequential(nn.Dropout(0.4),nn.Linear(model.classifier.in_features, num_classes))return model.to(device)
医疗影像训练周期较长且极易在后期陷入过拟合。为此,我们封装了 EarlyStopping 类。该类会实时监控验证集的损失值(Validation Loss),如果损失值在连续多个周期(Patience)内不再下降,系统将强制终止训练。这种“及时止损”的策略不仅节省了计算资源,更确保了模型最终保留的是在未知数据上表现最佳的权重参数,从而提升了医疗辅助诊断的可靠性。
class EarlyStopping:"""早停机制:防止模型对医疗图像样本过度拟合"""def __init__(self, patience=5):self.patience = patience # 容忍度:允许 loss 不下降的最大轮数self.best_loss = float('inf')self.counter = 0def should_stop(self, val_loss):# 实时评估验证集 Loss,决定是否提前终止训练if val_loss < self.best_loss:self.best_loss = val_lossself.counter = 0else:self.counter += 1return self.counter >= self.patience
通过这一阶段的架构布局,我们已经为“眼部扫描器”构建了坚实的骨架。MobileNetV3 负责提供轻量高效的特征解析,而精心设置的冻结策略与早停机制则为后续的训练过程拉起了两道“安全防护网”。
4.4训练模型
我们实现的 train_model 函数集成了训练模式(model.train)与评估模式(model.eval)的自动切换。在训练阶段,模型通过反向传播不断更新权重;而在验证阶段,我们通过 torch.no_grad() 禁用梯度计算,以节省显存并加快评估速度。该函数不仅记录了每一轮的损失(Loss)和准确率(Accuracy),还引入了 best_model_wts 机制,确保无论训练由于何种原因停止,返回的始终是在验证集上表现最好的那一套参数,从而规避了过拟合带来的性能回退。
def train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, early_stopping, epochs=20):"""核心训练逻辑:涵盖前向传播、反向传播、权重保存及早停监测"""train_losses, val_losses = [], []train_accs, val_accs = [], []all_val_labels, all_val_preds = [], []# 深度拷贝模型初始权重,用于后续保存最优模型best_model_wts = copy.deepcopy(model.state_dict())for epoch in range(epochs):start_time = time.time()# --- 训练阶段 (Training Phase) ---model.train()running_train_loss, correct_train, total_train = 0.0, 0, 0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)optimizer.zero_grad() # 梯度清零,防止累加outputs = model(images) # 前向传播loss = criterion(outputs, labels) # 计算损失loss.backward() # 反向传播计算梯度optimizer.step() # 更新权重参数running_train_loss += loss.item()preds = outputs.argmax(1)correct_train += (preds == labels).sum().item()total_train += labels.size(0)# 记录训练指标train_loss = running_train_loss / len(train_loader)train_acc = correct_train / total_traintrain_losses.append(train_loss)train_accs.append(train_acc)# --- 验证阶段 (Validation Phase) ---model.eval()running_val_loss, correct_val, total_val = 0.0, 0, 0with torch.no_grad(): # 关闭梯度跟踪,提升推理效率for images, labels in val_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)loss = criterion(outputs, labels)running_val_loss += loss.item()preds = outputs.argmax(1)correct_val += (preds == labels).sum().item()total_val += labels.size(0)# 记录验证集真实值与预测值,用于后续混淆矩阵分析all_val_labels.extend(labels.cpu().numpy())all_val_preds.extend(preds.cpu().numpy())val_loss = running_val_loss / len(val_loader)val_acc = correct_val / total_valval_losses.append(val_loss)val_accs.append(val_acc)# 动态调整学习率:如果验证损失进入平台期,自动缩减 LRscheduler.step(val_loss)epoch_time = time.time() - start_time# 实时打印当前轮次的监控数据print(f"Epoch {epoch+1}/{epochs} - Train Acc: {train_acc:.4f}, Train Loss: {train_loss:.4f}, "f"Val Acc: {val_acc:.4f}, Val Loss: {val_loss:.4f}, Time: {epoch_time:.2f}s")# 触发早停监测if early_stopping.should_stop(val_loss):print("触发早停机制,训练提前结束。")break# 更新最佳模型权重if val_loss < early_stopping.best_loss:best_model_wts = copy.deepcopy(model.state_dict())# 加载表现最好的权重并返回model.load_state_dict(best_model_wts)return model, train_losses, val_losses, train_accs, val_accs, all_val_labels, all_val_preds
为了让 MobileNetV3 在医疗影像任务中更快、更稳地收敛,我们选择了 Adam 优化器,它能自适应地调整每个参数的学习率。同时,我们引入了 ReduceLROnPlateau 调度器。这意味着当模型连续 3 轮验证损失不再下降时,学习率会自动缩减为原来的 10%,这种“退一步海阔天空”的策略能让模型在训练后期进行更细粒度的搜索,精准定位损失函数的最优解。
# 初始化模型与硬件部署mobilenet_model = get_model_mobilenetv3(num_classes, freeze_layers=True)# 配置优化器与调度策略optimizer_mnv3 = optim.Adam(mobilenet_model.parameters(), lr=learning_rate)scheduler_mnv3 = optim.lr_scheduler.ReduceLROnPlateau(optimizer_mnv3, mode='min', factor=0.1, patience=3)early_stopping_mnv3 = EarlyStopping(patience=5)# 执行训练mobilenet_model, train_losses_mnv3, val_losses_mnv3, train_accs_mnv3, val_accs_mnv3, val_labels_mnv3, val_preds_mnv3 = train_model(mobilenet_model, nn.CrossEntropyLoss(), optimizer_mnv3, scheduler_mnv3,train_loader, val_loader, early_stopping_mnv3, epochs=epochs)
通过这种严谨的循环逻辑,每一轮训练的时间都被控制在合理范围内,同时通过动态监控验证指标,我们构建了一个既能“拼命学习”又能“保持理智”的智能识别系统。这对于处理白内障、青光眼等病理细节复杂的图像尤为重要,确保了每一滴算力都用在刀刃上。
4.5模型评估
我们通过绘制准确率(Accuracy)与损失值(Loss)随迭代轮数变化的曲线,来复盘模型的“学习心路历程”。理想的医疗影像模型应当在训练集和验证集上表现出同步的提升趋势。如果两条曲线分叉过大,说明模型陷入了样本记忆而非特征理解。通过这张图,我们可以直观地看到早停机制是在哪一刻精准介入,为模型锁定了最具泛化能力的权重状态。
def plot_accuracy_and_loss(train_losses, val_losses, train_accs, val_accs):"""绘制训练与验证的双指标曲线:评估收敛速度与过拟合风险"""plt.figure(figsize=(12, 5))# 准确率曲线:观察模型对眼病特征的抓取上限plt.subplot(1, 2, 1)plt.plot(train_accs, label="训练集准确率")plt.plot(val_accs, label="验证集准确率")plt.xlabel("迭代轮数 (Epochs)")plt.ylabel("准确率 (Accuracy)")plt.title("模型准确率演变趋势")plt.legend()plt.grid(True)# 损失曲线:观察梯度下降的稳定性plt.subplot(1, 2, 2)plt.plot(train_losses, label="训练集损失")plt.plot(val_losses, label="验证集损失")plt.xlabel("迭代轮数 (Epochs)")plt.ylabel("损失值 (Loss)")plt.title("模型损失值演变趋势")plt.legend()plt.grid(True)plt.tight_layout()plt.show()
混淆矩阵是医疗辅助诊断中的“终极审计表”。对于眼科医生来说,误诊(将正常识别为患病)与漏诊(将患病识别为正常)的后果截然不同。通过热力图形式的混淆矩阵,我们可以清晰地看到不同眼疾类别之间的误判分布。随后,我们进一步计算了每一类的独立准确率柱状图,这能帮我们精准定位模型对哪种疾病(如病理特征不明显的早期青光眼)的敏感度不足,从而为后续的数据增强或模型微调提供数据支撑。
def plot_confusion_matrix(y_true, y_pred, class_names):"""构建混淆矩阵热力图:深挖各病种间的误判逻辑"""cm = confusion_matrix(y_true, y_pred)plt.figure(figsize=(8, 6))# 使用 Blues 调色板,颜色越深代表预测越准确sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",xticklabels=class_names, yticklabels=class_names)plt.title("眼部疾病分类混淆矩阵")plt.xlabel("模型预测结果")plt.ylabel("真实病理诊断")plt.show()def plot_per_class_accuracy(y_true, y_pred, class_names):"""计算并展示各疾病类别的独立识别率"""cm = confusion_matrix(y_true, y_pred)# 提取对角线元素(正确数)除以该类总数per_class_accuracy = np.diag(cm) / cm.sum(axis=1)plt.figure(figsize=(12, 6))plt.bar(class_names, per_class_accuracy, color="skyblue")plt.xlabel("眼病类别")plt.ylabel("识别准确率")plt.title("不同病种识别性能对比")plt.xticks(rotation=45)plt.show()# 启动评估流程print("nMobileNetV3 训练流程已全部完成,正在生成性能报告...n")plot_accuracy_and_loss(train_losses_mnv3, val_losses_mnv3, train_accs_mnv3, val_accs_mnv3)plot_confusion_matrix(val_labels_mnv3, val_preds_mnv3, class_names)plot_per_class_accuracy(val_labels_mnv3, val_preds_mnv3, class_names)
5.总结
本实验基于 Kaggle 提供的多类别眼部疾病数据集,深入探讨了计算机视觉在辅助医疗诊断中的实战潜力。该数据集具有极高的临床参考价值,其 5335 张原始图像均由专业眼科医院协助采集,并经由医疗领域专家精准标注,涵盖了从糖尿病视网膜病变到青光眼等多种致盲性风险疾病。为了提升轻量化模型MobileNetV3的鲁棒性,我们通过旋转、平移及缩放等增强技术将样本扩充至 16242 张,有效缓解了医学影像中常见的样本不均衡问题。实验结果表明,通过迁移学习与差异化层冻结策略,模型在训练集上达到了 94.75% 的高准确率,验证集准确率也稳健地保持在 86.39%,且损失值控制在 0.4247。这一性能表现证明了 MobileNetV3 架构在兼顾运算效率与识别精度方面的卓越平衡,不仅为眼部疾病的早期筛查提供了一种低成本、高效率的自动化方案,也为未来将深度学习算法集成到移动端或嵌入式医疗设备中奠定了坚实的实战基础。
79