From 266d05c3633fe57538053ee68852393325ec5dc4 Mon Sep 17 00:00:00 2001 From: Nikolay Semyachkin Date: Sat, 1 Dec 2018 16:59:22 +0300 Subject: [PATCH] working on cpu version --- model.py | 3 ++- predict.py | 2 +- train.py | 13 ++++++++++--- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/model.py b/model.py index aec6310..7fa56b6 100644 --- a/model.py +++ b/model.py @@ -131,6 +131,7 @@ def __init__(self, num_classes=1, num_filters=32, pretrained=True, is_deconv=Tru self.num_classes = num_classes self.Dropout = Dropout + res_blocks_dec = False self.res_blocks_dec = res_blocks_dec self.pool = nn.MaxPool2d(2, 2) self.relu = nn.ReLU(inplace=True) @@ -324,4 +325,4 @@ def _initialize_weights(self): elif isinstance(m, nn.Linear): n = m.weight.size(1) m.weight.data.normal_(0, 0.01) - m.bias.data.zero_() \ No newline at end of file + m.bias.data.zero_() diff --git a/predict.py b/predict.py index d334df5..eea0bb5 100644 --- a/predict.py +++ b/predict.py @@ -29,7 +29,7 @@ def split_video(filename, n_frames=20): print("file not found") sys.exit(-1) - if file_path.split(".")[-1] == "png": + if file_path.split(".")[-1] != "mp4": imgs = cv2.imread(file_path) imgs = cv2.cvtColor(imgs, cv2.COLOR_BGR2RGB) imgs = np.array(imgs, dtype=np.uint8) diff --git a/train.py b/train.py index 8d524b5..15fe6fd 100644 --- a/train.py +++ b/train.py @@ -11,6 +11,7 @@ import torch from torch.utils import data from torchvision import transforms +import cv2 from torch.autograd import Variable from model import * @@ -22,7 +23,7 @@ def save_checkpoint(checkpoint_path, model, optimizer): print('model saved to %s' % checkpoint_path) def load_checkpoint(checkpoint_path, model, optimizer): - state = torch.load(checkpoint_path) + state = torch.load(checkpoint_path, map_location='cpu') model.load_state_dict(state['state_dict']) if optimizer: optimizer.load_state_dict(state['optimizer']) @@ -125,7 +126,10 @@ def __init__(self, path=None, **kwargs): def get_model(self, model): model = model.train() - return model.cuda(self.device_idx) + if torch.cuda.is_available(): + return model.cuda(self.device_idx) + else: + return model def LR_finder(self, dataset, **kwargs): @@ -379,7 +383,10 @@ def predict_crop(self, imgs): for i in range(imgs.shape[0]): img = self.norm(cv2.resize(imgs[i], (256, 320), interpolation=cv2.INTER_LANCZOS4)) img = img.unsqueeze_(0) - img = img.type(torch.FloatTensor).cuda() + if torch.cuda.is_available(): + img = img.type(torch.FloatTensor).cuda() + else: + img = img.type(torch.FloatTensor) output = torch.nn.functional.sigmoid(self.model(Variable(img))) output = output.cpu().data.numpy() y_pred = np.squeeze(output[0])