ml/13_pytorch_resnet18.pyΒΆ

import torchvision
import json
import glob
from torchvision.models import resnet18
from torch import nn, optim
from torch.utils.data import DataLoader

from astrodata.ml.models import PytorchModel

if __name__ == "__main__":
    classes = json.load(open("testdata/imagenet_ex/imagenet_class_index.json"))
    weights = torchvision.models.ResNet18_Weights.DEFAULT
    transform = weights.transforms()

    model = PytorchModel(
        model_class=resnet18(weights=weights),
        model_params={},
        loss_fn=nn.CrossEntropyLoss,
        optimizer=optim.AdamW,
        optimizer_params={"lr": 1e-3},
        epochs=10,
        batch_size=32,
        device="cpu",
        with_weight_init=True,
    )
    
    print(model)
    
    img_paths = sorted(glob.glob("testdata/imagenet_ex/*.jpg"))

    img_list = []

    for image_path in img_paths:
        img_list.append(transform(torchvision.io.read_image(image_path)))
        
    dataloader_img_list = DataLoader(img_list, batch_size=1)
    pred = model.predict(dataloader_img_list, 1)
    
    for i in range(len(pred)):
        print(f"Ground Truth: {str.split(img_paths[i], '/')[-1]}    Prediction: {classes[str(pred[i])][1]}")