from tqdm import tqdm
from random import randint

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as VF
from torch.optim import Adam
from torch.utils.data import DataLoader

from image_dataset import ImageDataset


class ResidualBlock(nn.Module):
    def __init__(self, channels, kernel):
        super().__init__()
        # Pre-activation (BN + ReLU) applied before the residual branch.
        self.pre = nn.Sequential(
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel, 1, kernel // 2),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel, 1, kernel // 2)
        )

    def forward(self, x):
        x = self.pre(x)
        return self.conv(x) + x


class ConvPartWholeHierarchy(nn.Module):
    def __init__(self, n_layers=4):
        super().__init__()
        self.n_layers = n_layers
        layers = []
        for i in range(n_layers):
            layers.append(nn.Sequential(
                # The first stage reads RGB; later stages read the previous
                # stage's 128-channel feature map. Stride 2 halves the
                # resolution at every level, padding 3 keeps the halving exact.
                nn.Conv2d(3 if i == 0 else 128, 128, 7, 2, 3),
                ResidualBlock(128, 1),
                ResidualBlock(128, 1),
                nn.BatchNorm2d(128),
                nn.ReLU(),
            ))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        # Return the feature map of every level of the hierarchy, coarsest last.
        features = []
        for i in range(self.n_layers):
            x = self.layers[i](x)
            features.append(x)
        return features


BATCH_SIZE = 32

model = ConvPartWholeHierarchy().cuda()
optim = Adam(model.parameters(), lr=1e-4)

ds = ImageDataset('./data', horizontal_flip=False)
train_loader = DataLoader(ds, batch_size=BATCH_SIZE, num_workers=3,
                          shuffle=True, pin_memory=True, drop_last=True)

total_images = 0
for epoch in range(100):
    batch_index = 0
    for batch in tqdm(train_loader, smoothing=0, unit='img', unit_scale=BATCH_SIZE):
        batch = batch.cuda()
        total_images += BATCH_SIZE
        batch_index += 1

        # Sample one of the 16 axis-aligned flip/rotation combinations,
        # rejecting the two that are identity transformations (0 and 14).
        while True:
            flip_type = randint(1, 15)
            if flip_type != 14:
                break

        def flip(tensor):
            if flip_type & 4 > 0:
                tensor = VF.hflip(tensor)
            if flip_type & 8 > 0:
                tensor = VF.vflip(tensor)
            tensor = VF.rotate(tensor, (flip_type % 4) * 90)
            return tensor

        # Equivariance consistency loss on the top level of the hierarchy:
        # the features of the transformed image should match the transformed
        # features of the original image.
        x = flip(model(batch)[-1])
        y = model(flip(batch))[-1]
        loss = F.mse_loss(x, y)

        optim.zero_grad()
        loss.backward()
        optim.step()

        tqdm.write(f"Ep: {epoch+1:06}, Bat: {batch_index:06}, Loss: {loss.item():.5f}, K img: {total_images/1000:07.1f}")
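
# ---------------------------------------------------------------------------
# Optional sanity check, not part of the original training script: a rough
# sketch of how the learned equivariance could be inspected after training.
# It reuses the last `batch` left over from the loop above purely for
# illustration; in practice you would draw a dedicated validation batch.
import torch

with torch.no_grad():
    reference = model(batch)[-1]
    # Rotate the input by 90 degrees, then undo the rotation on the output
    # features so they can be compared directly against the reference.
    rotated = VF.rotate(model(VF.rotate(batch, 90))[-1], -90)
    tqdm.write(f"90-degree consistency MSE: {F.mse_loss(reference, rotated).item():.5f}")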