SI-projekt-smieciarka/snn/nn_train.py
2022-05-27 05:28:53 +02:00

25 lines
890 B
Python

# Unneeded file: the algorithm has already been trained on Google Colab.
from fastai.vision.all import *
# Folder that holds the image dataset, given relative to the working directory.
DATASET_PATH = Path('../dataset')

# Sorted listing of each class sub-folder, taken in the order
# glass -> plastic -> paper -> others.
# NOTE(review): nothing in the visible script reads these lists afterwards.
glass_train, plastic_train, paper_train, others_train = (
    (DATASET_PATH / class_dir).ls().sorted()
    for class_dir in ('glass', 'plastic', 'paper', 'others')
)

# Recipe for turning the dataset folder into (image, category) samples.
trash_datablock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    # Items are collected with get_image_files and labelled with parent_label.
    get_items=get_image_files,
    get_y=parent_label,
    # Per-item transform: random resized crop to 224 with min_scale=0.3.
    item_tfms=RandomResizedCrop(224, min_scale=0.3),
    # Random split with 20% held out for validation and a fixed seed of 100.
    splitter=RandomSplitter(valid_pct=0.2, seed=100),
    # Batch-level augmentation set built with mult=2.
    batch_tfms=aug_transforms(mult=2),
)

# The data source is the dataset path joined onto the current working directory.
dls = trash_datablock.dataloaders(
    os.path.join(Path.cwd(), DATASET_PATH),
    bs=32,
    num_workers=0,
)

# Build a resnet34-based learner that reports error_rate, fine-tune it with an
# epoch argument of 4, display results, and export it with default arguments.
learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(4)
learn.show_results()
learn.export()