diff --git a/neural_style_app/app.py b/neural_style_app/app.py index 4b3b7e9..d235aea 100644 --- a/neural_style_app/app.py +++ b/neural_style_app/app.py @@ -12,6 +12,9 @@ import torch.nn as nn app = Flask(__name__) +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +torch.set_default_device(device) + # Image transformation imsize = 512 if torch.cuda.is_available() else 128 loader = transforms.Compose([