# ml/13_pytorch_resnet18.py
import glob
import json
from pathlib import Path

import torch
import torchvision
from sklearn.metrics import accuracy_score
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models import resnet18

from astrodata.ml.metrics import SklearnMetric
from astrodata.ml.models import PytorchModel
def format_prediction(image_path, label, classes):
    """Return a printable line pairing an image's file name with its predicted class.

    Args:
        image_path: Path of the image that was classified.
        label: Predicted class index. Anything ``int()`` accepts works
            (Python int, NumPy integer, 0-d tensor).
        classes: Mapping from the class index, as a string, to a sequence whose
            second element is shown as the class name (as loaded from
            ``imagenet_class_index.json``; presumably ``[wnid, name]``).

    Returns:
        A string ``"Ground Truth: <file name> Prediction: <class name>"``.

    Raises:
        KeyError: if the label is not a key of ``classes``.
    """
    # Normalise to a plain int first: str() of a tensor element would give
    # "tensor(242)" rather than the bare number used as the JSON key.
    class_name = classes[str(int(label))][1]
    return f"Ground Truth: {Path(image_path).name} Prediction: {class_name}"


if __name__ == "__main__":
    # Index -> entry table used to turn predicted indices into display names.
    with open("testdata/imagenet_ex/imagenet_class_index.json", encoding="utf-8") as f:
        classes = json.load(f)

    # Default pretrained weights plus the preprocessing pipeline bundled with them.
    weights = torchvision.models.ResNet18_Weights.DEFAULT
    transform = weights.transforms()

    metrics = [SklearnMetric(accuracy_score, greater_is_better=True)]

    model = PytorchModel(
        model_class=resnet18(weights=weights),
        model_params={},
        loss_fn=nn.CrossEntropyLoss,
        optimizer=optim.AdamW,
        optimizer_params={"lr": 1e-3},
        epochs=10,
        batch_size=32,
        device="cpu",
        # NOTE(review): if this re-initialises parameters it would discard the
        # pretrained weights loaded above — confirm against PytorchModel.
        with_weight_init=True,
    )
    print(model)

    # Sorted so the image order lines up with the hard-coded labels below.
    img_paths = sorted(glob.glob("testdata/imagenet_ex/*.jpg"))
    # Expected class index for each image, in sorted file-name order; must be
    # kept in sync with the contents of testdata/imagenet_ex.
    y_true = [242, 0]
    if len(img_paths) != len(y_true):
        raise ValueError(
            f"Expected {len(y_true)} images in testdata/imagenet_ex, "
            f"found {len(img_paths)}"
        )

    img_list = [transform(torchvision.io.read_image(path)) for path in img_paths]
    img_dataset = torch.utils.data.TensorDataset(
        torch.stack(img_list), torch.tensor(y_true, dtype=torch.long)
    )

    # Run inference once and reuse the result for printing.
    pred = model.predict(data=img_dataset, batch_size=1)
    print(model.get_metrics(dataset=img_dataset, metrics=metrics))
    print(pred)
    for image_path, label in zip(img_paths, pred):
        print(format_prediction(image_path, label, classes))