Введение: Магия видеть внутренности (без рентгеновских очков)
Представьте: вы рентгенолог, изучающий МРТ-снимок в поисках опухоли, играющей в прятки в море оттенков серого. Это как искать Уолдо, только Уолдо может убить кого-нибудь, если вы его пропустите. Встречайте U-Net — Шерлока Холмса сегментации медицинских изображений. Мы собираемся создать систему, которая находит опухоли быстрее, чем малыш находит банку с печеньем. И не волнуйтесь, я проведу вас через каждый шаг, как GPS для новичков в глубоком обучении.
Архитектура U-Net: Цифровой танго сжатия и расширения
Представьте змею, которая ест медицинские изображения, переваривает особенности, а затем извергает идеальные пиксельные маски. Это, по сути, танец кодировщика-декодировщика U-Net. Кодировщик сжимает изображения, как застенчивая черепаха (путь сжатия), а декодер расширяет их, как самоуверенный баллон (путь расширения). Секретный соус? Пропускные соединения — архитектурный эквивалент шпаргалок между семестрами.
Ваш инструментарий: PyTorch и немного изящества
Прежде чем начать готовить, давайте настроим нашу кухню:
pip install torch torchvision monai matplotlib
Теперь наденьте свою виртуальную поварскую шляпу — мы готовим U-Net с нуля! Вот наш рецепт:
Шаг 1: Двойная свертка (рабочий слой)
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
"""Двойная свертка — вдвое больше веселья!"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
Шаг 2: Построение кодировщика (компактор)
class DownBlock(nn.Module):
"""Уменьшитель — как корпоративное, но полезное"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
Шаг 3: Создание декодера (луч обратного сжатия)
class UpBlock(nn.Module):
"""Вечеринка возвращения декодера"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
"""x1 — из декодера, x2 — пропускное соединение"""
x1 = self.up(x1)
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
Шаг 4: Сборка U-Net Франкенштейна
class UNet(nn.Module):
"""Главное событие — представьте роботов-трансформеров, но для медицины"""
def __init__(self, n_channels=1, n_classes=2):
super().__init__()
# Путь кодировщика
self.inc = DoubleConv(n_channels, 64)
self.down1 = DownBlock(64, 128)
self.down2 = DownBlock(128, 256)
self.down3 = DownBlock(256, 512)
self.down4 = DownBlock(512, 1024)
# Путь декодера с пропускными соединениями
self.up1 = UpBlock(1024, 512)
self.up2 = UpBlock(512, 256)
self.up3 = UpBlock(256, 128)
self.up4 = UpBlock(128, 64)
self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
def forward(self, x):
# Шаги кодировщика
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
# Шаги декодера с пропускными соединениями
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
Подготовка данных: Превращение сканирований в закуски для ИИ
Медицинские изображения — как привередливые едоки — им нужна специальная подготовка:
from torch.utils.data import Dataset
import nibabel as nib
import numpy as np
class MedicalImageDataset(Dataset):
def __init__(self, image_paths, mask_paths, transform=None):
self.image_paths = image_paths
self.mask_paths = mask_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image = nib.load(self.image_paths[idx]).get_fdata()
mask = nib.load(self.mask_paths[idx]).get_fdata()
# Нормализация и добавление размерности канала
image = (image - image.min()) / (image.max() - image.min())
image = np.expand_dims(image, axis=0)
# One-hot кодирование масок
mask = (mask > 0).astype(np.float32)
mask = np.expand_dims(mask, axis=0)
if self.transform:
image, mask = self.transform((image, mask))
return image, mask
Увеличение данных: Искусство медицинских подделок
import torchvision.transforms as transforms
# Потому что иногда нужно солгать вашей модели, чтобы сделать её умнее
train_transform = transforms.Compose([
transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2)
])
Обучение: Обучение вашего цифрового стажёра
Функция потерь Dice — когда точность просто не подходит
class DiceLoss(nn.Module):
"""Эквивалент трофея за участие в сегментации"""
def __init__(self, smooth=1e-6):
super().__init__()
self.smooth = smooth
def forward(self, logits, targets):
probs = torch.sigmoid(logits)
intersection = (probs * targets).sum()
union = probs.sum() + targets.sum()
dice = (2. * intersection + self.smooth) / (union + self.smooth)
return 1 - dice
Цикл обучения: Где происходит магия (медленно)
def train_model(model, dataloader, optimizer, loss_fn, device, epochs=30):
"""Эквивалент отправки вашего ребёнка в медицинскую школу для ИИ"""
model.train()
for epoch in range(epochs):
running