projekt_ml/settings.py
2025-01-19 16:27:39 +01:00

49 lines
1.4 KiB
Python

import warnings
from typing import Literal, Optional

import numpy as np
from torch.utils.data import WeightedRandomSampler
from torchvision.datasets import DatasetFolder
# torch emits many FutureWarnings, most noticeably when models are
# deserialized (loading pickled model files warns about unsafe pickles, which
# this project cannot act on). The filter below silences the whole category
# globally — for every module, not only torch.
warnings.simplefilter("ignore", category=FutureWarning)

# Relative directory holding the raw dataset.
RAW_DATASET_PATH = "raw/"

# Static types for the allowed class labels and dataset split names.
ClassNameTp = Literal["civilian", "military"]
SetNameTp = Literal["train", "test", "validate"]

# Class labels in use.
CLASSES: list[ClassNameTp] = ["civilian", "military"]
# Splits in use. NOTE(review): "validate" is a valid SetNameTp but is not
# listed here — confirm whether that is intentional.
SETS: list[SetNameTp] = ["train", "test"]

# File extension for torch-serialized files; the model file name derives from it.
TORCH_FILE_EXT = ".pth"
MODEL_FILENAME = f"model{TORCH_FILE_EXT}"

# Target image size (presumably the side length in pixels that images are
# resized to — confirm against the transform that consumes it).
IMAGE_TARGET_SIZE = 225
def get_balancing_sampler(
    dataset: DatasetFolder, num_samples: Optional[int] = None
) -> WeightedRandomSampler:
    """Create a WeightedRandomSampler that draws every class equally often.

    Each sample is weighted by the inverse frequency of its class, so in
    expectation every class contributes the same share of drawn indices.
    Sampling is done with replacement: minority-class items get repeated and
    majority-class items may be skipped within one pass over the sampler.

    Args:
        dataset (DatasetFolder): The dataset to balance. Only its ``targets``
            (one non-negative integer class index per sample) are used.
        num_samples (Optional[int]): How many indices one pass of the sampler
            yields. Defaults to the number of samples in ``dataset``.

    Returns:
        WeightedRandomSampler: A sampler yielding dataset indices with
        class-balanced probabilities.

    Raises:
        ValueError: If the dataset has no samples, or (raised by
            WeightedRandomSampler) if ``num_samples`` is not a positive int.
    """
    targets = np.asarray(dataset.targets, dtype=np.int64)
    if targets.size == 0:
        raise ValueError("cannot build a balancing sampler for an empty dataset")

    class_counts = np.bincount(targets)
    # Class indices with no samples (gaps in the label range) would otherwise
    # produce 1/0 -> inf plus a RuntimeWarning. They never occur in `targets`,
    # so their weight is never looked up; leave it at 0 instead.
    class_weights = np.zeros(class_counts.shape, dtype=np.float64)
    np.divide(1.0, class_counts, out=class_weights, where=class_counts > 0)
    sample_weights = class_weights[targets]

    if num_samples is None:
        num_samples = len(sample_weights)

    # replacement=True is what makes the balancing work: without it every
    # index is drawn exactly once and the weights would only affect order.
    return WeightedRandomSampler(
        weights=sample_weights.tolist(),
        num_samples=num_samples,
        replacement=True,
    )
from model import CNNv2 as OurCNN