strip classes

This commit is contained in:
Paweł Felcyn 2023-05-29 12:00:46 +02:00
parent 43c13acdee
commit 38c66cea53
3 changed files with 3 additions and 18 deletions

View File

@ -18,10 +18,10 @@ def _read_training_data() -> TrainingData:
line_attributes = values[:-1]
line_class = values[-1]
attributes.append(line_attributes)
classes.append(line_class)
classes.append(line_class.strip())
return TrainingData(attributes, classes)
def attributes_to_floats(attributes: list[str]) -> list[float]:
def _attributes_to_floats(attributes: list[str]) -> list[float]:
output: list[float] = []
if attributes[0] == 'Longitiudonal':
output.append(0)
@ -88,21 +88,6 @@ trainning_data = _read_training_data()
X = trainning_data.attributes
Y = trainning_data.classes
# le_shape = LabelEncoder()
# le_flexibility = LabelEncoder()
# le_color = LabelEncoder()
# le_shape.fit([x[0] for x in X])
# le_flexibility.fit([x[3] for x in X])
# le_color.fit([x[4] for x in X])
# X_encoded = np.array([
# [le_shape.transform([x[0]])[0], x[1], x[2], le_flexibility.transform([x[3]])[0], le_color.transform([x[4]])[0]]
# for x in X
# ])
# encoder = OneHotEncoder(categories='auto', sparse=False)
# X_encoded = encoder.fit_transform(X_encoded)
model = tree.DecisionTreeClassifier()
encoded = [_attributes_to_floats(x) for x in X]

Binary file not shown.

View File

@ -57,7 +57,7 @@ def create_garbage_pieces() -> List[Garbage]:
for line in lines[1:]:
param = line.strip().split(',')
garbage_pieces.append(
Garbage('img', param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7]))
Garbage('img', param[0], param[1], param[2], param[3], param[4], param[5], param[6], param[7].strip()))
return garbage_pieces