49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
import warnings
|
|
from typing import Literal
|
|
from torchvision.datasets import DatasetFolder
|
|
from torch.utils.data import WeightedRandomSampler
|
|
import numpy as np
|
|
|
|
# Silence FutureWarning globally. torch raises these a lot, most notably while
# deserializing models, and this module has no way to avoid loading the
# (nominally unsafe) pickled model files that trigger them.
warnings.simplefilter("ignore", FutureWarning)

# Directory holding the raw dataset (relative path).
RAW_DATASET_PATH = "raw/"

# Closed sets of valid class / split names, for static type checking.
ClassNameTp = Literal["civilian", "military"]
SetNameTp = Literal["train", "test", "validate"]

# NOTE(review): SetNameTp also allows "validate", but SETS lists only
# "train" and "test" -- confirm this is intentional.
SETS: list[SetNameTp] = ["train", "test"]
CLASSES: list[ClassNameTp] = ["civilian", "military"]

# File extension used for serialized torch files, and the model's file name.
TORCH_FILE_EXT = ".pth"
MODEL_FILENAME = f"model{TORCH_FILE_EXT}"

# Presumably the side length (in pixels) that images are resized to --
# TODO confirm against the code that consumes this constant.
IMAGE_TARGET_SIZE = 225
|
|
|
|
|
|
def get_balancing_sampler(dataset: DatasetFolder) -> WeightedRandomSampler:
    """Create a WeightedRandomSampler that balances the classes in the dataset.

    Every sample is weighted by the inverse of the number of samples sharing
    its class, so each class that occurs in ``dataset.targets`` carries the
    same total weight.

    Args:
        dataset: The dataset to balance. Only its ``targets`` attribute is
            read; it must hold one non-negative integer class index per
            sample (as required by ``np.bincount``).

    Returns:
        A sampler that draws ``len(dataset.targets)`` indices with
        replacement, each sample weighted by the inverse frequency of its
        class.
    """
    targets = np.asarray(dataset.targets, dtype=np.intp)
    class_counts = np.bincount(targets)

    # Look up each sample's class count *before* inverting. Every looked-up
    # count is >= 1, so a class index with no samples (a zero bin in
    # class_counts) can never cause a division by zero here.
    sample_weights = 1.0 / class_counts[targets]

    return WeightedRandomSampler(
        weights=sample_weights.tolist(),
        num_samples=len(sample_weights),
        replacement=True,
    )
|
|
|
|
|
|
from model import CNNv2 as OurCNN
|