# Unnecessary file: the model (algorithm) has already been trained on Google Colab.
"""Fine-tune a ResNet-34 trash classifier on the local dataset with fastai.

The dataset lives in DATASET_PATH, resolved against the current working
directory (not against this file), with one sub-folder per class:
``glass``, ``plastic``, ``paper`` and ``others``. Each image's label is the
name of its parent folder. After fine-tuning, sample predictions are shown
and the learner is exported with ``learn.export()``.
"""
from fastai.vision.all import *  # fastai convention; also supplies Path, DataBlock, resnet34, ...

# Relative path: interpreted against the current working directory.
DATASET_PATH = Path('../dataset')

# Sorted file listings per class folder. The training pipeline below does not
# use these names. NOTE(review): they presumably only serve as an early failure
# if a class folder is missing — confirm whether they are still needed.
glass_train = (DATASET_PATH/'glass').ls().sorted()
plastic_train = (DATASET_PATH/'plastic').ls().sorted()
paper_train = (DATASET_PATH/'paper').ls().sorted()
others_train = (DATASET_PATH/'others').ls().sorted()

trash_datablock = DataBlock(
    get_items=get_image_files,                          # all image files under the source path
    get_y=parent_label,                                 # label = name of the containing folder
    blocks=(ImageBlock, CategoryBlock),                 # image input, single-category target
    item_tfms=RandomResizedCrop(224, min_scale=0.3),    # per-item random crop to 224 px
    splitter=RandomSplitter(valid_pct=0.2, seed=100),   # 20% validation; fixed seed for a repeatable split
    batch_tfms=aug_transforms(mult=2),                  # per-batch augmentations (mult=2)
)

# Absolute dataset path built with pathlib only. The original used
# os.path.join/os.getcwd, and `os` was never imported explicitly.
# num_workers=0 loads data in the main process (no worker processes).
dls = trash_datablock.dataloaders(Path.cwd()/DATASET_PATH, num_workers=0, bs=32)

learn = vision_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(4)    # fine-tune for 4 epochs
learn.show_results()  # NOTE(review): outside a notebook the figure may need plt.show() to appear
learn.export()        # NOTE(review): fastai's default target is 'export.pkl' under learn.path — confirm location