Commit ac7e960

Support torch>=2.6 in word_language_model example (#1347)
Update the example usage of `torch.load()` with required safe globals.

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
1 parent cc8e404 commit ac7e960

3 files changed: +34 -5 lines changed
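Background on why safe globals are needed here: starting with torch 2.6, `torch.load()` defaults to `weights_only=True`, and the restricted unpickler rejects any class that is not explicitly allowlisted, so loading a checkpoint that was saved as a whole `nn.Module` (as this example does) fails without an allowlist. Below is a minimal, self-contained sketch of the pattern this commit applies; the throwaway `nn.Linear` module and the `model.pt` path are illustrative stand-ins, not taken from the example.

import torch
import torch.nn as nn

# Saving the module object pickles its class, not just its tensors.
model = nn.Linear(4, 2)
torch.save(model, "model.pt")

# Under torch>=2.6 the default weights_only=True unpickler refuses to
# rebuild nn.Linear unless the class is allowlisted as a safe global.
with torch.serialization.safe_globals([nn.Linear]):
    model = torch.load("model.pt")

Saving `model.state_dict()` instead, and loading it under the default `weights_only=True`, would avoid the allowlist entirely; the commit keeps the example's full-module checkpoint format and allowlists the required classes instead.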

run_python_examples.sh

Lines changed: 3 additions & 0 deletions
@@ -154,6 +154,9 @@ function vision_transformer() {
 
 function word_language_model() {
     uv run main.py --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
+    for model in "RNN_TANH" "RNN_RELU" "LSTM" "GRU" "Transformer"; do
+        uv run main.py --model $model --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
+    done
 }
 
 function gcn() {
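The loop is what gives the new load path coverage: `main.py` now builds a different safe-globals list for `Transformer` than for the RNN variants (`RNN_TANH`, `RNN_RELU`, `LSTM`, `GRU`), so each model type exercises the checkpoint-reload path in dry-run mode; see the `main.py` diff below. Assuming the harness keeps its existing convention of selecting tests by example name, an invocation like `./run_python_examples.sh word_language_model` would run just this check.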

word_language_model/main.py

Lines changed: 30 additions & 4 deletions
@@ -8,7 +8,7 @@
 import torch.onnx
 
 import data
-import model
+from model import PositionalEncoding, RNNModel, TransformerModel
 
 parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM/GRU/Transformer Language Model')
 parser.add_argument('--data', type=str, default='./data/wikitext-2',
@@ -108,9 +108,9 @@ def batchify(data, bsz):
 
 ntokens = len(corpus.dictionary)
 if args.model == 'Transformer':
-    model = model.TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
+    model = TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout).to(device)
 else:
-    model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)
+    model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)
 
 criterion = nn.NLLLoss()
 
@@ -243,7 +243,33 @@ def export_onnx(path, batch_size, seq_len):
 
 # Load the best saved model.
 with open(args.save, 'rb') as f:
-    model = torch.load(f)
+    if args.model == 'Transformer':
+        safe_globals = [
+            PositionalEncoding,
+            TransformerModel,
+            torch.nn.functional.relu,
+            torch.nn.modules.activation.MultiheadAttention,
+            torch.nn.modules.container.ModuleList,
+            torch.nn.modules.dropout.Dropout,
+            torch.nn.modules.linear.Linear,
+            torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
+            torch.nn.modules.normalization.LayerNorm,
+            torch.nn.modules.sparse.Embedding,
+            torch.nn.modules.transformer.TransformerEncoder,
+            torch.nn.modules.transformer.TransformerEncoderLayer,
+        ]
+    else:
+        safe_globals = [
+            RNNModel,
+            torch.nn.modules.dropout.Dropout,
+            torch.nn.modules.linear.Linear,
+            torch.nn.modules.rnn.GRU,
+            torch.nn.modules.rnn.LSTM,
+            torch.nn.modules.rnn.RNN,
+            torch.nn.modules.sparse.Embedding,
+        ]
+    with torch.serialization.safe_globals(safe_globals):
+        model = torch.load(f)
 # after load the rnn params are not a continuous chunk of memory
 # this makes them a continuous chunk, and will speed up forward pass
 # Currently, only rnn model supports flatten_parameters function.
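A side note on the API choice, with the alternative labeled as hypothetical rather than something the commit uses: `torch.serialization.add_safe_globals(...)` would register the same allowlist process-wide, whereas the `safe_globals(...)` context manager scopes the relaxed unpickling to this one `torch.load()` call.

# Hypothetical alternative (not what the commit does): register the
# allowlist once for the whole process instead of scoping it.
torch.serialization.add_safe_globals(safe_globals)
model = torch.load(f)

The scoped form is the more conservative choice for example code, since it cannot loosen deserialization for unrelated loads elsewhere in the process.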

word_language_model/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-torch<2.6
+torch>=2.6
