neural network with comments
This commit is contained in:
parent
b241f3ad1e
commit
2fc5eed8c4
Binary file not shown.
Binary file not shown.
Before Width: | Height: | Size: 1.5 KiB |
63
main.py
63
main.py
@ -1,7 +1,7 @@
|
|||||||
import pygame
|
import pygame
|
||||||
import sys
|
import sys
|
||||||
import random
|
import random
|
||||||
from settings import screen_height, screen_width, SIZE, SPECIES, block_size, tile, road_coords, directions
|
from settings import SIZE, directions, draw_lines_on_window
|
||||||
from src.map import drawRoads, seedForFirstTime, return_fields_list, WORLD_MATRIX, get_type_by_position
|
from src.map import drawRoads, seedForFirstTime, return_fields_list, WORLD_MATRIX, get_type_by_position
|
||||||
from src.Tractor import Tractor
|
from src.Tractor import Tractor
|
||||||
from src.bfs import Astar
|
from src.bfs import Astar
|
||||||
@ -9,7 +9,7 @@ from src.Plant import Plant
|
|||||||
from src.Field import Field
|
from src.Field import Field
|
||||||
import pickle
|
import pickle
|
||||||
import os
|
import os
|
||||||
from src.ID3 import make_decision
|
from src.ID3 import action
|
||||||
import torch
|
import torch
|
||||||
import src.neural_networks as neural_networks
|
import src.neural_networks as neural_networks
|
||||||
|
|
||||||
@ -43,15 +43,13 @@ def recognize_plants(fields, destination):
|
|||||||
else:
|
else:
|
||||||
pred = 'none'
|
pred = 'none'
|
||||||
print(pred)
|
print(pred)
|
||||||
|
return pred
|
||||||
|
|
||||||
|
|
||||||
# pygame initialization
|
# pygame initialization
|
||||||
pygame.init()
|
pygame.init()
|
||||||
clock = pygame.time.Clock()
|
clock = pygame.time.Clock()
|
||||||
#pygame.mouse.set_visible(False)
|
|
||||||
|
|
||||||
#GAME SCREEN
|
# GAME SCREEN
|
||||||
screen = pygame.display.set_mode(SIZE)
|
screen = pygame.display.set_mode(SIZE)
|
||||||
pygame.display.set_caption("Traktor_interaktor")
|
pygame.display.set_caption("Traktor_interaktor")
|
||||||
background = pygame.image.load("assets/farmland.jpg")
|
background = pygame.image.load("assets/farmland.jpg")
|
||||||
@ -60,14 +58,7 @@ screen.fill((90,50,20))
|
|||||||
background.fill((90,50,20))
|
background.fill((90,50,20))
|
||||||
background = drawRoads(background)
|
background = drawRoads(background)
|
||||||
|
|
||||||
for line in range(26):
|
draw_lines_on_window(background)
|
||||||
pygame.draw.line(background, (0, 0, 0), (0, line * block_size), (936, line * block_size))
|
|
||||||
pygame.draw.line(background, (0, 0, 0), (line * block_size, 0), (line * block_size, screen_height))
|
|
||||||
|
|
||||||
pygame.draw.line(background, (0, 0, 0), (968, 285), (1336 , 285))
|
|
||||||
pygame.draw.line(background, (0, 0, 0), (968, 649), (1336 , 649))
|
|
||||||
pygame.draw.line(background, (0, 0, 0), (968, 285), (968, 649))
|
|
||||||
pygame.draw.line(background, (0, 0, 0), (1336, 285), (1336, 649))
|
|
||||||
|
|
||||||
#TRACTOR
|
#TRACTOR
|
||||||
tractor = Tractor('oil','manual', 'fuel', 'fertilizer1', 20)
|
tractor = Tractor('oil','manual', 'fuel', 'fertilizer1', 20)
|
||||||
@ -82,20 +73,18 @@ plant_group = pygame.sprite.Group()
|
|||||||
plant_group = seedForFirstTime()
|
plant_group = seedForFirstTime()
|
||||||
fields = return_fields_list()
|
fields = return_fields_list()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
tractor_move = pygame.USEREVENT + 1
|
tractor_move = pygame.USEREVENT + 1
|
||||||
pygame.time.set_timer(tractor_move, 200)
|
pygame.time.set_timer(tractor_move, 200)
|
||||||
moves = []
|
moves = []
|
||||||
goal_astar = Astar()
|
goal_astar = Astar()
|
||||||
mx=random.randrange(0, 936, 36)
|
mx = random.randrange(0, 936, 36)
|
||||||
my=random.randrange(0, 936, 36)
|
my = random.randrange(0, 936, 36)
|
||||||
destination = (mx, my)
|
destination = (mx, my)
|
||||||
print("Destination: ", destination)
|
print("Destination: ", destination)
|
||||||
mx=int((mx+18)/36)
|
mx = int((mx+18)/36)
|
||||||
my=int((my+18)/36)
|
my = int((my+18)/36)
|
||||||
print("Destination: ", mx,my)
|
print("Destination: ", mx, my)
|
||||||
|
|
||||||
#ID3 TREE LOADING
|
#ID3 TREE LOADING
|
||||||
dtree = pickle.load(open(os.path.join('src','tree.plk'),'rb'))
|
dtree = pickle.load(open(os.path.join('src','tree.plk'),'rb'))
|
||||||
@ -104,25 +93,6 @@ dtree = pickle.load(open(os.path.join('src','tree.plk'),'rb'))
|
|||||||
this_field = WORLD_MATRIX[mx][my]
|
this_field = WORLD_MATRIX[mx][my]
|
||||||
this_contain = Field.getContain(this_field)
|
this_contain = Field.getContain(this_field)
|
||||||
|
|
||||||
def action(this_contain):
|
|
||||||
if isinstance(this_contain, Plant):
|
|
||||||
this_plant = this_contain
|
|
||||||
params=Plant.getParameters(this_plant)
|
|
||||||
# print(this_field)
|
|
||||||
#ID3 decision
|
|
||||||
decision=make_decision(params[0],params[1],params[2],params[3],params[4],tractor.fuel,tractor.capacity,params[5],dtree)
|
|
||||||
# print('wzorst',params[0],'wilgotnosc',params[1],'dni_od_nawiezienia',params[2],'pogoda',params[3],'zdrowa',params[4],'paliwo',tractor.fuel,'pojemnosc eq',tractor.capacity,'cena sprzedazy',params[5])
|
|
||||||
# print(decision)
|
|
||||||
if decision == 1:
|
|
||||||
print('Gotowe do zbioru')
|
|
||||||
return 1
|
|
||||||
else:
|
|
||||||
print('nie zbieramy')
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
print('Road, no plant growing')
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
moves = goal_astar.search(
|
moves = goal_astar.search(
|
||||||
[tractor.rect.x, tractor.rect.y, directions[tractor.rotation]], destination)
|
[tractor.rect.x, tractor.rect.y, directions[tractor.rotation]], destination)
|
||||||
@ -142,7 +112,7 @@ if __name__ == "__main__":
|
|||||||
if event.type == pygame.KEYDOWN:
|
if event.type == pygame.KEYDOWN:
|
||||||
if event.key==pygame.K_RETURN:
|
if event.key==pygame.K_RETURN:
|
||||||
tractor.collect(plant_group)
|
tractor.collect(plant_group)
|
||||||
recognize_plants(fields, destination)
|
# recognize_plants(fields, destination)
|
||||||
if event.key == pygame.K_ESCAPE:
|
if event.key == pygame.K_ESCAPE:
|
||||||
running = False
|
running = False
|
||||||
if event.type == tractor_move:
|
if event.type == tractor_move:
|
||||||
@ -151,11 +121,16 @@ if __name__ == "__main__":
|
|||||||
step = moves_list.pop() # pop the last element
|
step = moves_list.pop() # pop the last element
|
||||||
moves = tuple(moves_list) # convert back to tuple
|
moves = tuple(moves_list) # convert back to tuple
|
||||||
tractor.movement(step[0])
|
tractor.movement(step[0])
|
||||||
if tractor.rect.x == destination[0] and tractor.rect.y == destination[1] and action(this_contain) == 1:
|
# checks if tractor is in destiantion field and make decision if it's ready to collect
|
||||||
|
if tractor.rect.x == destination[0] and tractor.rect.y == destination[1] and action(this_contain, Plant, tractor, dtree) == 1:
|
||||||
|
# show what should be in this field
|
||||||
print('expected:', expected_plant)
|
print('expected:', expected_plant)
|
||||||
if recognize_plants(fields, destination) == 'carrot' or 'potato' or 'wheat':
|
# check if program correctly recognize plant
|
||||||
|
if recognize_plants(fields, destination) == expected_plant:
|
||||||
|
# if correctly recognized than plant can be collected
|
||||||
tractor.collect(plant_group)
|
tractor.collect(plant_group)
|
||||||
|
else:
|
||||||
|
print('wrong recognition')
|
||||||
|
|
||||||
|
|
||||||
Tractor.movement_using_keys(tractor)
|
Tractor.movement_using_keys(tractor)
|
||||||
|
11
settings.py
11
settings.py
@ -1,4 +1,5 @@
|
|||||||
from cmath import sqrt
|
from cmath import sqrt
|
||||||
|
import pygame
|
||||||
|
|
||||||
|
|
||||||
screen_width = 1368
|
screen_width = 1368
|
||||||
@ -19,3 +20,13 @@ field_size = field_width*field_height
|
|||||||
fields_amount = 25
|
fields_amount = 25
|
||||||
|
|
||||||
directions = {0: 'UP', 90: 'RIGHT', 180: 'DOWN', 270: 'LEFT'}
|
directions = {0: 'UP', 90: 'RIGHT', 180: 'DOWN', 270: 'LEFT'}
|
||||||
|
|
||||||
|
def draw_lines_on_window(background):
|
||||||
|
for line in range(26):
|
||||||
|
pygame.draw.line(background, (0, 0, 0), (0, line * block_size), (936, line * block_size))
|
||||||
|
pygame.draw.line(background, (0, 0, 0), (line * block_size, 0), (line * block_size, screen_height))
|
||||||
|
|
||||||
|
pygame.draw.line(background, (0, 0, 0), (968, 285), (1336 , 285))
|
||||||
|
pygame.draw.line(background, (0, 0, 0), (968, 649), (1336 , 649))
|
||||||
|
pygame.draw.line(background, (0, 0, 0), (968, 285), (968, 649))
|
||||||
|
pygame.draw.line(background, (0, 0, 0), (1336, 285), (1336, 649))
|
19
src/ID3.py
19
src/ID3.py
@ -42,3 +42,22 @@ def learnTree():
|
|||||||
# #przy robaczywej == 1 daje ok czyli jak 1 to git jest mozna zbierac, ale planowalem inaczej
|
# #przy robaczywej == 1 daje ok czyli jak 1 to git jest mozna zbierac, ale planowalem inaczej
|
||||||
# decision=make_decision(70,85,12,4,0,65,54,1500,dtree)
|
# decision=make_decision(70,85,12,4,0,65,54,1500,dtree)
|
||||||
# print(decision)
|
# print(decision)
|
||||||
|
|
||||||
|
def action(this_contain, Plant, tractor, dtree):
|
||||||
|
if isinstance(this_contain, Plant):
|
||||||
|
this_plant = this_contain
|
||||||
|
params=Plant.getParameters(this_plant)
|
||||||
|
# print(this_field)
|
||||||
|
#ID3 decision
|
||||||
|
decision=make_decision(params[0],params[1],params[2],params[3],params[4],tractor.fuel,tractor.capacity,params[5],dtree)
|
||||||
|
# print('wzorst',params[0],'wilgotnosc',params[1],'dni_od_nawiezienia',params[2],'pogoda',params[3],'zdrowa',params[4],'paliwo',tractor.fuel,'pojemnosc eq',tractor.capacity,'cena sprzedazy',params[5])
|
||||||
|
# print(decision)
|
||||||
|
if decision == 1:
|
||||||
|
print('Gotowe do zbioru')
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
print('nie zbieramy')
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
print('Road, no plant growing')
|
||||||
|
return 0
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -6,24 +6,45 @@ from torch.optim import Adam
|
|||||||
from torch.autograd import Variable
|
from torch.autograd import Variable
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Check if CUDA-enabled GPU is available and set the device
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
# Define the classes for classification
|
||||||
classes = ['carrot', 'potato', 'wheat']
|
classes = ['carrot', 'potato', 'wheat']
|
||||||
|
|
||||||
|
# Set the paths for the training and test data directories
|
||||||
train_path = 'assets/learning/train'
|
train_path = 'assets/learning/train'
|
||||||
test_path = 'assets/learning/test'
|
test_path = 'assets/learning/test'
|
||||||
|
|
||||||
|
#list of transforms to compose (lista przekształceń do utworzenia)
|
||||||
transformer = torchvision.transforms.Compose([
|
transformer = torchvision.transforms.Compose([
|
||||||
|
# resize input image to the given size
|
||||||
torchvision.transforms.Resize((150, 150)),
|
torchvision.transforms.Resize((150, 150)),
|
||||||
|
# convert image to tensor(muli dim array)
|
||||||
torchvision.transforms.ToTensor(),
|
torchvision.transforms.ToTensor(),
|
||||||
|
# normalize tensor image with wit mean and standard deviation
|
||||||
|
# normalize doesn't support PIL image -> that is why we do .ToTensor before
|
||||||
|
# output[channel] = (input[channel] - mean[channel]) / std[channel]
|
||||||
torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Module):
|
class Net(nn.Module):
|
||||||
def __init__(self, num_classes=3):
|
def __init__(self, num_classes=3):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
|
# Sequential - ordered dictionary
|
||||||
|
# Define the convolutional layers
|
||||||
|
# The output of one layer serves as the input to the next layer (3->12->20->32)
|
||||||
self.features = nn.Sequential(
|
self.features = nn.Sequential(
|
||||||
|
# Applies a 2D convolution over an input signal composed of several input planes
|
||||||
|
# Stosuje splot 2D dla sygnału wejściowego złożonego z kilku płaszczyzn wejściowych
|
||||||
nn.Conv2d(3, 12, kernel_size=3, stride=1, padding=1),
|
nn.Conv2d(3, 12, kernel_size=3, stride=1, padding=1),
|
||||||
|
# parameter of torch.nn.BatchNorm2d is the number of dimensions/channels that output
|
||||||
|
# from the last layer and come in to the batch norm layer.
|
||||||
nn.BatchNorm2d(12),
|
nn.BatchNorm2d(12),
|
||||||
|
# activation function relu(x) = { 0 if x<0, x if x > 0}
|
||||||
|
# after each layer, an activation function needs to be applied
|
||||||
|
# so as to make the network non-linear and fit complex data
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.MaxPool2d(kernel_size=2),
|
nn.MaxPool2d(kernel_size=2),
|
||||||
nn.Conv2d(12, 20, kernel_size=3, stride=1, padding=1),
|
nn.Conv2d(12, 20, kernel_size=3, stride=1, padding=1),
|
||||||
@ -32,72 +53,111 @@ class Net(nn.Module):
|
|||||||
nn.BatchNorm2d(32),
|
nn.BatchNorm2d(32),
|
||||||
nn.ReLU()
|
nn.ReLU()
|
||||||
)
|
)
|
||||||
|
# takes the flattened feature maps from the previous convolutional layers as input
|
||||||
self.classifier = nn.Linear(32 * 75 * 75, num_classes)
|
self.classifier = nn.Linear(32 * 75 * 75, num_classes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
# Forward pass through the network
|
||||||
x = self.features(x)
|
x = self.features(x)
|
||||||
|
# Pass the input through the sequential block of
|
||||||
|
# convolutional layers and activation functions
|
||||||
x = x.view(x.size(0), -1)
|
x = x.view(x.size(0), -1)
|
||||||
|
# Reshape the tensor by flattening it along the second dimension
|
||||||
x = self.classifier(x)
|
x = self.classifier(x)
|
||||||
|
# Pass the flattened tensor through the linear layer for classification
|
||||||
return x
|
return x
|
||||||
|
# Return the output tensor
|
||||||
|
|
||||||
def train(dataloader, model, optimizer, loss_fn):
|
def train(dataloader, model, optimizer, loss_fn):
|
||||||
model.train()
|
model.train()
|
||||||
size = len(dataloader.dataset)
|
size = len(dataloader.dataset)
|
||||||
|
# Get the total number of training examples
|
||||||
for batch, (X, y) in enumerate(dataloader):
|
for batch, (X, y) in enumerate(dataloader):
|
||||||
X, y = X.to(device), y.to(device)
|
X, y = X.to(device), y.to(device)
|
||||||
|
# Move the input tensors to the appropriate device (CPU or GPU)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
# Clear the gradients of the model parameters
|
||||||
pred = model(X.float())
|
pred = model(X.float())
|
||||||
|
# Perform a forward pass to obtain the predicted outputs
|
||||||
loss = loss_fn(pred, y)
|
loss = loss_fn(pred, y)
|
||||||
|
# Compute the loss between the predicted outputs and the ground truth labels
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
# Perform backpropagation to compute the gradients of the model parameters
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
# Update the model parameters using the computed gradients
|
||||||
if batch % 5 == 0:
|
if batch % 5 == 0:
|
||||||
current = batch * len(X)
|
current = batch * len(X)
|
||||||
|
# Compute the current batch size
|
||||||
print(f"loss: {loss.item():>7f} [{current:>5d}/{size:>5d}]")
|
print(f"loss: {loss.item():>7f} [{current:>5d}/{size:>5d}]")
|
||||||
|
# Print the current loss and the progress of the training in material def accuracy
|
||||||
|
|
||||||
def test(dataloader, model, loss_fn):
|
def test(dataloader, model, loss_fn):
|
||||||
model.eval()
|
model.eval()
|
||||||
|
# Set the model to evaluation mode
|
||||||
size = len(dataloader.dataset)
|
size = len(dataloader.dataset)
|
||||||
|
# Get the total number of examples in the dataloader
|
||||||
test_loss, correct = 0, 0
|
test_loss, correct = 0, 0
|
||||||
|
# Initialize variables to keep track of the total test
|
||||||
|
# loss and the number of correct predictions
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
# Disable gradient computation
|
||||||
for X, y in dataloader:
|
for X, y in dataloader:
|
||||||
X, y = X.to(device), y.to(device)
|
X, y = X.to(device), y.to(device)
|
||||||
|
# Move the input tensors to the appropriate device (CPU or GPU)
|
||||||
pred = model(X.float())
|
pred = model(X.float())
|
||||||
|
# Perform a forward pass to obtain the predicted outputs
|
||||||
test_loss += loss_fn(pred, y).item()
|
test_loss += loss_fn(pred, y).item()
|
||||||
|
# Compute the loss between the predicted outputs and the ground truth labels
|
||||||
correct += (pred.argmax(1) == y).sum().item()
|
correct += (pred.argmax(1) == y).sum().item()
|
||||||
|
# Count the number of correct predictions
|
||||||
|
|
||||||
test_loss /= size
|
test_loss /= size
|
||||||
|
# Calculate the average test loss
|
||||||
accuracy = 100.0 * correct / size
|
accuracy = 100.0 * correct / size
|
||||||
|
# Calculate the accuracy as a percentage
|
||||||
print(f"Test Error:\n Accuracy: {accuracy:.1f}%, Avg loss: {test_loss:.8f}\n")
|
print(f"Test Error:\n Accuracy: {accuracy:.1f}%, Avg loss: {test_loss:.8f}\n")
|
||||||
|
# Print the test accuracy and average test loss
|
||||||
|
|
||||||
def predict(img_path, model):
|
def predict(img_path, model):
|
||||||
image = Image.open(img_path).convert('RGB')
|
image = Image.open(img_path).convert('RGB')
|
||||||
|
# Open the image file from the given path and convert it to RGB mode
|
||||||
image_tensor = transformer(image).unsqueeze(0).to(device)
|
image_tensor = transformer(image).unsqueeze(0).to(device)
|
||||||
|
# Apply the image transformation pipeline defined earlier and convert the image to a tensor
|
||||||
|
# Add an extra dimension at the beginning to represent the batch
|
||||||
|
# Move the image tensor to the appropriate device (CPU or GPU)
|
||||||
output = model(image_tensor)
|
output = model(image_tensor)
|
||||||
|
# Pass the image tensor through the model to obtain the output logits
|
||||||
_, predicted_idx = torch.max(output, 1)
|
_, predicted_idx = torch.max(output, 1)
|
||||||
|
# Find the index of the predicted class by taking the maximum value along the second dimension
|
||||||
pred = classes[predicted_idx.item()]
|
pred = classes[predicted_idx.item()]
|
||||||
|
# Retrieve the corresponding class label from the classes list using the predicted index
|
||||||
return pred
|
return pred
|
||||||
|
|
||||||
def learn():
|
def learn():
|
||||||
num_epochs = 50
|
num_epochs = 50
|
||||||
batch_size = 64
|
batch_size = 64
|
||||||
|
# Create a dataset from the images in the train_path directory
|
||||||
train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transformer)
|
train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transformer)
|
||||||
|
# Create a data loader for the train dataset to load data in batches
|
||||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||||
|
# Create a dataset from the images in the test_path directory
|
||||||
test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transformer)
|
test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transformer)
|
||||||
|
# Create a data loader for the test dataset to load data in batches
|
||||||
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
|
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
|
||||||
|
# Create an instance of the neural network model
|
||||||
model = Net(len(classes)).to(device)
|
model = Net(len(classes)).to(device)
|
||||||
|
# Create an optimizer for updating the model parameters during training
|
||||||
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=0.0001)
|
optimizer = Adam(model.parameters(), lr=1e-3, weight_decay=0.0001)
|
||||||
|
# Define the loss function for computing the training loss
|
||||||
loss_fn = nn.CrossEntropyLoss()
|
loss_fn = nn.CrossEntropyLoss()
|
||||||
|
# Perform training for the specified number of epochs
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
print(f"Epoch {epoch + 1}\n-------------------------------")
|
print(f"Epoch {epoch + 1}\n-------------------------------")
|
||||||
|
# Train the model using the training data
|
||||||
train(train_loader, model, optimizer, loss_fn)
|
train(train_loader, model, optimizer, loss_fn)
|
||||||
|
# Evaluate the model on the test data
|
||||||
test(test_loader, model, loss_fn)
|
test(test_loader, model, loss_fn)
|
||||||
|
# Print a message indicating that the training is done
|
||||||
print("Done!")
|
print("Done!")
|
||||||
|
# Save the trained model's state dictionary to a file
|
||||||
torch.save(model.state_dict(), 'plants2.model')
|
torch.save(model.state_dict(), 'plants2.model')
|
Loading…
Reference in New Issue
Block a user