Введение: Магия видеть внутренности (без рентгеновских очков)

Представьте: вы рентгенолог, изучающий МРТ-снимок в поисках опухоли, играющей в прятки в море оттенков серого. Это как искать Уолдо, только Уолдо может убить кого-нибудь, если вы его пропустите. Встречайте U-Net — Шерлока Холмса сегментации медицинских изображений. Мы собираемся создать систему, которая находит опухоли быстрее, чем малыш находит банку с печеньем. И не волнуйтесь, я проведу вас через каждый шаг, как GPS для новичков в глубоком обучении.

Архитектура U-Net: Цифровой танго сжатия и расширения

Представьте змею, которая ест медицинские изображения, переваривает особенности, а затем извергает идеальные пиксельные маски. Это, по сути, танец кодировщика-декодировщика U-Net. Кодировщик сжимает изображения, как застенчивая черепаха (путь сжатия), а декодер расширяет их, как самоуверенный баллон (путь расширения). Секретный соус? Пропускные соединения — архитектурный эквивалент шпаргалок между семестрами.

graph TD A[Исходное изображение] --> B[Блок кодировщика] B --> C[MaxPooling] C --> D[Блок кодировщика] D --> E[MaxPooling] E --> F[Блок кодировщика] F --> G[Узкое место] G --> H[Блок декодера] H --> I[UpSampling] I --> J[Блок декодера] J --> K[UpSampling] K --> L[Блок декодера] L --> M[Выходной маска] B -->|Пропускное соединение| J D -->|Пропускное соединение| L

Ваш инструментарий: PyTorch и немного изящества

Прежде чем начать готовить, давайте настроим нашу кухню:

pip install torch torchvision monai matplotlib

Теперь наденьте свою виртуальную поварскую шляпу — мы готовим U-Net с нуля! Вот наш рецепт:

Шаг 1: Двойная свертка (рабочий слой)

import torch
import torch.nn as nn
class DoubleConv(nn.Module):
    """Двойная свертка — вдвое больше веселья!"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.double_conv(x)

Шаг 2: Построение кодировщика (компактор)

class DownBlock(nn.Module):
    """Уменьшитель — как корпоративное, но полезное"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )
    def forward(self, x):
        return self.maxpool_conv(x)

Шаг 3: Создание декодера (луч обратного сжатия)

class UpBlock(nn.Module):
    """Вечеринка возвращения декодера"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)
    def forward(self, x1, x2):
        """x1 — из декодера, x2 — пропускное соединение"""
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

Шаг 4: Сборка U-Net Франкенштейна

class UNet(nn.Module):
    """Главное событие — представьте роботов-трансформеров, но для медицины"""
    def __init__(self, n_channels=1, n_classes=2):
        super().__init__()
        # Путь кодировщика
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = DownBlock(64, 128)
        self.down2 = DownBlock(128, 256)
        self.down3 = DownBlock(256, 512)
        self.down4 = DownBlock(512, 1024)
        # Путь декодера с пропускными соединениями
        self.up1 = UpBlock(1024, 512)
        self.up2 = UpBlock(512, 256)
        self.up3 = UpBlock(256, 128)
        self.up4 = UpBlock(128, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)
    def forward(self, x):
        # Шаги кодировщика
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        # Шаги декодера с пропускными соединениями
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

Подготовка данных: Превращение сканирований в закуски для ИИ

Медицинские изображения — как привередливые едоки — им нужна специальная подготовка:

from torch.utils.data import Dataset
import nibabel as nib
import numpy as np
class MedicalImageDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        image = nib.load(self.image_paths[idx]).get_fdata()
        mask = nib.load(self.mask_paths[idx]).get_fdata()
        # Нормализация и добавление размерности канала
        image = (image - image.min()) / (image.max() - image.min())
        image = np.expand_dims(image, axis=0)
        # One-hot кодирование масок
        mask = (mask > 0).astype(np.float32)
        mask = np.expand_dims(mask, axis=0)
        if self.transform:
            image, mask = self.transform((image, mask))
        return image, mask

Увеличение данных: Искусство медицинских подделок

import torchvision.transforms as transforms
# Потому что иногда нужно солгать вашей модели, чтобы сделать её умнее
train_transform = transforms.Compose([
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2)
])

Обучение: Обучение вашего цифрового стажёра

Функция потерь Dice — когда точность просто не подходит

class DiceLoss(nn.Module):
    """Эквивалент трофея за участие в сегментации"""
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth
    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        intersection = (probs * targets).sum()
        union = probs.sum() + targets.sum()
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice

Цикл обучения: Где происходит магия (медленно)

def train_model(model, dataloader, optimizer, loss_fn, device, epochs=30):
    """Эквивалент отправки вашего ребёнка в медицинскую школу для ИИ"""
    model.train()
    for epoch in range(epochs):
        running