@@ -8,7 +8,7 @@
 import torch.onnx
 
 import data
-import model
+from model import PositionalEncoding, RNNModel, TransformerModel
 
 parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM/GRU/Transformer Language Model')
 parser.add_argument('--data', type=str, default='./data/wikitext-2',
@@ -108,9 +108,9 @@ def batchify(data, bsz):
 
 ntokens = len(corpus.dictionary)
 if args.model == 'Transformer':
-    model = model.TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
+    model = TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
 else:
-    model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)
+    model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)
 
 criterion = nn.NLLLoss()
 
@@ -243,7 +243,33 @@ def export_onnx(path, batch_size, seq_len):
 
 # Load the best saved model.
 with open(args.save, 'rb') as f:
-    model = torch.load(f)
+    if args.model == 'Transformer':
+        safe_globals = [
+            PositionalEncoding,
+            TransformerModel,
+            torch.nn.functional.relu,
+            torch.nn.modules.activation.MultiheadAttention,
+            torch.nn.modules.container.ModuleList,
+            torch.nn.modules.dropout.Dropout,
+            torch.nn.modules.linear.Linear,
+            torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
+            torch.nn.modules.normalization.LayerNorm,
+            torch.nn.modules.sparse.Embedding,
+            torch.nn.modules.transformer.TransformerEncoder,
+            torch.nn.modules.transformer.TransformerEncoderLayer,
+        ]
+    else:
+        safe_globals = [
+            RNNModel,
+            torch.nn.modules.dropout.Dropout,
+            torch.nn.modules.linear.Linear,
+            torch.nn.modules.rnn.GRU,
+            torch.nn.modules.rnn.LSTM,
+            torch.nn.modules.rnn.RNN,
+            torch.nn.modules.sparse.Embedding,
+        ]
+    with torch.serialization.safe_globals(safe_globals):
+        model = torch.load(f)
     # after load the rnn params are not a continuous chunk of memory
     # this makes them a continuous chunk, and will speed up forward pass
     # Currently, only rnn model supports flatten_parameters function.
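Context for the change above: starting with PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to unpickle any type not on an allowlist. Because this example pickles the whole model object rather than a state_dict, the custom classes (RNNModel, TransformerModel, PositionalEncoding) and the non-default nn modules inside them must be registered before loading. A minimal sketch of the same pattern outside this repo, with a hypothetical TinyModel standing in for the example's models:

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for the example's RNNModel / TransformerModel.
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

    torch.save(TinyModel(), 'tiny.pt')  # pickles the whole module, not just its state_dict

    # weights_only loading rejects any type not on the allowlist, so every
    # custom class reachable from the checkpoint is registered for this load:
    with torch.serialization.safe_globals([TinyModel, nn.Linear]):
        model = torch.load('tiny.pt', weights_only=True)

The context manager scopes the allowlist to this one load; torch.serialization.add_safe_globals registers types process-wide instead. Saving only model.state_dict() and loading it into a freshly constructed model would sidestep the allowlist entirely, since state dicts contain only tensors.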