Prześlij pliki do ''
This commit is contained in:
parent
580fbd9d2e
commit
cbe10a0da1
32
net_training.py
Normal file
32
net_training.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from tensorflow.keras.applications import MobileNetV2
|
||||||
|
from tensorflow.keras.applications.mobilenet import preprocess_input, decode_predictions
|
||||||
|
from tensorflow.keras.layers import Dense
|
||||||
|
from tensorflow.keras.models import Model
|
||||||
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
||||||
|
|
||||||
|
podstawa_modelu = MobileNetV2(weights="imagenet", include_top=False, pooling='avg')
|
||||||
|
x = podstawa_modelu.output
|
||||||
|
preds = Dense(7, activation='softmax')(x)
|
||||||
|
|
||||||
|
model = Model(inputs=podstawa_modelu.input, outputs=preds)
|
||||||
|
|
||||||
|
for layer in model.layers[:50]:
|
||||||
|
layer.trainable = False
|
||||||
|
for layer in model.layers[50:]:
|
||||||
|
layer.trainable = True
|
||||||
|
|
||||||
|
train_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
|
||||||
|
|
||||||
|
train_generator = train_datagen.flow_from_directory('./dataset', target_size=(224, 224), color_mode='rgb',
|
||||||
|
batch_size=32, class_mode='categorical', shuffle=True)
|
||||||
|
|
||||||
|
model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])
|
||||||
|
step_size_train = train_generator.n // train_generator.batch_size
|
||||||
|
model.fit_generator(generator=train_generator, steps_per_epoch=step_size_train, epochs=30)
|
||||||
|
|
||||||
|
model.save('neural_model.h5')
|
Loading…
Reference in New Issue
Block a user