Added model.py
This commit is contained in:
parent
61212f5e4c
commit
70f8d14fb8
@ -2,9 +2,13 @@ FROM ubuntu:latest
|
||||
|
||||
RUN apt-get update && apt-get install -y python3-pip unzip coreutils
|
||||
|
||||
RUN pip install --user kaggle pandas
|
||||
RUN pip install --user kaggle pandas scikit-learn tensorflow
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY ./data_processing.sh ./
|
||||
COPY ./OrangeQualityData.csv ./
|
||||
COPY ./orange_quality_model_tf.h5 ./
|
||||
COPY ./predictions_tf.json ./
|
||||
|
||||
CMD ["python3", "data_processing.sh"]
|
||||
|
40
model.py
Normal file
40
model.py
Normal file
@ -0,0 +1,40 @@
|
||||
import tensorflow as tf
|
||||
import pandas as pd
|
||||
from sklearn.preprocessing import LabelEncoder, StandardScaler
|
||||
from sklearn.model_selection import train_test_split
|
||||
import json
|
||||
|
||||
df = pd.read_csv('OrangeQualityData.csv')
|
||||
|
||||
encoder = LabelEncoder()
|
||||
df["Color"] = encoder.fit_transform(df["Color"])
|
||||
df["Variety"] = encoder.fit_transform(df["Variety"])
|
||||
df["Blemishes"] = df["Blemishes (Y/N)"].apply(lambda x: 1 if x.startswith("Y") else 0)
|
||||
|
||||
df.drop(columns=["Blemishes (Y/N)"], inplace=True)
|
||||
|
||||
X = df.drop(columns=["Quality (1-5)"])
|
||||
y = df["Quality (1-5)"]
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
|
||||
scaler = StandardScaler()
|
||||
X_train_scaled = scaler.fit_transform(X_train)
|
||||
X_test_scaled = scaler.transform(X_test)
|
||||
|
||||
model = tf.keras.Sequential([
|
||||
tf.keras.layers.Dense(64, activation='relu', input_shape=(X_train_scaled.shape[1],)),
|
||||
tf.keras.layers.Dense(32, activation='relu'),
|
||||
tf.keras.layers.Dense(1)
|
||||
])
|
||||
|
||||
model.compile(optimizer='sgd', loss='mse')
|
||||
|
||||
history = model.fit(X_train_scaled, y_train, epochs=100, verbose=0, validation_data=(X_test_scaled, y_test))
|
||||
|
||||
model.save('orange_quality_model_tf.h5')
|
||||
|
||||
predictions = model.predict(X_test_scaled)
|
||||
|
||||
with open('predictions_tf.json', 'w') as f:
|
||||
json.dump(predictions.tolist(), f, indent=4)
|
BIN
orange_quality_model_tf.h5
Normal file
BIN
orange_quality_model_tf.h5
Normal file
Binary file not shown.
149
predictions_tf.json
Normal file
149
predictions_tf.json
Normal file
@ -0,0 +1,149 @@
|
||||
[
|
||||
[
|
||||
4.033654689788818
|
||||
],
|
||||
[
|
||||
4.351343631744385
|
||||
],
|
||||
[
|
||||
1.5783445835113525
|
||||
],
|
||||
[
|
||||
4.1473917961120605
|
||||
],
|
||||
[
|
||||
3.9993104934692383
|
||||
],
|
||||
[
|
||||
3.8747072219848633
|
||||
],
|
||||
[
|
||||
4.48088264465332
|
||||
],
|
||||
[
|
||||
2.086705207824707
|
||||
],
|
||||
[
|
||||
4.511044979095459
|
||||
],
|
||||
[
|
||||
3.5592899322509766
|
||||
],
|
||||
[
|
||||
4.714838027954102
|
||||
],
|
||||
[
|
||||
4.666493892669678
|
||||
],
|
||||
[
|
||||
4.15949010848999
|
||||
],
|
||||
[
|
||||
4.062054634094238
|
||||
],
|
||||
[
|
||||
3.3104782104492188
|
||||
],
|
||||
[
|
||||
3.671990394592285
|
||||
],
|
||||
[
|
||||
4.121957302093506
|
||||
],
|
||||
[
|
||||
5.101129055023193
|
||||
],
|
||||
[
|
||||
3.2231392860412598
|
||||
],
|
||||
[
|
||||
4.860662937164307
|
||||
],
|
||||
[
|
||||
3.1851491928100586
|
||||
],
|
||||
[
|
||||
4.8820481300354
|
||||
],
|
||||
[
|
||||
2.043302059173584
|
||||
],
|
||||
[
|
||||
3.892570972442627
|
||||
],
|
||||
[
|
||||
4.5895609855651855
|
||||
],
|
||||
[
|
||||
2.4837639331817627
|
||||
],
|
||||
[
|
||||
2.157947063446045
|
||||
],
|
||||
[
|
||||
4.463848114013672
|
||||
],
|
||||
[
|
||||
4.560668468475342
|
||||
],
|
||||
[
|
||||
4.00075626373291
|
||||
],
|
||||
[
|
||||
2.392961263656616
|
||||
],
|
||||
[
|
||||
2.701927423477173
|
||||
],
|
||||
[
|
||||
3.338017463684082
|
||||
],
|
||||
[
|
||||
5.018939018249512
|
||||
],
|
||||
[
|
||||
3.04030179977417
|
||||
],
|
||||
[
|
||||
3.2020576000213623
|
||||
],
|
||||
[
|
||||
4.9051432609558105
|
||||
],
|
||||
[
|
||||
4.618875980377197
|
||||
],
|
||||
[
|
||||
3.9825124740600586
|
||||
],
|
||||
[
|
||||
4.486858367919922
|
||||
],
|
||||
[
|
||||
4.929688930511475
|
||||
],
|
||||
[
|
||||
3.033353328704834
|
||||
],
|
||||
[
|
||||
4.897153854370117
|
||||
],
|
||||
[
|
||||
3.149707555770874
|
||||
],
|
||||
[
|
||||
3.7602972984313965
|
||||
],
|
||||
[
|
||||
4.0963287353515625
|
||||
],
|
||||
[
|
||||
3.0543882846832275
|
||||
],
|
||||
[
|
||||
3.698991060256958
|
||||
],
|
||||
[
|
||||
2.9619979858398438
|
||||
]
|
||||
]
|
Loading…
Reference in New Issue
Block a user