changes in common part

This commit is contained in:
Kamila Bobkowska 2020-06-04 07:33:58 +00:00
parent 9e608db67c
commit 76ec19cde0

View File

@ -20,29 +20,30 @@ from models.garbageDump import Dump
def createSets(): def createSets():
rootDir = 'ClassificationGarbage' if not path.exists('ClassificationGarbage/trainSet'):
typesDir = ['/cardboard', '/glass', '/metal', '/paper', '/plastic'] rootDir = 'ClassificationGarbage'
testRatio = 0.2 typesDir = ['/cardboard', '/glass', '/metal', '/paper', '/plastic']
testRatio = 0.2
for cls in typesDir: for cls in typesDir:
os.makedirs(rootDir + '/trainSet' + cls) os.makedirs(rootDir + '/trainSet' + cls)
os.makedirs(rootDir + '/testSet' + cls) os.makedirs(rootDir + '/testSet' + cls)
sourceDir = rootDir + cls sourceDir = rootDir + cls
allFileNames = os.listdir(sourceDir) allFileNames = os.listdir(sourceDir)
np.random.shuffle(allFileNames) np.random.shuffle(allFileNames)
trainingFileNames, testFileNames = np.split(np.array(allFileNames), [int(len(allFileNames) * (1 - testRatio))]) trainingFileNames, testFileNames = np.split(np.array(allFileNames), [int(len(allFileNames) * (1 - testRatio))])
trainingFileNames = [sourceDir +'/' + name for name in trainingFileNames.tolist()] trainingFileNames = [sourceDir +'/' + name for name in trainingFileNames.tolist()]
testFileNames = [sourceDir +'/' + name for name in testFileNames.tolist()] testFileNames = [sourceDir +'/' + name for name in testFileNames.tolist()]
print(cls + ':') print(cls + ':')
print('Total images: ', len(allFileNames)) print('Total images: ', len(allFileNames))
print('Training: ', len(trainingFileNames)) print('Training: ', len(trainingFileNames))
print('Testing: ', len(testFileNames)) print('Testing: ', len(testFileNames))
for name in trainingFileNames: for name in trainingFileNames:
shutil.copy(name, rootDir +'/trainSet' + cls) shutil.copy(name, rootDir +'/trainSet' + cls)
for name in testFileNames: for name in testFileNames:
shutil.copy(name, rootDir +'/testSet' + cls) shutil.copy(name, rootDir +'/testSet' + cls)
print("Images copied.") print("Images copied.")
def processTrainData(): def processTrainData():