diff --git a/model.py b/model.py index 683b854..6e2aac9 100644 --- a/model.py +++ b/model.py @@ -24,7 +24,7 @@ class MLP(nn.Module): def forward(self, x): x = x.view(x.size(0), -1) - return self.layers(x) + return self.layers(x.float()) class PlantsDataset(Dataset):