|
31 | 31 | import requests
|
32 | 32 | import sys
|
33 | 33 | import json
|
| 34 | +import shutil |
34 | 35 |
|
35 | 36 | from hashlib import sha256
|
36 | 37 | from enum import IntEnum, auto
|
@@ -125,12 +126,27 @@ def download_model(model):
|
125 | 126 | if tokt == TOKENIZER_TYPE.UGM:
|
126 | 127 | files.append("spiece.model")
|
127 | 128 |
|
128 |
| - for file in files: |
129 |
| - save_path = f"models/tokenizers/{name}/{file}" |
130 |
| - if os.path.isfile(save_path): |
131 |
| - logger.info(f"{name}: File {save_path} already exists - skipping") |
132 |
| - continue |
133 |
| - download_file_with_auth(f"{repo}/resolve/main/{file}", token, save_path) |
| 129 | + if os.path.isdir(repo): |
| 130 | + # If repo is a path on the file system, copy the directory |
| 131 | + for file in files: |
| 132 | + src_path = os.path.join(repo, file) |
| 133 | + dst_path = f"models/tokenizers/{name}/{file}" |
| 134 | + if os.path.isfile(dst_path): |
| 135 | + logger.info(f"{name}: File {dst_path} already exists - skipping") |
| 136 | + continue |
| 137 | + if os.path.isfile(src_path): |
| 138 | + shutil.copy2(src_path, dst_path) |
| 139 | + logger.info(f"{name}: Copied {src_path} to {dst_path}") |
| 140 | + else: |
| 141 | + logger.warning(f"{name}: Source file {src_path} does not exist") |
| 142 | + else: |
| 143 | + # If repo is a URL, download the files |
| 144 | + for file in files: |
| 145 | + save_path = f"models/tokenizers/{name}/{file}" |
| 146 | + if os.path.isfile(save_path): |
| 147 | + logger.info(f"{name}: File {save_path} already exists - skipping") |
| 148 | + continue |
| 149 | + download_file_with_auth(f"{repo}/resolve/main/{file}", token, save_path) |
134 | 150 |
|
135 | 151 |
|
136 | 152 | for model in models:
|
|
0 commit comments