Upload folder using huggingface_hub
- .gitattributes +14 -0
- InternLM/__init__.py +0 -0
- InternLM/configs/kd_1b_to_300m.py +208 -0
- InternLM/configs/pretrain_300m.py +168 -0
- InternLM/internlm/__init__.py +10 -0
- InternLM/internlm/apis/__init__.py +0 -0
- InternLM/internlm/apis/inference.py +848 -0
- InternLM/internlm/core/__init__.py +9 -0
- InternLM/internlm/core/communication/__init__.py +32 -0
- InternLM/internlm/core/communication/p2p.py +582 -0
- InternLM/internlm/core/communication/utils.py +125 -0
- InternLM/internlm/core/context/__init__.py +49 -0
- InternLM/internlm/core/context/parallel_context.py +569 -0
- InternLM/internlm/core/context/process_group_initializer.py +418 -0
- InternLM/internlm/core/context/random.py +131 -0
- InternLM/internlm/core/engine.py +227 -0
- InternLM/internlm/core/gradient_handler.py +76 -0
- InternLM/internlm/core/naive_amp.py +136 -0
- InternLM/internlm/core/scheduler/__init__.py +14 -0
- InternLM/internlm/core/scheduler/base_scheduler.py +187 -0
- InternLM/internlm/core/scheduler/no_pipeline_scheduler.py +266 -0
- InternLM/internlm/core/scheduler/pipeline_scheduler.py +1363 -0
- InternLM/internlm/core/trainer.py +190 -0
- InternLM/internlm/data/__init__.py +13 -0
- InternLM/internlm/data/batch_sampler.py +354 -0
- InternLM/internlm/data/collaters.py +88 -0
- InternLM/internlm/data/dataset.py +56 -0
- InternLM/internlm/data/dummy_dataset.py +44 -0
- InternLM/internlm/data/packed_dataset.py +421 -0
- InternLM/internlm/data/single_dataset.py +117 -0
- InternLM/internlm/data/utils.py +46 -0
- InternLM/internlm/initialize/__init__.py +16 -0
- InternLM/internlm/initialize/initialize_tensor.py +63 -0
- InternLM/internlm/initialize/initialize_trainer.py +235 -0
- InternLM/internlm/initialize/launch.py +511 -0
- InternLM/internlm/initialize/legacy/__init__.py +0 -0
- InternLM/internlm/initialize/legacy/launch.py +40 -0
- InternLM/internlm/model/__init__.py +23 -0
- InternLM/internlm/model/embedding.py +273 -0
- InternLM/internlm/model/linear.py +201 -0
- InternLM/internlm/model/loss.py +81 -0
- InternLM/internlm/model/metrics.py +263 -0
- InternLM/internlm/model/modeling_internlm.py +524 -0
- InternLM/internlm/model/modeling_vit.py +527 -0
- InternLM/internlm/model/multi_head_attention.py +186 -0
- InternLM/internlm/model/muse/__init__.py +18 -0
- InternLM/internlm/model/muse/modeling_taming_vqgan.py +591 -0
- InternLM/internlm/model/muse/modeling_utils.py +1171 -0
- InternLM/internlm/model/norm.py +46 -0
- InternLM/internlm/model/utils.py +224 -0
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/derain_prompt/000000_img.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/derain_prompt/000000_label.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/derain_prompt/000001_img.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/derain_prompt/000001_label.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/derain_prompt/000002_img.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/derain_prompt/000002_label.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/examples/derain_1.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/examples/derain_2.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/examples/pose_2.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/examples/seg_1.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/examples/seg_2.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/pose_prompt/000002_img.png filter=lfs diff=lfs merge=lfs -text
+InternLM/tools/data/seg_prompt/000000_img.png filter=lfs diff=lfs merge=lfs -text
+figs/DeLVM.PNG filter=lfs diff=lfs merge=lfs -text
InternLM/__init__.py ADDED
File without changes
InternLM/configs/kd_1b_to_300m.py ADDED
@@ -0,0 +1,208 @@
kd_config = dict(gt_weight=1., kd_weight=1., temperature=1)
teacher_type = "INTERNLM"

teacher_ckpt_folder = '/path/to/teacher'

VQGAN_FOLDER = '/path/to/vqgan'
T_SEQ_LEN = 2048
T_HIDDEN_SIZE = 2048
T_NUM_ATTENTION_HEAD = 16
T_MLP_RATIO = 8 / 3
T_NUM_LAYER = 22
T_VOCAB_SIZE = 8192

teacher = dict(
    checkpoint=False,  # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
    num_attention_heads=T_NUM_ATTENTION_HEAD,
    embed_split_hidden=True,
    vocab_size=T_VOCAB_SIZE,
    embed_grad_scale=1,
    parallel_output=True,
    hidden_size=T_HIDDEN_SIZE,
    num_layers=T_NUM_LAYER,
    mlp_ratio=T_MLP_RATIO,
    apply_post_layer_norm=False,
    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
    lvm_config=dict(
        enable=True,
        embedding_cfg=dict(
            vq_model_path=VQGAN_FOLDER,
            embedding_dim=T_HIDDEN_SIZE,
            freeze_vq_model=True,
        ),
    )
)

########################################################
JOB_NAME = "lvm_llama_kd"
DO_ALERT = False
model_type = "INTERNLM"

SEQ_LEN = 2048
HIDDEN_SIZE = 1024
NUM_ATTENTION_HEAD = 8
MLP_RATIO = 8 / 3
NUM_LAYER = 22
VOCAB_SIZE = 8192

MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
SAVE_CKPT_FOLDER = "local:/path_to_save/"
LOAD_CKPT_FOLDER = "local:/path_to_load/"

CHECKPOINT_EVERY = 10000
ckpt = dict(
    enable_save_ckpt=True,  # set True to enable ckpt save.
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
    # load_ckpt_folder=dict(path=MODEL_ONLY_FOLDER, content=["all"], ckpt_type="normal"),
    # load_ckpt_folder="local:llm_ckpts/",
    # 'load_ckpt_info' setting guide:
    # 1. 'path' indicates the ckpt path,
    # 2. 'content' means which states will be loaded; supported: "model", "sampler", "optimizer", "scheduler", "all"
    # 3. 'ckpt_type' means the type of checkpoint to be loaded; currently only the 'normal' type is supported.
    # load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
    checkpoint_every=CHECKPOINT_EVERY,
    async_upload=True,  # async ckpt upload (only works for boto3 ckpt)
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
    # oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
    oss_snapshot_freq=0,
)

TRAIN_FOLDER = "/path/to/dataset"
VALID_FOLDER = "/path/to/dataset"
data = dict(
    seq_len=SEQ_LEN,
    # micro_num means the number of micro_batches contained in one gradient update
    micro_num=1,
    # packed_length = micro_bsz * SEQ_LEN
    micro_bsz=16,
    # defaults to the value of micro_num
    valid_micro_num=1,
    # defaults to 0, which disables evaluation
    valid_every=0,
    pack_sample_into_one=False,
    train_one_epoch=False,
    total_steps=40000,
    skip_batches="",
    rampup_batch_size="",
    # Datasets with fewer than 50 rows will be discarded
    min_length=50,
    train_folder=TRAIN_FOLDER,
    valid_folder=None,
    empty_cache_and_diag_interval=10000,
    diag_outlier_ratio=1.1,
)

grad_scaler = dict(
    fp16=dict(
        # the initial loss scale, defaults to 2**16
        initial_scale=2**16,
        # the minimum loss scale, defaults to None
        min_scale=1,
        # the number of steps to increase loss scale when no overflow occurs
        growth_interval=1000,
    ),
    # the multiplication factor for increasing loss scale, defaults to 2
    growth_factor=2,
    # the multiplication factor for decreasing loss scale, defaults to 0.5
    backoff_factor=0.5,
    # the maximum loss scale, defaults to None
    max_scale=2**24,
    # the number of overflows before decreasing loss scale, defaults to 2
    hysteresis=2,
)

hybrid_zero_optimizer = dict(
    # Enable low_level_optimizer overlap_communication
    overlap_sync_grad=True,
    overlap_sync_param=True,
    # bucket size for nccl communication params
    reduce_bucket_size=512 * 1024 * 1024,
    # grad clipping
    clip_grad_norm=1.0,
)

loss = dict(
    label_smoothing=0,
)

adam = dict(
    lr=1.5e-4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_beta2_c=0,
    adam_eps=1e-8,
    weight_decay=0.1,
)

lr_scheduler = dict(
    total_steps=data["total_steps"],
    init_steps=0,  # optimizer_warmup_step
    warmup_ratio=0.0056,
    eta_min=1.5e-5,
    last_epoch=-1,
)

beta2_scheduler = dict(
    init_beta2=adam["adam_beta2"],
    c=adam["adam_beta2_c"],
    cur_iter=-1,
)

model = dict(
    checkpoint=False,  # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
    num_attention_heads=NUM_ATTENTION_HEAD,
    embed_split_hidden=True,
    vocab_size=VOCAB_SIZE,
    embed_grad_scale=1,
    parallel_output=True,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYER,
    mlp_ratio=MLP_RATIO,
    apply_post_layer_norm=False,
    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
    lvm_config=dict(
        enable=True,
        embedding_cfg=dict(
            vq_model_path='/cache/ckpt/vqgan-f16-8192-laion/',
            embedding_dim=HIDDEN_SIZE,
            freeze_vq_model=True,
        ),
    )
)
"""
zero1 parallel:
    1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group,
       so parameters will be divided within the range of dp.
    2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
    3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
       For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel (dict):
    1. size: int, the size of pipeline parallel.
    2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
    zero1=8,
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=False,
)

cudnn_deterministic = False
cudnn_benchmark = False

monitor = dict(
    # feishu alert configs
    alert=dict(
        enable_feishu_alert=DO_ALERT,
        feishu_alert_address=None,  # feishu webhook to send alert message
        light_monitor_address=None,  # light_monitor address to send heartbeat
    ),
)
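Note on the `kd_config` block at the top of the file above: it only stores the loss weights and the distillation temperature. As a hedged illustration of how such settings are conventionally combined, the sketch below mixes a hard-label cross-entropy term with a temperature-scaled KL term. The function name `kd_loss_sketch` and its signature are assumptions made here for illustration; this is not the loss code shipped in this repository.

import torch.nn.functional as F

def kd_loss_sketch(student_logits, teacher_logits, labels, gt_weight=1.0, kd_weight=1.0, temperature=1.0):
    # Hypothetical sketch mirroring kd_config = dict(gt_weight, kd_weight, temperature); not the repo's implementation.
    vocab = student_logits.size(-1)
    # Hard-label cross-entropy against the ground-truth next tokens.
    ce = F.cross_entropy(student_logits.view(-1, vocab), labels.view(-1))
    # Soft-label KL divergence between temperature-scaled teacher and student distributions.
    t = temperature
    kl = F.kl_div(
        F.log_softmax(student_logits / t, dim=-1),
        F.softmax(teacher_logits / t, dim=-1),
        reduction="batchmean",
    ) * (t * t)  # the usual T^2 factor keeps gradient magnitudes comparable across temperatures
    return gt_weight * ce + kd_weight * kl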
InternLM/configs/pretrain_300m.py ADDED
@@ -0,0 +1,168 @@
JOB_NAME = "lvm_llama"
DO_ALERT = False
model_type = "INTERNLM"

SEQ_LEN = 2048
HIDDEN_SIZE = 1024
NUM_ATTENTION_HEAD = 8
MLP_RATIO = 8 / 3
NUM_LAYER = 22
VOCAB_SIZE = 8192

MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
SAVE_CKPT_FOLDER = "local:/path_to_save/"
LOAD_CKPT_FOLDER = "local:/path_to_load/"

CHECKPOINT_EVERY = 10000
ckpt = dict(
    enable_save_ckpt=True,  # set True to enable ckpt save.
    save_ckpt_folder=SAVE_CKPT_FOLDER,  # Path to save training ckpt.
    # load_ckpt_folder=dict(path=MODEL_ONLY_FOLDER, content=["all"], ckpt_type="normal"),
    # load_ckpt_folder="local:llm_ckpts/",
    # 'load_ckpt_info' setting guide:
    # 1. 'path' indicates the ckpt path,
    # 2. 'content' means which states will be loaded; supported: "model", "sampler", "optimizer", "scheduler", "all"
    # 3. 'ckpt_type' means the type of checkpoint to be loaded; currently only the 'normal' type is supported.
    # load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
    checkpoint_every=CHECKPOINT_EVERY,
    async_upload=True,  # async ckpt upload (only works for boto3 ckpt)
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # path for temporary files during asynchronous upload.
    # oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency.
    oss_snapshot_freq=0,
)

TRAIN_FOLDER = "/path/to/dataset"
VALID_FOLDER = "/path/to/dataset"
data = dict(
    seq_len=SEQ_LEN,
    # micro_num means the number of micro_batches contained in one gradient update
    micro_num=1,
    # packed_length = micro_bsz * SEQ_LEN
    micro_bsz=16,
    # defaults to the value of micro_num
    valid_micro_num=1,
    # defaults to 0, which disables evaluation
    valid_every=0,
    pack_sample_into_one=False,
    train_one_epoch=False,
    total_steps=40000,
    skip_batches="",
    rampup_batch_size="",
    # Datasets with fewer than 50 rows will be discarded
    min_length=50,
    train_folder=TRAIN_FOLDER,
    valid_folder=None,
    empty_cache_and_diag_interval=10000,
    diag_outlier_ratio=1.1,
)

grad_scaler = dict(
    fp16=dict(
        # the initial loss scale, defaults to 2**16
        initial_scale=2**16,
        # the minimum loss scale, defaults to None
        min_scale=1,
        # the number of steps to increase loss scale when no overflow occurs
        growth_interval=1000,
    ),
    # the multiplication factor for increasing loss scale, defaults to 2
    growth_factor=2,
    # the multiplication factor for decreasing loss scale, defaults to 0.5
    backoff_factor=0.5,
    # the maximum loss scale, defaults to None
    max_scale=2**24,
    # the number of overflows before decreasing loss scale, defaults to 2
    hysteresis=2,
)

hybrid_zero_optimizer = dict(
    # Enable low_level_optimizer overlap_communication
    overlap_sync_grad=True,
    overlap_sync_param=True,
    # bucket size for nccl communication params
    reduce_bucket_size=512 * 1024 * 1024,
    # grad clipping
    clip_grad_norm=1.0,
)

loss = dict(
    label_smoothing=0,
)

adam = dict(
    lr=1.5e-4,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_beta2_c=0,
    adam_eps=1e-8,
    weight_decay=0.1,
)

lr_scheduler = dict(
    total_steps=data["total_steps"],
    init_steps=0,  # optimizer_warmup_step
    warmup_ratio=0.0056,
    eta_min=1.5e-5,
    last_epoch=-1,
)

beta2_scheduler = dict(
    init_beta2=adam["adam_beta2"],
    c=adam["adam_beta2_c"],
    cur_iter=-1,
)

model = dict(
    checkpoint=False,  # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
    num_attention_heads=NUM_ATTENTION_HEAD,
    embed_split_hidden=True,
    vocab_size=VOCAB_SIZE,
    embed_grad_scale=1,
    parallel_output=True,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYER,
    mlp_ratio=MLP_RATIO,
    apply_post_layer_norm=False,
    dtype="torch.float16",  # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-5,
    use_flash_attn=True,
    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
    lvm_config=dict(
        enable=True,
        embedding_cfg=dict(
            vq_model_path='/cache/ckpt/vqgan-f16-8192-laion/',
            embedding_dim=HIDDEN_SIZE,
            freeze_vq_model=True,
        ),
    )
)
"""
zero1 parallel:
    1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group,
       so parameters will be divided within the range of dp.
    2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
    3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
       For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel (dict):
    1. size: int, the size of pipeline parallel.
    2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
    zero1=8,
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=False,
)

cudnn_deterministic = False
cudnn_benchmark = False

monitor = dict(
    # feishu alert configs
    alert=dict(
        enable_feishu_alert=DO_ALERT,
        feishu_alert_address=None,  # feishu webhook to send alert message
        light_monitor_address=None,  # light_monitor address to send heartbeat
    ),
)
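A quick, illustrative sanity check on the schedule defined above (these derived numbers are not part of the config file itself): with `warmup_ratio=0.0056` and `total_steps=40000`, the learning-rate warmup covers roughly 0.0056 * 40000 = 224 steps, and each optimizer step consumes `micro_num * micro_bsz * SEQ_LEN = 1 * 16 * 2048 = 32768` tokens per data-parallel rank.

# Derived quantities for pretrain_300m.py; illustrative only, not read by the trainer.
total_steps = 40000
warmup_ratio = 0.0056
micro_num, micro_bsz, seq_len = 1, 16, 2048

warmup_steps = round(total_steps * warmup_ratio)    # 224 warmup steps
tokens_per_step = micro_num * micro_bsz * seq_len   # 32768 tokens per DP rank per step
print(warmup_steps, tokens_per_step)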
InternLM/internlm/__init__.py ADDED
@@ -0,0 +1,10 @@
from .initialize.initialize_trainer import initialize_trainer, initialize_kd_trainer
from .initialize.launch import get_default_parser, launch_from_slurm, launch_from_torch

__all__ = [
    "get_default_parser",
    "initialize_kd_trainer",
    "initialize_trainer",
    "launch_from_slurm",
    "launch_from_torch",
]
InternLM/internlm/apis/__init__.py ADDED
File without changes
InternLM/internlm/apis/inference.py ADDED
@@ -0,0 +1,848 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.nn.functional as F
from torch import nn

__all__ = ["SequenceGenerator"]


class InferenceParams:
    """
    Intermediate cache objects for inference
    """

    def __init__(
        self,
        max_sequence_len,
        max_batch_size,
        sequence_len_offset=0,
        batch_size_offset=0,
        key_value_memory_dict: dict = None,
        lengths_per_sample=None,
        attention_mask=None,
    ) -> None:

        self.max_sequence_len: int = max_sequence_len
        self.max_batch_size: int = max_batch_size
        self.sequence_len_offset: int = sequence_len_offset
        self.batch_size_offset: int = batch_size_offset
        if key_value_memory_dict is None:
            key_value_memory_dict = {}
        self.key_value_memory_dict: dict = key_value_memory_dict
        self.fused_ft_kernel: bool = False
        self.lengths_per_sample = lengths_per_sample
        self.attention_mask = attention_mask

    def reorder_state(self, indices):
        if self.lengths_per_sample is not None:
            self.lengths_per_sample = self.lengths_per_sample.index_select(index=indices, dim=0)
        for key, value in list(self.key_value_memory_dict.items()):
            value = value.index_select(index=indices, dim=0)
            self.key_value_memory_dict[key] = value


def _get_model_device(model):
    """
    Obtain the device of an nn.Module model.

    Args:
        model: nn.Module

    Return: torch.device. If None, the model has no parameters.
    """
    assert isinstance(model, nn.Module)

    parameters = list(model.parameters())
    if len(parameters) == 0:
        return None
    else:
        return parameters[0].device


class SequenceGenerator:
    """
    Sequence Generator.
    """

    def __init__(self, decoder, eos_token_id, pad_token_id, bos_token_id):
        self.decoder = decoder
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id

    @torch.no_grad()
    def generate(
        self,
        tokens: "torch.LongTensor" = None,
        num_return_sequences=1,
        max_length: int = 20,
        num_beams: int = 1,
        do_sample: bool = True,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 1.0,
        repetition_penalty: float = 1,
        length_penalty: float = 1.0,
    ):
        """
        Args:
            tokens: the beginning tokens whose shape is [bsz, length]. If None, the default ``bos_token`` will be
                added to start generation.
            num_return_sequences: number of returned sequences.
            max_length: the max length of the generated sequence.
            num_beams: the size of beam search.
            do_sample: whether to use sampling.
            temperature: only meaningful when do_sample is True.
            top_k: sampling from top_k.
            top_p: sampling from top_p tokens (nucleus sampling).

        Return:
            the token sequence whose shape is [bsz, num_return_sequences, max_length]. If eos_token_id is not None,
            the ending of each sequence must be eos_token_id.
        """
        assert num_return_sequences <= num_beams, f"num_return_sequences ({num_return_sequences}) must be no greater than num_beams ({num_beams})."
        if do_sample:
            return sample_generate(
                self.decoder,
                tokens=tokens,
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                eos_token_id=self.eos_token_id,  # the ending token id
                pad_token_id=self.pad_token_id,
                repetition_penalty=repetition_penalty,  # the penalty degree for repetition tokens
                length_penalty=length_penalty,  # the penalty for length. if it > 1, then encourages long sequences.
                # Otherwise, encourages short sequences.
                bos_token_id=self.bos_token_id,
            )
        else:
            return greedy_generate(
                self.decoder,
                tokens=tokens,
                max_length=max_length,
                num_beams=num_beams,
                num_return_sequences=num_return_sequences,
                eos_token_id=self.eos_token_id,
                pad_token_id=self.pad_token_id,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                bos_token_id=self.bos_token_id,
            )


@torch.no_grad()
def greedy_generate(
    decoder,
    tokens=None,
    max_length=20,
    num_beams=1,
    num_return_sequences=1,
    eos_token_id=None,
    pad_token_id=0,
    repetition_penalty=1,
    length_penalty=1.0,
    bos_token_id=1,
    feat_mask=None,
    ffn_mask=None,
    layer_mask=None,
):
    """
    Search the sequence greedily.

    Args:
        decoder: the Decoder object.
        tokens: the shape is [batch size, length]. If None, generation begins with bos_token_id.
        max_length: the max length for the generated sequence.
        num_beams: the size of the beam to decode.
        eos_token_id: the ending token id. If None, the decode length is max_length.
        pad_token_id: the token id of pad.
        repetition_penalty: the penalty degree for repetition tokens
        length_penalty: the penalty for length.

    """
    if num_beams == 1:
        token_ids = _no_beam_search_generate(
            decoder,
            tokens=tokens,
            max_length=max_length,
            temperature=1,
            top_k=50,
            top_p=1,
            eos_token_id=eos_token_id,
            do_sample=False,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            feat_mask=feat_mask,
            ffn_mask=ffn_mask,
            layer_mask=layer_mask,
        )
    else:
        token_ids = _beam_search_generate(
            decoder,
            tokens=tokens,
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            temperature=1,
            top_k=50,
            top_p=1,
            eos_token_id=eos_token_id,
            do_sample=False,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            feat_mask=feat_mask,
            ffn_mask=ffn_mask,
            layer_mask=layer_mask,
        )

    return token_ids


@torch.no_grad()
def sample_generate(
    decoder,
    tokens,
    max_length=20,
    num_beams=1,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    eos_token_id=None,
    pad_token_id=0,
    repetition_penalty=1.0,
    length_penalty=1.0,
    bos_token_id=1,
):
    """
    Generate a sequence by sampling.

    Args:
        decoder: the Decoder object.
        tokens: the shape is [batch size, length]. If None, generation begins with bos_token_id.
        max_length: the max length for the generated sequence.
        num_beams: the size of the beam to decode.
        num_return_sequences: number of returned sequences.
        temperature: annealing magnitude during sampling.
        top_k: sampling from top_k. (Default: 50)
        top_p: sampling from top_p tokens (nucleus sampling). (Default: 1.0)
        eos_token_id: the ending token id. If None, the decode length is max_length.
        pad_token_id: the token id of pad.
        repetition_penalty: the penalty degree for repetition tokens
        length_penalty: the penalty for length.

    """
    if num_beams == 1:
        token_ids = _no_beam_search_generate(
            decoder,
            tokens=tokens,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=eos_token_id,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
        )
    else:
        token_ids = _beam_search_generate(
            decoder,
            tokens=tokens,
            max_length=max_length,
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=eos_token_id,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
        )
    return token_ids


@torch.no_grad()
def _no_beam_search_generate(
    decoder,
    tokens,
    inference_params=None,
    max_length=20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    eos_token_id=None,
    do_sample=True,
    repetition_penalty=1.0,
    length_penalty=1.0,
    pad_token_id=0,
    bos_token_id=1,
    feat_mask=None,
    ffn_mask=None,
    layer_mask=None,
):
    # delete num_return_sequences=1 for lint check;
    batch_size = tokens.size(0)
    if eos_token_id is None:
        _eos_token_id = -1
    else:
        _eos_token_id = eos_token_id

    has_bos = torch.all(tokens[:, 0].eq(bos_token_id))
    if has_bos:
        bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
        bos_sum = bos_pos.cumsum(dim=-1)
        bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
        to_atten_x = bos_pos[:, :, None]
        to_atten_y = bos_pos[:, None, :]
        # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
    else:
        bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
        to_atten_x = bos_pos[:, :, None]
        to_atten_y = bos_pos[:, None, :]
        # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
    attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
    if inference_params is None:
        inference_params = InferenceParams(
            max_sequence_len=max_length,
            max_batch_size=tokens.size(0),
            sequence_len_offset=0,
            batch_size_offset=0,
            key_value_memory_dict=None,
            lengths_per_sample=None,
            attention_mask=attention_mask,
        )

    if layer_mask is None:
        if feat_mask is None and ffn_mask is None:
            scores = decoder(**{"input_ids": tokens, "inference_params": inference_params})
        else:
            scores = decoder(
                **{
                    "input_ids": tokens,
                    "inference_params": inference_params,
                    "feat_mask": feat_mask,
                    "ffn_mask": ffn_mask,
                }
            )
    else:
        scores = decoder(
            **{
                "input_ids": tokens,
                "inference_params": inference_params,
                "feat_mask": feat_mask,
                "ffn_mask": ffn_mask,
                "layer_mask": layer_mask,
            }
        )

    if isinstance(scores, (list, tuple)):
        scores = scores[0]
    scores = scores[:, -1].float()
    inference_params.sequence_len_offset += tokens.size(1)
    if _eos_token_id != -1:
        scores[:, _eos_token_id] = -1e12
    next_tokens = scores.argmax(dim=-1, keepdim=True)
    token_ids = torch.cat([tokens, next_tokens], dim=1)
    cur_len = token_ids.size(1)
    dones = token_ids.new_zeros(batch_size).eq(1)
    # tokens = tokens[:, -1:]

    real_max_length = max_length
    max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)

    while cur_len < real_max_length:
        # batch_size x vocab_size
        if has_bos:
            bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
            bos_sum = bos_pos.cumsum(dim=-1)
            bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
            to_atten_x = bos_pos[:, :, None]
            to_atten_y = bos_pos[:, None, :]
            # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
        else:
            bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
            to_atten_x = bos_pos[:, :, None]
            to_atten_y = bos_pos[:, None, :]
            # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
        attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
        inference_params.attention_mask = attention_mask
        if layer_mask is None:
            if feat_mask is None and ffn_mask is None:
                scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
            else:
                scores = decoder(
                    **{
                        "input_ids": token_ids[:, -1:],
                        "inference_params": inference_params,
                        "feat_mask": feat_mask,
                        "ffn_mask": ffn_mask,
                    }
                )
        else:
            scores = decoder(
                **{
                    "input_ids": token_ids[:, -1:],
                    "inference_params": inference_params,
                    "feat_mask": feat_mask,
                    "ffn_mask": ffn_mask,
                    "layer_mask": layer_mask,
                }
            )

        if isinstance(scores, (list, tuple)):
            scores = scores[0]
        scores = scores[:, -1].float()
        inference_params.sequence_len_offset += 1

        if repetition_penalty != 1.0:
            token_scores = scores.gather(dim=1, index=token_ids)
            lt_zero_mask = token_scores.lt(0).float()
            ge_zero_mask = lt_zero_mask.eq(0).float()
            token_scores = (
                lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
            )
            scores.scatter_(dim=1, index=token_ids, src=token_scores)

        if eos_token_id is not None and length_penalty != 1.0:
            # batch_size x vocab_size
            token_scores = scores / cur_len**length_penalty
            eos_mask = scores.new_ones(scores.size(1))
            eos_mask[eos_token_id] = 0
            eos_mask = eos_mask.unsqueeze(0).eq(1)

            scores = scores.masked_scatter(eos_mask, token_scores)

        if do_sample:
            if temperature > 0 and temperature != 1:
                scores = scores / temperature

            scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
            # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
            probs = F.softmax(scores, dim=-1) + 1e-12

            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # batch_size
        else:
            next_tokens = torch.argmax(scores, dim=-1)  # batch_size

        if _eos_token_id != -1:
            next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), _eos_token_id)
        next_tokens = next_tokens.masked_fill(dones, pad_token_id)
        tokens = next_tokens.unsqueeze(1)

        token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len

        end_mask = next_tokens.eq(_eos_token_id)
        dones = dones.__or__(end_mask)
        cur_len += 1

        if dones.min() == 1:
            break

    # if eos_token_id is not None:
    #     # setting the eos at the maximum length position
    #     tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id)
    # if cur_len == max_length:
    #     # If eos is not reached by the maximum length, forcibly replace the last word with eos
    #     token_ids[:, -1].masked_fill_(~dones, eos_token_id)
    # TODO Here we are simply adding an extra dimension for interface compatibility, but in the future it will need to
    # be able to return multiple real results
    return token_ids[:, None]


@torch.no_grad()
def _beam_search_generate(
    decoder,
    tokens,
    inference_params=None,
    max_length=20,
    num_beams=4,
    num_return_sequences=1,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
    eos_token_id=None,
    do_sample=True,
    repetition_penalty=1.0,
    length_penalty=1.0,
    pad_token_id=0,
    bos_token_id=1,
    feat_mask=None,
    ffn_mask=None,
    layer_mask=None,
) -> torch.LongTensor:

    device = _get_model_device(decoder)
    batch_size = tokens.size(0)

    if eos_token_id is None:
        _eos_token_id = -1
    else:
        _eos_token_id = eos_token_id

    has_bos = torch.all(tokens[:, 0].eq(bos_token_id))

    if has_bos:
        bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
        bos_sum = bos_pos.cumsum(dim=-1)
        bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
        to_atten_x = bos_pos[:, :, None]
        to_atten_y = bos_pos[:, None, :]
        # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
    else:
        bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
        to_atten_x = bos_pos[:, :, None]
        to_atten_y = bos_pos[:, None, :]
        # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
    attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)

    if inference_params is None:
        inference_params = InferenceParams(
            max_sequence_len=max_length,
            max_batch_size=tokens.size(0),
            sequence_len_offset=0,
            batch_size_offset=0,
            key_value_memory_dict=None,
            lengths_per_sample=None,
            attention_mask=attention_mask,
        )

    if layer_mask is None:
        if feat_mask is None and ffn_mask is None:
            scores = decoder(**{"input_ids": tokens, "inference_params": inference_params})
        else:
            scores = decoder(
                **{
                    "input_ids": tokens,
                    "inference_params": inference_params,
                    "feat_mask": feat_mask,
                    "ffn_mask": ffn_mask,
                }
            )
    else:
        scores = decoder(
            **{
                "input_ids": tokens,
                "inference_params": inference_params,
                "feat_mask": feat_mask,
                "ffn_mask": ffn_mask,
                "layer_mask": layer_mask,
            }
        )

    if isinstance(scores, (list, tuple)):
        scores = scores[0]
    scores = scores[:, -1].float()
    inference_params.sequence_len_offset += tokens.size(1)
    if _eos_token_id != -1:
        scores[:, _eos_token_id] = -1e12
    vocab_size = scores.size(1)
    assert vocab_size >= num_beams, "num_beams should be smaller than " "the number of vocabulary size."

    if do_sample:
        probs = F.softmax(scores, dim=-1) + 1e-12
        # (batch_size, num_beams)
        next_tokens = torch.multinomial(probs, num_samples=num_beams)
        logits = probs.log()
        # (batch_size, num_beams)
        next_scores = logits.gather(dim=1, index=next_tokens)
    else:
        scores = F.log_softmax(scores, dim=-1)  # (batch_size, vocab_size)
        # obtain (batch_size, num_beams), (batch_size, num_beams)
        next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)

    indices = torch.arange(batch_size, dtype=torch.long).to(device)
    indices = indices.repeat_interleave(num_beams)
    inference_params.reorder_state(indices)

    # batch_size * num_beams x length
    tokens = tokens.index_select(dim=0, index=indices)
    # generated token (batch_size', cur_len)
    token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
    dones = [False] * batch_size

    beam_scores = next_scores.view(-1)  # batch_size * num_beams

    cur_len = token_ids.size(1)

    real_max_length = max_length
    max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
    hypos = [
        BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
    ]
    # 0, num_beams, 2*num_beams, ...
    batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

    while cur_len < real_max_length:
        if has_bos:
            bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
            bos_sum = bos_pos.cumsum(dim=-1)
            bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
            to_atten_x = bos_pos[:, :, None]
            to_atten_y = bos_pos[:, None, :]
            # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
        else:
            bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
            to_atten_x = bos_pos[:, :, None]
            to_atten_y = bos_pos[:, None, :]
            # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
        attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)

        inference_params.attention_mask = attention_mask
        # (bsz x num_beams, vocab_size)

        if layer_mask is None:
            if feat_mask is None and ffn_mask is None:
                scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
            else:
                scores = decoder(
                    **{
                        "input_ids": token_ids[:, -1:],
                        "inference_params": inference_params,
                        "feat_mask": feat_mask,
                        "ffn_mask": ffn_mask,
                    }
                )
        else:
            scores = decoder(
                **{
                    "input_ids": token_ids[:, -1:],
                    "inference_params": inference_params,
                    "feat_mask": feat_mask,
                    "ffn_mask": ffn_mask,
                    "layer_mask": layer_mask,
                }
            )

        if isinstance(scores, (list, tuple)):
            scores = scores[0]
        scores = scores[:, -1].float()
        inference_params.sequence_len_offset += 1
        if repetition_penalty != 1.0:
            token_scores = scores.gather(dim=1, index=token_ids)
            lt_zero_mask = token_scores.lt(0).float()
            ge_zero_mask = lt_zero_mask.eq(0).float()
            token_scores = (
                lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
            )
            scores.scatter_(dim=1, index=token_ids, src=token_scores)

        if _eos_token_id != -1:
            max_len_eos_mask = max_lengths.eq(cur_len + 1)
            eos_scores = scores[:, _eos_token_id]
            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores + 1e32, eos_scores)

        if do_sample:
            if temperature > 0 and temperature != 1:
                scores = scores / temperature

            scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
            # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
            probs = F.softmax(scores, dim=-1) + 1e-12

            # batch_size' x (num_beams+1)
            _tokens = torch.multinomial(probs, num_samples=num_beams + 1)

            logits = probs.log()
            # batch_size' x (num_beams+1)
            _scores = logits.gather(dim=1, index=_tokens)
            # batch_size' x (num_beams+1)
            _scores = _scores + beam_scores[:, None]
            _scores = _scores.view(batch_size, num_beams * (num_beams + 1))
            next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
            _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
            # (batch_size, 2*num_beams)
            next_tokens = _tokens.gather(dim=1, index=ids)
            # (batch_size, 2*num_beams)
            from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long()
        else:
            # (batch_size * num_beams, vocab_size)
            scores = F.log_softmax(scores, dim=-1)
            # (batch_size * num_beams, vocab_size)
            _scores = scores + beam_scores[:, None]
            # (batch_size, num_beams*vocab_size)
            _scores = _scores.view(batch_size, -1)
            # (bsz, 2*num_beams)
            next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
            # (batch_size, 2*num_beams)
            from_which_beam = torch.floor(ids.float() / vocab_size).long()
            next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)

        # next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
        # next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
        # from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)

        not_eos_mask = next_tokens.ne(_eos_token_id)
        keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
        keep_mask = not_eos_mask.__and__(keep_mask)

        _next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
        _from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)
        _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
        beam_scores = _next_scores.view(-1)

        flag = True
        if cur_len + 1 == real_max_length:
            eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
            eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)
            eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)
        else:
            effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id)  # batch_size x num_beams
            if effective_eos_mask.sum().gt(0):
                eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
                eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
                eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]
            else:
                flag = False

        if flag:
            _token_ids = torch.cat([token_ids, _next_tokens], dim=-1)
            for batch_idx, beam_ind, beam_idx in zip(
                eos_batch_idx.tolist(), eos_beam_ind.tolist(), eos_beam_idx.tolist()
            ):
                if not dones[batch_idx]:
                    score = next_scores[batch_idx, beam_ind].item()
                    if _eos_token_id != -1:
                        hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
                    else:
                        hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score)

        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
        inference_params.reorder_state(reorder_inds)
        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)

        for batch_idx in range(batch_size):
            dones[batch_idx] = (
                dones[batch_idx]
                or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
                or max_lengths[batch_idx * num_beams] == cur_len + 1
            )

        cur_len += 1

        if all(dones):
            break

    # select the best hypotheses
    tgt_len = token_ids.new_zeros(batch_size, num_return_sequences)
    best = []

    for i, hypotheses in enumerate(hypos):
        # best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
        sorted_hyp = list(sorted(hypotheses.hyp, key=lambda x: x[0], reverse=True))
        _best = []
        for j, hyp in zip(range(num_return_sequences), sorted_hyp):
            hyp = hyp[1]
            if _eos_token_id != -1:
                hyp = torch.cat([hyp, token_ids.new_ones(1) * _eos_token_id])
            tgt_len[i, j] = len(hyp)
            _best.append(hyp)
        best.append(_best)

    # generate target batch
    decoded = token_ids.new_zeros(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
    for i, hypo in enumerate(best):
        for j, _hypo in enumerate(hypo):
            decoded[i, j, : tgt_len[i, j]] = _hypo

    return decoded


class BeamHypotheses(object):
    """
    BeamHypotheses
    """

    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """Initialize n-best list of hypotheses."""
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.hyp = []
        self.worst_score = 1e9

    def __len__(self):
        """Number of hypotheses in the list."""
        return len(self.hyp)

    def add(self, hyp, sum_logprobs):
        """Add a new hypothesis to the list."""
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.hyp.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
                del self.hyp[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs):
        """If there are enough hypotheses and none of the hypotheses being
        generated can become better than the worst one in the heap, then we are
        done with this sentence."""
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
|
| 802 |
+
return True
|
| 803 |
+
else:
|
| 804 |
+
return self.worst_score >= best_sum_logprobs / self.max_length**self.length_penalty
|
| 805 |
+
|
| 806 |
+
|
| 807 |
+
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
|
| 808 |
+
"""
|
| 809 |
+
Based on the values of top_k and top_p, set the values that do not meet the criteria to the filter_value.
|
| 810 |
+
|
| 811 |
+
Args:
|
| 812 |
+
logits: logit value, shape is [bsz, vocab_size].
|
| 813 |
+
top_k: If it is greater than 0, only the probabilities of the top_k vocabulary are kept, and the rest of
|
| 814 |
+
the positions are set to filter_value.
|
| 815 |
+
top_p: according to http://arxiv.org/abs/1904.09751.
|
| 816 |
+
filter_value: filter value
|
| 817 |
+
min_tokens_to_keep: The probability of words in each sample‘s returned distribution will not be
|
| 818 |
+
lower than this value.
|
| 819 |
+
|
| 820 |
+
"""
|
| 821 |
+
if top_k > 0:
|
| 822 |
+
# Safety check
|
| 823 |
+
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
|
| 824 |
+
# Remove all tokens with a probability less than the last token of
|
| 825 |
+
# the top-k
|
| 826 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 827 |
+
logits[indices_to_remove] = filter_value
|
| 828 |
+
|
| 829 |
+
if top_p < 1.0:
|
| 830 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 831 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 832 |
+
|
| 833 |
+
# Remove tokens with cumulative probability above the threshold
|
| 834 |
+
# (token with 0 are kept)
|
| 835 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 836 |
+
if min_tokens_to_keep > 1:
|
| 837 |
+
# Keep at least min_tokens_to_keep
|
| 838 |
+
# (set to min_tokens_to_keep-1 because we add the first one below)
|
| 839 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
| 840 |
+
# Shift the indices to the right to keep also the first token
|
| 841 |
+
# above the threshold
|
| 842 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 843 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 844 |
+
|
| 845 |
+
# scatter sorted tensors to original indexing
|
| 846 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 847 |
+
logits[indices_to_remove] = filter_value
|
| 848 |
+
return logits
|
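Note (not part of the uploaded file): a minimal sketch of how `top_k_top_p_filtering` is typically paired with temperature scaling and multinomial sampling. The batch size, vocabulary size, and temperature below are illustrative assumptions, not values taken from this repository.

    import torch
    import torch.nn.functional as F

    # Hypothetical [batch, vocab] logits and sampling settings for illustration only.
    logits = torch.randn(2, 32000)
    temperature, top_k, top_p = 0.8, 50, 0.9

    # Mask out everything outside the top-k / nucleus set, then sample one token per row.
    filtered = top_k_top_p_filtering(logits / temperature, top_k=top_k, top_p=top_p)
    probs = F.softmax(filtered, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)  # shape [batch, 1]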
InternLM/internlm/core/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .engine import Engine
from .naive_amp import NaiveAMPModel
from .trainer import Trainer

__all__ = [
    "NaiveAMPModel",
    "Engine",
    "Trainer",
]
InternLM/internlm/core/communication/__init__.py
ADDED
@@ -0,0 +1,32 @@
from .p2p import (
    AsynCommunicator,
    recv_backward,
    recv_forward,
    send_backward,
    send_backward_and_recv_next_backward_async,
    send_backward_recv_backward,
    send_backward_recv_forward,
    send_forward,
    send_forward_and_recv_next_forward_async,
    send_forward_backward_recv_forward_backward,
    send_forward_recv_backward,
    send_forward_recv_forward,
)
from .utils import recv_obj_meta, send_obj_meta

__all__ = [
    "send_forward",
    "send_forward_recv_forward",
    "send_forward_backward_recv_forward_backward",
    "send_backward",
    "send_backward_recv_backward",
    "send_backward_recv_forward",
    "send_forward_recv_backward",
    "recv_backward",
    "recv_forward",
    "send_obj_meta",
    "recv_obj_meta",
    "send_backward_and_recv_next_backward_async",
    "send_forward_and_recv_next_forward_async",
    "AsynCommunicator",
]
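Note (not part of the uploaded file): a minimal sketch of how adjacent pipeline stages might pair the exported `recv_forward`/`send_forward` helpers defined in `p2p.py` below. It assumes torch.distributed and the InternLM parallel context have already been initialized elsewhere, and that `model_chunk`, `micro_batch`, and `hidden_shape` are placeholders supplied by the training schedule.

    import torch
    from internlm.core.communication import recv_forward, send_forward
    from internlm.core.context import ParallelMode
    from internlm.core.context import global_context as gpc

    def forward_step(model_chunk, micro_batch, hidden_shape, dtype=torch.float):
        # First stage consumes real data; later stages receive activations from the previous stage.
        if gpc.is_first_rank(ParallelMode.PIPELINE):
            stage_input = micro_batch
        else:
            stage_input = recv_forward(hidden_shape, dtype=dtype)
        stage_output = model_chunk(stage_input)
        # Every stage except the last forwards its activations downstream.
        is_last_stage = (
            gpc.get_local_rank(ParallelMode.PIPELINE) == gpc.get_world_size(ParallelMode.PIPELINE) - 1
        )
        if not is_last_stage:
            send_forward(stage_output)
        return stage_output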
InternLM/internlm/core/communication/p2p.py
ADDED
@@ -0,0 +1,582 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication

import operator
from functools import reduce
from typing import List, Tuple, Union

import torch
import torch.distributed as dist

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device

from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks

TensorShape = Union[torch.Size, List[int], Tuple[int]]


def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
    """get the exact tensor shape when communicating and return whether the tensor is a chunk

    Args:
        tensor_shape (:class:`torch.Size`): shape of tensor
        chunk_tensor (bool, optional): whether to chunk tensor, defaults to False

    Returns:
        Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
    """
    if chunk_tensor:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
        tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
        if tensor_chunk_shape % tensor_parallel_world_size == 0:
            tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
        else:
            tensor_chunk_shape = tensor_shape
            chunk_tensor = False
    else:
        tensor_chunk_shape = tensor_shape
    return tensor_chunk_shape, chunk_tensor


def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
    if isinstance(recv_shapes, torch.Size):
        recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
        buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
        return buffer_recv, recv_split
    buffer_recv = []
    for recv_shape in recv_shapes:
        recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
        tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
        buffer_recv.append(tensor_recv)
    return buffer_recv, recv_split


def process_object_to_send(object_send, scatter_gather_tensors):
    if isinstance(object_send, torch.Tensor):
        send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]
        if send_split:
            object_send = split_tensor_into_1d_equal_chunks(object_send)
        return object_send

    object_send_list = []
    for tensor_send in object_send:
        send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
        if send_split:
            object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send))
        else:
            object_send_list.append(tensor_send)
    object_send = tuple(object_send_list)

    return object_send


def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
    if isinstance(obj, torch.Tensor):
        op_to_add = dist.P2POp(comm_op, obj, comm_rank)
        ops_queue.append(op_to_add)
    else:
        for tensor_to_comm in obj:
            op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
            ops_queue.append(op_to_add)


def _communicate(
    object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
    object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
    recv_prev: bool = False,
    recv_next: bool = False,
    recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
    recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
    prev_rank: int = None,
    next_rank: int = None,
    dtype: torch.dtype = None,
    scatter_gather_tensors: bool = False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
    """
    Adapted from megatron.p2p_communication.
    Communicate tensors between stages. Used as helper method in other
    communication methods that are used in pipeline schedule.
    Takes the following arguments:
        object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank
            (no tensor sent if set to None).
        object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank
            (no tensor sent if set to None).
        recv_prev (bool): boolean for whether tensor should be received from
            previous rank.
        recv_next (bool): boolean for whether tensor should be received from
            next rank.
        recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
            from the previous stage, defaults to None.
        recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
            from the next stage, defaults to None.
        prev_rank (int): the rank of the previous pipeline stage, defaults to None,
        next_rank (int): the rank of the next pipeline stage, defaults to None,
        dtype (torch.dtype): data type of intermediate buffers, defaults to None
        scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False

    Returns:
        Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
    """

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None

    if recv_prev:
        assert recv_prev_shape is not None
        tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
            recv_prev_shape, dtype, scatter_gather_tensors
        )

    if recv_next:
        assert recv_next_shape is not None
        tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
            recv_next_shape, dtype, scatter_gather_tensors
        )

    if object_send_prev is not None or recv_prev:
        if prev_rank is None:
            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

    if object_send_next is not None or recv_next:
        if next_rank is None:
            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

    if object_send_prev is not None:
        object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors)

    if object_send_next is not None:
        object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)

    ops = []
    if object_send_prev is not None:
        filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

    if tensor_recv_prev is not None:
        filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

    if tensor_recv_next is not None:
        filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)

    if object_send_next is not None:
        filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

    if len(ops) > 0:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()

    if recv_prev and recv_prev_split:
        if isinstance(tensor_recv_prev, torch.Tensor):
            tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
        else:
            for index in range(len(tensor_recv_prev)):
                tensor_recv_prev[index] = (
                    gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
                )

    if recv_next and recv_next_split:
        if isinstance(tensor_recv_next, torch.Tensor):
            tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
        else:
            for index in range(len(tensor_recv_next)):
                tensor_recv_next[index] = (
                    gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
                )

    return tensor_recv_prev, tensor_recv_next


def recv_forward(
    input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

    Args:
        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.
        prev_rank (int, optional): The rank of the source of the tensor.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
    """
    input_tensor, _ = _communicate(
        recv_prev=True,
        recv_prev_shape=input_tensor_shape,
        prev_rank=prev_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )
    return input_tensor


def recv_backward(
    output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.

    Args:
        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.
        next_rank (int, optional): The rank of the source of the tensor.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradient tensor list.
    """
    _, output_tensor_grad = _communicate(
        recv_next=True,
        recv_next_shape=output_grad_shape,
        next_rank=next_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )
    return output_tensor_grad


def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
    """Sends the input tensor to the next stage in pipeline.

    Args:
        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
        next_rank (int, optional): The rank of the recipient of the tensor.
    """
    _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)


def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
    """Sends the gradient tensor to the previous stage in pipeline.

    Args:
        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
        prev_rank (int, optional): The rank of the recipient of the tensor
    """

    _communicate(object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors)


def send_forward_recv_backward(
    output_tensor, output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Batched communication operation. Sends the input tensor to the
    next stage in pipeline, while receiving the gradient tensor from the
    next stage in pipeline as the input gradient tensor of this stage.

    Args:
        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
    """
    _, output_tensor_grad = _communicate(
        object_send_next=output_tensor,
        recv_next=output_grad_shape is not None,
        recv_next_shape=output_grad_shape,
        next_rank=next_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )

    return output_tensor_grad


def send_backward_recv_forward(
    input_tensor_grad,
    input_tensor_shape,
    prev_rank=None,
    dtype=torch.float,
    scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Batched communication operation. Sends the gradient tensor to the
    previous stage in pipeline, while receiving the output tensor from the
    previous stage in pipeline as the input of this stage.

    Args:
        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
    """
    input_tensor, _ = _communicate(
        object_send_prev=input_tensor_grad,
        recv_prev=input_tensor_shape is not None,
        recv_prev_shape=input_tensor_shape,
        prev_rank=prev_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )

    return input_tensor


def send_forward_recv_forward(
    output_tensor,
    input_tensor_shape,
    prev_rank=None,
    next_rank=None,
    dtype=torch.float,
    scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Batched communication operation. Sends the input tensor to the
    next stage in pipeline, while receiving the output tensor from the
    previous stage in pipeline as the input of this stage.

    Args:
        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
    """
    input_tensor, _ = _communicate(
        object_send_next=output_tensor,
        recv_prev=input_tensor_shape is not None,
        recv_prev_shape=input_tensor_shape,
        prev_rank=prev_rank,
        next_rank=next_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )
    return input_tensor


def send_backward_recv_backward(
    input_tensor_grad,
    output_grad_shape,
    prev_rank=None,
    next_rank=None,
    dtype=torch.float,
    scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Batched communication operation. Sends the gradient tensor to the
    previous stage in pipeline, while receiving the gradient tensor from the
    next member in pipeline as the input of this stage.

    Args:
        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
            to be received.

    Returns:
        Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
    """
    _, output_tensor_grad = _communicate(
        object_send_prev=input_tensor_grad,
        recv_next=output_grad_shape is not None,
        recv_next_shape=output_grad_shape,
        prev_rank=prev_rank,
        next_rank=next_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
    output_tensor,
    input_tensor_grad,
    input_tensor_shape,
    output_grad_shape,
    prev_rank=None,
    next_rank=None,
    dtype=torch.float,
    scatter_gather_tensors=False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
    """Batched communication operation. Sends the input tensor to the next stage in pipeline and
    the gradient tensor to the previous stage, while receiving the input gradient tensor from the
    next stage and the input tensor from the previous stage.

    Args:
        output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next.
        input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous.
        input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
            from the previous.
        output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
            from the next.

    Returns:
        Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`,
            List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
    """
    input_tensor, output_tensor_grad = _communicate(
        object_send_next=output_tensor,
        object_send_prev=input_tensor_grad,
        recv_prev=input_tensor_shape is not None,
        recv_next=output_grad_shape is not None,
        recv_prev_shape=input_tensor_shape,
        recv_next_shape=output_grad_shape,
        prev_rank=prev_rank,
        next_rank=next_rank,
        dtype=dtype,
        scatter_gather_tensors=scatter_gather_tensors,
    )
    return input_tensor, output_tensor_grad


def send_forward_and_recv_next_forward_async(
    output_tensor,
    recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
    dtype: torch.dtype = None,
    scatter_gather_tensors=False,
):
    """send forward output to next rank and recv forward input from prev rank"""

    reqs = []
    tensor_recv_prev = None

    # prepare send operations
    if output_tensor is not None:
        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

        output_tensor = process_object_to_send(output_tensor, scatter_gather_tensors)

        if isinstance(output_tensor, torch.Tensor):
            reqs.append(dist.P2POp(dist.isend, output_tensor, next_rank))
        else:
            for tensor_to_comm in output_tensor:
                reqs.append(dist.P2POp(dist.isend, tensor_to_comm, next_rank))

    # prepare receive operations
    if recv_prev_shape is not None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
        # create receive buffer
        tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
            recv_prev_shape, dtype, scatter_gather_tensors
        )
        # generate async receive operations
        if isinstance(tensor_recv_prev, torch.Tensor):
            reqs.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
        else:
            for tensor_to_comm in tensor_recv_prev:
                reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, prev_rank))

    if len(reqs) > 0:
        reqs = dist.batch_isend_irecv(reqs)

    # return and do other things
    yield

    # check communication completed
    for req in reqs:
        req.wait()
    # To protect against race condition when using batch_isend_irecv()
    torch.cuda.synchronize()

    # Process received data
    if recv_prev_shape is not None and recv_prev_split:
        if isinstance(tensor_recv_prev, torch.Tensor):
            tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
        else:
            for index in range(len(tensor_recv_prev)):
                tensor_recv_prev[index] = (
                    gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
                )

    yield tensor_recv_prev


def send_backward_and_recv_next_backward_async(
    input_tensor,
    recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
    dtype: torch.dtype = None,
    scatter_gather_tensors=False,
):
    reqs = []
    tensor_recv_next = None

    # prepare send operations
    if input_tensor is not None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

        input_tensor = process_object_to_send(input_tensor, scatter_gather_tensors)

        if isinstance(input_tensor, torch.Tensor):
            reqs.append(dist.P2POp(dist.isend, input_tensor, prev_rank))
        else:
            for tensor_to_comm in input_tensor:
                reqs.append(dist.P2POp(dist.isend, tensor_to_comm, prev_rank))

    # prepare receive operations
    if recv_next_shape is not None:
        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
        # create receive buffer
        tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
            recv_next_shape, dtype, scatter_gather_tensors
        )
        # generate async receive operations
        if isinstance(tensor_recv_next, torch.Tensor):
            reqs.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
        else:
            for tensor_to_comm in tensor_recv_next:
                reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, next_rank))

    if len(reqs) > 0:
        reqs = dist.batch_isend_irecv(reqs)

    # return and do other things
    yield

    # check communication completed
    for req in reqs:
        req.wait()
    # To protect against race condition when using batch_isend_irecv()
    torch.cuda.synchronize()

    # Process received data
    if recv_next_shape is not None and recv_next_split:
        if isinstance(tensor_recv_next, torch.Tensor):
            tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
        else:
            for index in range(len(tensor_recv_next)):
                tensor_recv_next[index] = (
                    gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
                )

    yield tensor_recv_next


class AsynCommunicator:
    """AsynCommunicator for managing async communication."""

    def __init__(
        self,
        tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
        recv_shape: Union[torch.Size, List[torch.Size]],
        dtype: torch.dtype = None,
        scatter_gather_tensors=False,
        forward: bool = True,
    ) -> None:
        self._need_receive = recv_shape is not None

        if forward:
            self._coroutine = send_forward_and_recv_next_forward_async(
                tensor_to_send, recv_shape, dtype, scatter_gather_tensors
            )
        else:
            self._coroutine = send_backward_and_recv_next_backward_async(
                tensor_to_send, recv_shape, dtype, scatter_gather_tensors
            )

    @property
    def need_receive(self) -> bool:
        return self._need_receive

    def start(self) -> None:
        next(self._coroutine)

    def wait_and_receive(self) -> Union[torch.Tensor, List[torch.Tensor]]:
        received = next(self._coroutine)
        self._coroutine.close()

        return received
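Note (not part of the uploaded file): a minimal sketch of driving `AsynCommunicator` to overlap pipeline communication with computation. It assumes the pipeline process groups are already initialized and that this is an interior stage; `stage_output` and `local_compute` are hypothetical placeholders, not names from the repository.

    # stage_output: activations produced by this stage (hypothetical).
    comm = AsynCommunicator(
        tensor_to_send=stage_output,
        recv_shape=stage_output.shape,   # shape expected from the previous stage
        dtype=stage_output.dtype,
        forward=True,
    )
    comm.start()            # posts the batched isend/irecv ops and returns immediately
    local_compute()         # hypothetical work overlapped with the in-flight communication
    if comm.need_receive:
        next_input = comm.wait_and_receive()  # blocks until the P2P ops complete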
InternLM/internlm/core/communication/utils.py
ADDED
@@ -0,0 +1,125 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication

from typing import List, Tuple, Union

import torch
import torch.distributed as dist

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device

TensorShape = Union[torch.Size, List[int], Tuple[int]]


def send_meta_helper(obj, next_rank, tensor_kwargs):
    send_shape = torch.tensor(obj.size(), **tensor_kwargs)
    send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
    dist.send(send_ndims, next_rank)
    dist.send(send_shape, next_rank)


def send_obj_meta(obj, next_rank=None):
    """Sends obj meta information before sending a specific obj.
    Since the recipient must know the shape of the obj in p2p communications,
    meta information of the obj should be sent before communications. This function
    synchronizes with :func:`recv_obj_meta`.

    Args:
        obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
        need_meta (bool, optional): If False, meta information won't be sent.
        next_rank (int): The rank of the next member in pipeline parallel group.

    Returns:
        bool: False
    """
    if next_rank is None:
        next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

    tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
    if isinstance(obj, torch.Tensor):
        send_obj_nums = torch.tensor(1, **tensor_kwargs)
        dist.send(send_obj_nums, next_rank)
        send_meta_helper(obj, next_rank, tensor_kwargs)
    else:
        send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
        dist.send(send_obj_nums, next_rank)
        for tensor_to_send in obj:
            send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)


def recv_meta_helper(prev_rank, tensor_kwargs):
    recv_ndims = torch.empty((), **tensor_kwargs)
    dist.recv(recv_ndims, prev_rank)
    recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
    dist.recv(recv_shape, prev_rank)
    return recv_shape


def recv_obj_meta(prev_rank=None) -> torch.Size:
    """Receives obj meta information before receiving a specific obj.
    Since the recipient must know the shape of the obj in p2p communications,
    meta information of the obj should be received before communications. This function
    synchronizes with :func:`send_obj_meta`.

    Args:
        obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
        prev_rank (int): The rank of the source of the obj.

    Returns:
        Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
    """
    if prev_rank is None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

    tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
    recv_obj_nums = torch.empty((), **tensor_kwargs)
    dist.recv(recv_obj_nums, prev_rank)
    if recv_obj_nums.item() == 1:
        recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
        obj_shape = torch.Size(recv_shape)
    else:
        obj_shape = []
        for _ in range(recv_obj_nums.item()):
            recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
            obj_shape.append(torch.Size(recv_shape))

    return obj_shape


def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
    """Break a tensor into equal 1D chunks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be split before communication.
        new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.

    Returns:
        :class:`torch.Tensor`: The split tensor
    """
    partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR)
    start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR)
    end_index = start_index + partition_size
    if new_buffer:
        data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
        data.copy_(tensor.view(-1)[start_index:end_index])
    else:
        data = tensor.view(-1)[start_index:end_index]
    return data


def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Opposite of above function, gather values from model parallel ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
    Returns:
        :class:`torch.Tensor`: The gathered tensor.
    """
    world_size = gpc.get_world_size(ParallelMode.TENSOR)
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
    chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR))
    return gathered
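Note (not part of the uploaded file): a minimal sketch of the scatter-gather round trip performed by `split_tensor_into_1d_equal_chunks` and `gather_split_1d_tensor`. It assumes a tensor-parallel group of size N is initialized, that every rank in that group calls `gather_split_1d_tensor` collectively (it wraps `dist.all_gather`), and that the activation numel is divisible by N, which is what `_get_tensor_shape` checks before enabling the split in `p2p.py`; the tensor shape below is illustrative.

    import torch

    # Hypothetical activation tensor shared logically across the tensor-parallel group.
    activation = torch.randn(4, 1024, 2048, device=torch.cuda.current_device())

    # Each rank keeps only its 1/N slice of the flattened tensor before P2P transfer.
    chunk = split_tensor_into_1d_equal_chunks(activation)

    # After transfer, the slices from all ranks are all-gathered and reshaped back.
    restored = gather_split_1d_tensor(chunk).view(activation.shape)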
InternLM/internlm/core/context/__init__.py
ADDED
@@ -0,0 +1,49 @@
from .parallel_context import (
    IS_TENSOR_PARALLEL,
    Config,
    ParallelContext,
    global_context,
)
from .process_group_initializer import (
    Initializer_Data,
    Initializer_Model,
    Initializer_Nettest,
    Initializer_Pipeline,
    Initializer_Tensor,
    Initializer_Zero1,
    ParallelMode,
    ProcessGroupInitializer,
)
from .random import (
    add_seed,
    get_current_mode,
    get_seeds,
    get_states,
    seed,
    set_mode,
    set_seed_states,
    sync_states,
)

__all__ = [
    "Config",
    "IS_TENSOR_PARALLEL",
    "global_context",
    "ParallelContext",
    "ParallelMode",
    "Initializer_Tensor",
    "Initializer_Pipeline",
    "Initializer_Data",
    "Initializer_Zero1",
    "Initializer_Nettest",
    "ProcessGroupInitializer",
    "Initializer_Model",
    "seed",
    "set_mode",
    "add_seed",
    "get_seeds",
    "get_states",
    "get_current_mode",
    "set_seed_states",
    "sync_states",
]
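Note (not part of the uploaded file): a minimal sketch of querying the re-exported `global_context` singleton, whose `ParallelContext` methods are defined in `parallel_context.py` below. It assumes the distributed environment and process groups have already been initialized elsewhere (e.g. by the launch utilities in `internlm.initialize`).

    from internlm.core.context import ParallelMode
    from internlm.core.context import global_context as gpc

    # Inspect this rank's position in the tensor-parallel group, if one is in use.
    if gpc.is_using_tp():
        tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
        tp_size = gpc.get_world_size(ParallelMode.TENSOR)
        print(f"tensor-parallel rank {tp_rank} of {tp_size}")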
InternLM/internlm/core/context/parallel_context.py
ADDED
@@ -0,0 +1,569 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context

import inspect
import random
import socket
import sys
from collections import Counter
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import Union

import numpy as np
import torch
import torch.distributed as dist

from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger
from internlm.utils.timeout import LLM_NCCL_TIMEOUT

from . import process_group_initializer as pgroup_initializer
from .process_group_initializer import ParallelMode
from .random import add_seed, get_seeds, set_mode

IS_TENSOR_PARALLEL = "is_tensor_parallel"

logger = get_logger(__file__)


class Config(dict):
    """This is a wrapper class for dict objects so that their values can be
    accessed as attributes.

    Args:
        config (dict): The dict object to be wrapped.
    """

    def __init__(self, config: dict = None):  # pylint: disable=W0231
        if config is not None:
            for k, v in config.items():
                self._add_item(k, v)

    def __missing__(self, key):
        raise KeyError(key)

    def __getattr__(self, key):
        try:
            value = super().__getitem__(key)
            return value
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        super().__setitem__(key, value)

    def _add_item(self, key, value):
        if isinstance(value, dict):
            self.__setattr__(key, Config(value))
        else:
            self.__setattr__(key, value)

    def update(self, config):
        assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
        for k, v in config.items():
            self._add_item(k, v)
        return self

    @staticmethod
    def from_file(filename: str) -> object:
        """Reads a python file and constructs a corresponding :class:`Config` object.

        Args:
            filename (str): Name of the file to construct the return object.

        Returns:
            :class:`Config`: A :class:`Config` object constructed with information in the file.

        Raises:
            AssertionError: Raises an AssertionError if the file does not exist, or the file is not a .py file
        """

        # check config path
        if isinstance(filename, str):
            filepath = Path(filename).absolute()
        elif isinstance(filename, Path):
            filepath = filename.absolute()

        assert filepath.exists(), f"{filename} is not found, please check your configuration path"

        # check extension
        extension = filepath.suffix
        assert extension == ".py", "only .py files are supported"

        # import the config as module
        remove_path = False
        if filepath.parent not in sys.path:
            sys.path.insert(0, (filepath))
            remove_path = True

        module_name = filepath.stem
        source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
        module = source_file.load_module()  # pylint: disable=W4902,E1120,W1505

        # load into config
        config = Config()

        for k, v in module.__dict__.items():
            if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
                continue
            else:
                config._add_item(k, v)

        # remove module
        del sys.modules[module_name]
        if remove_path:
            sys.path.pop(0)

        return config


class ParallelContext(metaclass=SingletonMeta):
    """This class provides interface functions for users to get the parallel context,
    such as the global rank, the local rank, the world size, etc. of each device.

    """

    def __init__(self):
        # distributed settings
        self._global_ranks = dict()
        self._local_ranks = dict()
        self._world_sizes = dict()
        self._groups = dict()
        self._cpu_groups = dict()
        self._ranks_in_group = dict()

        # load config from file
        self._config = None

        # default parallel args, will be overwritten during process group initialization
        self.world_size = 1
        self.data_parallel_size = 1
        self.pipeline_parallel_size = 1
        self.tensor_parallel_size = 1
        self.zero1_parallel_size = -1
        self.nettest_parallel_size = 1
        self.num_processes_on_current_node = -1
        self.virtual_pipeline_parallel_size = None
        self.virtual_pipeline_parallel_rank = None

    @property
    def config(self):
        return self._config

    def load_config(self, config: Union[dict, str]):
        """Loads the configuration from either a dict or a file.

        Args:
            config (dict or str): Either a dict containing the configuration information or the filename
                of a file containing the configuration information.

        Raises:
            TypeError: Raises a TypeError if `config` is neither a dict nor a str.
        """
        if isinstance(config, str):
            self._config = Config.from_file(config)
        elif isinstance(config, dict):
            self._config = Config(config)
        else:
            raise TypeError("Invalid type for config, only dictionary or string is supported")

    def detect_num_processes_on_current_node(self):
        hostname = socket.gethostname()
        hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))]
        dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL))
        counter = Counter(hostname_list)
        self.num_processes_on_current_node = counter[hostname]

    @staticmethod
    def _check_parallel_mode(parallel_mode: ParallelMode):
        assert isinstance(
            parallel_mode, ParallelMode
        ), f"expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}"

    def get_global_rank(self):
        """Returns the global rank of the current device.

        Returns:
            int: The global rank of the current device
        """
        return self._global_ranks[ParallelMode.GLOBAL]

    def get_local_rank(self, parallel_mode: ParallelMode):
        """Returns the local rank of the current device.

        Args:
            parallel_mode: The parallel mode for the rank.

        Returns:
            int: The local rank of the current device for `parallel_mode`.
        """
        self._check_parallel_mode(parallel_mode)
        return self._local_ranks.get(parallel_mode, 0)

    def get_next_global_rank(self, parallel_mode: ParallelMode):
        """Returns the global rank of the next device.

        Args:
            parallel_mode: The parallel mode for the rank.

        Returns:
            int: The global rank of the next device for `parallel_mode`.
        """
        self._check_parallel_mode(parallel_mode)

        # get rank and world size
        local_rank = self.get_local_rank(parallel_mode)
        world_size = self.get_world_size(parallel_mode)
        ranks_in_group = self.get_ranks_in_group(parallel_mode)

        return ranks_in_group[(local_rank + 1) % world_size]

    def get_prev_global_rank(self, parallel_mode: ParallelMode):
        """Returns the global rank of the previous device.

        Args:
            parallel_mode: The chosen parallel mode.

        Returns:
            int: The global rank of the previous device for `parallel_mode`.
        """
        self._check_parallel_mode(parallel_mode)

        # get rank and world size
        local_rank = self.get_local_rank(parallel_mode)
        world_size = self.get_world_size(parallel_mode)
        ranks_in_group = self.get_ranks_in_group(parallel_mode)

        return ranks_in_group[(local_rank - 1) % world_size]

    def is_using_dp(self):
        """Returns a boolean value indicating whether the current device is initialized with
        ParallelMode.DATA and its world_size is greater than 1.
        """
        return self.is_initialized(ParallelMode.DATA) and self.get_world_size(ParallelMode.DATA) > 1

    def is_using_tp(self):
        """Returns a boolean value indicating whether the current device is initialized with
        ParallelMode.TENSOR and its world_size is greater than 1.
        """
        return self.is_initialized(ParallelMode.TENSOR) and self.get_world_size(ParallelMode.TENSOR) > 1

    def is_using_pp(self):
        """Returns a boolean value indicating whether the current device is initialized with
        ParallelMode.PIPELINE and its world_size is greater than 1.
        """
        return self.is_initialized(ParallelMode.PIPELINE) and self.get_world_size(ParallelMode.PIPELINE) > 1

    def is_using_sequence(self):
        """Returns a boolean value indicating whether the current device is initialized with
        ParallelMode.SEQUENCE and its world_size is greater than 1.
        """
        return False
        # return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1

    def is_first_rank(self, parallel_mode: ParallelMode):
        """Returns a boolean value indicating whether the current device is the first one
        among its group for `parallel_mode`.

        Args:
            parallel_mode: The chosen parallel mode.

        Returns:
|
| 275 |
+
bool: a boolean value indicating whether the current device is the first one
|
| 276 |
+
among its group for `parallel_mode`.
|
| 277 |
+
"""
|
| 278 |
+
rank = 0
|
| 279 |
+
if self.is_initialized(parallel_mode):
|
| 280 |
+
rank = self.get_local_rank(parallel_mode)
|
| 281 |
+
return rank == 0
|
| 282 |
+
|
| 283 |
+
def is_rank_for_log(self):
|
| 284 |
+
"""Returns a boolean value indicating whether the current device should print log."""
|
| 285 |
+
is_log_rank = (
|
| 286 |
+
self.is_first_rank(ParallelMode.DATA)
|
| 287 |
+
and self.is_first_rank(ParallelMode.TENSOR)
|
| 288 |
+
and self.is_last_rank(ParallelMode.PIPELINE)
|
| 289 |
+
)
|
| 290 |
+
return is_log_rank
|
| 291 |
+
|
| 292 |
+
def is_last_rank(self, parallel_mode: ParallelMode):
|
| 293 |
+
"""Returns a boolean value indicating whether the current device is the last one
|
| 294 |
+
among its group for `parallel_mode`.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
parallel_mode: The chosen parallel mode.
|
| 298 |
+
|
| 299 |
+
Returns:
|
| 300 |
+
bool: a boolean value indicating whether the current device is the first one
|
| 301 |
+
among its group for `parallel_mode`.
|
| 302 |
+
"""
|
| 303 |
+
rank = 0
|
| 304 |
+
world_size = 1
|
| 305 |
+
if self.is_initialized(parallel_mode):
|
| 306 |
+
rank = self.get_local_rank(parallel_mode)
|
| 307 |
+
world_size = self.get_world_size(parallel_mode)
|
| 308 |
+
return rank == world_size - 1
|
| 309 |
+
|
| 310 |
+
def is_pipeline_first_stage(self, ignore_virtual=False):
|
| 311 |
+
if not ignore_virtual:
|
| 312 |
+
if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != 0:
|
| 313 |
+
return False
|
| 314 |
+
return self.is_first_rank(ParallelMode.PIPELINE)
|
| 315 |
+
|
| 316 |
+
def is_pipeline_last_stage(self, ignore_virtual=False):
|
| 317 |
+
if not ignore_virtual:
|
| 318 |
+
if (
|
| 319 |
+
self.virtual_pipeline_parallel_size is not None
|
| 320 |
+
and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1
|
| 321 |
+
):
|
| 322 |
+
return False
|
| 323 |
+
return self.is_last_rank(ParallelMode.PIPELINE)
|
| 324 |
+
|
| 325 |
+
def get_world_size(self, parallel_mode: ParallelMode):
|
| 326 |
+
"""Returns the world size for `parallel_mode`.
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
parallel_mode: The chosen parallel mode.
|
| 330 |
+
|
| 331 |
+
Returns:
|
| 332 |
+
int: The world size for `parallel_mode`.
|
| 333 |
+
"""
|
| 334 |
+
self._check_parallel_mode(parallel_mode)
|
| 335 |
+
return self._world_sizes.get(parallel_mode, 1)
|
| 336 |
+
|
| 337 |
+
def get_group(self, parallel_mode: ParallelMode):
|
| 338 |
+
"""Returns the group of the current device for `parallel_mode`.
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
parallel_mode: The chosen parallel mode.
|
| 342 |
+
|
| 343 |
+
Returns:
|
| 344 |
+
torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`.
|
| 345 |
+
"""
|
| 346 |
+
self._check_parallel_mode(parallel_mode)
|
| 347 |
+
return self._groups[parallel_mode]
|
| 348 |
+
|
| 349 |
+
def get_ranks_in_group(self, parallel_mode: ParallelMode):
|
| 350 |
+
"""Returns the rank of the current device for `parallel_mode` in the group.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
parallel_mode: The chosen parallel mode.
|
| 354 |
+
|
| 355 |
+
Returns:
|
| 356 |
+
int: The rank of the current device for `parallel_mode` in the group.
|
| 357 |
+
"""
|
| 358 |
+
self._check_parallel_mode(parallel_mode)
|
| 359 |
+
return self._ranks_in_group[parallel_mode]
|
| 360 |
+
|
| 361 |
+
def get_cpu_group(self, parallel_mode: ParallelMode):
|
| 362 |
+
self._check_parallel_mode(parallel_mode)
|
| 363 |
+
return self._cpu_groups[parallel_mode]
|
| 364 |
+
|
| 365 |
+
def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int, use_cpu: bool = False):
|
| 366 |
+
"""Initializes the global distributed environment
|
| 367 |
+
|
| 368 |
+
Args:
|
| 369 |
+
rank (int): rank for the default process group.
|
| 370 |
+
world_size (int): world size of the default process group.
|
| 371 |
+
backend (str): backend for ``torch.distributed``
|
| 372 |
+
host (str): the master address for distributed training.
|
| 373 |
+
port (str): the master port for distributed training.
|
| 374 |
+
use_cpu (bool): whether to set up cpu process group.
|
| 375 |
+
"""
|
| 376 |
+
# initialize the default process group
|
| 377 |
+
init_method = f"tcp://[{host}]:{port}"
|
| 378 |
+
dist.init_process_group(
|
| 379 |
+
rank=rank,
|
| 380 |
+
world_size=world_size,
|
| 381 |
+
backend=backend,
|
| 382 |
+
init_method=init_method,
|
| 383 |
+
timeout=LLM_NCCL_TIMEOUT,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
# None will give the default global process group for pytorch dist operations
|
| 387 |
+
ranks = list(range(world_size))
|
| 388 |
+
if use_cpu:
|
| 389 |
+
cpu_group = (
|
| 390 |
+
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
|
| 391 |
+
if dist.get_backend() != "gloo"
|
| 392 |
+
else None
|
| 393 |
+
)
|
| 394 |
+
else:
|
| 395 |
+
cpu_group = None
|
| 396 |
+
self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
|
| 397 |
+
self._global_ranks[ParallelMode.GLOBAL] = rank
|
| 398 |
+
|
| 399 |
+
def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
|
| 400 |
+
self._check_parallel_mode(mode)
|
| 401 |
+
self._local_ranks[mode] = local_rank
|
| 402 |
+
self._world_sizes[mode] = world_size
|
| 403 |
+
self._groups[mode] = process_group
|
| 404 |
+
self._cpu_groups[mode] = cpu_group
|
| 405 |
+
self._ranks_in_group[mode] = ranks_in_group
|
| 406 |
+
|
| 407 |
+
def check_sanity(self):
|
| 408 |
+
"""Checks sanity of the parallel context.
|
| 409 |
+
|
| 410 |
+
Raises:
|
| 411 |
+
AssertionError: Raises an AssertionError if the world size does not equal to the product
|
| 412 |
+
of data parallel size, pipeline parallel size and tensor parallel size.
|
| 413 |
+
"""
|
| 414 |
+
dps = self.data_parallel_size
|
| 415 |
+
pps = self.pipeline_parallel_size
|
| 416 |
+
tps = self.tensor_parallel_size
|
| 417 |
+
ws = self.world_size
|
| 418 |
+
assert ws == dps * pps * tps, (
|
| 419 |
+
f"Expected the world size {ws} to be equal to data"
|
| 420 |
+
f" parallel size ({dps}) * pipeline parallel size "
|
| 421 |
+
f"({pps}) * tensor parallel size ({tps})"
|
| 422 |
+
)
|
| 423 |
+
assert self.zero1_parallel_size > 0
|
| 424 |
+
assert self.data_parallel_size % self.zero1_parallel_size == 0
|
| 425 |
+
|
| 426 |
+
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
|
| 427 |
+
if key in config:
|
| 428 |
+
ele = config[key]
|
| 429 |
+
if isinstance(ele, int):
|
| 430 |
+
setattr(self, attr_name, ele)
|
| 431 |
+
elif isinstance(ele, dict):
|
| 432 |
+
setattr(self, attr_name, ele["size"])
|
| 433 |
+
else:
|
| 434 |
+
raise NotImplementedError(
|
| 435 |
+
f'{"Parallel configuration does not support this kind of argument, please use int or dict"}'
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
def init_parallel_groups(self):
|
| 439 |
+
"""Initializes the parallel groups."""
|
| 440 |
+
|
| 441 |
+
# get rank and world size
|
| 442 |
+
rank = self.get_global_rank()
|
| 443 |
+
world_size = self.get_world_size(ParallelMode.GLOBAL)
|
| 444 |
+
self.world_size = world_size
|
| 445 |
+
|
| 446 |
+
# set parallel size as attributes for global context
|
| 447 |
+
parallel_config = self.config.get("parallel", None)
|
| 448 |
+
if parallel_config is not None:
|
| 449 |
+
self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
|
| 450 |
+
self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size")
|
| 451 |
+
self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size")
|
| 452 |
+
|
| 453 |
+
# the user should not set the data parallel size manually
|
| 454 |
+
# instead, it should be calculated based on other parallel config
|
| 455 |
+
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
|
| 456 |
+
|
| 457 |
+
# the recommended nettest_parallel_size is 32 GPUs
|
| 458 |
+
self.nettest_parallel_size = 32
|
| 459 |
+
|
| 460 |
+
if self.zero1_parallel_size <= 0:
|
| 461 |
+
self.zero1_parallel_size = self.data_parallel_size
|
| 462 |
+
|
| 463 |
+
self.check_sanity()
|
| 464 |
+
|
| 465 |
+
initializer_args = [
|
| 466 |
+
rank,
|
| 467 |
+
world_size,
|
| 468 |
+
self.data_parallel_size,
|
| 469 |
+
self.pipeline_parallel_size,
|
| 470 |
+
self.tensor_parallel_size,
|
| 471 |
+
self.zero1_parallel_size,
|
| 472 |
+
self.nettest_parallel_size,
|
| 473 |
+
]
|
| 474 |
+
|
| 475 |
+
# run initialization of different process groups
|
| 476 |
+
initializers = []
|
| 477 |
+
initializers.append(pgroup_initializer.Initializer_Data(*initializer_args))
|
| 478 |
+
initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
|
| 479 |
+
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
|
| 480 |
+
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
|
| 481 |
+
initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
|
| 482 |
+
if self.pipeline_parallel_size > 1:
|
| 483 |
+
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
|
| 484 |
+
for initializer in initializers:
|
| 485 |
+
parallel_setting = initializer.init_dist_group()
|
| 486 |
+
if isinstance(parallel_setting, list):
|
| 487 |
+
for args in parallel_setting:
|
| 488 |
+
self._register_dist(*args)
|
| 489 |
+
else:
|
| 490 |
+
self._register_dist(*parallel_setting)
|
| 491 |
+
|
| 492 |
+
def is_initialized(self, parallel_mode: ParallelMode):
|
| 493 |
+
"""Returns a boolean value indicating whether `parallel_mode` is initialized
|
| 494 |
+
in the current system.
|
| 495 |
+
"""
|
| 496 |
+
return parallel_mode in self._groups
|
| 497 |
+
|
| 498 |
+
def destroy(self):
|
| 499 |
+
"""Destroys the current distributed parallel environment."""
|
| 500 |
+
for mode, group in self._groups.items():
|
| 501 |
+
if mode is not ParallelMode.GLOBAL:
|
| 502 |
+
dist.destroy_process_group(group)
|
| 503 |
+
# destroy global process group
|
| 504 |
+
dist.destroy_process_group()
|
| 505 |
+
self._groups.clear()
|
| 506 |
+
|
| 507 |
+
def set_device(self, device_ordinal: int = None):
|
| 508 |
+
"""Sets distributed processes to be bound to devices.
|
| 509 |
+
|
| 510 |
+
Args:
|
| 511 |
+
device_ordinal (int, optional): the device id to be bound to
|
| 512 |
+
"""
|
| 513 |
+
global_rank = self.get_global_rank()
|
| 514 |
+
if device_ordinal is None:
|
| 515 |
+
devices_per_node = torch.cuda.device_count()
|
| 516 |
+
device_ordinal = global_rank % devices_per_node
|
| 517 |
+
|
| 518 |
+
torch.cuda.set_device(device_ordinal)
|
| 519 |
+
logger.info(f"process rank {global_rank} is bound to host:{socket.gethostname()} device: {device_ordinal}")
|
| 520 |
+
|
| 521 |
+
def set_seed(self, seed: int, dpseed_with_tpoffset: bool = False):
|
| 522 |
+
"""Sets seeds for all random libraries.
|
| 523 |
+
|
| 524 |
+
Args:
|
| 525 |
+
seed (int): seed for random states
|
| 526 |
+
"""
|
| 527 |
+
pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
|
| 528 |
+
global_rank = self.get_global_rank()
|
| 529 |
+
|
| 530 |
+
random.seed(seed)
|
| 531 |
+
np.random.seed(seed)
|
| 532 |
+
torch.manual_seed(seed)
|
| 533 |
+
assert torch.cuda.is_available()
|
| 534 |
+
|
| 535 |
+
# data parallel seed are kept the same in the same pipeline stage
|
| 536 |
+
dp_seed = seed
|
| 537 |
+
if dpseed_with_tpoffset:
|
| 538 |
+
dp_seed = seed + pipeline_offset * 1024
|
| 539 |
+
add_seed(ParallelMode.DATA, dp_seed)
|
| 540 |
+
add_seed(ParallelMode.DUMMY, dp_seed)
|
| 541 |
+
|
| 542 |
+
# model parallel seeds are different across ranks
|
| 543 |
+
if self.is_initialized(ParallelMode.TENSOR):
|
| 544 |
+
tp_rank = self.get_local_rank(ParallelMode.TENSOR)
|
| 545 |
+
tp_seed = seed + tp_rank + pipeline_offset * 1024
|
| 546 |
+
add_seed(ParallelMode.TENSOR, tp_seed)
|
| 547 |
+
|
| 548 |
+
# we do not set the random state mode to ParallelMode.DATA until model is built (instead, we use a dummy mode
|
| 549 |
+
# during model construction), this is because the random state will be different in different tensor parallel
|
| 550 |
+
# device of the same data parallel group. The underlying reason is that the device of tp_rank = 0 will perform
|
| 551 |
+
# additional random operations during the RowParallelLinear module building process.
|
| 552 |
+
set_mode(ParallelMode.DUMMY)
|
| 553 |
+
|
| 554 |
+
seeds = get_seeds()
|
| 555 |
+
seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
|
| 556 |
+
logger.info(
|
| 557 |
+
f"initialized seed on rank {global_rank}, "
|
| 558 |
+
f"numpy: {seed}, python random: {seed}, {seed_str},"
|
| 559 |
+
f"the default parallel seed is {ParallelMode.DATA}."
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
def set_virtual_pipeline_parallel_size(self, size):
|
| 563 |
+
self.virtual_pipeline_parallel_size = size
|
| 564 |
+
|
| 565 |
+
def set_virtual_pipeline_parallel_rank(self, rank):
|
| 566 |
+
self.virtual_pipeline_parallel_rank = rank
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
global_context = ParallelContext()
|
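A minimal usage sketch of the context above (not part of the diff): it assumes a single-process run with a hypothetical inline config dict and an available NCCL backend, and only mirrors the method names defined in parallel_context.py; the repo's real launch flow goes through internlm.initialize, so treat this purely as an illustration.

    # Illustrative sketch only; assumes rank 0 of a world_size-1 job.
    from internlm.core.context import ParallelMode
    from internlm.core.context.parallel_context import global_context as gpc

    gpc.load_config({"parallel": {"tensor": 1, "pipeline": 1, "zero1": -1}})
    gpc.init_global_dist(rank=0, world_size=1, backend="nccl", host="127.0.0.1", port=29500)
    gpc.init_parallel_groups()   # builds DATA/MODEL/TENSOR/ZERO1/NETTEST groups
    gpc.set_device()             # bind this process to a GPU
    gpc.set_seed(1024)           # register per-mode RNG seeds

    print(gpc.get_global_rank(), gpc.get_world_size(ParallelMode.DATA), gpc.is_rank_for_log())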
InternLM/internlm/core/context/process_group_initializer.py
ADDED
@@ -0,0 +1,418 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context

import math
from abc import ABC, abstractmethod
from enum import Enum

import torch.distributed as dist

from internlm.utils.timeout import LLM_NCCL_TIMEOUT


# parallel modes
class ParallelMode(Enum):
    """This is an enumeration class containing all possible parallel modes."""

    GLOBAL = "global"

    # common parallel
    DATA = "data"

    # model parallel - containing tensor and pipeline parallel groups
    # this is added to facilitate amp and grad clipping in hybrid parallel
    MODEL = "model"

    # pipeline parallel
    PIPELINE = "pipe"

    # containing all ranks in tensor parallel
    TENSOR = "tensor"

    # zero1 parallel
    ZERO1 = "zero1"

    # runtime network test
    NETTEST = "nettest"

    # dummy mode, only used during model construction
    DUMMY = "dummy"


class ProcessGroupInitializer(ABC):
    """An object, knowing the parallelism configuration, that initializes parallel groups.

    Args:
        rank (int): The rank of current process.
        world_size (int): Size of whole communication world.
        data_parallel_size (int): Size of data parallel.
        pipeline_parallel_size (int): Size of pipeline parallel.
        tensor_parallel_size (int): Size of tensor parallel.
        zero1_parallel_size (int): Size of zero1 parallel.
    """

    def __init__(
        self,
        rank: int,
        world_size: int,
        data_parallel_size: int,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        zero1_parallel_size: int,
        nettest_parallel_size: int,
    ):
        self.rank = rank
        self.world_size = world_size
        self.data_parallel_size = data_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.zero1_parallel_size = zero1_parallel_size
        self.nettest_parallel_size = nettest_parallel_size
        super().__init__()

    @abstractmethod
    def init_dist_group(self, use_cpu: bool = False):
        pass


class Initializer_Data(ProcessGroupInitializer):
    """A ProcessGroupInitializer for data parallelism.

    Args:
        rank (int): The rank of current process.
        world_size (int): Size of whole communication world.
        data_parallel_size (int): Size of data parallel.
        pipeline_parallel_size (int): Size of pipeline parallel.
        tensor_parallel_size (int): Size of tensor parallel.
        zero1_parallel_size (int): Size of zero1 parallel.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rank_num_per_dp_group = self.world_size // self.data_parallel_size

        assert self.world_size % self.data_parallel_size == 0

    def init_dist_group(self, use_cpu: bool = False):
        """Initialize data parallel groups, and assign local_ranks and groups to each gpu.

        Returns:
            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
                A Data parallelism's information tuple.
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        cpu_group = None
        group_world_size = None
        mode = ParallelMode.DATA

        for i in range(self.rank_num_per_dp_group):
            ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
            if use_cpu:
                group_cpu = (
                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                    if dist.get_backend() != "gloo"
                    else group
                )
            else:
                group_cpu = None

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                cpu_group = group_cpu
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_Model(ProcessGroupInitializer):
    """A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel
    groups).

    Args:
        rank (int): The rank of current process.
        world_size (int): Size of whole communication world.
        data_parallel_size (int): Size of data parallel.
        pipeline_parallel_size (int): Size of pipeline parallel.
        tensor_parallel_size (int): Size of tensor parallel.
        zero1_parallel_size (int): Size of zero1 parallel.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rank_num_per_group = self.tensor_parallel_size * self.pipeline_parallel_size
        self.num_group = self.world_size // self.rank_num_per_group

        assert self.world_size % self.rank_num_per_group == 0

    def init_dist_group(self, use_cpu: bool = False):
        """Initialize model parallel groups, and assign local_ranks and groups to each gpu.

        Returns:
            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
                A Model parallelism's information tuple.
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        cpu_group = None
        group_world_size = None
        mode = ParallelMode.MODEL

        for i in range(self.num_group):
            ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)]
            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
            if use_cpu:
                group_cpu = (
                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                    if dist.get_backend() != "gloo"
                    else group
                )
            else:
                group_cpu = None

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                cpu_group = group_cpu
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_Pipeline(ProcessGroupInitializer):
    """A ProcessGroupInitializer for pipeline parallelism.

    Args:
        rank (int): The rank of current process.
        world_size (int): Size of whole communication world.
        data_parallel_size (int): Size of data parallel.
        pipeline_parallel_size (int): Size of pipeline parallel.
        tensor_parallel_size (int): Size of tensor parallel.
        zero1_parallel_size (int): Size of zero1 parallel.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
        self.pipeline_stage_size = self.rank_num_per_dp_group // self.pipeline_parallel_size

        assert self.world_size % self.data_parallel_size == 0
        assert self.rank_num_per_dp_group % self.pipeline_parallel_size == 0

    def init_dist_group(self, use_cpu: bool = False):
        """Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu.

        Returns:
            List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
                A Pipeline parallelism's information in list of tuples.
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        cpu_group = None
        group_world_size = None
        mode = ParallelMode.PIPELINE

        for i in range(self.data_parallel_size):
            for j in range(self.pipeline_stage_size):
                ranks = list(
                    range(
                        i * self.rank_num_per_dp_group + j,
                        (i + 1) * self.rank_num_per_dp_group,
                        self.pipeline_stage_size,
                    )
                )
                pipe_group_size = len(ranks)
                pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
                if use_cpu:
                    group_cpu = (
                        dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                        if dist.get_backend() != "gloo"
                        else pipe_group
                    )
                else:
                    group_cpu = None

                if self.rank in ranks:
                    local_rank = ranks.index(self.rank)
                    group_world_size = pipe_group_size
                    process_group = pipe_group
                    cpu_group = group_cpu
                    ranks_in_group = ranks

        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_Tensor(ProcessGroupInitializer):
    """A ProcessGroupInitializer for tensor parallelism.

    Args:
        rank (int): The rank of current process.
        world_size (int): Size of whole communication world.
        data_parallel_size (int): Size of data parallel.
        pipeline_parallel_size (int): Size of pipeline parallel.
        tensor_parallel_size (int): Size of tensor parallel.
        zero1_parallel_size (int): Size of zero1 parallel.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_tensor_parallel_group = self.world_size // self.tensor_parallel_size

        assert self.world_size % self.tensor_parallel_size == 0

    def init_dist_group(self, use_cpu: bool = False):
        """Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.

        Returns:
            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
                A Tensor parallelism's information tuple.
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        cpu_group = None
        group_world_size = None
        mode = ParallelMode.TENSOR

        for i in range(self.num_tensor_parallel_group):
            ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
            if use_cpu:
                group_cpu = (
                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                    if dist.get_backend() != "gloo"
                    else group
                )
            else:
                group_cpu = None

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                cpu_group = group_cpu
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_Zero1(ProcessGroupInitializer):
    """A ProcessGroupInitializer for zero-1 parallelism.

    Args:
        rank (int): The rank of current process.
        world_size (int): Size of whole communication world.
        data_parallel_size (int): Size of data parallel.
        pipeline_parallel_size (int): Size of pipeline parallel.
        tensor_parallel_size (int): Size of tensor parallel.
        zero1_parallel_size (int): Size of zero-1 parallel.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
        self.num_zero1_parallel_group = self.data_parallel_size // self.zero1_parallel_size

        assert self.world_size % self.data_parallel_size == 0
        assert self.world_size % self.zero1_parallel_size == 0

    def init_dist_group(self, use_cpu: bool = False):
        """Initialize zero1 parallel groups, and assign local_ranks and groups to each gpu.

        Returns:
            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
                A zero1 parallelism's information tuple.
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        cpu_group = None
        group_world_size = None
        mode = ParallelMode.ZERO1

        for i in range(self.rank_num_per_dp_group):
            for j in range(self.num_zero1_parallel_group):
                ranks = [
                    i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group
                    for k in range(self.zero1_parallel_size)
                ]
                group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
                if use_cpu:
                    group_cpu = (
                        dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                        if dist.get_backend() != "gloo"
                        else group
                    )
                else:
                    group_cpu = None

                if self.rank in ranks:
                    local_rank = ranks.index(self.rank)
                    group_world_size = len(ranks)
                    process_group = group
                    cpu_group = group_cpu
                    ranks_in_group = ranks

        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_Nettest(ProcessGroupInitializer):
    """A ProcessGroupInitializer for network test, especially for NCCL.

    Args:
        rank (int): The rank of current process.
        world_size (int): Size of whole communication world.
        nettest_parallel_size (int): Size of a network test group.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_nettest_group = math.ceil(self.world_size / self.nettest_parallel_size)

    def init_dist_group(self, use_cpu: bool = False):
        """Initialize network test groups, and assign local_ranks and groups to each gpu.

        Returns:
            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
                A network test parallelism's information tuple.
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        cpu_group = None
        group_world_size = None
        mode = ParallelMode.NETTEST

        for i in range(self.num_nettest_group):
            ranks = []
            for j in range(self.nettest_parallel_size):
                rank = i * self.nettest_parallel_size + j
                if rank < self.world_size:
                    ranks.append(rank)
            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
            if use_cpu:
                group_cpu = (
                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                    if dist.get_backend() != "gloo"
                    else group
                )
            else:
                group_cpu = None

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                cpu_group = group_cpu
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
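To make the grouping arithmetic above concrete, here is a small standalone sketch (not part of the diff) that reproduces the rank lists Initializer_Data and Initializer_Tensor would build for a hypothetical world_size=8, tensor=2, pipeline=1 layout, so data_parallel_size=4. It only mirrors the list comprehensions; no process groups are created.

    world_size, tensor_parallel_size, pipeline_parallel_size = 8, 2, 1
    data_parallel_size = world_size // (tensor_parallel_size * pipeline_parallel_size)  # 4

    # Initializer_Data: ranks that hold replicas of the same model shard.
    rank_num_per_dp_group = world_size // data_parallel_size  # 2
    data_groups = [
        [i + j * rank_num_per_dp_group for j in range(data_parallel_size)]
        for i in range(rank_num_per_dp_group)
    ]

    # Initializer_Tensor: consecutive ranks that shard the same layers.
    tensor_groups = [
        [i * tensor_parallel_size + j for j in range(tensor_parallel_size)]
        for i in range(world_size // tensor_parallel_size)
    ]

    print(data_groups)    # [[0, 2, 4, 6], [1, 3, 5, 7]]
    print(tensor_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]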
InternLM/internlm/core/context/random.py
ADDED
@@ -0,0 +1,131 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context

from contextlib import contextmanager

import torch
import torch.cuda
from torch import Tensor

from .process_group_initializer import ParallelMode


class SeedManager:
    """This class is a manager of all random seeds involved in the system."""

    def __init__(self):
        self._current_mode = None
        self._seeds = {}
        self._seed_states = {}

    @property
    def current_mode(self):
        return self._current_mode

    @property
    def seeds(self):
        return self._seeds

    @property
    def seed_states(self):
        return self._seed_states

    def set_state(self, parallel_mode: ParallelMode, state: Tensor):
        """Sets the state of the seed manager for `parallel_mode`."""
        assert parallel_mode in self._seed_states, f"{parallel_mode} not found in seed manager"
        self._seed_states[parallel_mode] = state

    def set_mode(self, parallel_mode: ParallelMode):
        """Sets the current mode of the seed manager."""
        if self.current_mode:
            # save state for current mode
            self._seed_states[self._current_mode] = torch.cuda.get_rng_state()

        # set new state for new mode
        self._current_mode = parallel_mode
        torch.cuda.set_rng_state(self._seed_states[parallel_mode])

    def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
        """Adds a seed to the seed manager for `parallel_mode`."""
        assert isinstance(parallel_mode, ParallelMode), "Invalid ParallelMode"
        if not overwrite:
            assert parallel_mode not in self._seed_states, f"Seed for {parallel_mode} exists"
        elif parallel_mode in self._seed_states:
            print(f"Warning: {parallel_mode} seed overwritten.", flush=True)

        current_state = torch.cuda.get_rng_state()
        torch.cuda.manual_seed(seed)
        self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
        self._seeds[parallel_mode] = seed
        torch.cuda.set_rng_state(current_state)

    def reset(self):
        self._current_mode = None
        self._seeds = {}
        self._seed_states = {}


_SEED_MANAGER = SeedManager()


def get_seeds():
    """Returns the seeds of the seed manager.

    Returns:
        dict: The seeds of the seed manager.
    """
    return _SEED_MANAGER.seeds


def get_states(copy=False):
    """Returns the seed states of the seed manager.

    Returns:
        dict: The seed states of the seed manager.
    """
    states = _SEED_MANAGER.seed_states
    if copy:
        new_states = dict()
        for parallel_mode, state in states.items():
            new_states[parallel_mode] = state.clone()
        return new_states
    else:
        return _SEED_MANAGER.seed_states


def get_current_mode():
    """Returns the current mode of the seed manager.

    Returns:
        :class:`ParallelMode`: The current mode of the seed manager.
    """
    return _SEED_MANAGER.current_mode


def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
    """Adds a seed to the seed manager for `parallel_mode`."""
    _SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)


def set_mode(parallel_mode: ParallelMode):
    """Sets the current mode of the seed manager."""
    _SEED_MANAGER.set_mode(parallel_mode)


def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
    """Sets the state of the seed manager for `parallel_mode`."""
    _SEED_MANAGER.set_state(parallel_mode, state)


def sync_states():
    current_mode = get_current_mode()
    current_states = torch.cuda.get_rng_state()
    set_seed_states(current_mode, current_states)


@contextmanager
def seed(parallel_mode: ParallelMode):
    """A context for seed switch"""
    current_mode = _SEED_MANAGER.current_mode
    try:
        yield _SEED_MANAGER.set_mode(parallel_mode)
    finally:
        _SEED_MANAGER.set_mode(current_mode)
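A small sketch (not part of the diff) of how the seed manager above is typically driven: seeds are registered per ParallelMode, and the seed(...) context manager temporarily switches the CUDA RNG state, restoring the previous mode on exit. It assumes a CUDA device is available and that the two modes have not been seeded yet (in the real flow, ParallelContext.set_seed does the registration).

    import torch

    from internlm.core.context.process_group_initializer import ParallelMode
    from internlm.core.context.random import add_seed, seed, set_mode

    # register independent CUDA RNG streams for two modes
    add_seed(ParallelMode.DATA, 1024)
    add_seed(ParallelMode.TENSOR, 1025)
    set_mode(ParallelMode.DATA)

    with seed(ParallelMode.TENSOR):
        # random ops here consume the TENSOR-mode RNG state
        x = torch.rand(4, device="cuda")

    # back on the DATA-mode RNG state here
    y = torch.rand(4, device="cuda")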
InternLM/internlm/core/engine.py
ADDED
@@ -0,0 +1,227 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

from typing import List, Optional

import torch
from torch.nn import Module
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import _LRScheduler

from internlm.core.gradient_handler import BaseGradientHandler
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
from internlm.utils.common import get_batch_size, move_to_device


class Engine:
    """
    The Engine class is responsible for managing the training and evaluation process of a neural network model.
    It handles the forward and backward passes, parameter updates, gradient handling, and mode switching between
    training and evaluation.

    Args:
        model (torch.nn.Module): The neural network model to be trained or evaluated.
        optimizer (BaseOptimizer): The optimizer used for updating the parameters of the model.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate scheduler for the optimizer.
            Default is None.
        beta2_scheduler (internlm.solver.beta2_scheduler.Beta2Scheduler, optional): The beta2 scheduler for the
            optimizer. Default is None.
        criterion (torch.nn.modules.loss._Loss, optional): The loss function used for calculating the loss during
            training. Default is None.
        gradient_handlers (List[BaseGradientHandler], optional): A list of gradient handlers used in the backward pass.
            Default is None.
        clip_grad_norm (float, optional): The norm value for gradient clipping. Default is 0.0.

    Examples:
        >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
        >>> model = ...
        >>> criterion = ...
        >>> optimizer = ...
        >>> train_dataloader = ...
        >>> engine, _, _, _ = internlm.initialize_engine(model, optimizer, criterion)
        >>> engine.train()
        >>> for inputs, labels in train_dataloader:
        >>>     # set gradients to zero
        >>>     engine.zero_grad()
        >>>     # run forward pass
        >>>     outputs = engine(inputs)
        >>>     # compute loss value and run backward pass
        >>>     loss = engine.criterion(outputs, labels)
        >>>     engine.backward(loss)
        >>>     # update parameters
        >>>     engine.step()
    """

    def __init__(
        self,
        model: Module,
        optimizer: BaseOptimizer,
        lr_scheduler: Optional[_LRScheduler] = None,
        beta2_scheduler: Optional[Beta2Scheduler] = None,
        criterion: Optional[_Loss] = None,
        gradient_handlers: Optional[List[BaseGradientHandler]] = None,
        clip_grad_norm: float = 0.0,
    ):
        self._model = model
        self._optimizer = optimizer
        self._lr_scheduler = lr_scheduler
        self._beta2_scheduler = beta2_scheduler
        self._criterion = criterion
        self._clip_grad_norm = clip_grad_norm

        # state
        self.training = True  # default

        # build gradient handler
        self._gradient_handlers = gradient_handlers if gradient_handlers else []

    @property
    def model(self):
        """Returns the model attached to the engine."""
        return self._model

    @property
    def optimizer(self):
        """Returns the optimizer attached to the engine."""
        return self._optimizer

    @property
    def criterion(self):
        """Returns the criterion (loss function) attached to the engine."""
        return self._criterion

    def _all_reduce_gradients(self):
        """Handles all-reduce operations of gradients across different parallel groups."""
        for handler in self._gradient_handlers:
            handler.handle_gradient()

    def zero_grad(self):
        """Sets the gradient of all parameters in the model to zero."""
        self.optimizer.zero_grad()

    def step(self):
        """
        Executes the parameter update step. This includes all-reduce operations of gradients, gradient clipping,
        and parameter update. If successful, it also steps the learning rate scheduler and beta2 scheduler
        if they exist.

        Returns:
            success (bool): Whether the parameter update was successful.
            grad_norm (float): The norm of the gradient after clipping.
        """
        self._all_reduce_gradients()
        self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)

        success, grad_norm = self.optimizer.step()

        if success and self._lr_scheduler is not None:
            self._lr_scheduler.step()

        if success and self._beta2_scheduler is not None:
            self._beta2_scheduler.step()

        return success, grad_norm

    def train(self):
        """Sets the model to training mode."""
        self.training = True
        self._model.train()

    def eval(self):
        """Sets the model to evaluation mode."""
        self.training = False
        self._model.eval()

    def backward(self, loss: torch.Tensor):
        """
        Starts the backward propagation given the loss value computed by a loss function.

        Args:
            loss (torch.Tensor): The loss value computed by a loss function.
        """
        return self.optimizer.backward(loss)

    def backward_by_grad(self, tensor, grad):
        """
        Starts the backward propagation given the gradient of the output tensor.

        Args:
            tensor (torch.Tensor): The output tensor.
            grad (torch.Tensor): The gradient passed back to the output tensor.
        """
        return self.optimizer.backward_by_grad(tensor, grad)

    def __call__(self, *args, **kwargs):
        """
        Runs the forward step for the model.

        Returns:
            torch.Tensor: The output of the model.
        """
        return self.model(*args, **kwargs)

    def load_batch(self, data_iter, to_gpu=True):
        """
        Loads a batch from the data iterator. It returns the data and labels which are
        already in the same GPU as where the model is.

        Args:
            data_iter (Iterable): The data iterator from which to get a batch of data, obtained by calling
                iter(dataloader).
            to_gpu (bool, optional): Whether the data should be moved to the GPU. Default is True.

        Returns:
            Tuple (torch.Tensor, torch.Tensor): A tuple of (data, label).
        """
        if data_iter is None:
            raise RuntimeError("Dataloader is not defined.")
        try:
            batch_data = next(data_iter)
        except TypeError:
            batch_data = data_iter

        if to_gpu:
            batch_data = move_to_device(batch_data)
        batch_size = get_batch_size(batch_data)

        return batch_data, batch_size


class KDEngine(Engine):
    def __init__(
        self,
        model: Module,
        teacher: Module,
        optimizer: BaseOptimizer,
        lr_scheduler: Optional[_LRScheduler] = None,
        beta2_scheduler: Optional[Beta2Scheduler] = None,
        criterion: Optional[_Loss] = None,
        kd_criterion: Optional[_Loss] = None,
        gradient_handlers: Optional[List[BaseGradientHandler]] = None,
        clip_grad_norm: float = 0.0,
    ):
        self._teacher = teacher
        self._kd_criterion = kd_criterion

        super().__init__(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            beta2_scheduler=beta2_scheduler,
            criterion=criterion,
            gradient_handlers=gradient_handlers,
            clip_grad_norm=clip_grad_norm,
        )

    @property
    def teacher(self):
        """Returns the teacher model attached to the engine."""
        return self._teacher

    @property
    def kd_criterion(self):
        """Returns the knowledge-distillation criterion attached to the engine."""
        return self._kd_criterion
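KDEngine only stores the teacher and the distillation criterion; how they are combined lives in the training loop and scheduler elsewhere in the repo. The hedged sketch below (not part of the diff) shows one plausible step, assuming hypothetical student/teacher/optimizer/criterion/kd_criterion/dataloader objects built elsewhere and a kd_criterion(student_out, teacher_out) that returns a scalar loss; it is an illustration, not the repo's actual loop.

    import torch

    # Hypothetical objects; constructed elsewhere.
    engine = KDEngine(
        model=student,
        teacher=teacher,
        optimizer=optimizer,
        criterion=criterion,
        kd_criterion=kd_criterion,
    )
    engine.train()
    engine.teacher.eval()

    for inputs, labels in dataloader:
        engine.zero_grad()
        student_out = engine(inputs)
        with torch.no_grad():
            teacher_out = engine.teacher(inputs)
        # assumed combination: task loss plus distillation loss
        loss = engine.criterion(student_out, labels) + engine.kd_criterion(student_out, teacher_out)
        engine.backward(loss)
        engine.step()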
InternLM/internlm/core/gradient_handler.py
ADDED
@@ -0,0 +1,76 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod
from collections import defaultdict

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from internlm.core.context import global_context as gpc


class BaseGradientHandler(ABC):
    """A basic helper class to handle all-reduce operations of gradients across different parallel groups
    before optimization.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def __init__(self, model, optimizer):
        self._model = model
        self._optimizer = optimizer

    @abstractmethod
    def handle_gradient(self):
        """A method to accumulate gradients across different parallel groups. Users should
        write their own functions or just use the functions in pre-defined subclasses.
        """
        pass


class PipelineSharedModuleGradientHandler(BaseGradientHandler):
    """A helper class to handle all-reduce operations in sub parallel groups.
    An all-reduce collective communication will be operated in
    :func:`handle_gradient` among all sub pipeline parallel groups.
    For better performance, it bucketizes the gradients of all parameters that are
    of the same type to improve the efficiency of communication.

    Args:
        model (Module): Model where the gradients accumulate.
        optimizer (Optimizer): Optimizer for updating the parameters.
    """

    def handle_gradient(self):
        """A method running an all-reduce operation in sub pipeline parallel groups."""
        if gpc.pipeline_parallel_size > 1:
            # bucketize and all-reduce
            buckets = defaultdict(lambda: defaultdict(list))
            # Pack the buckets.
            for param in self._model.parameters():
                group = getattr(param, "pipeline_shared_module_pg", None)
                if (
                    param.requires_grad
                    and group is not None
                    and (
                        (hasattr(param, "colo_attr") and not param.colo_attr.saved_grad.is_null())
                        or param.grad is not None
                    )
                ):
                    tp = param.data.type()
                    buckets[group][tp].append(param)

            # For each bucket, all-reduce and copy all-reduced grads.
            for group, group_buckets in buckets.items():
                for tp, bucket in group_buckets.items():
                    grads = [
                        param.colo_attr.grad_payload if hasattr(param, "colo_attr") else param.grad.data
                        for param in bucket
                    ]
                    coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
                    dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)
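The bucketing trick above (flatten many gradients into one contiguous buffer, reduce it once, scatter the result back) can be seen in isolation with the same torch utilities. The sketch below (not part of the diff) skips the actual dist.all_reduce, since that needs an initialized process group, and stands in a scalar multiply for the reduction so the flatten/unflatten round trip runs on CPU.

    import torch
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    grads = [torch.ones(3), torch.full((2, 2), 2.0)]   # stand-ins for per-parameter grads

    coalesced = _flatten_dense_tensors(grads)          # one contiguous 7-element buffer
    coalesced.mul_(0.5)                                # stand-in for dist.all_reduce + averaging
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)                              # write reduced values back in place

    print(grads[0])  # tensor([0.5000, 0.5000, 0.5000])
    print(grads[1])  # tensor([[1., 1.], [1., 1.]])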
InternLM/internlm/core/naive_amp.py
ADDED
@@ -0,0 +1,136 @@
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
# adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp
|
| 5 |
+
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
from torch import Tensor, nn
|
| 11 |
+
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
| 12 |
+
from torch.distributed import ReduceOp
|
| 13 |
+
|
| 14 |
+
from internlm.core.context import ParallelMode
|
| 15 |
+
from internlm.core.context.parallel_context import global_context as gpc
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class NaiveAMPModel(nn.Module):
|
| 19 |
+
"""
|
| 20 |
+
This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
|
| 21 |
+
It also provides options to cast the output back to fp32 and to synchronize buffers.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
model (torch.nn.Module): The model to be wrapped and cast into fp16.
|
| 25 |
+
output_to_fp32 (bool, optional): If True, the output of this module is cast into fp32. Defaults to True.
|
| 26 |
+
parallel_mode (:class:`internlm.core.context.ParallelMode`): The parallel group mode used in this module.
|
| 27 |
+
Defaults to ``ParallelMode.DATA``.
|
| 28 |
+
sync_buffer (bool, optional): If True, the buffers are synchronized. Defaults to True.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
model: nn.Module,
|
| 34 |
+
output_to_fp32: bool = True,
|
| 35 |
+
parallel_mode: ParallelMode = ParallelMode.DATA,
|
| 36 |
+
sync_buffer: bool = True,
|
| 37 |
+
dtype=torch.float16,
|
| 38 |
+
):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.model = model.to(dtype)
|
| 41 |
+
self._output_to_fp32 = output_to_fp32
|
| 42 |
+
self._sync_buf = sync_buffer
|
| 43 |
+
self.dtype = dtype
|
| 44 |
+
|
| 45 |
+
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
|
| 46 |
+
self._process_group = gpc.get_group(parallel_mode)
|
| 47 |
+
self._world_size = gpc.get_world_size(parallel_mode)
|
| 48 |
+
else:
|
| 49 |
+
self._process_group = None
|
| 50 |
+
self._world_size = 1
|
| 51 |
+
self._sync_buf = False
|
| 52 |
+
self._first_eval_run = False
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def sync_buffer(self):
|
| 56 |
+
"""Returns the current state of the buffer synchronization."""
|
| 57 |
+
return self._sync_buf
|
| 58 |
+
|
| 59 |
+
@sync_buffer.setter
|
| 60 |
+
def sync_buffer(self, state: bool):
|
| 61 |
+
"""Sets the state of the buffer synchronization."""
|
| 62 |
+
self._sync_buf = state
|
| 63 |
+
|
| 64 |
+
def _convert_to_fp16(self, input_: Any):
|
| 65 |
+
"""Converts the input to fp16 if it is a Tensor of dtype float32."""
|
| 66 |
+
if isinstance(input_, Tensor) and input_.dtype == torch.float32:
|
| 67 |
+
input_ = input_.to(self.dtype)
|
| 68 |
+
return input_
|
| 69 |
+
|
| 70 |
+
def _convert_to_fp32(self, input_: Any):
|
| 71 |
+
"""Converts the input to fp32 if it is a Tensor of dtype float16."""
|
| 72 |
+
if isinstance(input_, Tensor) and input_.dtype == torch.float16:
|
| 73 |
+
input_ = input_.float()
|
| 74 |
+
return input_
|
| 75 |
+
|
| 76 |
+
def convert_to_fp32(self, out):
|
| 77 |
+
"""Converts the output to fp32"""
|
| 78 |
+
if isinstance(out, Tensor):
|
| 79 |
+
out = self._convert_to_fp32(out)
|
| 80 |
+
elif isinstance(out, (tuple, list)):
|
| 81 |
+
out = [self._convert_to_fp32(val) for val in out]
|
| 82 |
+
elif isinstance(out, dict):
|
| 83 |
+
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
|
| 84 |
+
|
| 85 |
+
return out
|
| 86 |
+
|
| 87 |
+
def _reduce_module_buffer(self):
|
| 88 |
+
"""
|
| 89 |
+
All-reduces the buffers (e.g., running stats of batch normalization) across
|
| 90 |
+
data parallel ranks so that all the ranks will produce consistent results
|
| 91 |
+
when given the same input.
|
| 92 |
+
"""
|
| 93 |
+
buf_list = []
|
| 94 |
+
|
| 95 |
+
# find valid buffers
|
| 96 |
+
for buf in self.model.buffers():
|
| 97 |
+
if buf is not None:
|
| 98 |
+
buf_list.append(buf)
|
| 99 |
+
|
| 100 |
+
# reduce buffers across data parallel ranks
|
| 101 |
+
if buf_list:
|
| 102 |
+
coalesced_buf = _flatten_dense_tensors(buf_list)
|
| 103 |
+
coalesced_buf.div_(self._world_size)
|
| 104 |
+
dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
|
| 105 |
+
unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
|
| 106 |
+
for old, new in zip(buf_list, unflattened_buf_list):
|
| 107 |
+
old.copy_(new)
|
| 108 |
+
|
| 109 |
+
def eval(self):
|
| 110 |
+
"""Sets the model to evaluation mode. Buffers are only synchronized in the first eval iteration."""
|
| 111 |
+
self.model.eval()
|
| 112 |
+
self._first_eval_run = True
|
| 113 |
+
|
| 114 |
+
def forward(self, *args, **kwargs):
|
| 115 |
+
"""
|
| 116 |
+
Performs a forward pass on the model. Buffers are synchronized before the forward pass.
|
| 117 |
+
The inputs are converted to fp16 and the outputs are optionally converted back to fp32.
|
| 118 |
+
"""
|
| 119 |
+
if (self.training or self._first_eval_run) and self._sync_buf:
|
| 120 |
+
with torch.no_grad():
|
| 121 |
+
self._reduce_module_buffer()
|
| 122 |
+
|
| 123 |
+
if self._first_eval_run:
|
| 124 |
+
self._first_eval_run = False
|
| 125 |
+
|
| 126 |
+
if args:
|
| 127 |
+
args = [self._convert_to_fp16(arg) for arg in args]
|
| 128 |
+
if kwargs:
|
| 129 |
+
for k, v in kwargs.items():
|
| 130 |
+
kwargs[k] = self._convert_to_fp16(v)
|
| 131 |
+
|
| 132 |
+
out = self.model(*args, **kwargs)
|
| 133 |
+
|
| 134 |
+
if self._output_to_fp32:
|
| 135 |
+
out = self.convert_to_fp32(out)
|
| 136 |
+
return out
|
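Illustrative sketch (not part of the uploaded files): the casting behaviour NaiveAMPModel implements, reduced to plain PyTorch so it runs standalone. bfloat16 is used here only so the example works on CPU; the wrapper above defaults to float16 and additionally synchronizes buffers across its parallel group.

# Minimal low-precision wrapper: cast parameters once, cast fp32 inputs down, cast outputs back up.
import torch
from torch import nn

class TinyHalfWrapper(nn.Module):
    def __init__(self, model: nn.Module, dtype=torch.bfloat16, output_to_fp32: bool = True):
        super().__init__()
        self.model = model.to(dtype)          # cast parameters/buffers once
        self.dtype = dtype
        self.output_to_fp32 = output_to_fp32

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dtype == torch.float32:          # cast fp32 inputs down, leave other dtypes untouched
            x = x.to(self.dtype)
        out = self.model(x)
        if self.output_to_fp32 and out.dtype == self.dtype:
            out = out.float()                 # hand fp32 back to the caller / loss
        return out

wrapped = TinyHalfWrapper(nn.Linear(8, 2))
y = wrapped(torch.randn(4, 8))                # fp32 in
print(y.dtype)                                # fp32 out, compute ran in bfloat16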
InternLM/internlm/core/scheduler/__init__.py
ADDED
@@ -0,0 +1,14 @@
from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
from .no_pipeline_scheduler import NonPipelineScheduler, KDNonPipelineScheduler
from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler, KDPipelineScheduler

__all__ = [
    "BaseScheduler",
    "NonPipelineScheduler",
    "KDNonPipelineScheduler",
    "InterleavedPipelineScheduler",
    "PipelineScheduler",
    "KDPipelineScheduler",
    "SchedulerHook",
    "SchedulerMetricHook",
]
InternLM/internlm/core/scheduler/base_scheduler.py
ADDED
@@ -0,0 +1,187 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, Optional

import torch

from internlm.core.engine import Engine
from internlm.utils.megatron_timers import megatron_timer as timer


class BaseScheduler(ABC):
    """A basic helper class to control the process of training or evaluation.
    It mainly consists of forward_backward_step for the backward pass and
    optimizer_step for the parameter update.
    For convenience in enabling FP16, all code controlling FP16 is aggregated in the schedule class.

    Args:
        data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
            them into data and label.
    """

    def __init__(self, data_process_func: Callable = None):
        self.data_process_func = data_process_func

    @abstractmethod
    def pre_processing(self, engine: Engine):
        """To perform actions before running the schedule.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
        """
        pass

    def _load_micro_batch(self, data, label, offset, micro_bsz):
        assert isinstance(data, dict) and isinstance(label, torch.Tensor)
        micro_batch_data = {k: v[offset : offset + micro_bsz] for k, v in data.items()}
        micro_batch_label = label[offset : offset + micro_bsz]

        return micro_batch_data, micro_batch_label

    @abstractmethod
    def forward_backward_step(
        self,
        engine: Engine,
        data_iter: Iterable,
        forward_only: bool,
        return_loss: bool = True,
        return_output_label: bool = True,
    ):
        """The process function over a batch of dataset for training or evaluation.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
            data_iter (Iterable): Data iterator from which to get a batch of data, obtained by calling iter(dataloader).
            forward_only (bool): If True, the process won't include backward.
            return_loss (bool, optional): If False, the loss won't be returned.
            return_output_label (bool, optional): If False, the output and label won't be returned.
        """
        pass

    @staticmethod
    def _call_engine(engine: Engine, inputs: Any):
        """Calls the engine with the given inputs.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
            inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.
        """
        if isinstance(inputs, torch.Tensor):
            return engine(inputs)
        elif isinstance(inputs, (list, tuple)):
            return engine(*inputs)
        elif isinstance(inputs, dict):
            return engine(**inputs)
        else:
            raise TypeError(
                f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
            )

    @staticmethod
    def _call_engine_criterion(criterion, outputs: Any, labels: Any):
        """Calls the engine's criterion with the given outputs and labels.

        Args:
            criterion (Callable): The loss function used by the engine.
            outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
            labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
        """
        assert isinstance(
            outputs, (torch.Tensor, list, tuple, dict)
        ), f"Expected model output to be of type torch.Tensor, list, tuple, or dict, but got {type(outputs)}"
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        if isinstance(labels, torch.Tensor):
            labels = (labels,)

        if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
            return criterion(*outputs, *labels)
        elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
            return criterion(*outputs, **labels)
        elif isinstance(outputs, dict) and isinstance(labels, dict):
            return criterion(**outputs, **labels)
        elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
            raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
        else:
            raise TypeError(
                f"Expected model outputs and labels to be of type torch.Tensor "
                f"(which is auto-converted to tuple), list, tuple, or dict, "
                f"but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
            )


class SchedulerHook(ABC):
    """
    Scheduler Hook.
    """

    @abstractmethod
    def before_forward(self, scheduler, inputs) -> None:
        """Actions before forward"""

    @abstractmethod
    def after_forward(self, scheduler, outputs) -> None:
        """Actions after forward"""

    @abstractmethod
    def before_criterion(self, scheduler, outputs, label) -> None:
        """Actions before criterion"""

    @abstractmethod
    def after_criterion(self, scheduler, loss) -> None:
        """Actions after criterion"""

    @abstractmethod
    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
        """Actions before backward"""

    @abstractmethod
    def after_backward(self, scheduler, inputs_grad) -> None:
        """Actions after backward"""

    @abstractmethod
    def post_helper_func(self, scheduler, outputs, label) -> None:
        """A post helper function"""


class SchedulerMetricHook(SchedulerHook):
    """
    Scheduler Metric Hook.
    """

    def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
        self._post_func = metric
        self._skip = skip

    def before_forward(self, scheduler, inputs) -> None:
        if not self._skip:
            timer("fwd").start()

    def after_forward(self, scheduler, outputs) -> None:
        if not self._skip:
            timer("fwd").stop()

    def before_criterion(self, scheduler, outputs, label) -> None:
        if not self._skip:
            timer("cal_loss").start()

    def after_criterion(self, scheduler, loss) -> None:
        if not self._skip:
            timer("cal_loss").stop()

    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
        if not self._skip:
            timer("bwd").start()

    def after_backward(self, scheduler, inputs_grad) -> None:
        if not self._skip:
            timer("bwd").stop()

    def post_helper_func(self, scheduler, outputs, label) -> None:
        if self._post_func is not None:
            self._post_func(outputs, label)
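Illustrative sketch (not part of the uploaded files): a custom hook built on the SchedulerHook ABC above. It assumes the internlm package from this upload is importable; the counting logic itself is invented for the example.

# A hook that counts the forward/backward micro-steps the scheduler runs.
from internlm.core.scheduler import SchedulerHook

class StepCounterHook(SchedulerHook):
    """Counts forward/backward micro-steps seen by the scheduler."""

    def __init__(self):
        self.num_forward = 0
        self.num_backward = 0

    def before_forward(self, scheduler, inputs) -> None:
        self.num_forward += 1

    def after_forward(self, scheduler, outputs) -> None:
        pass

    def before_criterion(self, scheduler, outputs, label) -> None:
        pass

    def after_criterion(self, scheduler, loss) -> None:
        pass

    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
        self.num_backward += 1

    def after_backward(self, scheduler, inputs_grad) -> None:
        pass

    def post_helper_func(self, scheduler, outputs, label) -> None:
        pass

# Such a hook would be passed via scheduler_hooks=[StepCounterHook(), ...] when constructing
# one of the schedulers defined below.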
InternLM/internlm/core/scheduler/no_pipeline_scheduler.py
ADDED
@@ -0,0 +1,266 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

from typing import Any, Callable, Iterable, List, Optional

import torch

from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine, KDEngine
from internlm.utils.common import conditional_context
from internlm.utils.timeout import llm_timeout
from collections import defaultdict
from .base_scheduler import BaseScheduler, SchedulerHook


class NonPipelineScheduler(BaseScheduler):
    """A helper schedule class for a no-pipeline-parallelism running environment.
    During one process, it loads a batch of dataset and feeds it to the model.
    After getting the output and calculating the loss, it will use :meth:`step`
    to update the parameters if it is in training mode.

    Args:
        data_process_func (Callable, optional): The preprocessing function which receives a batch of data
            and returns a tuple in the form of (data, label); it will be executed in load_batch.
        gradient_accumulation_size (int, optional): the number of gradient accumulation steps, 1 to disable
            gradient accumulation.

    Examples:
        >>> # this shows an example of a customized data_process_func
        >>> def data_process_func(dataloader_output):
        >>>     item1, item2, item3 = dataloader_output
        >>>     data = (item1, item2)
        >>>     label = item3
        >>>     return data, label
    """

    def __init__(
        self,
        data_process_func: Callable = None,
        gradient_accumulation_size: int = 1,
        scheduler_hooks: Optional[List[SchedulerHook]] = None,
    ):
        self._grad_accum_size = gradient_accumulation_size
        self._grad_accum_offset = 0

        self._hooks = scheduler_hooks

        super().__init__(data_process_func)

    def pre_processing(self, engine: Engine):
        """Performs actions before running the schedule.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
        """
        pass

    def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
        for hook in self._hooks:
            getattr(hook, func_name)(self, *args, **kwargs)

    def _load_accum_batch(self, data: Any, label: Any):
        """Loads a batch of data and label for gradient accumulation.

        Args:
            data (Any): The data to be loaded.
            label (Any): The label to be loaded.
        """

        _data, _label = self._load_micro_batch(
            data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
        )
        self._grad_accum_offset += self._grad_accum_batch_size

        if self.data_process_func:
            _data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
            _label = self.data_process_func(_label, _data["cu_seqlens"])
            _data.pop("cu_seqlens")
            _data.pop("indexes")

        return _data, _label

    def _train_one_batch(
        self,
        data: Any,
        label: Any,
        engine: Engine,
        forward_only: bool = False,
        return_loss: bool = True,
        scale_loss: int = 1,
    ):
        """Trains one batch of data.

        Args:
            data (Any): The data to be trained.
            label (Any): The label for the data.
            engine (internlm.core.Engine): InternLM engine for training and inference.
            forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will
                be executed.
            return_loss (bool, optional): Loss will be returned if True.
            scale_loss (int, optional): The scale factor for the loss.
        """

        # forward
        with conditional_context(torch.no_grad(), enable=forward_only):
            self._call_hooks("before_forward", data)
            output = self._call_engine(engine, data)
            self._call_hooks("after_forward", output)

            self._call_hooks("post_helper_func", output, label)

            if return_loss:
                self._call_hooks("before_criterion", output, label)
                loss = self._call_engine_criterion(engine.criterion, output, label)
                self._call_hooks("after_criterion", loss)
                loss /= scale_loss

        # backward
        if not forward_only:
            self._call_hooks("before_backward", None, None)
            engine.backward(loss)
            self._call_hooks("after_backward", None)

        if not return_loss:
            loss = None

        return output, dict(loss=loss)

    @llm_timeout(func_name="nopp_forward_backward_step")
    def forward_backward_step(
        self,
        engine: Engine,
        data_iter: Iterable,
        forward_only: bool = False,
        return_loss: bool = True,
        return_output_label: bool = True,
    ):
        """The process function that loads a batch of dataset and feeds it to the model.
        The returned labels and loss will be None if :attr:`return_loss` is False.

        Args:
            engine (internlm.core.Engine): InternLM engine for training and inference.
            data_iter (Iterable): Dataloader in the form of an iterator, obtained by calling iter(dataloader).
            forward_only (bool, optional):
                If True, the model is run for the forward pass, else back propagation will be executed.
            return_loss (bool, optional): Loss will be returned if True.
            return_output_label (bool, optional): Output and label will be returned if True.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss); loss and label could be None.
        """
        assert (
            forward_only or return_loss
        ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."

        batch_data, batch_size = engine.load_batch(data_iter)

        assert (
            batch_size % self._grad_accum_size == 0
        ), f"batch_size:{batch_size} must be an integer multiple of gradient accumulation steps:{self._grad_accum_size}"
        self._grad_accum_batch_size = batch_size // self._grad_accum_size

        data, label = batch_data

        loss = defaultdict(int) if return_loss else None
        outputs = []
        labels = []

        # reset accumulation microbatch offset
        self._grad_accum_offset = 0

        for _current_accum_step in range(self._grad_accum_size):
            if _current_accum_step == self._grad_accum_size - 1:
                engine.optimizer.skip_grad_reduce = False
            else:
                engine.optimizer.skip_grad_reduce = True

            _data, _label = self._load_accum_batch(data, label)

            _output, _loss = self._train_one_batch(
                _data, _label, engine, forward_only, return_loss, self._grad_accum_size
            )

            if return_loss:
                for k in _loss:
                    loss[k] += _loss[k]
            if return_output_label:
                outputs.append(_output)
                labels.append(_label)

        if not return_output_label:
            outputs, labels = None, None

        return outputs, labels, loss


class KDNonPipelineScheduler(NonPipelineScheduler):

    def __init__(
        self,
        data_process_func: Callable = None,
        gradient_accumulation_size: int = 1,
        scheduler_hooks: Optional[List[SchedulerHook]] = None,
    ):
        super().__init__(
            data_process_func=data_process_func,
            gradient_accumulation_size=gradient_accumulation_size,
            scheduler_hooks=scheduler_hooks,
        )

    def _train_one_batch(
        self,
        data: Any,
        label: Any,
        engine: KDEngine,
        forward_only: bool = False,
        return_loss: bool = True,
        scale_loss: int = 1,
    ):
        """Trains one batch of data.

        Args:
            data (Any): The data to be trained.
            label (Any): The label for the data.
            engine (internlm.core.Engine): InternLM engine for training and inference.
            forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will
                be executed.
            return_loss (bool, optional): Loss will be returned if True.
            scale_loss (int, optional): The scale factor for the loss.
        """

        # forward
        with conditional_context(torch.no_grad(), enable=forward_only):
            self._call_hooks("before_forward", data)
            output = self._call_engine(engine, data)
            self._call_hooks("after_forward", output)

            self._call_hooks("post_helper_func", output, label)

            if return_loss:
                self._call_hooks("before_criterion", output, label)
                loss_gt = gpc.config.kd_config['gt_weight'] * self._call_engine_criterion(
                    engine.criterion, output, label
                )

                with torch.no_grad():
                    engine.teacher.eval()
                    output_t = self._call_engine(engine.teacher, data)

                loss_kd = gpc.config.kd_config['kd_weight'] * self._call_engine_criterion(
                    engine.kd_criterion, output, (output_t, label)
                )

                self._call_hooks("after_criterion", loss_gt + loss_kd)
                loss_gt /= scale_loss
                loss_kd /= scale_loss

        # backward
        if not forward_only:
            self._call_hooks("before_backward", None, None)
            engine.backward(loss_gt + loss_kd)
            self._call_hooks("after_backward", None)

        if not return_loss:
            loss_gt = None
            loss_kd = None

        return output, dict(loss_gt=loss_gt, loss_kd=loss_kd)
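Illustrative sketch (not part of the uploaded files): the weighted loss combination KDNonPipelineScheduler computes, written standalone. The KL-divergence form of the distillation term and the temperature are assumptions for the demo; the repo's actual engine.kd_criterion may be defined differently.

# Combine a hard-label cross-entropy term with a soft-label distillation term.
import torch
import torch.nn.functional as F

def combined_kd_loss(student_logits, teacher_logits, labels, gt_weight=1.0, kd_weight=1.0, T=1.0):
    # Hard-label term, playing the role of engine.criterion above.
    loss_gt = F.cross_entropy(student_logits, labels)
    # Soft-label term against the (detached) teacher distribution.
    loss_kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits.detach() / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    return gt_weight * loss_gt + kd_weight * loss_kd

student = torch.randn(4, 10, requires_grad=True)
teacher = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
combined_kd_loss(student, teacher, labels).backward()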
InternLM/internlm/core/scheduler/pipeline_scheduler.py
ADDED
|
@@ -0,0 +1,1363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
| 5 |
+
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
from typing import Callable, List, Optional, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import torch.cuda
|
| 10 |
+
|
| 11 |
+
import internlm.core.communication as comm
|
| 12 |
+
from internlm.core.context import ParallelMode
|
| 13 |
+
from internlm.core.context import global_context as gpc
|
| 14 |
+
from internlm.core.engine import Engine
|
| 15 |
+
from internlm.core.naive_amp import NaiveAMPModel
|
| 16 |
+
from internlm.utils.common import get_current_device, move_to_device
|
| 17 |
+
from internlm.utils.logger import get_logger
|
| 18 |
+
from internlm.utils.timeout import llm_timeout
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
from .base_scheduler import BaseScheduler, SchedulerHook
|
| 21 |
+
|
| 22 |
+
logger = get_logger(__file__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_tensor_shape():
|
| 26 |
+
if hasattr(gpc.config, "TENSOR_SHAPE"):
|
| 27 |
+
return gpc.config.TENSOR_SHAPE
|
| 28 |
+
|
| 29 |
+
if not gpc.is_initialized(ParallelMode.PIPELINE):
|
| 30 |
+
return None
|
| 31 |
+
|
| 32 |
+
if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
|
| 33 |
+
if gpc.config.model.use_flash_attn:
|
| 34 |
+
if gpc.config.parallel.sequence_parallel:
|
| 35 |
+
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
|
| 36 |
+
tensor_shape = (
|
| 37 |
+
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
|
| 38 |
+
gpc.config.HIDDEN_SIZE,
|
| 39 |
+
)
|
| 40 |
+
else:
|
| 41 |
+
tensor_shape = (
|
| 42 |
+
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
|
| 43 |
+
gpc.config.HIDDEN_SIZE,
|
| 44 |
+
)
|
| 45 |
+
else:
|
| 46 |
+
tensor_shape = (
|
| 47 |
+
gpc.config.data["micro_bsz"],
|
| 48 |
+
gpc.config.SEQ_LEN,
|
| 49 |
+
gpc.config.HIDDEN_SIZE,
|
| 50 |
+
)
|
| 51 |
+
return tensor_shape
|
| 52 |
+
else:
|
| 53 |
+
return None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def pack_return_tensors(return_tensors):
|
| 57 |
+
output, label = tuple(zip(*return_tensors))
|
| 58 |
+
if isinstance(output[0], torch.Tensor):
|
| 59 |
+
output = torch.cat(output, dim=0)
|
| 60 |
+
elif isinstance(output[0], (list, tuple)):
|
| 61 |
+
output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
|
| 62 |
+
else:
|
| 63 |
+
raise TypeError("Output of model must be tensor or list/tuple of tensors")
|
| 64 |
+
if isinstance(label[0], torch.Tensor):
|
| 65 |
+
label = torch.cat(label, dim=0)
|
| 66 |
+
else:
|
| 67 |
+
merged_label = {k: [] for k in label[0].keys()}
|
| 68 |
+
for d in label:
|
| 69 |
+
for k, v in d.items():
|
| 70 |
+
merged_label[k].append(v)
|
| 71 |
+
label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
|
| 72 |
+
return output, label
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@contextmanager
|
| 76 |
+
def switch_virtual_pipeline_parallel_rank(rank):
|
| 77 |
+
prev_rank = gpc.virtual_pipeline_parallel_rank
|
| 78 |
+
try:
|
| 79 |
+
gpc.set_virtual_pipeline_parallel_rank(rank)
|
| 80 |
+
yield
|
| 81 |
+
finally:
|
| 82 |
+
gpc.set_virtual_pipeline_parallel_rank(prev_rank)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@contextmanager
|
| 86 |
+
def switch_optimizer_grad_sync_skip_mode(optimizer, skip: bool = True):
|
| 87 |
+
prev_mode = optimizer.skip_grad_reduce
|
| 88 |
+
try:
|
| 89 |
+
optimizer.skip_grad_reduce = skip
|
| 90 |
+
yield
|
| 91 |
+
finally:
|
| 92 |
+
optimizer.skip_grad_reduce = prev_mode
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class PipelineScheduler(BaseScheduler):
|
| 96 |
+
"""
|
| 97 |
+
A helper schedule class for pipeline parallelism running environment.
|
| 98 |
+
It uses non-interleaved 1F1B strategy. Other properties are similar as
|
| 99 |
+
:class:`NonPipelineSchedule`.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
num_microbatches (int): The number of microbatches.
|
| 103 |
+
dtype (torch.dtype): Type of data. torch.float by default.
|
| 104 |
+
data_process_func (Callable, optional):
|
| 105 |
+
The post processing function which receives a micro batch of data, and it will be executed
|
| 106 |
+
in `load_micro_batch`.
|
| 107 |
+
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
|
| 108 |
+
scatter_gather_tensors (bool, optional):
|
| 109 |
+
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
|
| 110 |
+
scheduler_hooks (Optional[List[SchedulerHook]], optional): List of scheduler hooks.
|
| 111 |
+
"""
|
| 112 |
+
|
| 113 |
+
def __init__(
|
| 114 |
+
self,
|
| 115 |
+
num_microbatches: int,
|
| 116 |
+
dtype: torch.dtype = torch.float,
|
| 117 |
+
data_process_func: Callable = None,
|
| 118 |
+
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
| 119 |
+
scatter_gather_tensors: bool = False,
|
| 120 |
+
scheduler_hooks: Optional[List[SchedulerHook]] = None,
|
| 121 |
+
):
|
| 122 |
+
assert num_microbatches > 0, f"expected num_microbatches to be larger then 1, but got {num_microbatches}"
|
| 123 |
+
|
| 124 |
+
assert not isinstance(
|
| 125 |
+
tensor_shape, int
|
| 126 |
+
), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
|
| 127 |
+
|
| 128 |
+
super().__init__(data_process_func=data_process_func)
|
| 129 |
+
|
| 130 |
+
self.num_microbatches = num_microbatches
|
| 131 |
+
self.dtype = dtype
|
| 132 |
+
self._hooks = scheduler_hooks
|
| 133 |
+
|
| 134 |
+
self._tensor_shape = (
|
| 135 |
+
tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape)
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
self.scatter_gather_tensors = (
|
| 139 |
+
scatter_gather_tensors
|
| 140 |
+
and gpc.is_initialized(ParallelMode.TENSOR)
|
| 141 |
+
and gpc.get_world_size(ParallelMode.TENSOR) > 1
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
if gpc.config.parallel.sequence_parallel:
|
| 145 |
+
self.scatter_gather_tensors = False
|
| 146 |
+
|
| 147 |
+
# cache for the batch data
|
| 148 |
+
self.batch_data = None
|
| 149 |
+
|
| 150 |
+
@property
|
| 151 |
+
def tensor_shape(self) -> torch.Size:
|
| 152 |
+
return self._tensor_shape
|
| 153 |
+
|
| 154 |
+
@tensor_shape.setter
|
| 155 |
+
def tensor_shape(self, tensor_shape: torch.Size):
|
| 156 |
+
self._tensor_shape = tensor_shape
|
| 157 |
+
|
| 158 |
+
def pre_processing(self, engine):
|
| 159 |
+
types = set()
|
| 160 |
+
|
| 161 |
+
for param in engine.model.parameters():
|
| 162 |
+
types.add(param.dtype)
|
| 163 |
+
assert len(types) == 1, f"Mixed types of parameter detected, {types}"
|
| 164 |
+
|
| 165 |
+
self.dtype = types.pop()
|
| 166 |
+
|
| 167 |
+
@staticmethod
|
| 168 |
+
def _call_engine(engine, data): # pylint: disable=W0237
|
| 169 |
+
if data is None:
|
| 170 |
+
return None
|
| 171 |
+
|
| 172 |
+
if isinstance(data, torch.Tensor):
|
| 173 |
+
return engine(data)
|
| 174 |
+
elif isinstance(data, (list, tuple)):
|
| 175 |
+
return engine(*data)
|
| 176 |
+
elif isinstance(data, dict):
|
| 177 |
+
stage_output = data.pop("stage_output", None)
|
| 178 |
+
|
| 179 |
+
if stage_output is None:
|
| 180 |
+
return engine(**data)
|
| 181 |
+
elif isinstance(stage_output, torch.Tensor):
|
| 182 |
+
return engine(stage_output, **data)
|
| 183 |
+
elif isinstance(stage_output, (tuple, list)):
|
| 184 |
+
return engine(*stage_output, **data)
|
| 185 |
+
else:
|
| 186 |
+
raise TypeError(
|
| 187 |
+
f"Expected stage_output to be of type torch.Tensor, list, or tuple, "
|
| 188 |
+
f"but got {type(stage_output)}"
|
| 189 |
+
)
|
| 190 |
+
else:
|
| 191 |
+
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
|
| 192 |
+
|
| 193 |
+
def load_batch(self, engine, data_iter):
|
| 194 |
+
# Pipeline schedule just puts data in memory
|
| 195 |
+
batch_data, batch_size = engine.load_batch(data_iter, to_gpu=False)
|
| 196 |
+
assert batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
|
| 197 |
+
|
| 198 |
+
self.microbatch_offset = 0
|
| 199 |
+
self.batch_size = batch_size
|
| 200 |
+
self.batch_data, self.batch_label = batch_data
|
| 201 |
+
self.microbatch_size = self.batch_size // self.num_microbatches
|
| 202 |
+
|
| 203 |
+
def load_micro_batch(self):
|
| 204 |
+
micro_batch_data, micro_batch_label = self._load_micro_batch(
|
| 205 |
+
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
|
| 206 |
+
)
|
| 207 |
+
if self.data_process_func:
|
| 208 |
+
micro_batch_data["input_ids"] = self.data_process_func(
|
| 209 |
+
micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
|
| 210 |
+
)
|
| 211 |
+
micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
|
| 212 |
+
|
| 213 |
+
micro_batch_data.pop("cu_seqlens")
|
| 214 |
+
micro_batch_data.pop("indexes")
|
| 215 |
+
|
| 216 |
+
micro_batch_data["label"] = micro_batch_label
|
| 217 |
+
self.microbatch_offset += self.microbatch_size
|
| 218 |
+
|
| 219 |
+
return move_to_device(micro_batch_data)
|
| 220 |
+
|
| 221 |
+
def _get_data_label_for_current_step(self, stage_output, micro_batch_data):
|
| 222 |
+
if isinstance(micro_batch_data, (tuple, list)):
|
| 223 |
+
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
| 224 |
+
# for the first stage, we use the data from the
|
| 225 |
+
# dataloader output by default
|
| 226 |
+
data, label = micro_batch_data
|
| 227 |
+
else:
|
| 228 |
+
# for non-first stage, we use the output passed
|
| 229 |
+
# by the previous as the model input
|
| 230 |
+
data = stage_output
|
| 231 |
+
_, label = micro_batch_data
|
| 232 |
+
elif isinstance(micro_batch_data, dict):
|
| 233 |
+
label = micro_batch_data.pop("label", None)
|
| 234 |
+
data = {"stage_output": stage_output, **micro_batch_data}
|
| 235 |
+
|
| 236 |
+
return data, label
|
| 237 |
+
|
| 238 |
+
def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
|
| 239 |
+
for hook in self._hooks:
|
| 240 |
+
getattr(hook, func_name)(self, *args, **kwargs)
|
| 241 |
+
|
| 242 |
+
def _get_current_microbatch_id(self, step_id: int) -> int:
|
| 243 |
+
"""
|
| 244 |
+
Get the current microbatch ID based on the step ID.
|
| 245 |
+
In 1f1b scheduler, the microbatch ID is the same as the step ID,
|
| 246 |
+
but it is important to note that the step ID is calculated separately
|
| 247 |
+
for forward and backward passes.
|
| 248 |
+
"""
|
| 249 |
+
return step_id
|
| 250 |
+
|
| 251 |
+
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
|
| 252 |
+
"""
|
| 253 |
+
Forward step for passed-in model. If it is the first stage, the input tensor
|
| 254 |
+
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
| 255 |
+
Returns output tensor. This is a helper function and can be ignored by users.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
| 259 |
+
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
|
| 260 |
+
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
|
| 261 |
+
return_output_label (bool, optional): Whether returns output labels.
|
| 262 |
+
accum_loss (optional): Where accumulated loss stores.
|
| 263 |
+
Returns:
|
| 264 |
+
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
|
| 265 |
+
pipeline stage.
|
| 266 |
+
"""
|
| 267 |
+
micro_batch_data = self.load_micro_batch()
|
| 268 |
+
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
|
| 269 |
+
|
| 270 |
+
self._call_hooks("before_forward", data)
|
| 271 |
+
output_obj = self._call_engine(engine.model, data)
|
| 272 |
+
self._call_hooks("after_forward", output_obj)
|
| 273 |
+
|
| 274 |
+
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
| 275 |
+
self._call_hooks("post_helper_func", output_obj, label)
|
| 276 |
+
if return_output_label:
|
| 277 |
+
return_tensors.append((output_obj, label))
|
| 278 |
+
if accum_loss is not None:
|
| 279 |
+
self._call_hooks("before_criterion", output_obj, label)
|
| 280 |
+
loss = self._call_engine_criterion(engine.criterion, output_obj, label)
|
| 281 |
+
self._call_hooks("after_criterion", loss)
|
| 282 |
+
|
| 283 |
+
loss_reduced = loss / self.num_microbatches
|
| 284 |
+
accum_loss['loss'].add_(loss_reduced.detach())
|
| 285 |
+
output_obj = loss_reduced
|
| 286 |
+
|
| 287 |
+
return output_obj
|
| 288 |
+
|
| 289 |
+
def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad):
|
| 290 |
+
"""
|
| 291 |
+
Backward step through the passed-in output tensor. If it is the last stage, the
|
| 292 |
+
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
|
| 293 |
+
Returns the gradients with respect to the input tensor (None if first stage).
|
| 294 |
+
This is a helper function and can be ignored by users.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
| 298 |
+
step_id (int): The ID of the current step.
|
| 299 |
+
input_obj (Union[torch.Tensor, List[torch.Tensor]]): Input tensor for this stage.
|
| 300 |
+
output_obj (Union[torch.Tensor, List[torch.Tensor]]): Output tensor for this stage.
|
| 301 |
+
output_obj_grad (Union[torch.Tensor, List[torch.Tensor]]): Gradient of output tensor for this stage.
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
Union[torch.Tensor, List[torch.Tensor]]: Gradient of input tensor.
|
| 305 |
+
"""
|
| 306 |
+
|
| 307 |
+
# Retain the grad on the input_obj.
|
| 308 |
+
if input_obj is not None:
|
| 309 |
+
if isinstance(input_obj, torch.Tensor):
|
| 310 |
+
input_obj.retain_grad()
|
| 311 |
+
else:
|
| 312 |
+
for in_tensor in input_obj:
|
| 313 |
+
if in_tensor is not None:
|
| 314 |
+
in_tensor.retain_grad()
|
| 315 |
+
|
| 316 |
+
# Backward pass.
|
| 317 |
+
|
| 318 |
+
# Only the last microbatch does syncing grad.
|
| 319 |
+
skip_grad_sync = self._get_current_microbatch_id(step_id) != self.num_microbatches - 1
|
| 320 |
+
|
| 321 |
+
self._call_hooks("before_backward", output_obj, output_obj_grad)
|
| 322 |
+
with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
|
| 323 |
+
if output_obj_grad is None:
|
| 324 |
+
engine.backward(output_obj)
|
| 325 |
+
else:
|
| 326 |
+
engine.backward_by_grad(output_obj, output_obj_grad)
|
| 327 |
+
|
| 328 |
+
# Collect the grad of the input_obj.
|
| 329 |
+
input_obj_grad = None
|
| 330 |
+
if input_obj is not None:
|
| 331 |
+
if isinstance(input_obj, torch.Tensor):
|
| 332 |
+
input_obj_grad = input_obj.grad
|
| 333 |
+
else:
|
| 334 |
+
input_obj_grad = []
|
| 335 |
+
for in_tensor in input_obj:
|
| 336 |
+
input_obj_grad.append(in_tensor.grad)
|
| 337 |
+
self._call_hooks("after_backward", input_obj_grad)
|
| 338 |
+
|
| 339 |
+
return input_obj_grad
|
| 340 |
+
|
| 341 |
+
def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
|
| 342 |
+
"""
|
| 343 |
+
This function performs forward only computation process. The scheduling of microbatches is similar to the
|
| 344 |
+
warmup phase, where each microbatch first receives the forward input from the previous stage, then performs
|
| 345 |
+
the forward computation, and finally passes the forward computation output to the next stage. There are two
|
| 346 |
+
special cases to note:
|
| 347 |
+
1. The first stage of the pipeline does not need to receive forward input; its input comes from the dataloader.
|
| 348 |
+
2. The last stage of the pipeline does not need to send forward output; its output is returned to the user code
|
| 349 |
+
for processing.
|
| 350 |
+
|
| 351 |
+
Args:
|
| 352 |
+
engine (colossalai.engine.Engine): internlm engine for training and inference.
|
| 353 |
+
return_loss (bool, optional): Whether to return the accumulated loss.
|
| 354 |
+
return_output_label (bool, optional): Whether to return outputs and labels.
|
| 355 |
+
|
| 356 |
+
Returns:
|
| 357 |
+
Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]:
|
| 358 |
+
output, label, and accumulated loss.
|
| 359 |
+
"""
|
| 360 |
+
|
| 361 |
+
# Input, output tensors only need to be saved when doing backward passes
|
| 362 |
+
return_tensors = []
|
| 363 |
+
accum_loss_init_func = lambda: torch.zeros(1, device=get_current_device())
|
| 364 |
+
accum_loss = defaultdict(accum_loss_init_func) if return_loss and gpc.is_pipeline_last_stage(
|
| 365 |
+
ignore_virtual=True) else None
|
| 366 |
+
|
| 367 |
+
# Used for tensor meta information communication
|
| 368 |
+
forward_recv_shapes = self.tensor_shape
|
| 369 |
+
need_forward_meta = self.tensor_shape is None
|
| 370 |
+
|
| 371 |
+
# Run all forward passes.
|
| 372 |
+
for _ in range(self.num_microbatches):
|
| 373 |
+
# Receive input from the previous stage
|
| 374 |
+
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
| 375 |
+
if forward_recv_shapes is None:
|
| 376 |
+
forward_recv_shapes = comm.recv_obj_meta()
|
| 377 |
+
input_obj = comm.recv_forward(
|
| 378 |
+
forward_recv_shapes,
|
| 379 |
+
dtype=self.dtype,
|
| 380 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 381 |
+
)
|
| 382 |
+
else:
|
| 383 |
+
input_obj = None
|
| 384 |
+
|
| 385 |
+
# Perform forward computation
|
| 386 |
+
output_obj = self._forward_step(
|
| 387 |
+
engine,
|
| 388 |
+
input_obj,
|
| 389 |
+
return_tensors,
|
| 390 |
+
return_output_label=return_output_label,
|
| 391 |
+
accum_loss=accum_loss,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
| 395 |
+
if need_forward_meta:
|
| 396 |
+
comm.send_obj_meta(output_obj)
|
| 397 |
+
need_forward_meta = False # send only once.
|
| 398 |
+
# Send the forward computation output to the next stage
|
| 399 |
+
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
|
| 400 |
+
|
| 401 |
+
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
|
| 402 |
+
|
| 403 |
+
return output, label, accum_loss
|
| 404 |
+
|
| 405 |
+
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
|
| 406 |
+
"""
|
| 407 |
+
This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
|
| 408 |
+
It consists of three stages: warmup, 1F1B, and cooldown.
|
| 409 |
+
|
| 410 |
+
1. Warmup Stage:
|
| 411 |
+
The warmup stage performs num_warmup forward microsteps. The calculation of num_warmup is the pipeline length
|
| 412 |
+
minus the rank of the current pipeline minus 1. For each microstep, it receives data as input from the previous
|
| 413 |
+
stage, performs the forward computation, and then sends the result to the next stage.
|
| 414 |
+
|
| 415 |
+
2. 1F1B Stage:
|
| 416 |
+
The 1F1B stage consists of pairs of forward and backward microsteps. It performs num_1f1b_micropairs iterations,
|
| 417 |
+
where num_1f1b_micropairs is calculated as the total number of microbatches minus the number of microbatches in
|
| 418 |
+
the warmup stage. In each iteration, it first performs a forward computation, sends the result to the next
|
| 419 |
+
stage, receives input for the backward computation, performs the backward computation, and finally sends the
|
| 420 |
+
result to the previous stage to receive input for the next forward computation.
|
| 421 |
+
|
| 422 |
+
3. Cooldown Stage:
|
| 423 |
+
The cooldown stage performs the same number of iterations as the warmup stage. In each iteration, it receives
|
| 424 |
+
input for the backward computation, performs the backward computation, and finally sends the result to the
|
| 425 |
+
previous stage.
|
| 426 |
+
|
| 427 |
+
There are two special cases to consider:
|
| 428 |
+
1. The first stage of the pipeline does not need to receive forward input or send backward output. The last
|
| 429 |
+
stage does not need to send forward output or receive backward input.
|
| 430 |
+
2. Communication between adjacent stages must be carefully paired, and additional communication is used to bridge the transitions between the warmup, 1F1B, and cooldown phases.
|
| 431 |
+
|
| 432 |
+
Args:
|
| 433 |
+
engine (Engine): The engine used for computation.
|
| 434 |
+
return_loss (bool, optional): Whether to return the accumulated loss.
|
| 435 |
+
return_output_label (bool, optional): Whether to return outputs and labels.
|
| 436 |
+
|
| 437 |
+
Returns:
|
| 438 |
+
Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]:
|
| 439 |
+
The output, label, and accumulated loss.
|
| 440 |
+
"""
|
| 441 |
+
|
| 442 |
+
num_warmup_microsteps = (
|
| 443 |
+
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
|
| 444 |
+
)
|
| 445 |
+
num_warmup_microsteps = min(num_warmup_microsteps, self.num_microbatches)
|
| 446 |
+
num_1f1b_micropairs = self.num_microbatches - num_warmup_microsteps
|
| 447 |
+
|
| 448 |
+
# Input, output tensors only need to be saved when doing backward passes
|
| 449 |
+
input_objs = []
|
| 450 |
+
output_objs = []
|
| 451 |
+
return_tensors = []
|
| 452 |
+
accum_loss_init_func = lambda: torch.zeros(1, device=get_current_device())
|
| 453 |
+
accum_loss = defaultdict(accum_loss_init_func) if return_loss and gpc.is_pipeline_last_stage(
|
| 454 |
+
ignore_virtual=True) else None
|
| 455 |
+
|
| 456 |
+
# Used for tensor meta information communication
|
| 457 |
+
forward_recv_shapes = self.tensor_shape
|
| 458 |
+
backward_recv_shapes = None
|
| 459 |
+
need_forward_meta = self.tensor_shape is None
|
| 460 |
+
|
| 461 |
+
# Run warmup forward passes.
|
| 462 |
+
for i in range(num_warmup_microsteps):
|
| 463 |
+
# Receive the input from the previous stage
|
| 464 |
+
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
| 465 |
+
if forward_recv_shapes is None:
|
| 466 |
+
forward_recv_shapes = comm.recv_obj_meta()
|
| 467 |
+
input_obj = comm.recv_forward(
|
| 468 |
+
forward_recv_shapes,
|
| 469 |
+
dtype=self.dtype,
|
| 470 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 471 |
+
)
|
| 472 |
+
else:
|
| 473 |
+
input_obj = None
|
| 474 |
+
|
| 475 |
+
# Perform forward computation
|
| 476 |
+
output_obj = self._forward_step(
|
| 477 |
+
engine,
|
| 478 |
+
input_obj,
|
| 479 |
+
return_tensors,
|
| 480 |
+
return_output_label=return_output_label,
|
| 481 |
+
accum_loss=accum_loss,
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
| 485 |
+
if isinstance(output_obj, torch.Tensor):
|
| 486 |
+
backward_recv_shapes = output_obj.shape
|
| 487 |
+
else:
|
| 488 |
+
backward_recv_shapes = [out_tensor.shape for out_tensor in output_obj]
|
| 489 |
+
|
| 490 |
+
if need_forward_meta:
|
| 491 |
+
comm.send_obj_meta(output_obj)
|
| 492 |
+
need_forward_meta = False # send only once.
|
| 493 |
+
|
| 494 |
+
# Send the output of forward computation of this pipeline stage to the next pipeline stage as input for
|
| 495 |
+
# forward computation
|
| 496 |
+
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
| 497 |
+
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
|
| 498 |
+
|
| 499 |
+
input_objs.append(input_obj)
|
| 500 |
+
output_objs.append(output_obj)
|
| 501 |
+
|
| 502 |
+
# Before running 1F1B, need to receive first forward tensor.
|
| 503 |
+
# If all microbatches are run in warmup / cooldown phase, then no need to
|
| 504 |
+
# receive this tensor here.
|
| 505 |
+
if num_1f1b_micropairs > 0:
|
| 506 |
+
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
| 507 |
+
if forward_recv_shapes is None:
|
| 508 |
+
forward_recv_shapes = comm.recv_obj_meta(forward_recv_shapes)
|
| 509 |
+
input_obj = comm.recv_forward(
|
| 510 |
+
forward_recv_shapes,
|
| 511 |
+
dtype=self.dtype,
|
| 512 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 513 |
+
)
|
| 514 |
+
else:
|
| 515 |
+
input_obj = None
|
| 516 |
+
|
| 517 |
+
# Run 1F1B in steady state.
|
| 518 |
+
for i in range(num_1f1b_micropairs):
|
| 519 |
+
# Perform forward computation
|
| 520 |
+
output_obj = self._forward_step(
|
| 521 |
+
engine,
|
| 522 |
+
input_obj,
|
| 523 |
+
return_tensors,
|
| 524 |
+
return_output_label=return_output_label,
|
| 525 |
+
accum_loss=accum_loss,
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
| 529 |
+
output_obj_grad = None
|
| 530 |
+
else:
|
| 531 |
+
output_obj_grad = comm.send_forward_recv_backward(
|
| 532 |
+
output_obj,
|
| 533 |
+
backward_recv_shapes,
|
| 534 |
+
dtype=self.dtype,
|
| 535 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
# Add input_obj and output_obj to end of list.
|
| 539 |
+
input_objs.append(input_obj)
|
| 540 |
+
output_objs.append(output_obj)
|
| 541 |
+
|
| 542 |
+
# Pop input_obj and output_obj from the start of the list for
|
| 543 |
+
# the backward pass.
|
| 544 |
+
input_obj = input_objs.pop(0)
|
| 545 |
+
output_obj = output_objs.pop(0)
|
| 546 |
+
|
| 547 |
+
input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad)
|
| 548 |
+
|
| 549 |
+
if i == (num_1f1b_micropairs - 1):
|
| 550 |
+
input_obj = None
|
| 551 |
+
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
| 552 |
+
comm.send_backward(
|
| 553 |
+
input_obj_grad,
|
| 554 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 555 |
+
)
|
| 556 |
+
else:
|
| 557 |
+
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
| 558 |
+
input_obj = None
|
| 559 |
+
else:
|
| 560 |
+
input_obj = comm.send_backward_recv_forward(
|
| 561 |
+
input_obj_grad,
|
| 562 |
+
forward_recv_shapes,
|
| 563 |
+
dtype=self.dtype,
|
| 564 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
# Run cooldown backward passes.
|
| 568 |
+
for i in range(num_warmup_microsteps):
|
| 569 |
+
input_obj = input_objs.pop(0)
|
| 570 |
+
output_obj = output_objs.pop(0)
|
| 571 |
+
|
| 572 |
+
if not gpc.is_last_rank(ParallelMode.PIPELINE):
|
| 573 |
+
output_obj_grad = comm.recv_backward(
|
| 574 |
+
backward_recv_shapes,
|
| 575 |
+
dtype=self.dtype,
|
| 576 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 577 |
+
)
|
| 578 |
+
else:
|
| 579 |
+
output_obj_grad = None
|
| 580 |
+
|
| 581 |
+
input_obj_grad = self._backward_step(
|
| 582 |
+
engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
if not gpc.is_first_rank(ParallelMode.PIPELINE):
|
| 586 |
+
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
|
| 587 |
+
|
| 588 |
+
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
|
| 589 |
+
|
| 590 |
+
return output, label, accum_loss
|
| 591 |
+
|
| 592 |
+
@llm_timeout(func_name="nointerleaved_forward_backward_step")
|
| 593 |
+
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
|
| 594 |
+
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
| 595 |
+
Returns a tuple of (output, label, loss); the loss is only accumulated on the last pipeline stage.
|
| 596 |
+
|
| 597 |
+
Args:
|
| 598 |
+
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
| 599 |
+
data_iter (Iterable): An iterator over the dataloader, obtained by calling iter(dataloader).
|
| 600 |
+
forward_only (bool, optional):
|
| 601 |
+
Whether to run the forward step only. Default is False. If True, no backward pass will be run.
|
| 602 |
+
return_loss (bool, optional): Whether to return the loss value. Default is True.
|
| 603 |
+
return_output_label (bool, optional): If False, the output and label won't be returned.
|
| 604 |
+
Returns:
|
| 605 |
+
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
| 606 |
+
"""
|
| 607 |
+
|
| 608 |
+
assert (
|
| 609 |
+
forward_only or return_loss
|
| 610 |
+
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
| 611 |
+
|
| 612 |
+
# Load data first
|
| 613 |
+
self.load_batch(engine, data_iter)
|
| 614 |
+
|
| 615 |
+
if forward_only:
|
| 616 |
+
return self._forward_only_step(engine, return_loss, return_output_label)
|
| 617 |
+
else:
|
| 618 |
+
return self._forward_backward_step(engine, return_loss, return_output_label)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class InterleavedPipelineScheduler(PipelineScheduler):
|
| 622 |
+
"""
|
| 623 |
+
Interleaved Pipeline Scheduler.
|
| 624 |
+
"""
|
| 625 |
+
|
| 626 |
+
def __init__(
|
| 627 |
+
self,
|
| 628 |
+
num_microbatches: int,
|
| 629 |
+
num_chunks: int,
|
| 630 |
+
dtype: torch.dtype = torch.float,
|
| 631 |
+
data_process_func: Callable = None,
|
| 632 |
+
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
| 633 |
+
scatter_gather_tensors: bool = False,
|
| 634 |
+
scheduler_hooks: Optional[List[SchedulerHook]] = None,
|
| 635 |
+
communication_overlap: bool = False,
|
| 636 |
+
):
|
| 637 |
+
"""A helper schedule class for pipeline parallelism running environment.
|
| 638 |
+
It uses interleaved 1F1B strategy. Other properties are similar as
|
| 639 |
+
:class:`NonPipelineSchedule`.
|
| 640 |
+
|
| 641 |
+
Args:
|
| 642 |
+
num_microbatches (int): The number of microbatches.
|
| 643 |
+
num_chunks (int): The number of model chunks.
|
| 644 |
+
dtype (torch.dtype, optional): The data type of the tensors. Default is torch.float.
|
| 645 |
+
data_process_func (Callable, optional):
|
| 646 |
+
The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
|
| 647 |
+
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
|
| 648 |
+
scatter_gather_tensors (bool, optional):
|
| 649 |
+
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
|
| 650 |
+
scheduler_hooks (List[SchedulerHook], optional): List of scheduler hooks. Default is None.
|
| 651 |
+
communication_overlap (bool, optional): Whether to enable communication overlap. Default is False.
|
| 652 |
+
"""
|
| 653 |
+
assert (
|
| 654 |
+
num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0
|
| 655 |
+
), "num_microbatches must be an integer multiple of pipeline parallel world size"
|
| 656 |
+
|
| 657 |
+
assert (
|
| 658 |
+
isinstance(num_chunks, int) and num_chunks > 0
|
| 659 |
+
), f"expected num_chunks to be an integer and larger than 0, but got {num_chunks}"
|
| 660 |
+
|
| 661 |
+
super().__init__(
|
| 662 |
+
num_microbatches,
|
| 663 |
+
dtype=dtype,
|
| 664 |
+
data_process_func=data_process_func,
|
| 665 |
+
tensor_shape=tensor_shape,
|
| 666 |
+
scatter_gather_tensors=scatter_gather_tensors,
|
| 667 |
+
scheduler_hooks=scheduler_hooks,
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
gpc.set_virtual_pipeline_parallel_size(num_chunks)
|
| 671 |
+
gpc.set_virtual_pipeline_parallel_rank(0)
|
| 672 |
+
|
| 673 |
+
self._num_chunks = num_chunks
|
| 674 |
+
self._communication_overlap = communication_overlap
|
| 675 |
+
# switch 1f1b loop runner function according to communication overlap
|
| 676 |
+
self._run_1f1b_loop = (
|
| 677 |
+
self._run_1f1b_loop_with_overlap if communication_overlap else self._run_1f1b_loop_without_overlap
|
| 678 |
+
)
|
| 679 |
+
|
| 680 |
+
# states
|
| 681 |
+
self._pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
| 682 |
+
self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
| 683 |
+
|
| 684 |
+
self._accum_loss = None
|
| 685 |
+
self._return_tensors = None
|
| 686 |
+
self._input_objs = [[] for _ in range(num_chunks)]
|
| 687 |
+
self._output_objs = [[] for _ in range(num_chunks)]
|
| 688 |
+
self._output_obj_grads = [[] for _ in range(num_chunks)]
|
| 689 |
+
|
| 690 |
+
self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
|
| 691 |
+
self._output_obj_shapes = [None for _ in range(num_chunks)]
|
| 692 |
+
self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]
|
| 693 |
+
|
| 694 |
+
@property
|
| 695 |
+
def tensor_shape(self) -> torch.Size:
|
| 696 |
+
return self._tensor_shape
|
| 697 |
+
|
| 698 |
+
@tensor_shape.setter
|
| 699 |
+
def tensor_shape(self, tensor_shape: torch.Size):
|
| 700 |
+
self._tensor_shape = tensor_shape
|
| 701 |
+
self._input_obj_shapes = [self._tensor_shape for _ in range(self._num_chunks)]
|
| 702 |
+
self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(self._num_chunks)]
|
| 703 |
+
|
| 704 |
+
def _clear_state(self) -> None:
|
| 705 |
+
self._accum_loss = None
|
| 706 |
+
self._return_tensors = None
|
| 707 |
+
self._input_objs = [[] for _ in range(self._num_chunks)]
|
| 708 |
+
self._output_objs = [[] for _ in range(self._num_chunks)]
|
| 709 |
+
self._output_obj_grads = [[] for _ in range(self._num_chunks)]
|
| 710 |
+
|
| 711 |
+
self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
|
| 712 |
+
self._output_obj_shapes = [None for _ in range(self._num_chunks)]
|
| 713 |
+
self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(self._num_chunks)]
|
| 714 |
+
|
| 715 |
+
def load_batch(self, engine, data_iter):
|
| 716 |
+
super().load_batch(engine, data_iter)
|
| 717 |
+
# overwrite microbatch_offset, since all model chunks load from the same microbatch and each should track its own offset
|
| 718 |
+
self.microbatch_offset = [0 for _ in range(self._num_chunks)]
|
| 719 |
+
|
| 720 |
+
def load_micro_batch(self, model_chunk_id):
|
| 721 |
+
micro_batch_data, micro_batch_label = self._load_micro_batch(
|
| 722 |
+
data=self.batch_data,
|
| 723 |
+
label=self.batch_label,
|
| 724 |
+
offset=self.microbatch_offset[model_chunk_id],
|
| 725 |
+
micro_bsz=self.microbatch_size,
|
| 726 |
+
)
|
| 727 |
+
micro_batch_data["label"] = micro_batch_label
|
| 728 |
+
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
| 729 |
+
return move_to_device(micro_batch_data)
|
| 730 |
+
|
| 731 |
+
def _forward_step(self, engine, chunk_id):
|
| 732 |
+
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
| 733 |
+
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
| 734 |
+
Returns output tensor. This is a helper function and can be ignored by users.
|
| 735 |
+
|
| 736 |
+
Args:
|
| 737 |
+
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
| 738 |
+
chunk_id (int): The id of model chunks.
|
| 739 |
+
Returns:
|
| 740 |
+
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
|
| 741 |
+
pipeline stage.
|
| 742 |
+
"""
|
| 743 |
+
gpc.set_virtual_pipeline_parallel_rank(chunk_id)
|
| 744 |
+
|
| 745 |
+
if gpc.is_pipeline_first_stage() and len(self._input_objs[chunk_id]) == len(self._output_objs[chunk_id]):
|
| 746 |
+
self._input_objs[chunk_id].append(None)
|
| 747 |
+
input_obj = self._input_objs[chunk_id][-1]
|
| 748 |
+
|
| 749 |
+
micro_batch_data = self.load_micro_batch(chunk_id)
|
| 750 |
+
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
|
| 751 |
+
|
| 752 |
+
self._call_hooks("before_forward", data)
|
| 753 |
+
output_obj = self._call_engine(engine.model[chunk_id], data)
|
| 754 |
+
# Convert output_obj to fp32 when last model chunk of last stage
|
| 755 |
+
if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
|
| 756 |
+
output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
|
| 757 |
+
self._call_hooks("after_forward", output_obj)
|
| 758 |
+
|
| 759 |
+
if gpc.is_pipeline_last_stage():
|
| 760 |
+
self._call_hooks("post_helper_func", output_obj, label)
|
| 761 |
+
|
| 762 |
+
if self._return_tensors is not None:
|
| 763 |
+
self._return_tensors.append((output_obj, label))
|
| 764 |
+
if self._accum_loss is not None:
|
| 765 |
+
self._call_hooks("before_criterion", output_obj, label)
|
| 766 |
+
loss = self._call_engine_criterion(engine.criterion, output_obj, label)
|
| 767 |
+
self._call_hooks("after_criterion", loss)
|
| 768 |
+
|
| 769 |
+
loss_reduced = loss / self.num_microbatches
|
| 770 |
+
self._accum_loss.add_(loss_reduced.detach())
|
| 771 |
+
output_obj = loss_reduced
|
| 772 |
+
|
| 773 |
+
self._output_objs[chunk_id].append(output_obj)
|
| 774 |
+
|
| 775 |
+
return output_obj
|
| 776 |
+
|
| 777 |
+
def _backward_step(self, engine, chunk_id, step_id):
|
| 778 |
+
"""
|
| 779 |
+
Backward step for passed-in model. If it is the last stage, the input tensor
|
| 780 |
+
is obtained from the previous forward step, otherwise the passed-in input_obj is used.
|
| 781 |
+
Returns input tensor gradient. This is a helper function and can be ignored by users.
|
| 782 |
+
|
| 783 |
+
Args:
|
| 784 |
+
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
| 785 |
+
chunk_id (int): The id of model chunks.
|
| 786 |
+
step_id (int): The current step id.
|
| 787 |
+
|
| 788 |
+
Returns:
|
| 789 |
+
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: input tensor gradient.
|
| 790 |
+
"""
|
| 791 |
+
gpc.set_virtual_pipeline_parallel_rank(chunk_id)
|
| 792 |
+
|
| 793 |
+
if gpc.is_pipeline_last_stage() and len(self._output_obj_grads[chunk_id]) == 0:
|
| 794 |
+
self._output_obj_grads[chunk_id].append(None)
|
| 795 |
+
|
| 796 |
+
input_obj = self._input_objs[chunk_id].pop(0)
|
| 797 |
+
output_obj = self._output_objs[chunk_id].pop(0)
|
| 798 |
+
output_obj_grad = self._output_obj_grads[chunk_id].pop(0)
|
| 799 |
+
|
| 800 |
+
input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad)
|
| 801 |
+
|
| 802 |
+
return input_obj_grad
|
| 803 |
+
|
| 804 |
+
def _get_chunk_by_microbatch(self, step_id: int, backward: bool = False) -> int:
|
| 805 |
+
"""Helper method to get the model chunk ID given the iteration number."""
|
| 806 |
+
microbatch_id_in_group = step_id % (self._pp_size * self._num_chunks)
|
| 807 |
+
chunk_id = microbatch_id_in_group // self._pp_size
|
| 808 |
+
|
| 809 |
+
if backward:
|
| 810 |
+
chunk_id = self._num_chunks - chunk_id - 1
|
| 811 |
+
|
| 812 |
+
return chunk_id
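# Worked example (illustrative): with self._pp_size = 4 and self._num_chunks = 2,
# forward microsteps 0-3 map to chunk 0 and microsteps 4-7 map to chunk 1; with
# backward=True the mapping is mirrored, so backward microstep 0 maps to chunk 1.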
|
| 813 |
+
|
| 814 |
+
def _get_current_microbatch_id(self, step_id: int) -> int:
|
| 815 |
+
# format:
|
| 816 |
+
# microstep_id : 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
|
| 817 |
+
# microbatch_id: 1 2 3 4 1 2 3 4 5 6 7 8 5 6 7 8
|
| 818 |
+
num_microbatch_group = step_id // (self._pp_size * self._num_chunks)
|
| 819 |
+
step_id_in_group = step_id % (self._pp_size * self._num_chunks)
|
| 820 |
+
|
| 821 |
+
microbatch_id = num_microbatch_group * self._pp_size + step_id_in_group % self._pp_size
|
| 822 |
+
|
| 823 |
+
return microbatch_id
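# Worked example (illustrative): with self._pp_size = 4 and self._num_chunks = 2,
# step_id = 6 gives num_microbatch_group = 0, step_id_in_group = 6 and
# microbatch_id = 0 * 4 + 6 % 4 = 2, which is the 0-indexed counterpart of the
# mapping sketched in the comment above.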
|
| 824 |
+
|
| 825 |
+
def _run_warmup_loop(
|
| 826 |
+
self,
|
| 827 |
+
engine: Engine,
|
| 828 |
+
num_microsteps: int,
|
| 829 |
+
num_warmup_microsteps: int,
|
| 830 |
+
receive_extra_backward: bool = False,
|
| 831 |
+
forward_only: bool = False,
|
| 832 |
+
) -> None:
|
| 833 |
+
"""
|
| 834 |
+
Run the warm-up loop and prepare data for the 1F1B stage.
|
| 835 |
+
|
| 836 |
+
During the warm-up process, for each execution, it first performs a forward computation,
|
| 837 |
+
and then sends the computation result to the next stage.
|
| 838 |
+
It also receives data for the next forward computation.
|
| 839 |
+
Since the input for the first forward computation is not considered initially,
|
| 840 |
+
it needs to receive data once at the beginning.
|
| 841 |
+
|
| 842 |
+
After the warm-up is completed, we need to prepare data for the 1F1B stage.
|
| 843 |
+
The data preparation process should be consistent with the communication method of the 1F1B stage.
|
| 844 |
+
|
| 845 |
+
Args:
|
| 846 |
+
engine (Engine): The engine to run the warm-up loop.
|
| 847 |
+
num_microsteps (int): The total number of microsteps.
|
| 848 |
+
num_warmup_microsteps (int): The number of warm-up microsteps.
|
| 849 |
+
receive_extra_backward (bool, optional): Whether to receive extra backward input for the 1F1B stage.
|
| 850 |
+
Default is False.
|
| 851 |
+
forward_only (bool, optional): Whether to only perform forward pass. Default is False.
|
| 852 |
+
"""
|
| 853 |
+
if not gpc.is_pipeline_first_stage():
|
| 854 |
+
if self._input_obj_shapes[0] is None:
|
| 855 |
+
self._input_obj_shapes[0] = comm.recv_obj_meta(self._input_obj_shapes[0])
|
| 856 |
+
self._input_objs[0].append(
|
| 857 |
+
comm.recv_forward(
|
| 858 |
+
self._input_obj_shapes[0],
|
| 859 |
+
dtype=self.dtype,
|
| 860 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 861 |
+
)
|
| 862 |
+
)
|
| 863 |
+
else:
|
| 864 |
+
self._input_objs[0].append(None)
|
| 865 |
+
|
| 866 |
+
for k in range(num_warmup_microsteps):
|
| 867 |
+
chunk_id = self._get_chunk_by_microbatch(k)
|
| 868 |
+
|
| 869 |
+
output_obj = self._forward_step(engine, chunk_id)
|
| 870 |
+
|
| 871 |
+
if forward_only:
|
| 872 |
+
# when forward-only, no need to save tensors for a backward pass
|
| 873 |
+
self._input_objs[chunk_id].pop()
|
| 874 |
+
self._output_objs[chunk_id].pop()
|
| 875 |
+
|
| 876 |
+
if not gpc.is_pipeline_last_stage():
|
| 877 |
+
if isinstance(output_obj, torch.Tensor):
|
| 878 |
+
self._output_obj_shapes[chunk_id] = output_obj.shape
|
| 879 |
+
else:
|
| 880 |
+
self._output_obj_shapes[chunk_id] = [out_tensor.shape for out_tensor in output_obj]
|
| 881 |
+
|
| 882 |
+
if self._send_tensor_shape_flags[chunk_id]:
|
| 883 |
+
comm.send_obj_meta(output_obj)
|
| 884 |
+
self._send_tensor_shape_flags[chunk_id] = False # send only once for each chunk.
|
| 885 |
+
|
| 886 |
+
# Determine if tensor should be received from previous stage.
|
| 887 |
+
next_forward_chunk_id = self._get_chunk_by_microbatch(k + 1)
|
| 888 |
+
|
| 889 |
+
with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
|
| 890 |
+
if not gpc.is_pipeline_first_stage() and self._input_obj_shapes[next_forward_chunk_id] is None:
|
| 891 |
+
self._input_obj_shapes[next_forward_chunk_id] = comm.recv_obj_meta()
|
| 892 |
+
if k == (num_microsteps - 1) or gpc.is_pipeline_first_stage():
|
| 893 |
+
input_shape = None
|
| 894 |
+
else:
|
| 895 |
+
input_shape = self._input_obj_shapes[next_forward_chunk_id]
|
| 896 |
+
|
| 897 |
+
# Don't send tensor downstream if on last stage.
|
| 898 |
+
if gpc.is_pipeline_last_stage():
|
| 899 |
+
output_obj = None
|
| 900 |
+
|
| 901 |
+
# Send and receive tensors as appropriate (send tensors computed
|
| 902 |
+
# in this iteration; receive tensors for next iteration).
|
| 903 |
+
if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
|
| 904 |
+
# Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
|
| 905 |
+
input_obj = comm.send_forward_recv_forward(
|
| 906 |
+
output_obj,
|
| 907 |
+
input_shape,
|
| 908 |
+
dtype=self.dtype,
|
| 909 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 910 |
+
)
|
| 911 |
+
else:
|
| 912 |
+
# Receive output_obj_grad for next backward, if receive_extra_backward is True.
|
| 913 |
+
if self._communication_overlap:
|
| 914 |
+
# In this case, we should handle forward and backward communication separately, consistent with the
|
| 915 |
+
# overlap version of the 1F1B stage
|
| 916 |
+
input_obj = comm.send_forward_recv_forward(
|
| 917 |
+
output_obj,
|
| 918 |
+
input_shape,
|
| 919 |
+
dtype=self.dtype,
|
| 920 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 921 |
+
)
|
| 922 |
+
output_obj_grad = comm.send_backward_recv_backward(
|
| 923 |
+
None, # nothing to send
|
| 924 |
+
self._output_obj_shapes[self._num_chunks - 1],
|
| 925 |
+
dtype=self.dtype,
|
| 926 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 927 |
+
)
|
| 928 |
+
self._output_obj_grads[self._num_chunks - 1].append(output_obj_grad)
|
| 929 |
+
else:
|
| 930 |
+
# In this case, we should handle forward and backward communication together, consistent with the
|
| 931 |
+
# non-overlap version of the 1F1B stage
|
| 932 |
+
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
| 933 |
+
output_obj,
|
| 934 |
+
None, # no backward grad to send
|
| 935 |
+
input_shape,
|
| 936 |
+
self._output_obj_shapes[self._num_chunks - 1],
|
| 937 |
+
dtype=self.dtype,
|
| 938 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 939 |
+
)
|
| 940 |
+
self._output_obj_grads[self._num_chunks - 1].append(output_obj_grad)
|
| 941 |
+
|
| 942 |
+
self._input_objs[next_forward_chunk_id].append(input_obj)
|
| 943 |
+
|
| 944 |
+
def _run_1f1b_loop_with_overlap(
|
| 945 |
+
self,
|
| 946 |
+
engine: Engine,
|
| 947 |
+
num_warmup_microsteps: int,
|
| 948 |
+
num_1f1b_micropairs: int,
|
| 949 |
+
all_warmup_microsteps: bool = False,
|
| 950 |
+
) -> None:
|
| 951 |
+
"""
|
| 952 |
+
Run the 1F1B loop with overlap.
|
| 953 |
+
|
| 954 |
+
The 1F1B loop with overlap consists of the following steps:
|
| 955 |
+
1. Perform the forward pass.
|
| 956 |
+
2. Check if the backward input is ready.
|
| 957 |
+
3. Send the forward output and receive the forward input for the next iteration.
|
| 958 |
+
4. Perform the backward pass.
|
| 959 |
+
5. Check if the forward input is ready.
|
| 960 |
+
6. Send the backward output and receive the backward input for the next iteration.
|
| 961 |
+
|
| 962 |
+
Args:
|
| 963 |
+
engine (Engine): The engine to run the 1F1B loop.
|
| 964 |
+
num_warmup_microsteps (int): The number of warm-up microsteps.
|
| 965 |
+
num_1f1b_micropairs (int): The number of 1F1B micropairs.
|
| 966 |
+
all_warmup_microsteps (bool, optional): Whether to run all warm-up microsteps. Default is False.
|
| 967 |
+
"""
|
| 968 |
+
|
| 969 |
+
backward_async_communicator = None
|
| 970 |
+
|
| 971 |
+
# Run 1F1B in steady state.
|
| 972 |
+
for k in range(num_1f1b_micropairs):
|
| 973 |
+
forward_microstep_id = k + num_warmup_microsteps
|
| 974 |
+
backward_microstep_id = k
|
| 975 |
+
forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
|
| 976 |
+
backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
|
| 977 |
+
|
| 978 |
+
# 1. Forward pass.
|
| 979 |
+
output_obj = self._forward_step(engine, forward_chunk_id)
|
| 980 |
+
|
| 981 |
+
# 2. Check if the backward input is ready.
|
| 982 |
+
if backward_async_communicator is not None:
|
| 983 |
+
output_obj_grad = backward_async_communicator.wait_and_receive()
|
| 984 |
+
|
| 985 |
+
if backward_async_communicator.need_receive:
|
| 986 |
+
self._output_obj_grads[backward_chunk_id].append(output_obj_grad)
|
| 987 |
+
|
| 988 |
+
# 3. Send the forward outputs and receive the forward inputs from the previous rank.
|
| 989 |
+
|
| 990 |
+
# Check if it is the last model chunk of the last pipeline stage, no need to send forward output.
|
| 991 |
+
gpc.set_virtual_pipeline_parallel_rank(forward_chunk_id)
|
| 992 |
+
if gpc.is_pipeline_last_stage():
|
| 993 |
+
output_obj = None
|
| 994 |
+
|
| 995 |
+
# Check if it needs to receive the results from the previous rank.
|
| 996 |
+
next_forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id + 1)
|
| 997 |
+
with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
|
| 998 |
+
if gpc.is_pipeline_first_stage() or k == num_1f1b_micropairs - 1:
|
| 999 |
+
input_obj_shape = None
|
| 1000 |
+
else:
|
| 1001 |
+
input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
|
| 1002 |
+
|
| 1003 |
+
forward_async_communicator = comm.AsynCommunicator(
|
| 1004 |
+
output_obj,
|
| 1005 |
+
input_obj_shape,
|
| 1006 |
+
self.dtype,
|
| 1007 |
+
self.scatter_gather_tensors,
|
| 1008 |
+
forward=True,
|
| 1009 |
+
)
|
| 1010 |
+
forward_async_communicator.start()
|
| 1011 |
+
|
| 1012 |
+
# 5. Backward pass.
|
| 1013 |
+
|
| 1014 |
+
input_obj_grad = self._backward_step(engine, backward_chunk_id, backward_microstep_id)
|
| 1015 |
+
|
| 1016 |
+
input_obj = forward_async_communicator.wait_and_receive()
|
| 1017 |
+
if forward_async_communicator.need_receive:
|
| 1018 |
+
self._input_objs[next_forward_chunk_id].append(input_obj)
|
| 1019 |
+
|
| 1020 |
+
# 6. Send the backward output and receive the backward input for the next iteration.
|
| 1021 |
+
gpc.set_virtual_pipeline_parallel_rank(backward_chunk_id)
|
| 1022 |
+
if gpc.is_pipeline_first_stage():
|
| 1023 |
+
input_obj_grad = None
|
| 1024 |
+
|
| 1025 |
+
next_backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id + 1, backward=True)
|
| 1026 |
+
with switch_virtual_pipeline_parallel_rank(next_backward_chunk_id):
|
| 1027 |
+
if gpc.is_pipeline_last_stage():
|
| 1028 |
+
output_obj_shape = None
|
| 1029 |
+
else:
|
| 1030 |
+
output_obj_shape = self._output_obj_shapes[next_backward_chunk_id]
|
| 1031 |
+
|
| 1032 |
+
backward_async_communicator = comm.AsynCommunicator(
|
| 1033 |
+
input_obj_grad,
|
| 1034 |
+
output_obj_shape,
|
| 1035 |
+
self.dtype,
|
| 1036 |
+
self.scatter_gather_tensors,
|
| 1037 |
+
forward=False,
|
| 1038 |
+
)
|
| 1039 |
+
backward_async_communicator.start()
|
| 1040 |
+
|
| 1041 |
+
if all_warmup_microsteps:
|
| 1042 |
+
if not gpc.is_pipeline_last_stage():
|
| 1043 |
+
self._output_obj_grads[self._num_chunks - 1].append(
|
| 1044 |
+
comm.recv_backward(
|
| 1045 |
+
self._output_obj_shapes[self._num_chunks - 1],
|
| 1046 |
+
dtype=self.dtype,
|
| 1047 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 1048 |
+
)
|
| 1049 |
+
)
|
| 1050 |
+
else:
|
| 1051 |
+
self._output_obj_grads[self._num_chunks - 1].append(None)
|
| 1052 |
+
else:
|
| 1053 |
+
output_obj_grad = backward_async_communicator.wait_and_receive()
|
| 1054 |
+
if backward_async_communicator.need_receive:
|
| 1055 |
+
backward_chunk_id = self._get_chunk_by_microbatch(num_1f1b_micropairs, backward=True)
|
| 1056 |
+
self._output_obj_grads[backward_chunk_id].append(output_obj_grad)
|
| 1057 |
+
|
| 1058 |
+
def _run_1f1b_loop_without_overlap(
|
| 1059 |
+
self,
|
| 1060 |
+
engine: Engine,
|
| 1061 |
+
num_warmup_microsteps: int,
|
| 1062 |
+
num_1f1b_micropairs: int,
|
| 1063 |
+
all_warmup_microsteps: bool = False,
|
| 1064 |
+
) -> None:
|
| 1065 |
+
"""
|
| 1066 |
+
Run the 1F1B loop without overlap.
|
| 1067 |
+
|
| 1068 |
+
The 1F1B loop without overlap consists of the following steps:
|
| 1069 |
+
1. Perform the forward pass.
|
| 1070 |
+
2. Perform the backward pass.
|
| 1071 |
+
3. Send the forward output of this iteration to the next stage, and send the backward output of this iteration
|
| 1072 |
+
to the previous stage, and receive the forward and backward inputs for the next iteration.
|
| 1073 |
+
|
| 1074 |
+
Args:
|
| 1075 |
+
engine (Engine): The engine to use for computation.
|
| 1076 |
+
num_warmup_microsteps (int): The number of warmup microsteps.
|
| 1077 |
+
num_1f1b_micropairs (int): The number of 1F1B micro-pairs.
|
| 1078 |
+
all_warmup_microsteps (bool, optional): Whether to run all warmup microsteps. Defaults to False.
|
| 1079 |
+
"""
|
| 1080 |
+
for k in range(num_1f1b_micropairs):
|
| 1081 |
+
# Forward pass.
|
| 1082 |
+
forward_microstep_id = k + num_warmup_microsteps
|
| 1083 |
+
forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
|
| 1084 |
+
output_obj = self._forward_step(engine, forward_chunk_id)
|
| 1085 |
+
|
| 1086 |
+
# Backward pass.
|
| 1087 |
+
backward_microstep_id = k
|
| 1088 |
+
backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
|
| 1089 |
+
input_obj_grad = self._backward_step(engine, backward_chunk_id, backward_microstep_id)
|
| 1090 |
+
|
| 1091 |
+
# Send output_obj and input_obj_grad, receive input_obj
|
| 1092 |
+
# and output_obj_grad.
|
| 1093 |
+
|
| 1094 |
+
# Determine if current stage has anything to send in either direction,
|
| 1095 |
+
# otherwise set obj to None.
|
| 1096 |
+
gpc.set_virtual_pipeline_parallel_rank(forward_chunk_id)
|
| 1097 |
+
if gpc.is_pipeline_last_stage():
|
| 1098 |
+
output_obj = None
|
| 1099 |
+
|
| 1100 |
+
gpc.set_virtual_pipeline_parallel_rank(backward_chunk_id)
|
| 1101 |
+
if gpc.is_pipeline_first_stage():
|
| 1102 |
+
input_obj_grad = None
|
| 1103 |
+
|
| 1104 |
+
# Determine if peers are sending, and where in data structure to put
|
| 1105 |
+
# received tensors.
|
| 1106 |
+
next_forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id + 1)
|
| 1107 |
+
with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
|
| 1108 |
+
if gpc.is_pipeline_first_stage() or k == num_1f1b_micropairs - 1:
|
| 1109 |
+
recv_prev = False
|
| 1110 |
+
else:
|
| 1111 |
+
recv_prev = True
|
| 1112 |
+
|
| 1113 |
+
next_backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id + 1, backward=True)
|
| 1114 |
+
with switch_virtual_pipeline_parallel_rank(next_backward_chunk_id):
|
| 1115 |
+
if gpc.is_pipeline_last_stage():
|
| 1116 |
+
recv_next = False
|
| 1117 |
+
else:
|
| 1118 |
+
recv_next = True
|
| 1119 |
+
|
| 1120 |
+
input_shape = self._input_obj_shapes[next_forward_chunk_id] if recv_prev else None
|
| 1121 |
+
output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
|
| 1122 |
+
|
| 1123 |
+
# Communicate objs.
|
| 1124 |
+
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
|
| 1125 |
+
output_obj,
|
| 1126 |
+
input_obj_grad,
|
| 1127 |
+
input_shape,
|
| 1128 |
+
output_shape,
|
| 1129 |
+
dtype=self.dtype,
|
| 1130 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
# Put input_obj and output_obj_grad in data structures in the
|
| 1134 |
+
# right location.
|
| 1135 |
+
if recv_prev:
|
| 1136 |
+
self._input_objs[next_forward_chunk_id].append(input_obj)
|
| 1137 |
+
if recv_next:
|
| 1138 |
+
self._output_obj_grads[next_backward_chunk_id].append(output_obj_grad)
|
| 1139 |
+
|
| 1140 |
+
# receive necessary data for next cooldown loop
|
| 1141 |
+
if all_warmup_microsteps:
|
| 1142 |
+
if not gpc.is_pipeline_last_stage():
|
| 1143 |
+
self._output_obj_grads[self._num_chunks - 1].append(
|
| 1144 |
+
comm.recv_backward(
|
| 1145 |
+
self._output_obj_shapes[self._num_chunks - 1],
|
| 1146 |
+
dtype=self.dtype,
|
| 1147 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 1148 |
+
)
|
| 1149 |
+
)
|
| 1150 |
+
else:
|
| 1151 |
+
self._output_obj_grads[self._num_chunks - 1].append(None)
|
| 1152 |
+
|
| 1153 |
+
def _run_cooldown_loop(self, engine: Engine, num_microsteps: int, num_1f1b_micropairs: int) -> None:
|
| 1154 |
+
"""
|
| 1155 |
+
Run the cooldown loop.
|
| 1156 |
+
|
| 1157 |
+
The cooldown loop consists of the following steps:
|
| 1158 |
+
1. Perform the backward step.
|
| 1159 |
+
2. Send the backward output to the next stage and receive inputs for next backward.
|
| 1160 |
+
|
| 1161 |
+
Args:
|
| 1162 |
+
engine (Engine): The engine to use for computation.
|
| 1163 |
+
num_microsteps (int): The total number of microsteps.
|
| 1164 |
+
num_1f1b_micropairs (int): The number of 1F1B micro-pairs.
|
| 1165 |
+
"""
|
| 1166 |
+
for k in range(num_1f1b_micropairs, num_microsteps):
|
| 1167 |
+
chunk_id = self._get_chunk_by_microbatch(k, backward=True)
|
| 1168 |
+
|
| 1169 |
+
input_obj_grad = self._backward_step(engine, chunk_id, k)
|
| 1170 |
+
|
| 1171 |
+
next_backward_chunk_id = self._get_chunk_by_microbatch(k + 1, backward=True)
|
| 1172 |
+
|
| 1173 |
+
if k != (num_microsteps - 1) and not (
|
| 1174 |
+
gpc.is_pipeline_last_stage(ignore_virtual=True) and next_backward_chunk_id == (self._num_chunks - 1)
|
| 1175 |
+
):
|
| 1176 |
+
output_shape = self._output_obj_shapes[next_backward_chunk_id]
|
| 1177 |
+
else:
|
| 1178 |
+
output_shape = None
|
| 1179 |
+
|
| 1180 |
+
self._output_obj_grads[next_backward_chunk_id].append(
|
| 1181 |
+
comm.send_backward_recv_backward(
|
| 1182 |
+
input_obj_grad,
|
| 1183 |
+
output_shape,
|
| 1184 |
+
dtype=self.dtype,
|
| 1185 |
+
scatter_gather_tensors=self.scatter_gather_tensors,
|
| 1186 |
+
)
|
| 1187 |
+
)
|
| 1188 |
+
|
| 1189 |
+
def _forward_only_step(self, engine: Engine):
|
| 1190 |
+
num_microsteps = self.num_microbatches * self._num_chunks
|
| 1191 |
+
num_warmup_microsteps = num_microsteps
|
| 1192 |
+
|
| 1193 |
+
self._run_warmup_loop(
|
| 1194 |
+
engine,
|
| 1195 |
+
num_microsteps,
|
| 1196 |
+
num_warmup_microsteps,
|
| 1197 |
+
receive_extra_backward=False,
|
| 1198 |
+
forward_only=True,
|
| 1199 |
+
)
|
| 1200 |
+
|
| 1201 |
+
def _forward_backward_step(self, engine: Engine):
|
| 1202 |
+
# Compute number of warmup and remaining microbatches.
|
| 1203 |
+
all_warmup_microsteps = False
|
| 1204 |
+
num_microsteps = self.num_microbatches * self._num_chunks
|
| 1205 |
+
|
| 1206 |
+
# Run all forward passes and then all backward passes if number of
|
| 1207 |
+
# microbatches is just the number of pipeline stages.
|
| 1208 |
+
# Otherwise, perform (num_chunks-1)*pipeline_parallel_size on
|
| 1209 |
+
# all workers, followed by more microbatches after depending on
|
| 1210 |
+
# stage ID (more forward passes for earlier stages, later stages can
|
| 1211 |
+
# immediately start with 1F1B).
|
| 1212 |
+
if self.num_microbatches == self._pp_size:
|
| 1213 |
+
num_warmup_steps = num_microsteps
|
| 1214 |
+
all_warmup_microsteps = True
|
| 1215 |
+
else:
|
| 1216 |
+
num_warmup_steps = (self._pp_size - self._pp_rank - 1) * 2
|
| 1217 |
+
num_warmup_steps += (self._num_chunks - 1) * self._pp_size
|
| 1218 |
+
num_warmup_steps = min(num_warmup_steps, num_microsteps)
|
| 1219 |
+
num_1f1b_micropairs = num_microsteps - num_warmup_steps
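# Worked example (illustrative): pp_size = 4, pp_rank = 1, num_chunks = 2 and
# num_microbatches = 8 give num_microsteps = 16 and
# num_warmup_steps = (4 - 1 - 1) * 2 + (2 - 1) * 4 = 8, leaving 8 1F1B micro-pairs
# and 8 cooldown backward steps for this rank.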
|
| 1220 |
+
|
| 1221 |
+
# We usually need to prepare an extra backward data for the 1F1B stage when the WarmUp stage ends,
|
| 1222 |
+
# because the 1F1B stage typically performs one forward and backward pass together,
|
| 1223 |
+
# except in the following cases:
|
| 1224 |
+
receive_extra_backward = not (
|
| 1225 |
+
all_warmup_microsteps # Only warmup microsteps
|
| 1226 |
+
or gpc.is_pipeline_last_stage(ignore_virtual=True) # The rank is the last pipeline stage
|
| 1227 |
+
)
|
| 1228 |
+
|
| 1229 |
+
# 1. Warmup
|
| 1230 |
+
self._run_warmup_loop(
|
| 1231 |
+
engine,
|
| 1232 |
+
num_microsteps,
|
| 1233 |
+
num_warmup_steps,
|
| 1234 |
+
receive_extra_backward=receive_extra_backward,
|
| 1235 |
+
)
|
| 1236 |
+
|
| 1237 |
+
# 2. 1F1B
|
| 1238 |
+
self._run_1f1b_loop(
|
| 1239 |
+
engine,
|
| 1240 |
+
num_warmup_steps,
|
| 1241 |
+
num_1f1b_micropairs=num_1f1b_micropairs,
|
| 1242 |
+
all_warmup_microsteps=all_warmup_microsteps,
|
| 1243 |
+
)
|
| 1244 |
+
|
| 1245 |
+
# 3. Cooldown
|
| 1246 |
+
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
|
| 1247 |
+
|
| 1248 |
+
@llm_timeout(func_name="interleaved_forward_backward_step")
|
| 1249 |
+
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
|
| 1250 |
+
"""Run interleaved 1F1B schedule (model split into model chunks), with
|
| 1251 |
+
communication between pipeline stages as needed.
|
| 1252 |
+
|
| 1253 |
+
Args:
|
| 1254 |
+
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
| 1255 |
+
data_iter (Iterable): An iterator over the dataloader, obtained by calling iter(dataloader).
|
| 1256 |
+
forward_only (bool, optional):
|
| 1257 |
+
Whether to run the forward step only. Default is False. If True, no backward pass will be run.
|
| 1258 |
+
return_loss (bool, optional): Whether to return the loss value. Default is True.
|
| 1259 |
+
return_output_label (bool, optional): If False, the output and label won't be returned.
|
| 1260 |
+
|
| 1261 |
+
Returns:
|
| 1262 |
+
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
|
| 1263 |
+
The loss would be returned only in the last stage.
|
| 1264 |
+
"""
|
| 1265 |
+
assert (
|
| 1266 |
+
forward_only or return_loss
|
| 1267 |
+
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
| 1268 |
+
|
| 1269 |
+
gpc.set_virtual_pipeline_parallel_rank(0)
|
| 1270 |
+
|
| 1271 |
+
self.load_batch(engine, data_iter)
|
| 1272 |
+
|
| 1273 |
+
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
| 1274 |
+
self._accum_loss = torch.zeros(1, device=get_current_device())
|
| 1275 |
+
if return_output_label:
|
| 1276 |
+
self._return_tensors = []
|
| 1277 |
+
|
| 1278 |
+
if forward_only:
|
| 1279 |
+
self._forward_only_step(engine)
|
| 1280 |
+
else:
|
| 1281 |
+
self._forward_backward_step(engine)
|
| 1282 |
+
|
| 1283 |
+
if return_output_label and len(self._return_tensors) > 0:
|
| 1284 |
+
output, label = pack_return_tensors(self._return_tensors)
|
| 1285 |
+
else:
|
| 1286 |
+
output, label = (None, None)
|
| 1287 |
+
accum_loss = self._accum_loss
|
| 1288 |
+
|
| 1289 |
+
self._clear_state()
|
| 1290 |
+
|
| 1291 |
+
return output, label, accum_loss
|
| 1292 |
+
|
| 1293 |
+
|
| 1294 |
+
class KDPipelineScheduler(PipelineScheduler):
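"""
Knowledge-distillation variant of the non-interleaved pipeline scheduler.

On the last pipeline stage, each forward pass also runs a frozen teacher model
(engine.teacher) and adds a distillation loss (engine.kd_criterion) to the
ground-truth loss, weighted by gpc.config.kd_config['gt_weight'] and
gpc.config.kd_config['kd_weight'] respectively.
"""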
|
| 1295 |
+
|
| 1296 |
+
def __init__(
|
| 1297 |
+
self,
|
| 1298 |
+
num_microbatches: int,
|
| 1299 |
+
dtype: torch.dtype = torch.float,
|
| 1300 |
+
data_process_func: Callable = None,
|
| 1301 |
+
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
|
| 1302 |
+
scatter_gather_tensors: bool = False,
|
| 1303 |
+
scheduler_hooks: Optional[List[SchedulerHook]] = None,
|
| 1304 |
+
):
|
| 1305 |
+
super().__init__(
|
| 1306 |
+
num_microbatches=num_microbatches,
|
| 1307 |
+
dtype=dtype,
|
| 1308 |
+
data_process_func=data_process_func,
|
| 1309 |
+
tensor_shape=tensor_shape,
|
| 1310 |
+
scatter_gather_tensors=scatter_gather_tensors,
|
| 1311 |
+
scheduler_hooks=scheduler_hooks,
|
| 1312 |
+
)
|
| 1313 |
+
|
| 1314 |
+
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
|
| 1315 |
+
"""
|
| 1316 |
+
Forward step for passed-in model. If it is the first stage, the input tensor
|
| 1317 |
+
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
| 1318 |
+
Returns output tensor. This is a helper function and can be ignored by users.
|
| 1319 |
+
|
| 1320 |
+
Args:
|
| 1321 |
+
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
|
| 1322 |
+
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
|
| 1323 |
+
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
|
| 1324 |
+
return_output_label (bool, optional): Whether to return outputs and labels.
|
| 1325 |
+
accum_loss (optional): Where the accumulated loss is stored.
|
| 1326 |
+
Returns:
|
| 1327 |
+
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
|
| 1328 |
+
pipeline stage.
|
| 1329 |
+
"""
|
| 1330 |
+
micro_batch_data = self.load_micro_batch()
|
| 1331 |
+
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
|
| 1332 |
+
|
| 1333 |
+
self._call_hooks("before_forward", data)
|
| 1334 |
+
output_obj = self._call_engine(engine.model, data)
|
| 1335 |
+
self._call_hooks("after_forward", output_obj)
|
| 1336 |
+
|
| 1337 |
+
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
| 1338 |
+
self._call_hooks("post_helper_func", output_obj, label)
|
| 1339 |
+
if return_output_label:
|
| 1340 |
+
return_tensors.append((output_obj, label))
|
| 1341 |
+
if accum_loss is not None:
|
| 1342 |
+
self._call_hooks("before_criterion", output_obj, label)
|
| 1343 |
+
loss_gt = gpc.config.kd_config['gt_weight'] * self._call_engine_criterion(engine.criterion, output_obj,
|
| 1344 |
+
label)
|
| 1345 |
+
|
| 1346 |
+
with torch.no_grad():
|
| 1347 |
+
engine.teacher.eval()
|
| 1348 |
+
output_obj_t = self._call_engine(engine.teacher, data)
|
| 1349 |
+
|
| 1350 |
+
loss_kd = gpc.config.kd_config['kd_weight'] * self._call_engine_criterion(engine.kd_criterion,
|
| 1351 |
+
output_obj,
|
| 1352 |
+
(output_obj_t, label))
|
| 1353 |
+
# loss = (gpc.config.kd_config['gt_weight'] * loss_gt +
|
| 1354 |
+
# gpc.config.kd_config['kd_weight'] * loss_kd)
|
| 1355 |
+
self._call_hooks("after_criterion", loss_gt + loss_kd)
|
| 1356 |
+
|
| 1357 |
+
loss_gt_reduced = loss_gt / self.num_microbatches
|
| 1358 |
+
loss_kd_reduced = loss_kd / self.num_microbatches
|
| 1359 |
+
accum_loss['loss_gt'].add_(loss_gt_reduced.detach())
|
| 1360 |
+
accum_loss['loss_kd'].add_(loss_kd_reduced.detach())
|
| 1361 |
+
output_obj = loss_gt_reduced + loss_kd_reduced
|
| 1362 |
+
|
| 1363 |
+
return output_obj
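# The scheduler above calls engine.kd_criterion with the student output and a
# (teacher_output, label) tuple. A minimal sketch of such a distillation criterion is
# given below; it is an illustrative assumption (including the `temperature` knob and
# the SoftLabelKDLoss name), not the repository's actual implementation.
import torch
import torch.nn.functional as F


class SoftLabelKDLoss(torch.nn.Module):
    """Minimal KL-divergence distillation loss between student and teacher logits."""

    def __init__(self, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_logits, teacher_and_label):
        teacher_logits, _label = teacher_and_label  # the label is unused here
        t = self.temperature
        # KL divergence on temperature-softened distributions, scaled by t^2 as is
        # conventional so gradients keep a comparable magnitude across temperatures.
        student_log_prob = F.log_softmax(student_logits / t, dim=-1)
        teacher_prob = F.softmax(teacher_logits / t, dim=-1)
        return F.kl_div(student_log_prob, teacher_prob, reduction="batchmean") * (t * t)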
|
InternLM/internlm/core/trainer.py
ADDED
|
@@ -0,0 +1,190 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from typing import Iterable, Optional
|
| 8 |
+
|
| 9 |
+
from internlm.core.engine import Engine
|
| 10 |
+
from internlm.core.scheduler import (
|
| 11 |
+
BaseScheduler,
|
| 12 |
+
InterleavedPipelineScheduler,
|
| 13 |
+
NonPipelineScheduler,
|
| 14 |
+
PipelineScheduler,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TrainState:
|
| 19 |
+
"""
|
| 20 |
+
The TrainState class is used to record the current state of training.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
train_dl (DataLoader): The DataLoader object used for training.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, config, batch_sampler) -> None:
|
| 27 |
+
"""
|
| 28 |
+
Args:
|
| 29 |
+
config (Config): internlm config
|
| 30 |
+
batch_sampler (torch.utils.data.Sampler): Because the dataloader loading is
|
| 31 |
+
asynchronous and prefetched, the batch_sampler state maintained inside the
|
| 32 |
+
dataloader is ahead of the actual training progress, so we copy the
|
| 33 |
+
batch_sampler as the anchor point of ckpt reload.
|
| 34 |
+
"""
|
| 35 |
+
# The number of batches produced by the data iterator
|
| 36 |
+
self.batch_count: int = 0
|
| 37 |
+
# Used to store the number of samples consumed in the current epoch
|
| 38 |
+
self.num_consumed_samples_in_epoch: int = 0
|
| 39 |
+
# Total number of tokens consumed
|
| 40 |
+
self.num_consumed_tokens: int = 0
|
| 41 |
+
# Number of batches skipped due to inf or nan values
|
| 42 |
+
self.inf_nan_skip_batches: int = 0
|
| 43 |
+
# Records the number of updates, skipped batches and inf batches are not counted
|
| 44 |
+
self.step_count: int = 0
|
| 45 |
+
|
| 46 |
+
# Total step count
|
| 47 |
+
self.total_steps: int = config.data.total_steps
|
| 48 |
+
|
| 49 |
+
# resume tensorboard folder, need load from checkpoint or set manually.
|
| 50 |
+
self.resume_tb_folder = config.resume_tb_folder
|
| 51 |
+
|
| 52 |
+
self.tensorboard_folder = config.tensorboard_folder
|
| 53 |
+
|
| 54 |
+
# learning rate
|
| 55 |
+
self.lr = config.adam.lr
|
| 56 |
+
|
| 57 |
+
# sampler state
|
| 58 |
+
if batch_sampler:
|
| 59 |
+
self.init_batch_sampler(batch_sampler)
|
| 60 |
+
|
| 61 |
+
def init_batch_sampler(self, batch_sampler):
|
| 62 |
+
"""
|
| 63 |
+
Args:
|
| 64 |
+
batch_sampler (torch.utils.data.Sampler): sampler.
|
| 65 |
+
"""
|
| 66 |
+
# make a copy of batch_sampler.
|
| 67 |
+
self.batch_sampler = batch_sampler.copy()
|
| 68 |
+
# Iterator for the batch sampler
|
| 69 |
+
self.batch_sampler_iter = iter(self.batch_sampler)
|
| 70 |
+
|
| 71 |
+
def __str__(self) -> str:
|
| 72 |
+
"""Returns a string representation of the training state in JSON format."""
|
| 73 |
+
info = {
|
| 74 |
+
"batch_count": self.batch_count,
|
| 75 |
+
"num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
|
| 76 |
+
"num_consumed_tokens": self.num_consumed_tokens,
|
| 77 |
+
"inf_nan_skip_batches": self.inf_nan_skip_batches,
|
| 78 |
+
"step_count": self.step_count,
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
return json.dumps(info, indent=4, sort_keys=True)
|
| 82 |
+
|
| 83 |
+
def load_state_dict(self, other_stuffs):
|
| 84 |
+
"""
|
| 85 |
+
Resumes training from a checkpoint.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
other_stuffs (dict): Other information needed to resume training.
|
| 89 |
+
"""
|
| 90 |
+
self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"]
|
| 91 |
+
self.num_consumed_tokens = other_stuffs["num_consumed_tokens"]
|
| 92 |
+
self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"]
|
| 93 |
+
|
| 94 |
+
# Because the ckpt save occurs after updating 'step_count',
|
| 95 |
+
# there is no need to increment 'step_count' here (Does our step count start from 0 ?),
|
| 96 |
+
# However, 'batch_count' is updated before ckpt storage, so it needs to be incremented by 1 on resume.
|
| 97 |
+
self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward
|
| 98 |
+
self.step_count = other_stuffs.get("step_count", self.batch_count)
|
| 99 |
+
|
| 100 |
+
# resume tensorboard from older tensorboard_folder
|
| 101 |
+
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
|
| 102 |
+
|
| 103 |
+
def state_dict(self):
|
| 104 |
+
return {
|
| 105 |
+
"batch_count": self.batch_count,
|
| 106 |
+
"num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
|
| 107 |
+
"num_consumed_tokens": self.num_consumed_tokens,
|
| 108 |
+
"inf_nan_skip_batches": self.inf_nan_skip_batches,
|
| 109 |
+
"step_count": self.step_count,
|
| 110 |
+
"tensorboard_folder": self.tensorboard_folder,
|
| 111 |
+
}
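# Hedged usage sketch (illustrative; `config`, `train_dl` and the checkpoint dict are
# assumed to come from the surrounding training script, they are not defined here):
#
#   train_state = TrainState(config, train_dl.batch_sampler)
#   ckpt = {"train_state": train_state.state_dict()}        # on save
#   train_state.load_state_dict(ckpt["train_state"])        # on resume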
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class Trainer:
|
| 115 |
+
"""This is a class tending for easy deployments of users' training and evaluation instead of
|
| 116 |
+
writing their own scripts.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
engine (:class:`Engine`): Engine responsible for the process function.
|
| 120 |
+
schedule (:class:`BaseScheduler`, optional): Runtime schedule. Defaults to None.
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
def __init__(
|
| 124 |
+
self,
|
| 125 |
+
engine: Engine,
|
| 126 |
+
schedule: Optional[BaseScheduler] = None,
|
| 127 |
+
):
|
| 128 |
+
"""Initializes the Trainer class.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
engine (Engine): The engine responsible for the process function.
|
| 132 |
+
schedule (Optional[BaseScheduler], optional): The runtime schedule. Defaults to None.
|
| 133 |
+
"""
|
| 134 |
+
self._engine = engine
|
| 135 |
+
|
| 136 |
+
# build schedule
|
| 137 |
+
if schedule is None:
|
| 138 |
+
self._schedule = NonPipelineScheduler()
|
| 139 |
+
else:
|
| 140 |
+
assert isinstance(
|
| 141 |
+
schedule, BaseScheduler
|
| 142 |
+
), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
|
| 143 |
+
self._schedule = schedule
|
| 144 |
+
|
| 145 |
+
self._schedule.pre_processing(self._engine)
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def engine(self):
|
| 149 |
+
"""Returns the engine that responsible for managing the training and evaluation process."""
|
| 150 |
+
return self._engine
|
| 151 |
+
|
| 152 |
+
@property
|
| 153 |
+
def schedule(self):
|
| 154 |
+
"""Returns the runtime scheduler."""
|
| 155 |
+
return self._schedule
|
| 156 |
+
|
| 157 |
+
@property
|
| 158 |
+
def uses_pipeline(self):
|
| 159 |
+
"""Returns whether the pipeline parallel is used or not."""
|
| 160 |
+
return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))
|
| 161 |
+
|
| 162 |
+
def train(self):
|
| 163 |
+
"""Sets the model to training mode."""
|
| 164 |
+
self._engine.train()
|
| 165 |
+
|
| 166 |
+
def eval(self):
|
| 167 |
+
"""Sets the model to evaluation mode."""
|
| 168 |
+
self._engine.eval()
|
| 169 |
+
|
| 170 |
+
def zero_grad(self):
|
| 171 |
+
"""Sets the gradient of all parameters in the model to zero."""
|
| 172 |
+
self._engine.zero_grad()
|
| 173 |
+
|
| 174 |
+
def step(self):
|
| 175 |
+
"""Executes the parameter update step."""
|
| 176 |
+
return self._engine.step()
|
| 177 |
+
|
| 178 |
+
def execute_schedule(self, data_iter: Iterable, **kwargs):
|
| 179 |
+
"""Runs the forward, loss computation, and backward for the model.
|
| 180 |
+
Returns a tuple of (output, label, loss).
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
data_iter (Iterable): The data iterator.
|
| 184 |
+
**kwargs: Additional keyword arguments.
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
|
| 188 |
+
"""
|
| 189 |
+
output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
|
| 190 |
+
return output, label, loss
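# Hedged usage sketch (illustrative; `engine`, `scheduler` and `train_dl` are assumed
# to be built elsewhere by the training script):
#
#   trainer = Trainer(engine, scheduler)
#   trainer.train()
#   data_iter = iter(train_dl)
#   trainer.zero_grad()
#   _, _, loss = trainer.execute_schedule(data_iter, return_loss=True)
#   trainer.step()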
|
InternLM/internlm/data/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
| 1 |
+
from .batch_sampler import get_dpsampler_dataloader
|
| 2 |
+
from .collaters import jsonl_ds_collate_fn, packed_collate_fn
|
| 3 |
+
from .dummy_dataset import RandomDataset
|
| 4 |
+
from .packed_dataset import PackedDataset, PackedDatasetWithoutCuSeqlen
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"jsonl_ds_collate_fn",
|
| 8 |
+
"packed_collate_fn",
|
| 9 |
+
"RandomDataset",
|
| 10 |
+
"PackedDataset",
|
| 11 |
+
"PackedDatasetWithoutCuSeqlen",
|
| 12 |
+
"get_dpsampler_dataloader",
|
| 13 |
+
]
|
InternLM/internlm/data/batch_sampler.py
ADDED
|
@@ -0,0 +1,354 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
from typing import Iterator, TypeVar
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import DataLoader, Dataset, Sampler
|
| 11 |
+
|
| 12 |
+
from internlm.core.context import ParallelMode
|
| 13 |
+
from internlm.core.context import global_context as gpc
|
| 14 |
+
from internlm.utils.logger import get_logger
|
| 15 |
+
|
| 16 |
+
logger = get_logger(__file__)
|
| 17 |
+
|
| 18 |
+
T_co = TypeVar("T_co", covariant=True)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DataParallelSampler(Sampler):
|
| 22 |
+
"""A data sampler for distributed data parallelism.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling.
|
| 26 |
+
shuffle (bool, optional): Whether to shuffle data, defaults to False.
|
| 27 |
+
seed (int, optional): The random seed used for sampling, defaults to 0.
|
| 28 |
+
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
| 29 |
+
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
| 30 |
+
the batch size, then the last batch will be smaller, defaults to False.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
dataset: Dataset,
|
| 36 |
+
shuffle: bool = False,
|
| 37 |
+
seed: int = 0,
|
| 38 |
+
drop_last: bool = False,
|
| 39 |
+
) -> None:
|
| 40 |
+
self.dataset = dataset
|
| 41 |
+
self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
|
| 42 |
+
self.rank = gpc.get_local_rank(ParallelMode.DATA)
|
| 43 |
+
self.epoch = 0
|
| 44 |
+
self.drop_last = drop_last
|
| 45 |
+
# If the dataset length is evenly divisible by # of replicas, then there
|
| 46 |
+
# is no need to drop any data, since the dataset will be split equally.
|
| 47 |
+
# type: ignore[arg-type]
|
| 48 |
+
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
|
| 49 |
+
# Split to nearest available length that is evenly divisible.
|
| 50 |
+
# This is to ensure each rank receives the same amount of data when
|
| 51 |
+
# using this Sampler.
|
| 52 |
+
self.num_samples = math.ceil(
|
| 53 |
+
# `type:ignore` is required because Dataset cannot provide a default __len__
|
| 54 |
+
# see NOTE in pytorch/torch/utils/data/sampler.py
|
| 55 |
+
(len(self.dataset) - self.num_replicas)
|
| 56 |
+
/ self.num_replicas # type: ignore[arg-type]
|
| 57 |
+
)
|
| 58 |
+
else:
|
| 59 |
+
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
|
| 60 |
+
self.total_size = self.num_samples * self.num_replicas
|
| 61 |
+
self.shuffle = shuffle
|
| 62 |
+
self.seed = seed
|
| 63 |
+
|
| 64 |
+
def __iter__(self) -> Iterator[T_co]:
|
| 65 |
+
if self.shuffle:
|
| 66 |
+
# deterministically shuffle based on epoch and seed
|
| 67 |
+
g = torch.Generator()
|
| 68 |
+
g.manual_seed(self.seed + self.epoch)
|
| 69 |
+
# type: ignore[arg-type]
|
| 70 |
+
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
| 71 |
+
|
| 72 |
+
# update for next epoch so that there is no need to call
|
| 73 |
+
# set_epoch manually
|
| 74 |
+
self.epoch += 1
|
| 75 |
+
else:
|
| 76 |
+
indices = list(range(len(self.dataset))) # type: ignore[arg-type]
|
| 77 |
+
|
| 78 |
+
if not self.drop_last:
|
| 79 |
+
# add extra samples to make it evenly divisible
|
| 80 |
+
padding_size = self.total_size - len(indices)
|
| 81 |
+
if padding_size <= len(indices):
|
| 82 |
+
indices += indices[:padding_size]
|
| 83 |
+
else:
|
| 84 |
+
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
| 85 |
+
else:
|
| 86 |
+
# remove tail of data to make it evenly divisible.
|
| 87 |
+
indices = indices[: self.total_size]
|
| 88 |
+
assert len(indices) == self.total_size
|
| 89 |
+
|
| 90 |
+
# subsample
|
| 91 |
+
indices = indices[self.rank : self.total_size : self.num_replicas]
|
| 92 |
+
assert len(indices) == self.num_samples
|
| 93 |
+
|
| 94 |
+
return iter(indices)
|
| 95 |
+
|
| 96 |
+
def __len__(self) -> int:
|
| 97 |
+
return self.num_samples
|
| 98 |
+
|
| 99 |
+
def set_epoch(self, epoch: int) -> None:
|
| 100 |
+
r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
|
| 101 |
+
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
| 102 |
+
sampler will yield the same ordering.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
epoch (int): Epoch number.
|
| 106 |
+
"""
|
| 107 |
+
self.epoch = epoch
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_dpsampler_dataloader(
|
| 111 |
+
dataset,
|
| 112 |
+
shuffle=False,
|
| 113 |
+
seed=1024,
|
| 114 |
+
add_sampler=True,
|
| 115 |
+
drop_last=False,
|
| 116 |
+
pin_memory=False,
|
| 117 |
+
num_workers=0,
|
| 118 |
+
**kwargs,
|
| 119 |
+
):
|
| 120 |
+
r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
|
| 121 |
+
|
| 122 |
+
Note:
|
| 123 |
+
When pipeline parallel is enabled, shuffle cannot be True, as it would result in a mismatch between input data
|
| 124 |
+
on the 1st stage and label on the last stage.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded.
|
| 128 |
+
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
| 129 |
+
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
| 130 |
+
add_sampler (bool, optional): Whether to add a ``DataParallelSampler`` to the dataloader. Defaults to True.
|
| 131 |
+
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
| 132 |
+
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
| 133 |
+
the batch size, then the last batch will be smaller, defaults to False.
|
| 134 |
+
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
| 135 |
+
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
| 136 |
+
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
| 137 |
+
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
| 141 |
+
"""
|
| 142 |
+
_kwargs = kwargs.copy()
|
| 143 |
+
|
| 144 |
+
if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
|
| 145 |
+
sampler = DataParallelSampler(dataset, shuffle=shuffle, drop_last=drop_last)
|
| 146 |
+
else:
|
| 147 |
+
sampler = None
|
| 148 |
+
|
| 149 |
+
# Deterministic dataloader
|
| 150 |
+
def seed_worker(_):  # worker_init_fn is called with the worker id
|
| 151 |
+
worker_seed = seed
|
| 152 |
+
np.random.seed(worker_seed)
|
| 153 |
+
torch.manual_seed(worker_seed)
|
| 154 |
+
random.seed(worker_seed)
|
| 155 |
+
|
| 156 |
+
if sampler is None:
|
| 157 |
+
return DataLoader(
|
| 158 |
+
dataset,
|
| 159 |
+
worker_init_fn=seed_worker,
|
| 160 |
+
shuffle=shuffle,
|
| 161 |
+
drop_last=drop_last,
|
| 162 |
+
pin_memory=pin_memory,
|
| 163 |
+
num_workers=num_workers,
|
| 164 |
+
**_kwargs,
|
| 165 |
+
)
|
| 166 |
+
else:
|
| 167 |
+
return DataLoader(
|
| 168 |
+
dataset,
|
| 169 |
+
sampler=sampler,
|
| 170 |
+
worker_init_fn=seed_worker,
|
| 171 |
+
drop_last=drop_last,
|
| 172 |
+
pin_memory=pin_memory,
|
| 173 |
+
num_workers=num_workers,
|
| 174 |
+
**_kwargs,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class StaticBatchSampler:
|
| 179 |
+
"""
|
| 180 |
+
A static batch sampler that generates batches with a fixed micro-batch size.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
datasets (list): The datasets to sample from; their combined size gives the total number of samples.
|
| 184 |
+
batch_size (int): The batch size for the current rank. Defaults to 192.
|
| 185 |
+
rampup_batch_size (str): A string with three space-separated integers representing the
|
| 186 |
+
starting batch size, the increment, and the number of steps between
|
| 187 |
+
each increment. For tools, "192 24 8" means that the batch size
|
| 188 |
+
starts at 192 and increases by 24 every 8 steps. Defaults to
|
| 189 |
+
"6 2 8", which corresponds to a batch size of 2 for the first 6 steps.
|
| 190 |
+
micro_bsz (int): The micro-batch size. Defaults to 2.
|
| 191 |
+
seed (int): The random seed for shuffling the indices. Defaults to 0.
|
| 192 |
+
drop_last (bool): If True, drop the last incomplete batch. Currently only supports True. Defaults to True.
|
| 193 |
+
data_rank (int): The rank of the current process in the data parallel group. Defaults to 0.
|
| 194 |
+
data_world_size (int): The number of processes in the data parallel group. Defaults to 1.
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
def __init__(
|
| 198 |
+
self,
|
| 199 |
+
datasets,
|
| 200 |
+
batch_size=192,
|
| 201 |
+
rampup_batch_size="6 2 8",
|
| 202 |
+
micro_bsz=2,
|
| 203 |
+
seed=0,
|
| 204 |
+
drop_last=True,
|
| 205 |
+
data_rank=0,
|
| 206 |
+
data_world_size=1,
|
| 207 |
+
):
|
| 208 |
+
assert drop_last is True, "Currently only support drop last"
|
| 209 |
+
if rampup_batch_size:
|
| 210 |
+
# Ramp the batch size up gradually until it reaches batch_size
|
| 211 |
+
start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
|
| 212 |
+
else:
|
| 213 |
+
start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1
|
| 214 |
+
self.raw_rampup_batch_size = rampup_batch_size
|
| 215 |
+
self.start_bsz = start_bsz
|
| 216 |
+
self.bsz_incre = bsz_incre
|
| 217 |
+
self.incre_every = incre_every
|
| 218 |
+
if gpc.is_initialized(ParallelMode.PIPELINE):
|
| 219 |
+
assert (
|
| 220 |
+
batch_size - self.start_bsz
|
| 221 |
+
) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
|
| 222 |
+
assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
|
| 223 |
+
assert (
|
| 224 |
+
self.start_bsz % micro_bsz == 0
|
| 225 |
+
), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})"
|
| 226 |
+
assert (
|
| 227 |
+
self.bsz_incre % micro_bsz == 0
|
| 228 |
+
), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})"
|
| 229 |
+
|
| 230 |
+
self.batch_size = batch_size
|
| 231 |
+
self.epoch = 0
|
| 232 |
+
self.seed = seed
|
| 233 |
+
self.rng = np.random.RandomState(seed)
|
| 234 |
+
self.batch_count = 0
|
| 235 |
+
self.micro_bsz = micro_bsz
|
| 236 |
+
self.data_rank = data_rank
|
| 237 |
+
self.data_world_size = data_world_size
|
| 238 |
+
self.num_consumed_samples_in_epoch = 0
|
| 239 |
+
self.datasets = datasets
|
| 240 |
+
self.num_samples = sum([len(ds) for ds in datasets])
|
| 241 |
+
|
| 242 |
+
self.get_indices() # get data
|
| 243 |
+
|
| 244 |
+
def get_indices(self, old_indices=None):
|
| 245 |
+
if old_indices is not None:
|
| 246 |
+
assert (
|
| 247 |
+
len(old_indices) <= self.num_samples
|
| 248 |
+
), f"The checkpoint has {len(old_indices)} samples, \
|
| 249 |
+
while the new restart uses fewer samples ({self.num_samples})"
|
| 250 |
+
|
| 251 |
+
else:
|
| 252 |
+
old_indices = np.array([])
|
| 253 |
+
|
| 254 |
+
# indices includes len(old_indices) but not self.num_samples
|
| 255 |
+
indices = np.arange(len(old_indices), self.num_samples)
|
| 256 |
+
self.rng_state = self.rng.get_state()
|
| 257 |
+
self.rng.shuffle(indices)
|
| 258 |
+
# Need to consider drop_last
|
| 259 |
+
ramp_steps = (self.batch_size - self.start_bsz) // self.bsz_incre
|
| 260 |
+
if self.batch_count < ramp_steps * self.incre_every:
|
| 261 |
+
rampup_samples = 0
|
| 262 |
+
for i in range(ramp_steps):
|
| 263 |
+
rampup_samples += (i * self.bsz_incre + self.start_bsz) * self.incre_every
|
| 264 |
+
assert (
|
| 265 |
+
rampup_samples * self.data_world_size <= self.num_samples
|
| 266 |
+
), f"Too much rampup samples: \
|
| 267 |
+
{rampup_samples*self.data_world_size} Vs. self.num_samples: {self.num_samples}"
|
| 268 |
+
|
| 269 |
+
num_samples = (self.num_samples - rampup_samples * self.data_world_size) // (
|
| 270 |
+
self.batch_size * self.data_world_size
|
| 271 |
+
)
|
| 272 |
+
num_samples = num_samples * self.batch_size * self.data_world_size + rampup_samples * self.data_world_size
|
| 273 |
+
else:
|
| 274 |
+
num_samples = self.num_samples // (self.batch_size * self.data_world_size)
|
| 275 |
+
num_samples = num_samples * self.batch_size * self.data_world_size
|
| 276 |
+
indices = np.concatenate([old_indices, indices]).astype(int) # It needs to be spliced with the previous
|
| 277 |
+
indices = indices[:num_samples]
|
| 278 |
+
self.indices = indices
|
| 279 |
+
assert len(self.indices) >= self.batch_size, "The number of samples should be larger than batch_size"
|
| 280 |
+
self.num_consumed_samples_in_epoch = 0
|
| 281 |
+
|
| 282 |
+
def set_epoch(self, epoch):
|
| 283 |
+
self.epoch = epoch
|
| 284 |
+
self.rng = np.random.RandomState(self.seed + self.epoch)
|
| 285 |
+
|
| 286 |
+
def __len__(self):
|
| 287 |
+
ramp_steps = (self.batch_size - self.start_bsz) // self.bsz_incre
|
| 288 |
+
if self.batch_count < ramp_steps * self.incre_every:
|
| 289 |
+
rampup_samples = 0
|
| 290 |
+
for i in range(ramp_steps):
|
| 291 |
+
rampup_samples += (i * self.bsz_incre + self.start_bsz) * self.incre_every
|
| 292 |
+
assert (
|
| 293 |
+
rampup_samples * self.data_world_size <= self.num_samples
|
| 294 |
+
), f"Too much rampup samples: {rampup_samples*self.data_world_size} \
|
| 295 |
+
Vs. self.num_samples: {self.num_samples}"
|
| 296 |
+
|
| 297 |
+
num_batches = (self.num_samples - rampup_samples * self.data_world_size) // self.batch_size
|
| 298 |
+
num_batches = num_batches // self.data_world_size + self.incre_every * ramp_steps
|
| 299 |
+
else:
|
| 300 |
+
num_batches = self.num_samples // self.batch_size // self.data_world_size
|
| 301 |
+
|
| 302 |
+
return num_batches
|
| 303 |
+
|
| 304 |
+
def __iter__(self):
|
| 305 |
+
indices = self.indices[self.data_rank :: self.data_world_size]
|
| 306 |
+
while self.num_consumed_samples_in_epoch < len(indices):
|
| 307 |
+
batch_rampup_idx = self.batch_count // self.incre_every
|
| 308 |
+
cur_batch_size = batch_rampup_idx * self.bsz_incre + self.start_bsz
|
| 309 |
+
cur_batch_size = min(cur_batch_size, self.batch_size)
|
| 310 |
+
batch = indices[self.num_consumed_samples_in_epoch : self.num_consumed_samples_in_epoch + cur_batch_size]
|
| 311 |
+
yield batch
|
| 312 |
+
self.num_consumed_samples_in_epoch += len(batch) # Consider multiple processes.
|
| 313 |
+
self.batch_count += 1
|
| 314 |
+
self.get_indices() # get a new round
|
| 315 |
+
|
| 316 |
+
def state_dict(self):
|
| 317 |
+
states = {
|
| 318 |
+
"batch_size": self.batch_size,
|
| 319 |
+
"raw_rampup_batch_size": self.raw_rampup_batch_size,
|
| 320 |
+
"rng_state": self.rng_state,
|
| 321 |
+
"epoch": self.epoch,
|
| 322 |
+
"seed": self.seed,
|
| 323 |
+
"data_world_size": self.data_world_size,
|
| 324 |
+
"num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
|
| 325 |
+
"batch_count": self.batch_count, # The batch_count here is due to the existence of multiple processes,
|
| 326 |
+
# so this value needs to be overwritten by the externally tracked batch_count
|
| 327 |
+
"indices": self.indices, # The sequence used to breakpoint retraining is the same as before
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
return states
|
| 331 |
+
|
| 332 |
+
def load_state_dict(self, states):
|
| 333 |
+
for name in ("data_world_size", "raw_rampup_batch_size", "seed"): # 'batch_size'
|
| 334 |
+
assert states[name] == getattr(self, name), (name, states[name], getattr(self, name)) # should not change
|
| 335 |
+
self.rng.set_state(states["rng_state"])
|
| 336 |
+
self.get_indices(old_indices=None) # Regenerate indices based on random state
|
| 337 |
+
self.epoch = states["epoch"]
|
| 338 |
+
self.batch_count = states["batch_count"]
|
| 339 |
+
self.num_consumed_samples_in_epoch = states["num_consumed_samples_in_epoch"]
|
| 340 |
+
|
| 341 |
+
def copy(self):
|
| 342 |
+
copy_sampler = StaticBatchSampler(
|
| 343 |
+
self.datasets,
|
| 344 |
+
self.batch_size,
|
| 345 |
+
self.raw_rampup_batch_size,
|
| 346 |
+
self.micro_bsz,
|
| 347 |
+
self.seed,
|
| 348 |
+
drop_last=True,
|
| 349 |
+
data_rank=self.data_rank,
|
| 350 |
+
data_world_size=self.data_world_size,
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
copy_sampler.load_state_dict(self.state_dict())
|
| 354 |
+
return copy_sampler
|
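As a small, self-contained sketch of the ramp-up that StaticBatchSampler implements, the helper below (not part of the repo, names are illustrative) mirrors the arithmetic in __iter__: the rampup_batch_size string "start increment interval" grows the batch size stepwise until it reaches batch_size.

def rampup_schedule(rampup_batch_size: str, batch_size: int, num_steps: int):
    # "6 2 8" -> start at 6, add 2 every 8 steps, capped at batch_size
    start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
    return [
        min(start_bsz + (step // incre_every) * bsz_incre, batch_size)
        for step in range(num_steps)
    ]

print(rampup_schedule("6 2 8", batch_size=12, num_steps=40))
# 6 for steps 0-7, 8 for steps 8-15, 10 for steps 16-23, then capped at 12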
InternLM/internlm/data/collaters.py
ADDED
|
@@ -0,0 +1,88 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def packed_collate_fn(batch, packed_length):
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
Collate function for packed input sequences.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
batch (List[Dict]): List of dictionaries representing each sample in batch.
|
| 14 |
+
Each dictionary contains "tokens", "labels", "type_ids", "cu_seqlens", and "indexes" keys.
|
| 15 |
+
packed_length (int): The length of packed sequence.
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids",
|
| 19 |
+
"cu_seqlens", "indexes", and "type_ids" keys, and the tensor of padded "labels".
|
| 20 |
+
|
| 21 |
+
Raises:
|
| 22 |
+
AssertionError: If the length of a sample is not equal to packed_length.
|
| 23 |
+
AssertionError: If the shape of the padded "input_ids" tensor does not have the correct shape.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
xs, ys, cu_seqlens, indexes, ts = [], [], [], [], []
|
| 27 |
+
for b in batch:
|
| 28 |
+
assert (
|
| 29 |
+
len(b["tokens"]) == packed_length
|
| 30 |
+
), f"length of a sample should be equal to packed_length, but got {len(b['tokens'])} and {packed_length})"
|
| 31 |
+
assert (
|
| 32 |
+
len(b["labels"]) == packed_length
|
| 33 |
+
), f"length of a sample should be equal to packed_length, but got {len(b['labels'])} and {packed_length})"
|
| 34 |
+
assert (
|
| 35 |
+
len(b["type_ids"]) == packed_length
|
| 36 |
+
), f"length of a sample should be equal to packed_length, but got {len(b['type_ids'])} and {packed_length})"
|
| 37 |
+
|
| 38 |
+
tokens = [abs(w) for w in b["tokens"]]
|
| 39 |
+
labels = [w if w > 0 else -100 for w in b["labels"]]
|
| 40 |
+
|
| 41 |
+
xs.append(torch.LongTensor(tokens))
|
| 42 |
+
# The labels have been shifted here, so they are aligned with the output corresponding to the token
|
| 43 |
+
ys.append(torch.LongTensor(labels))
|
| 44 |
+
ts.append(torch.LongTensor(b["type_ids"]))
|
| 45 |
+
cu_seqlens.append(torch.IntTensor(b["cu_seqlens"]))
|
| 46 |
+
indexes.append(torch.LongTensor(b["indexes"]))
|
| 47 |
+
|
| 48 |
+
xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)
|
| 49 |
+
ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100)
|
| 50 |
+
ts = torch.nn.utils.rnn.pad_sequence(ts, batch_first=True, padding_value=0)
|
| 51 |
+
indexes = torch.stack(indexes, dim=0)
|
| 52 |
+
if len(set(map(len, cu_seqlens))) == 1:  # if lengths are uniform, stack to save device transfer time
|
| 53 |
+
cu_seqlens = torch.stack(cu_seqlens, dim=0)
|
| 54 |
+
|
| 55 |
+
assert xs.shape[1] == packed_length, (xs.shape[1], packed_length)
|
| 56 |
+
|
| 57 |
+
return {"input_ids": xs, "cu_seqlens": cu_seqlens, "indexes": indexes, "type_ids": ts}, ys
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def jsonl_ds_collate_fn(batch, max_length_per_sample):
|
| 61 |
+
"""
|
| 62 |
+
Collate function for json dataset.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
batch (List[Dict]): List of dictionaries representing each sample in batch.
|
| 66 |
+
Each dictionary contains "tokens".
|
| 67 |
+
max_length_per_sample (int): The length of output sequence.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids",
|
| 71 |
+
and the tensor of padded "labels".
|
| 72 |
+
|
| 73 |
+
"""
|
| 74 |
+
xs, ys = [], []
|
| 75 |
+
for x in batch:
|
| 76 |
+
x["tokens"] = x["tokens"][:max_length_per_sample]
|
| 77 |
+
tokens = [abs(w) for w in x["tokens"]]
|
| 78 |
+
labels = [w if w > 0 else -100 for w in x["tokens"]]
|
| 79 |
+
labels = labels[1:] + [-100]
|
| 80 |
+
xs.append(torch.as_tensor(tokens))
|
| 81 |
+
ys.append(torch.as_tensor(labels)) # y has been shifted
|
| 82 |
+
xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)
|
| 83 |
+
ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100)
|
| 84 |
+
|
| 85 |
+
xs = torch.cat([xs, xs.new_zeros(len(xs), max_length_per_sample - len(xs[0]))], dim=-1)
|
| 86 |
+
ys = torch.cat([ys, ys.new_full((len(ys), max_length_per_sample - len(ys[0])), fill_value=-100)], dim=-1)
|
| 87 |
+
|
| 88 |
+
return {"input_ids": xs}, ys
|
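A hedged usage sketch of packed_collate_fn: two identical toy packed samples of packed_length 8 (the values are made up; non-positive label positions are masked to -100 exactly as the collater above does). It assumes the internlm package is importable.

from internlm.data.collaters import packed_collate_fn

packed_length = 8
sample = {
    "tokens": [1, 2, 3, 4, 5, 6, 7, 8],
    "labels": [2, 3, -100, 5, 6, 7, 8, -100],
    "type_ids": [0] * packed_length,
    "cu_seqlens": [0, 4, 8],            # two sequences of length 4 packed together
    "indexes": [0, 1, 2, 3, 0, 1, 2, 3],
}
data, labels = packed_collate_fn([sample, sample], packed_length=packed_length)
print(data["input_ids"].shape)   # torch.Size([2, 8])
print(data["cu_seqlens"].shape)  # torch.Size([2, 3]) -- stacked because lengths are uniform
print(labels.shape)              # torch.Size([2, 8])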
InternLM/internlm/data/dataset.py
ADDED
|
@@ -0,0 +1,56 @@
| 1 |
+
import os
|
| 2 |
+
from typing import Dict
|
| 3 |
+
|
| 4 |
+
from torch.utils.data import ConcatDataset
|
| 5 |
+
|
| 6 |
+
from internlm.data.single_dataset import JsonlDataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_dataset_dict(folder, split="valid") -> Dict:
|
| 10 |
+
"""
|
| 11 |
+
Return a dictionary of Datasets from a folder containing data files for validation.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
folder (str): The path to the folder containing data files.
|
| 15 |
+
split (str): The split of the data files to be used, default is "valid".
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
A dictionary containing Datasets for each folder in the given path
|
| 19 |
+
that contains data files with the specified split.
|
| 20 |
+
|
| 21 |
+
Raises:
|
| 22 |
+
AssertionError: If the given folder does not exist.
|
| 23 |
+
|
| 24 |
+
Example:
|
| 25 |
+
If the given folder is as follows,
|
| 26 |
+
- data
|
| 27 |
+
- zhihu
|
| 28 |
+
- xxx.bin
|
| 29 |
+
- valid.bin
|
| 30 |
+
- baike
|
| 31 |
+
- xxx.bin
|
| 32 |
+
- valid.bin
|
| 33 |
+
|
| 34 |
+
The returned dictionary will be,
|
| 35 |
+
{
|
| 36 |
+
'zhihu': Dataset,
|
| 37 |
+
'baike': Dataset
|
| 38 |
+
}
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
assert os.path.exists(folder), f"folder `{folder}` not exists"
|
| 42 |
+
data_dict = {}
|
| 43 |
+
|
| 44 |
+
for root, dirs, files in os.walk(folder, followlinks=True):
|
| 45 |
+
dirs.sort()  # Guarantee a deterministic order; newly added data prefixed with "z" sorts last
|
| 46 |
+
datasets = []
|
| 47 |
+
for fn in sorted(files): # Need sorted to ensure that the order is consistent
|
| 48 |
+
if fn.endswith(".bin") and split in fn:
|
| 49 |
+
fp = os.path.join(root, fn)
|
| 50 |
+
ds = JsonlDataset(fp)
|
| 51 |
+
datasets.append(ds)
|
| 52 |
+
if datasets:
|
| 53 |
+
ds = ConcatDataset(datasets=datasets)
|
| 54 |
+
data_dict[os.path.basename(root)] = ds
|
| 55 |
+
|
| 56 |
+
return data_dict
|
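A hedged sketch of calling get_dataset_dict against a folder laid out as in the docstring above (tokenized *.bin files plus the *.bin.meta caches expected by JsonlDataset). The path is a placeholder.

from internlm.data.dataset import get_dataset_dict

valid_sets = get_dataset_dict("/path/to/data", split="valid")  # placeholder path
for name, ds in valid_sets.items():
    print(name, len(ds))  # e.g. "zhihu 1234", "baike 5678"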
InternLM/internlm/data/dummy_dataset.py
ADDED
|
@@ -0,0 +1,44 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RandomDataset(Dataset):
|
| 9 |
+
"""
|
| 10 |
+
RandomDataset for generating random dataset.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
num_samples (int): The number of samples to generate.
|
| 14 |
+
max_len (int): The maximum length of each sample.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self, num_samples=10000, max_len=1024) -> None:
|
| 19 |
+
super().__init__()
|
| 20 |
+
rng = np.random.RandomState(1999)
|
| 21 |
+
max_num = rng.randint(1, 30, size=(num_samples,))
|
| 22 |
+
rep_num = rng.randint(10, 200, size=(num_samples,))
|
| 23 |
+
data = []
|
| 24 |
+
lengths = []
|
| 25 |
+
for n, r in zip(max_num, rep_num):
|
| 26 |
+
d = list(range(n)) * r
|
| 27 |
+
d = [n, r] + d
|
| 28 |
+
d = d[:max_len]
|
| 29 |
+
data.append(d)
|
| 30 |
+
lengths.append(len(d))
|
| 31 |
+
self.data = data
|
| 32 |
+
self.max_len = max_len
|
| 33 |
+
self.lengths = np.array(lengths, dtype=int)
|
| 34 |
+
|
| 35 |
+
def __getitem__(self, index):
|
| 36 |
+
d = self.data[index]
|
| 37 |
+
input_ids = np.array(d, dtype=int)
|
| 38 |
+
return {"tokens": list(input_ids), "type_id": 0}
|
| 39 |
+
|
| 40 |
+
def get_dataset_name(self):
|
| 41 |
+
return "dummy_path/dummy_lang/dummy_ds/train.bin"
|
| 42 |
+
|
| 43 |
+
def __len__(self):
|
| 44 |
+
return len(self.data)
|
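RandomDataset needs no files, so it is handy for smoke tests; every sample starts with the pair [n, r] followed by the pattern 0..n-1 repeated r times, truncated to max_len:

from internlm.data.dummy_dataset import RandomDataset

ds = RandomDataset(num_samples=4, max_len=16)
print(len(ds), ds.lengths)   # 4 samples and their per-sample token counts
print(ds[0]["tokens"][:6])   # e.g. [n, r, 0, 1, 2, 3]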
InternLM/internlm/data/packed_dataset.py
ADDED
|
@@ -0,0 +1,421 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import itertools as it
|
| 5 |
+
import operator
|
| 6 |
+
import os
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
from typing import Dict
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from torch.utils.data import ConcatDataset
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from internlm.core.context import global_context as gpc
|
| 16 |
+
from internlm.data.single_dataset import JsonlDataset
|
| 17 |
+
from internlm.data.utils import get_dataset_type_id
|
| 18 |
+
from internlm.utils.logger import get_logger
|
| 19 |
+
|
| 20 |
+
DEFAULT_SEED = 1024
|
| 21 |
+
logger = get_logger(__file__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class PackedDataset(torch.utils.data.Dataset):
|
| 25 |
+
"""
|
| 26 |
+
The class PackedDataset takes in a dataset and aggregates samples of different
|
| 27 |
+
lengths together based on the packed_length.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
dataset: The original dataset to pack.
|
| 31 |
+
max_length_per_sample: The maximum length of each original sample. Default is 2048.
|
| 32 |
+
packed_length: The length of each packed sample. Default is 4096.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
dataset,
|
| 38 |
+
max_length_per_sample: int = 2048,
|
| 39 |
+
packed_length: int = 4096,
|
| 40 |
+
):
|
| 41 |
+
assert hasattr(dataset, "lengths")
|
| 42 |
+
assert len(getattr(dataset, "lengths")) == len(
|
| 43 |
+
dataset
|
| 44 |
+
), "The dataset must have lengths attribute and have the same length as the dataset"
|
| 45 |
+
self.dataset = dataset
|
| 46 |
+
self.max_length_per_sample = max_length_per_sample
|
| 47 |
+
self.lengths = getattr(self.dataset, "lengths")
|
| 48 |
+
self.packed_length = packed_length
|
| 49 |
+
# Force a seed to be fixed to prevent problems caused by the seed not being restored when restarting
|
| 50 |
+
|
| 51 |
+
self.seed = DEFAULT_SEED
|
| 52 |
+
self.sample_indices, self.len_samples_shuffled, self.acm_len_samples = self.accu_sample_len(seed=self.seed)
|
| 53 |
+
self.num_tokens = sum(self.lengths)
|
| 54 |
+
|
| 55 |
+
def get_dataset_name(self):
|
| 56 |
+
return self.dataset.get_dataset_name()
|
| 57 |
+
|
| 58 |
+
def accu_sample_len(self, seed=None):
|
| 59 |
+
"""accumulative length of samples"""
|
| 60 |
+
if seed is not None:
|
| 61 |
+
rng = np.random.RandomState(seed)
|
| 62 |
+
else:
|
| 63 |
+
rng = np.random.RandomState(self.seed - 1)
|
| 64 |
+
|
| 65 |
+
sample_indices = np.arange(len(self.lengths))
|
| 66 |
+
rng.shuffle(sample_indices)
|
| 67 |
+
len_samples_shuffled = list(map(self.lengths.__getitem__, sample_indices))
|
| 68 |
+
acm_len_samples = list(it.accumulate(len_samples_shuffled, operator.add))
|
| 69 |
+
return sample_indices, len_samples_shuffled, acm_len_samples
|
| 70 |
+
|
| 71 |
+
def __len__(self):
|
| 72 |
+
# Packed by direct splicing, as in line 405 of metaseq's document_to_sequence.py,
|
| 73 |
+
# without additional consideration of sos or eos
|
| 74 |
+
n_packs = self.num_tokens // self.packed_length
|
| 75 |
+
return n_packs
|
| 76 |
+
|
| 77 |
+
def cal_map(self, carriage_idx: int = 0):
|
| 78 |
+
assert carriage_idx >= 0
|
| 79 |
+
length_train = (carriage_idx + 1) * self.packed_length
|
| 80 |
+
post_pos = np.searchsorted(self.acm_len_samples, length_train, side="left")
|
| 81 |
+
return post_pos
|
| 82 |
+
|
| 83 |
+
def mapping(self, pack_idx: int = 0):
|
| 84 |
+
# pack_idx is zero-based
|
| 85 |
+
pre_pos, pre_token_id = 0, 0
|
| 86 |
+
if pack_idx > 0:
|
| 87 |
+
pre_pos = self.cal_map(pack_idx - 1)
|
| 88 |
+
pre_token_id = self.len_samples_shuffled[pre_pos] - (
|
| 89 |
+
self.acm_len_samples[pre_pos] - (pack_idx) * self.packed_length
|
| 90 |
+
)
|
| 91 |
+
if pre_token_id == self.len_samples_shuffled[pre_pos]:
|
| 92 |
+
pre_pos += 1
|
| 93 |
+
pre_token_id = 0
|
| 94 |
+
|
| 95 |
+
pos = self.cal_map(pack_idx)
|
| 96 |
+
token_id = self.len_samples_shuffled[pos] - (self.acm_len_samples[pos] - (pack_idx + 1) * self.packed_length)
|
| 97 |
+
return pre_pos, pre_token_id, pos, token_id
|
| 98 |
+
|
| 99 |
+
def build_pack(self, pre_pos: int, pre_token_id: int, pos: int, token_id: int):
|
| 100 |
+
pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
|
| 101 |
+
|
| 102 |
+
while pre_pos < pos:
|
| 103 |
+
sample_idx = self.sample_indices[pre_pos]
|
| 104 |
+
sample = self.dataset[sample_idx]
|
| 105 |
+
chunk = sample["tokens"][pre_token_id:]
|
| 106 |
+
pack.extend(chunk)
|
| 107 |
+
_labels = deepcopy(chunk)
|
| 108 |
+
_labels = list(_labels[1:]) + [-100]
|
| 109 |
+
assert len(_labels) == len(chunk), (_labels, chunk)
|
| 110 |
+
labels.extend(_labels)
|
| 111 |
+
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
|
| 112 |
+
num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample)
|
| 113 |
+
for _ in range(num_new_samples):
|
| 114 |
+
cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample)
|
| 115 |
+
indexes.extend(list(range(self.max_length_per_sample)))
|
| 116 |
+
if tokens_left > 0:
|
| 117 |
+
cu_seqlens.append(cu_seqlens[-1] + tokens_left)
|
| 118 |
+
indexes.extend(list(range(tokens_left)))
|
| 119 |
+
pre_pos = pre_pos + 1
|
| 120 |
+
pre_token_id = 0
|
| 121 |
+
|
| 122 |
+
sample_idx = self.sample_indices[pos]
|
| 123 |
+
sample = self.dataset[sample_idx]
|
| 124 |
+
chunk = sample["tokens"][pre_token_id:token_id] # fragement of a sample
|
| 125 |
+
pack.extend(chunk)
|
| 126 |
+
_labels = deepcopy(chunk)
|
| 127 |
+
if token_id == len(sample["tokens"]):
|
| 128 |
+
_labels = list(_labels[1:]) + [-100]
|
| 129 |
+
else:
|
| 130 |
+
if token_id > len(sample["tokens"]):
|
| 131 |
+
print(f"token_id {token_id}, len of sample {len(sample['tokens'])}")
|
| 132 |
+
_labels = list(_labels[1:]) + [sample["tokens"][token_id]]
|
| 133 |
+
assert len(_labels) == len(chunk), (_labels, chunk)
|
| 134 |
+
labels.extend(_labels)
|
| 135 |
+
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
|
| 136 |
+
num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample)
|
| 137 |
+
for _ in range(num_new_samples):
|
| 138 |
+
cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample)
|
| 139 |
+
indexes.extend(list(range(self.max_length_per_sample)))
|
| 140 |
+
if tokens_left > 0:
|
| 141 |
+
cu_seqlens.append(cu_seqlens[-1] + tokens_left)
|
| 142 |
+
indexes.extend(list(range(tokens_left)))
|
| 143 |
+
|
| 144 |
+
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
|
| 145 |
+
return out
|
| 146 |
+
|
| 147 |
+
def cal_pos_unpack(self, index):
|
| 148 |
+
if index == 0:
|
| 149 |
+
pre_pos = 0
|
| 150 |
+
else:
|
| 151 |
+
pre_pos = index * gpc.config.data["micro_bsz"]
|
| 152 |
+
|
| 153 |
+
pos = (index + 1) * gpc.config.data["micro_bsz"]
|
| 154 |
+
return pre_pos, pos
|
| 155 |
+
|
| 156 |
+
def build_unpack(self, index):
|
| 157 |
+
|
| 158 |
+
pre_pos, pos = self.cal_pos_unpack(index)
|
| 159 |
+
|
| 160 |
+
pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
|
| 161 |
+
|
| 162 |
+
while pre_pos < pos and pre_pos < len(self.dataset):
|
| 163 |
+
sample_idx = self.sample_indices[pre_pos]
|
| 164 |
+
sample = self.dataset[sample_idx]
|
| 165 |
+
length = min(len(sample["tokens"]), self.max_length_per_sample)
|
| 166 |
+
chunk = sample["tokens"][0:length]
|
| 167 |
+
pack.extend(chunk)
|
| 168 |
+
_labels = deepcopy(chunk)
|
| 169 |
+
_labels = list(_labels[1:]) + [-100]
|
| 170 |
+
assert len(_labels) == len(chunk), (_labels, chunk)
|
| 171 |
+
labels.extend(_labels)
|
| 172 |
+
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
|
| 173 |
+
cu_seqlens.append(cu_seqlens[-1] + len(chunk))
|
| 174 |
+
indexes.extend(list(range(length)))
|
| 175 |
+
pre_pos = pre_pos + 1
|
| 176 |
+
|
| 177 |
+
if cu_seqlens[-1] != self.packed_length:
|
| 178 |
+
pack = pack + [0] * (self.packed_length - cu_seqlens[-1])
|
| 179 |
+
labels = labels + [0] * (self.packed_length - cu_seqlens[-1])
|
| 180 |
+
type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1])
|
| 181 |
+
indexes.extend(list(range(self.packed_length - cu_seqlens[-1])))
|
| 182 |
+
cu_seqlens.append(self.packed_length)
|
| 183 |
+
|
| 184 |
+
assert len(pack) == self.packed_length
|
| 185 |
+
|
| 186 |
+
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
|
| 187 |
+
return out
|
| 188 |
+
|
| 189 |
+
def __getitem__(self, item: int) -> Dict:
|
| 190 |
+
"""Given the index, it returns a dict as
|
| 191 |
+
{
|
| 192 |
+
'tokens': List[int],
|
| 193 |
+
'cu_seqlens': List[int],
|
| 194 |
+
'indexes': List[int], # denotes positional vector as 'tokens'
|
| 195 |
+
'labels': List[int], # already shifted relative to 'tokens'; -100 means the position is skipped in the loss
|
| 196 |
+
}
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
if gpc.config.model.use_flash_attn:
|
| 200 |
+
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
|
| 201 |
+
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
|
| 202 |
+
|
| 203 |
+
return self.build_unpack(item)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):
|
| 207 |
+
"""
|
| 208 |
+
A dataset wrapper that aggregates samples with different lengths based on packed_length.
|
| 209 |
+
If a sample is shorter than max_length_per_sample, it will be merged with other samples.
|
| 210 |
+
For example, given a dataset with samples such as:
|
| 211 |
+
[1, 2, 3, 4, 5]
|
| 212 |
+
[6, 7]
|
| 213 |
+
[8, 9, 10, 11]
|
| 214 |
+
[12, ..., 100]
|
| 215 |
+
...
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
dataset: The original dataset to be wrapped.
|
| 219 |
+
max_length_per_sample (int): The maximum length allowed for each sample.
|
| 220 |
+
packed_length (int): The desired length for each packed sample.
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
def __init__(
|
| 224 |
+
self,
|
| 225 |
+
dataset,
|
| 226 |
+
max_length_per_sample: int = 2048,
|
| 227 |
+
packed_length: int = 4096,
|
| 228 |
+
debug=False,
|
| 229 |
+
):
|
| 230 |
+
assert packed_length % max_length_per_sample == 0
|
| 231 |
+
assert hasattr(dataset, "lengths")
|
| 232 |
+
assert len(getattr(dataset, "lengths")) == len(
|
| 233 |
+
dataset
|
| 234 |
+
), "The dataset must have lengths attribute and have the same length as the dataset"
|
| 235 |
+
self.dataset = dataset
|
| 236 |
+
self.max_length_per_sample = max_length_per_sample
|
| 237 |
+
self.lengths = getattr(self.dataset, "lengths")
|
| 238 |
+
self.bsz = packed_length // max_length_per_sample
|
| 239 |
+
self.packed_length = packed_length
|
| 240 |
+
self.debug = debug
|
| 241 |
+
# Force a seed to be fixed to prevent problems caused by the seed not being restored when restarting
|
| 242 |
+
|
| 243 |
+
self.seed = DEFAULT_SEED
|
| 244 |
+
indices = np.arange(len(self.lengths))
|
| 245 |
+
rng = np.random.RandomState(self.seed)
|
| 246 |
+
rng.shuffle(indices)
|
| 247 |
+
self.indices = indices
|
| 248 |
+
self.cum_lens = np.cumsum(self.lengths[self.indices])
|
| 249 |
+
self.num_tokens = sum(self.lengths)
|
| 250 |
+
|
| 251 |
+
def get_dataset_name(self):
|
| 252 |
+
return self.dataset.get_dataset_name()
|
| 253 |
+
|
| 254 |
+
def __len__(self):
|
| 255 |
+
n_packs = self.num_tokens // self.packed_length
|
| 256 |
+
return n_packs
|
| 257 |
+
|
| 258 |
+
def find_offset(self, offset):
|
| 259 |
+
idx = np.searchsorted(self.cum_lens, offset, side="right")
|
| 260 |
+
if idx == 0:
|
| 261 |
+
return idx, offset
|
| 262 |
+
length = offset - self.cum_lens[idx - 1]
|
| 263 |
+
return idx, length
|
| 264 |
+
|
| 265 |
+
def pdebug(self, line):
|
| 266 |
+
if self.debug:
|
| 267 |
+
print(line, flush=True)
|
| 268 |
+
|
| 269 |
+
def __getitem__(self, item: int) -> Dict:
|
| 270 |
+
"""Given the index, it returns a dict as
|
| 271 |
+
{
|
| 272 |
+
'tokens': List[int],
|
| 273 |
+
'cu_seqlens': List[int],
|
| 274 |
+
'indexes': List[int], # denotes positional vector as 'tokens'
|
| 275 |
+
'labels': List[int], # already shifted relative to 'tokens'; -100 means the position is skipped in the loss
|
| 276 |
+
}
|
| 277 |
+
"""
|
| 278 |
+
|
| 279 |
+
start_idx, start_length = self.find_offset(item * self.packed_length)
|
| 280 |
+
end_idx, end_length = self.find_offset((item + 1) * self.packed_length)
|
| 281 |
+
pack_tokens = []
|
| 282 |
+
pack_labels = []
|
| 283 |
+
type_ids = []
|
| 284 |
+
|
| 285 |
+
self.pdebug(f"item : {item}, start_idx:{start_idx}, start_length:{start_length} ")
|
| 286 |
+
self.pdebug(f"item : {item}, end_idx:{end_idx}, end_length:{end_length} ")
|
| 287 |
+
|
| 288 |
+
if start_idx == end_idx:
|
| 289 |
+
idx = self.indices[start_idx]
|
| 290 |
+
sample = self.dataset[idx]
|
| 291 |
+
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
|
| 292 |
+
tokens = sample["tokens"][start_length:end_length]
|
| 293 |
+
pack_tokens.extend(tokens)
|
| 294 |
+
pack_labels.extend(tokens[1:] + [-100])
|
| 295 |
+
type_ids.extend([sample["type_id"]] * len(tokens))
|
| 296 |
+
return {
|
| 297 |
+
"tokens": pack_tokens,
|
| 298 |
+
"cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)],
|
| 299 |
+
"indexes": list(range(self.max_length_per_sample)) * self.bsz,
|
| 300 |
+
"labels": pack_labels,
|
| 301 |
+
"type_ids": type_ids,
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
idx = self.indices[start_idx]
|
| 305 |
+
sample = self.dataset[idx]
|
| 306 |
+
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
|
| 307 |
+
tokens = sample["tokens"][start_length:]
|
| 308 |
+
pack_tokens.extend(tokens)
|
| 309 |
+
pack_labels.extend(tokens[1:] + [-100])
|
| 310 |
+
type_ids.extend([sample["type_id"]] * len(tokens))
|
| 311 |
+
|
| 312 |
+
for i in range(start_idx + 1, end_idx):
|
| 313 |
+
idx = self.indices[i]
|
| 314 |
+
sample = self.dataset[idx]
|
| 315 |
+
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
|
| 316 |
+
tokens = sample["tokens"]
|
| 317 |
+
pack_tokens.extend(tokens)
|
| 318 |
+
pack_labels.extend(tokens[1:] + [-100])
|
| 319 |
+
type_ids.extend([sample.get("type_id")] * len(tokens))
|
| 320 |
+
|
| 321 |
+
# corner case, the last sample is useless
|
| 322 |
+
if end_length == 0:
|
| 323 |
+
pass
|
| 324 |
+
else:
|
| 325 |
+
idx = self.indices[end_idx]
|
| 326 |
+
sample = self.dataset[idx]
|
| 327 |
+
self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
|
| 328 |
+
tokens = sample["tokens"][:end_length]
|
| 329 |
+
pack_tokens.extend(tokens)
|
| 330 |
+
pack_labels.extend(tokens[1:] + [-100])
|
| 331 |
+
type_ids.extend([sample.get("type_id")] * len(tokens))
|
| 332 |
+
|
| 333 |
+
return {
|
| 334 |
+
"tokens": pack_tokens,
|
| 335 |
+
"cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)],
|
| 336 |
+
"indexes": list(range(self.max_length_per_sample)) * self.bsz,
|
| 337 |
+
"labels": pack_labels,
|
| 338 |
+
"type_ids": type_ids,
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def get_packed_dataset_without_short_length(
|
| 343 |
+
folder,
|
| 344 |
+
max_length_per_sample=2048,
|
| 345 |
+
packed_length=4096,
|
| 346 |
+
show_progress=False,
|
| 347 |
+
min_length=50,
|
| 348 |
+
min_length_dict=None,
|
| 349 |
+
pack_into_one_sample=False,
|
| 350 |
+
):
|
| 351 |
+
"""
|
| 352 |
+
Given a folder, combine all the .bin files into a single large dataset.
|
| 353 |
+
Samples shorter than 'min_length' are filtered out.
|
| 354 |
+
|
| 355 |
+
Each .bin file is treated as a separate dataset.
|
| 356 |
+
|
| 357 |
+
Args:
|
| 358 |
+
folder (str): Path to the folder containing the .bin files.
|
| 359 |
+
max_length_per_sample (int): Maximum length of each sample.
|
| 360 |
+
packed_length (int): Length to pack samples to.
|
| 361 |
+
show_progress (bool): Whether to show the progress bar.
|
| 362 |
+
min_length (int): The minimum length of the sample.
|
| 363 |
+
min_length_dict (dict): The minimum length of the sample for each dataset.
|
| 364 |
+
The format is something like {'pile-arxiv': 50}
|
| 365 |
+
pack_into_one_sample (bool): If True, wrap each file with PackedDatasetWithoutCuSeqlen instead of PackedDataset.
|
| 366 |
+
|
| 367 |
+
Returns:
|
| 368 |
+
A packed dataset containing all the data from the .bin files.
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
assert os.path.exists(folder), f"{folder} does not exist."
|
| 372 |
+
datasets = []
|
| 373 |
+
delete_samples = 0
|
| 374 |
+
|
| 375 |
+
for root, dirs, files in os.walk(folder, followlinks=True):
|
| 376 |
+
dirs.sort()  # Ensure folders are visited in a fixed order
|
| 377 |
+
if gpc.is_rank_for_log():
|
| 378 |
+
logger.info(f"Reading {root}...")
|
| 379 |
+
num_token_in_folder = 0
|
| 380 |
+
|
| 381 |
+
for fn in tqdm(sorted(files), total=len(files), leave=False, disable=not show_progress):
|
| 382 |
+
if fn.endswith(".bin"):
|
| 383 |
+
fp = os.path.join(root, fn)
|
| 384 |
+
catch_ml_keys = []
|
| 385 |
+
min_length_num = min_length
|
| 386 |
+
if min_length_dict is not None:
|
| 387 |
+
for k, v in min_length_dict.items():
|
| 388 |
+
if k in fp:
|
| 389 |
+
min_length_num = v
|
| 390 |
+
catch_ml_keys.append(k)
|
| 391 |
+
assert (
|
| 392 |
+
len(catch_ml_keys) < 2
|
| 393 |
+
), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}"
|
| 394 |
+
|
| 395 |
+
ds_type_id = get_dataset_type_id(path=fp)
|
| 396 |
+
ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num)
|
| 397 |
+
|
| 398 |
+
if hasattr(ds, "old_length"):
|
| 399 |
+
delete_samples += ds.old_length - len(ds)
|
| 400 |
+
if len(ds) == 0:
|
| 401 |
+
if gpc.is_rank_for_log():
|
| 402 |
+
logger.info(f"None of the data in `{fp}` is longer than {min_length}")
|
| 403 |
+
continue
|
| 404 |
+
|
| 405 |
+
if pack_into_one_sample:
|
| 406 |
+
ds = PackedDatasetWithoutCuSeqlen(ds, max_length_per_sample, packed_length)
|
| 407 |
+
else:
|
| 408 |
+
ds = PackedDataset(ds, max_length_per_sample, packed_length)
|
| 409 |
+
|
| 410 |
+
num_token_in_folder += len(ds) * packed_length
|
| 411 |
+
datasets.append(ds)
|
| 412 |
+
|
| 413 |
+
dataset = ConcatDataset(datasets=datasets)
|
| 414 |
+
if gpc.is_rank_for_log():
|
| 415 |
+
logger.info(
|
| 416 |
+
f"Find `{len(datasets)}` datasets, \
|
| 417 |
+
{len(dataset)} samples, \
|
| 418 |
+
delete `{delete_samples}` because of short length",
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
return dataset
|
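A hedged sketch of packing the toy RandomDataset into fixed-length rows. mapping() and build_pack() are called directly so the example avoids __getitem__, which consults the global context (gpc) for the flash-attention flag; it assumes the internlm package is importable.

from internlm.data.dummy_dataset import RandomDataset
from internlm.data.packed_dataset import PackedDataset

ds = RandomDataset(num_samples=100, max_len=128)
packed = PackedDataset(ds, max_length_per_sample=128, packed_length=512)
print(len(packed))  # number of packed rows = total tokens // packed_length

pack = packed.build_pack(*packed.mapping(0))        # first packed row
print(len(pack["tokens"]), pack["cu_seqlens"][:4])  # 512 tokens plus cumulative sequence boundaries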
InternLM/internlm/data/single_dataset.py
ADDED
|
@@ -0,0 +1,117 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
A .bin file corresponds to a Dataset instance here.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import mmap
|
| 10 |
+
import os
|
| 11 |
+
import threading
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class JsonlDataset(torch.utils.data.Dataset):
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
JSONL format is expected to roughly follow that of The Pile.
|
| 22 |
+
One-line-per-document of the form:
|
| 23 |
+
```
|
| 24 |
+
{
|
| 25 |
+
"tokens": List[int],
|
| 26 |
+
}
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
Note that only the "tokens" key is used.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
def __init__(self, path: str, dataset_type_id: int = 0, min_length=50):
|
| 33 |
+
self.path = path
|
| 34 |
+
self.threadlocal = threading.local()
|
| 35 |
+
resolved_path = Path(path).resolve()
|
| 36 |
+
self.resolved_path = resolved_path
|
| 37 |
+
self.meta = Path(f"{resolved_path}.meta")
|
| 38 |
+
self.type_id = dataset_type_id
|
| 39 |
+
|
| 40 |
+
# only build the cache on the primary worker to prevent overloading nfs
|
| 41 |
+
assert os.path.exists(self.meta), f"The cache file:{self.meta} is not found for file:{self.path}"
|
| 42 |
+
try:
|
| 43 |
+
with open(self.meta, "rb") as f:
|
| 44 |
+
meta = np.load(f)
|
| 45 |
+
except Exception as e:
|
| 46 |
+
print(f"Cannot load file {self.meta}...")
|
| 47 |
+
raise e
|
| 48 |
+
self.offsets = meta[:, 0]
|
| 49 |
+
self.lengths = meta[:, -1]
|
| 50 |
+
|
| 51 |
+
if min_length > 0:
|
| 52 |
+
mask = self.lengths >= min_length
|
| 53 |
+
self.old_lengths = self.lengths.copy()
|
| 54 |
+
self.old_length = len(self.offsets)
|
| 55 |
+
self.offsets = self.offsets[mask]
|
| 56 |
+
self.lengths = self.lengths[mask]
|
| 57 |
+
|
| 58 |
+
def __getitem__(self, idx):
|
| 59 |
+
f = self._get_mmap()
|
| 60 |
+
position = self.offsets[idx]
|
| 61 |
+
f.seek(position)
|
| 62 |
+
item = f.readline().decode("utf-8")
|
| 63 |
+
try:
|
| 64 |
+
item = json.loads(item)
|
| 65 |
+
item["length"] = len(item["tokens"]) # add a length info
|
| 66 |
+
item["type_id"] = self.type_id
|
| 67 |
+
except Exception as err:
|
| 68 |
+
raise json.decoder.JSONDecodeError(
|
| 69 |
+
doc=self.path,
|
| 70 |
+
pos=position,
|
| 71 |
+
msg=(
|
| 72 |
+
f"Error while loading JSONL line in file {self.path} at byte "
|
| 73 |
+
f"{position}. Contents of line:\n{item}\n{err}"
|
| 74 |
+
),
|
| 75 |
+
)
|
| 76 |
+
return item
|
| 77 |
+
|
| 78 |
+
def get_dataset_name(self):
|
| 79 |
+
return str(self.resolved_path)
|
| 80 |
+
|
| 81 |
+
def _get_mmap(self):
|
| 82 |
+
if not hasattr(self.threadlocal, "handles"):
|
| 83 |
+
with open(self.path, "rb") as f:
|
| 84 |
+
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
| 85 |
+
self.threadlocal.handles = [f, mm]
|
| 86 |
+
if self.path.endswith(".gz") or self.path.endswith(".bz") or self.path.endswith(".bz2"):
|
| 87 |
+
raise NotImplementedError(
|
| 88 |
+
"Compressed files are not supported because .seek() would require "
|
| 89 |
+
"rereading the entire file, making performance too slow."
|
| 90 |
+
)
|
| 91 |
+
return self.threadlocal.handles[-1]
|
| 92 |
+
|
| 93 |
+
def __setstate__(self, state):
|
| 94 |
+
self.__dict__ = state
|
| 95 |
+
self.threadlocal = threading.local()
|
| 96 |
+
|
| 97 |
+
def __getstate__(self):
|
| 98 |
+
d = {}
|
| 99 |
+
for i, v in self.__dict__.items():
|
| 100 |
+
if i != "threadlocal":
|
| 101 |
+
d[i] = v
|
| 102 |
+
return d
|
| 103 |
+
|
| 104 |
+
def __del__(self):
|
| 105 |
+
if hasattr(self.threadlocal, "handles"):
|
| 106 |
+
# cleanup files we opened on initialization
|
| 107 |
+
while self.threadlocal.handles:
|
| 108 |
+
self.threadlocal.handles.pop().close()
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
def exists(path):
|
| 112 |
+
return os.path.exists(path)
|
| 113 |
+
|
| 114 |
+
def __len__(self):
|
| 115 |
+
# Virtual length of the dataset depends on the epoch number if the number of documents
|
| 116 |
+
# is not perfectly divisible by the data_subshard_count
|
| 117 |
+
return len(self.offsets)
|
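JsonlDataset requires a pre-built `<file>.bin.meta` cache: a NumPy array whose first column holds each line's byte offset and whose last column holds its token count (the class reads meta[:, 0] and meta[:, -1]). One possible way to produce such a cache is sketched below; the repo may ship its own tooling for this, so treat the helper as an assumption.

import json
import numpy as np

def build_meta(bin_path: str) -> None:
    offsets, lengths = [], []
    offset = 0
    with open(bin_path, "rb") as f:
        for line in f:
            offsets.append(offset)                           # byte offset of this line
            lengths.append(len(json.loads(line)["tokens"]))  # token count of this line
            offset += len(line)
    with open(f"{bin_path}.meta", "wb") as f:
        np.save(f, np.stack([offsets, lengths], axis=1))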
InternLM/internlm/data/utils.py
ADDED
|
@@ -0,0 +1,46 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from internlm.core.context import global_context as gpc
|
| 7 |
+
|
| 8 |
+
DATASET_TYPE_IDS_MAP = {"vision": 0}
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_dataset_type_id(path):
|
| 12 |
+
import re
|
| 13 |
+
|
| 14 |
+
match_idxes = []
|
| 15 |
+
for key, idx in DATASET_TYPE_IDS_MAP.items():
|
| 16 |
+
if re.search(rf"/[z_]*{key}/", path):
|
| 17 |
+
match_idxes.append(idx)
|
| 18 |
+
assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
|
| 19 |
+
return match_idxes[0]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def unpack_data(input_ids, cu_seqlens):
|
| 23 |
+
"""
|
| 24 |
+
input_ids: (n, packed_length)
|
| 25 |
+
Return:
|
| 26 |
+
output: (batch_size, max_length)
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
bsz = input_ids.shape[0]
|
| 30 |
+
|
| 31 |
+
num_sequence = gpc.config.data["micro_bsz"]
|
| 32 |
+
|
| 33 |
+
outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
|
| 34 |
+
|
| 35 |
+
for i in range(bsz):
|
| 36 |
+
output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
|
| 37 |
+
cu_seqlens_slice = cu_seqlens[i]
|
| 38 |
+
for j in range(num_sequence):
|
| 39 |
+
seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
|
| 40 |
+
output[j, 0:seq_length] = input_ids[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]  # read row i, not row 0
|
| 41 |
+
outputs[i] = output
|
| 42 |
+
|
| 43 |
+
if bsz == 1:
|
| 44 |
+
outputs = outputs.squeeze(0)
|
| 45 |
+
|
| 46 |
+
return outputs
|
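get_dataset_type_id expects exactly one directory segment of the path to match a key of DATASET_TYPE_IDS_MAP (optionally prefixed by the "z"/"_" characters used to control sort order); a quick illustration:

from internlm.data.utils import get_dataset_type_id

print(get_dataset_type_id("/data/vision/train.bin"))    # 0
print(get_dataset_type_id("/data/z_vision/train.bin"))  # 0 (the "z_" prefix still matches)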
InternLM/internlm/initialize/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
| 1 |
+
from .initialize_trainer import initialize_trainer, initialize_kd_trainer
|
| 2 |
+
from .launch import (
|
| 3 |
+
get_default_parser,
|
| 4 |
+
initialize_distributed_env,
|
| 5 |
+
launch_from_slurm,
|
| 6 |
+
launch_from_torch,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"get_default_parser",
|
| 11 |
+
"initialize_trainer",
|
| 12 |
+
"initialize_kd_trainer",
|
| 13 |
+
"launch_from_slurm",
|
| 14 |
+
"launch_from_torch",
|
| 15 |
+
"initialize_distributed_env",
|
| 16 |
+
]
|
InternLM/internlm/initialize/initialize_tensor.py
ADDED
|
@@ -0,0 +1,63 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def scaled_init_method_normal(sigma: float = 1.0, num_layers: int = 1):
|
| 10 |
+
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
|
| 11 |
+
std = sigma / math.sqrt(2.0 * num_layers)
|
| 12 |
+
|
| 13 |
+
def init_(tensor):
|
| 14 |
+
return nn.init.normal_(tensor, mean=0.0, std=std)
|
| 15 |
+
|
| 16 |
+
return init_
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def normal_(mean: float = 0.0, std: float = 1.0):
|
| 20 |
+
r"""Return the initializer filling the input Tensor with values drawn from the normal distribution
|
| 21 |
+
|
| 22 |
+
.. math::
|
| 23 |
+
\mathcal{N}(\text{mean}, \text{std}^2)
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
mean (float): the mean of the normal distribution. Defaults 0.0.
|
| 27 |
+
std (float): the standard deviation of the normal distribution. Defaults 1.0.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def initializer(tensor: Tensor):
|
| 31 |
+
return nn.init.normal_(tensor, mean, std)
|
| 32 |
+
|
| 33 |
+
return initializer
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def scaled_init_method_uniform(sigma: float = 1.0, num_layers: int = 1):
|
| 37 |
+
"""Init method based on p(x)=Uniform(-a, a) where std(x)=sigma/sqrt(2*num_layers)."""
|
| 38 |
+
std = sigma / math.sqrt(2.0 * num_layers)
|
| 39 |
+
a = math.sqrt(3.0) * std  # uniform(-a, a) has std a / sqrt(3)
|
| 40 |
+
|
| 41 |
+
def init_(tensor):
|
| 42 |
+
return nn.init.uniform_(tensor, -a, a)
|
| 43 |
+
|
| 44 |
+
return init_
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def uniform_(mean: float = 0.0, std: float = 1.0):
|
| 48 |
+
r"""Return the initializer filling the input Tensor with values drawn from the uniform distribution
|
| 49 |
+
|
| 50 |
+
.. math::
|
| 51 |
+
\mathcal{U}(mean-a, mean+a), where a satisfies \mathcal{U}_{std}=std.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
mean (float): the mean of the uniform distribution. Defaults 0.0.
|
| 55 |
+
std (float): the standard deviation of the uniform distribution. Defaults 1.0.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
a = math.sqrt(3.0) * std  # uniform(mean - a, mean + a) has std a / sqrt(3)
|
| 59 |
+
|
| 60 |
+
def initializer(tensor: Tensor):
|
| 61 |
+
return nn.init.uniform_(tensor, mean - a, mean + a)
|
| 62 |
+
|
| 63 |
+
return initializer
|
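The factories above return callables that fill a tensor in place, so they can be applied directly to module parameters; a brief illustration (sizes are arbitrary, and it assumes the internlm package is importable):

from torch import nn
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal

linear = nn.Linear(64, 64)
out_proj = nn.Linear(64, 64)
normal_(std=0.02)(linear.weight)                                        # N(0, 0.02^2)
scaled_init_method_normal(sigma=0.02, num_layers=24)(out_proj.weight)   # std scaled by 1/sqrt(2 * num_layers)
print(float(linear.weight.std()))  # roughly 0.02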
InternLM/internlm/initialize/initialize_trainer.py
ADDED
|
@@ -0,0 +1,235 @@
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize
|
| 5 |
+
|
| 6 |
+
from typing import Callable, Iterable, List, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn.modules.loss import _Loss
|
| 10 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 11 |
+
from torch.optim.optimizer import Optimizer
|
| 12 |
+
from torch.utils.data import DataLoader
|
| 13 |
+
|
| 14 |
+
from internlm.core.context import global_context as gpc
|
| 15 |
+
from internlm.core.context import ParallelMode
|
| 16 |
+
from internlm.core.engine import Engine, KDEngine
|
| 17 |
+
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
|
| 18 |
+
from internlm.core.scheduler import (InterleavedPipelineScheduler, KDNonPipelineScheduler, KDPipelineScheduler,
|
| 19 |
+
NonPipelineScheduler, PipelineScheduler, SchedulerHook)
|
| 20 |
+
from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
|
| 21 |
+
from internlm.core.trainer import Trainer
|
| 22 |
+
from internlm.data.utils import unpack_data
|
| 23 |
+
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
| 24 |
+
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
|
| 25 |
+
from internlm.utils.common import get_current_device
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def initialize_kd_trainer(
|
| 29 |
+
model: nn.Module,
|
| 30 |
+
teacher: nn.Module,
|
| 31 |
+
optimizer: Optimizer,
|
| 32 |
+
criterion: Optional[_Loss] = None,
|
| 33 |
+
kd_criterion: Optional[_Loss] = None,
|
| 34 |
+
train_dataloader: Optional[Iterable] = None,
|
| 35 |
+
test_dataloader: Optional[Iterable] = None,
|
| 36 |
+
lr_scheduler: Optional[_LRScheduler] = None,
|
| 37 |
+
beta2_scheduler: Optional[Beta2Scheduler] = None,
|
| 38 |
+
scheduler_hooks: Optional[List[SchedulerHook]] = None,
|
| 39 |
+
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
|
| 40 |
+
"""Core function to wrap the essential training components with our functionality based on the config which is
|
| 41 |
+
loaded into gpc.config.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
model (:class:`torch.nn.Module` or `Callable`): Your model instance or a function to build the model.
|
| 45 |
+
optimizer (:class:`BaseOptimizer`): Your optimizer for training.
|
| 46 |
+
criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
|
| 47 |
+
train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
|
| 48 |
+
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
|
| 49 |
+
lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler):
|
| 53 |
+
A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``
|
| 54 |
+
where only ``trainer`` is guaranteed not to be None.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
if isinstance(model, nn.Module):
|
| 58 |
+
# first sync model across dp ranks
|
| 59 |
+
model.to(get_current_device())
|
| 60 |
+
elif isinstance(model, Callable):
|
| 61 |
+
model = model().to(get_current_device())
|
| 62 |
+
|
| 63 |
+
# clip grad norm
|
| 64 |
+
clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)
|
| 65 |
+
|
| 66 |
+
assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
|
| 67 |
+
|
| 68 |
+
# gradient handler; only PipelineSharedModuleGradientHandler is supported now
|
| 69 |
+
if gpc.is_using_pp():
|
| 70 |
+
gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
|
| 71 |
+
gradient_handler_cfg = gpc.config.get("gradient_handler", [])
|
| 72 |
+
gradient_handlers = []
|
| 73 |
+
assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
|
| 74 |
+
for config in gradient_handler_cfg:
|
| 75 |
+
if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
|
| 76 |
+
handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
|
| 77 |
+
gradient_handlers.append(handler)
|
| 78 |
+
|
| 79 |
+
# initialize scheduler for trainer
|
| 80 |
+
scheduler = None
|
| 81 |
+
if gpc.config.model.use_flash_attn:
|
| 82 |
+
data_fn = None
|
| 83 |
+
else:
|
| 84 |
+
data_fn = unpack_data
|
| 85 |
+
if gpc.is_using_pp():
|
| 86 |
+
gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
|
| 87 |
+
tensor_shape = get_tensor_shape()
|
| 88 |
+
use_interleaved = (
|
| 89 |
+
hasattr(gpc.config, "model") and hasattr(gpc.config.model,
|
| 90 |
+
"num_chunks") and gpc.config.model.num_chunks > 1
|
| 91 |
+
)
|
| 92 |
+
scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)
|
| 93 |
+
if use_interleaved:
|
| 94 |
+
raise NotImplementedError('InterleavedPipelineScheduler for KD is not implemented')
|
| 95 |
+
|
| 96 |
+
else:
|
| 97 |
+
scheduler = KDPipelineScheduler(
|
| 98 |
+
data_process_func=data_fn,
|
| 99 |
+
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
|
| 100 |
+
dtype=gpc.config.model["dtype"],
|
| 101 |
+
tensor_shape=tensor_shape,
|
| 102 |
+
scatter_gather_tensors=scatter_gather,
|
| 103 |
+
scheduler_hooks=scheduler_hooks,
|
| 104 |
+
)
|
| 105 |
+
else:
|
| 106 |
+
scheduler = KDNonPipelineScheduler(
|
| 107 |
+
data_process_func=data_fn,
|
| 108 |
+
gradient_accumulation_size=gpc.config.data.gradient_accumulation,
|
| 109 |
+
scheduler_hooks=scheduler_hooks,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# initialize engine for trainer
|
| 113 |
+
engine = KDEngine(
|
| 114 |
+
model=model,
|
| 115 |
+
teacher=teacher,
|
| 116 |
+
optimizer=optimizer,
|
| 117 |
+
lr_scheduler=lr_scheduler,
|
| 118 |
+
beta2_scheduler=beta2_scheduler,
|
| 119 |
+
criterion=criterion,
|
| 120 |
+
kd_criterion=kd_criterion,
|
| 121 |
+
gradient_handlers=gradient_handlers,
|
| 122 |
+
clip_grad_norm=clip_grad_norm,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
trainer = Trainer(engine, scheduler)
|
| 126 |
+
|
| 127 |
+
return trainer, train_dataloader, test_dataloader, lr_scheduler
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def initialize_trainer(
|
| 131 |
+
model: nn.Module,
|
| 132 |
+
optimizer: Optimizer,
|
| 133 |
+
criterion: Optional[_Loss] = None,
|
| 134 |
+
train_dataloader: Optional[Iterable] = None,
|
| 135 |
+
test_dataloader: Optional[Iterable] = None,
|
| 136 |
+
lr_scheduler: Optional[_LRScheduler] = None,
|
| 137 |
+
beta2_scheduler: Optional[Beta2Scheduler] = None,
|
| 138 |
+
scheduler_hooks: Optional[List[SchedulerHook]] = None,
|
| 139 |
+
) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
|
| 140 |
+
"""Core function to wrap the essential training components with our functionality based on the config which is
|
| 141 |
+
loaded into gpc.config.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
model (:class:`torch.nn.Module` or `Callable`): Your model instance or a function to build the model.
|
| 145 |
+
optimizer (:class:`BaseOptimizer`): Your optimizer for training.
|
| 146 |
+
criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
|
| 147 |
+
train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
|
| 148 |
+
test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
|
| 149 |
+
lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
|
| 150 |
+
|
| 151 |
+
Returns:
|
| 152 |
+
Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler):
|
| 153 |
+
A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``
|
| 154 |
+
where only ``trainer`` is guaranteed not to be None.
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
if isinstance(model, nn.Module):
|
| 158 |
+
# first sync model across dp ranks
|
| 159 |
+
model.to(get_current_device())
|
| 160 |
+
elif isinstance(model, Callable):
|
| 161 |
+
model = model().to(get_current_device())
|
| 162 |
+
|
| 163 |
+
# clip grad norm
|
| 164 |
+
clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)
|
| 165 |
+
|
| 166 |
+
assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
|
| 167 |
+
|
| 168 |
+
# gradient handler; only PipelineSharedModuleGradientHandler is supported now
|
| 169 |
+
if gpc.is_using_pp():
|
| 170 |
+
gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
|
| 171 |
+
gradient_handler_cfg = gpc.config.get("gradient_handler", [])
|
| 172 |
+
gradient_handlers = []
|
| 173 |
+
assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
|
| 174 |
+
for config in gradient_handler_cfg:
|
| 175 |
+
if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
|
| 176 |
+
handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
|
| 177 |
+
gradient_handlers.append(handler)
|
| 178 |
+
|
| 179 |
+
# initialize scheduler for trainer
|
| 180 |
+
scheduler = None
|
| 181 |
+
if gpc.config.model.use_flash_attn:
|
| 182 |
+
data_fn = None
|
| 183 |
+
else:
|
| 184 |
+
data_fn = unpack_data
|
| 185 |
+
if gpc.is_using_pp():
|
| 186 |
+
gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
|
| 187 |
+
tensor_shape = get_tensor_shape()
|
| 188 |
+
use_interleaved = (
|
| 189 |
+
hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
|
| 190 |
+
)
|
| 191 |
+
scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)
|
| 192 |
+
if use_interleaved:
|
| 193 |
+
if isinstance(model, nn.Sequential):
|
| 194 |
+
model = nn.ModuleList([model])
|
| 195 |
+
|
| 196 |
+
communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
|
| 197 |
+
scheduler = InterleavedPipelineScheduler(
|
| 198 |
+
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
|
| 199 |
+
num_chunks=gpc.config.model.num_chunks,
|
| 200 |
+
dtype=gpc.config.model["dtype"],
|
| 201 |
+
tensor_shape=tensor_shape,
|
| 202 |
+
scatter_gather_tensors=scatter_gather,
|
| 203 |
+
scheduler_hooks=scheduler_hooks,
|
| 204 |
+
communication_overlap=communication_overlap,
|
| 205 |
+
)
|
| 206 |
+
else:
|
| 207 |
+
scheduler = PipelineScheduler(
|
| 208 |
+
data_process_func=data_fn,
|
| 209 |
+
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
|
| 210 |
+
dtype=gpc.config.model["dtype"],
|
| 211 |
+
tensor_shape=tensor_shape,
|
| 212 |
+
scatter_gather_tensors=scatter_gather,
|
| 213 |
+
scheduler_hooks=scheduler_hooks,
|
| 214 |
+
)
|
| 215 |
+
else:
|
| 216 |
+
scheduler = NonPipelineScheduler(
|
| 217 |
+
data_process_func=data_fn,
|
| 218 |
+
gradient_accumulation_size=gpc.config.data.gradient_accumulation,
|
| 219 |
+
scheduler_hooks=scheduler_hooks,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
# initialize engine for trainer
|
| 223 |
+
engine = Engine(
|
| 224 |
+
model=model,
|
| 225 |
+
optimizer=optimizer,
|
| 226 |
+
lr_scheduler=lr_scheduler,
|
| 227 |
+
beta2_scheduler=beta2_scheduler,
|
| 228 |
+
criterion=criterion,
|
| 229 |
+
gradient_handlers=gradient_handlers,
|
| 230 |
+
clip_grad_norm=clip_grad_norm,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
trainer = Trainer(engine, scheduler)
|
| 234 |
+
|
| 235 |
+
return trainer, train_dataloader, test_dataloader, lr_scheduler
|
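A sketch of how the trainer entry points above are typically used, assuming the distributed environment and gpc.config are already initialized and that the student/teacher modules, a BaseOptimizer instance, the loss functions and the dataloaders have been built elsewhere (all of those names are placeholders):

from internlm.initialize import initialize_kd_trainer

trainer, train_dl, test_dl, lr_sched = initialize_kd_trainer(
    model=student,               # placeholder student nn.Module (or a builder callable)
    teacher=teacher,             # placeholder frozen teacher nn.Module
    optimizer=optimizer,         # must be an internlm BaseOptimizer
    criterion=criterion,         # task loss, e.g. a language-modeling loss
    kd_criterion=kd_criterion,   # distillation loss between student and teacher outputs
    train_dataloader=train_dl,
    lr_scheduler=lr_sched,
)

# The returned Trainer wraps a KDEngine plus a KD scheduler; the training loop then
# drives it (e.g. trainer.train() / trainer.execute_schedule(...), assuming the usual
# ColossalAI-style Trainer interface).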
InternLM/internlm/initialize/launch.py
ADDED
|
@@ -0,0 +1,511 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from internlm.core.context import Config
|
| 12 |
+
from internlm.core.context import global_context as gpc
|
| 13 |
+
from internlm.monitor import initialize_light_monitor
|
| 14 |
+
from internlm.utils.common import get_master_node
|
| 15 |
+
from internlm.utils.logger import get_logger
|
| 16 |
+
from internlm.utils.timeout import llm_timeout
|
| 17 |
+
|
| 18 |
+
logger = get_logger(__file__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_default_parser():
|
| 22 |
+
"""Reads user command line and uses an argument parser to parse the input arguments.
|
| 23 |
+
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
Parser: Returns the parser with the default arguments; the user may add customized arguments to this parser.
|
| 27 |
+
"""
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
parser.add_argument("--config", type=str, help="path to the config file")
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--launcher",
|
| 32 |
+
type=str,
|
| 33 |
+
default="slurm",
|
| 34 |
+
choices=["slurm", "torch"],
|
| 35 |
+
help="launcher for launching distributed environment",
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument("--host", type=str, help="the master address for distributed training")
|
| 38 |
+
parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training")
|
| 39 |
+
parser.add_argument("--world_size", type=int, help="world size for distributed training")
|
| 40 |
+
parser.add_argument("--rank", type=int, help="rank for the default process group")
|
| 41 |
+
parser.add_argument("--local_rank", type=int, help="local rank on the node")
|
| 42 |
+
parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
|
| 43 |
+
parser.add_argument("--seed", type=int, default=1024)
|
| 44 |
+
parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
|
| 45 |
+
return parser
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def args_sanity_check():
|
| 49 |
+
assert gpc.config is not None, "config is not loaded!"
|
| 50 |
+
|
| 51 |
+
# the default model type is INTERNLM
|
| 52 |
+
if "model_type" not in gpc.config:
|
| 53 |
+
gpc.config._add_item("model_type", "INTERNLM")
|
| 54 |
+
|
| 55 |
+
# processing the parallel config in gpc
|
| 56 |
+
if "zero1" not in gpc.config.parallel:
|
| 57 |
+
gpc.config.parallel._add_item("zero1", -1)
|
| 58 |
+
|
| 59 |
+
if "pipeline" not in gpc.config.parallel:
|
| 60 |
+
gpc.config.parallel._add_item("pipeline", 1)
|
| 61 |
+
|
| 62 |
+
if "tensor" not in gpc.config.parallel:
|
| 63 |
+
gpc.config.parallel._add_item("tensor", 1)
|
| 64 |
+
|
| 65 |
+
# processing the data config in gpc
|
| 66 |
+
data = gpc.config.data
|
| 67 |
+
|
| 68 |
+
assert data.seq_len is not None, "'seq_len' must be given a value"
|
| 69 |
+
assert data.micro_bsz is not None, "'micro_bsz' must be given a value"
|
| 70 |
+
|
| 71 |
+
if "packed_length" in data and gpc.is_rank_for_log():
|
| 72 |
+
logger.warning("packed_length would be ignored and will be setted as seq_len * micro_bsz.")
|
| 73 |
+
|
| 74 |
+
data._add_item("packed_length", data.seq_len * data.micro_bsz)
|
| 75 |
+
|
| 76 |
+
if "micro_num" not in data:
|
| 77 |
+
data._add_item("micro_num", 1)
|
| 78 |
+
|
| 79 |
+
data._add_item("gradient_accumulation", data.micro_num)
|
| 80 |
+
if gpc.is_rank_for_log():
|
| 81 |
+
logger.info(f"gradient_accumulation size will be setted to {data.micro_num}.")
|
| 82 |
+
|
| 83 |
+
# batch_size should be equal with micro_num, should not use it directly
|
| 84 |
+
data._add_item("batch_size", data.micro_num)
|
| 85 |
+
|
| 86 |
+
if "min_length" not in data:
|
| 87 |
+
data._add_item("min_length", 0)
|
| 88 |
+
|
| 89 |
+
if "train_folder" not in data:
|
| 90 |
+
data._add_item("train_folder", None)
|
| 91 |
+
|
| 92 |
+
if "valid_folder" not in data:
|
| 93 |
+
data._add_item("valid_folder", None)
|
| 94 |
+
|
| 95 |
+
if "valid_micro_num" not in data:
|
| 96 |
+
data._add_item("valid_micro_num", data.micro_num)
|
| 97 |
+
|
| 98 |
+
if "valid_every" not in data:
|
| 99 |
+
data._add_item("valid_every", 0)
|
| 100 |
+
|
| 101 |
+
if "empty_cache_and_diag_interval" not in data:
|
| 102 |
+
data._add_item("empty_cache_and_diag_interval", 50)
|
| 103 |
+
|
| 104 |
+
if "diag_outlier_ratio" not in data:
|
| 105 |
+
data._add_item("diag_outlier_ratio", 1.1)
|
| 106 |
+
data.diag_outlier_ratio = max(1, data.diag_outlier_ratio)
|
| 107 |
+
|
| 108 |
+
if gpc.is_rank_for_log():
|
| 109 |
+
logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201
|
| 110 |
+
logger.info(f"seq_len: {data.seq_len}")
|
| 111 |
+
logger.info(f"micro_num: {data.micro_num}")
|
| 112 |
+
logger.info(f"micro_bsz: {data.micro_bsz}")
|
| 113 |
+
logger.info(f"packed_length: {data.packed_length}")
|
| 114 |
+
logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
|
| 115 |
+
logger.info(f"min_length: {data.min_length}")
|
| 116 |
+
logger.info(f"valid_micro_num: {data.valid_micro_num}")
|
| 117 |
+
logger.info(f"valid_every: {data.valid_every}")
|
| 118 |
+
|
| 119 |
+
# processing the checkpoint config
|
| 120 |
+
ckpt = gpc.config.ckpt
|
| 121 |
+
if "enable_save_ckpt" not in ckpt:
|
| 122 |
+
ckpt._add_item("enable_save_ckpt", True)
|
| 123 |
+
|
| 124 |
+
# Saving checkpoint args.
|
| 125 |
+
if ckpt.enable_save_ckpt:
|
| 126 |
+
assert "checkpoint_every" in ckpt, "If enable save checkpoint, must give checkpoint_every in config.data!"
|
| 127 |
+
assert ckpt.checkpoint_every > 0
|
| 128 |
+
assert "save_ckpt_folder" in ckpt, "If enable save checkpoint, must give save_ckpt_folder in config.data!"
|
| 129 |
+
|
| 130 |
+
if "async_upload" not in ckpt:
|
| 131 |
+
ckpt._add_item("async_upload", False) # async defalut is False.
|
| 132 |
+
else:
|
| 133 |
+
if ckpt.async_upload:
|
| 134 |
+
assert "save_ckpt_folder" in ckpt
|
| 135 |
+
if "boto3:" not in ckpt.save_ckpt_folder:
|
| 136 |
+
if gpc.is_rank_for_log():
|
| 137 |
+
logger.warning(
|
| 138 |
+
"Storing ckpt on file system does not support asynchronous storage, will use sync save!"
|
| 139 |
+
)
|
| 140 |
+
ckpt.async_upload = False
|
| 141 |
+
else:
|
| 142 |
+
if "async_upload_tmp_folder" not in ckpt:
|
| 143 |
+
ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
|
| 144 |
+
|
| 145 |
+
if not ckpt.async_upload:
|
| 146 |
+
ckpt._add_item("async_upload_tmp_folder", None)
|
| 147 |
+
|
| 148 |
+
if "oss_snapshot_freq" not in ckpt:
|
| 149 |
+
ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable.
|
| 150 |
+
else:
|
| 151 |
+
ckpt._add_item("checkpoint_every", float("inf"))
|
| 152 |
+
ckpt._add_item("oss_snapshot_freq", float("inf"))
|
| 153 |
+
ckpt._add_item("save_ckpt_folder", None)
|
| 154 |
+
ckpt._add_item("async_upload", False)
|
| 155 |
+
ckpt._add_item("async_upload_tmp_folder", None)
|
| 156 |
+
ckpt._add_item("snapshot_ckpt_folder", None)
|
| 157 |
+
|
| 158 |
+
if "load_ckpt_folder" not in ckpt:
|
| 159 |
+
ckpt._add_item("load_ckpt_folder", None)
|
| 160 |
+
|
| 161 |
+
if "stop_file_path" not in ckpt:
|
| 162 |
+
ckpt._add_item("stop_file_path", None)
|
| 163 |
+
|
| 164 |
+
if "auto_resume" not in ckpt:
|
| 165 |
+
# If 'auto_resume' is not given, we set it to True, so internlm has the opportunity
|
| 166 |
+
# to auto-load the latest checkpoint.
|
| 167 |
+
ckpt._add_item("auto_resume", True)
|
| 168 |
+
|
| 169 |
+
if gpc.is_rank_for_log():
|
| 170 |
+
logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201
|
| 171 |
+
logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
|
| 172 |
+
logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
|
| 173 |
+
logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
|
| 174 |
+
|
| 175 |
+
# tensorboard writer config
|
| 176 |
+
if "enable_tb" not in gpc.config:
|
| 177 |
+
gpc.config._add_item("enable_tb", True)
|
| 178 |
+
if "tensorboard_folder" not in gpc.config:
|
| 179 |
+
gpc.config._add_item(
|
| 180 |
+
"tensorboard_folder", os.environ["tensorboard_folder"] if "tensorboard_folder" in os.environ else None
|
| 181 |
+
)
|
| 182 |
+
if "resume_tb_folder" not in gpc.config:
|
| 183 |
+
gpc.config._add_item(
|
| 184 |
+
"resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
if gpc.is_rank_for_log():
|
| 188 |
+
logger.info(f"tensorboard_folder: {gpc.config.tensorboard_folder}")
|
| 189 |
+
logger.info(f"resume_tb_folder: {gpc.config.resume_tb_folder}")
|
| 190 |
+
|
| 191 |
+
# cudnn
|
| 192 |
+
torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
|
| 193 |
+
torch.backends.cudnn.deterministic = gpc.config.get("cudnn_deterministic", False)
|
| 194 |
+
clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)
|
| 195 |
+
|
| 196 |
+
if gpc.is_rank_for_log():
|
| 197 |
+
logger.info("+" * 15 + " Other Info " + "+" * 15) # pylint: disable=W1201
|
| 198 |
+
logger.info(f"cudnn.benchmark: {torch.backends.cudnn.benchmark }")
|
| 199 |
+
logger.info(f"cudnn.deterministic: {torch.backends.cudnn.deterministic }")
|
| 200 |
+
logger.info(f"clip_grad_norm: {clip_grad_norm}")
|
| 201 |
+
|
| 202 |
+
model = gpc.config.model
|
| 203 |
+
if "dtype" not in model:
|
| 204 |
+
logger.warning("dtype is not set, use torch.float16 by defalut!")
|
| 205 |
+
model._add_item("dtype", torch.float16)
|
| 206 |
+
else:
|
| 207 |
+
if gpc.config.model.dtype == "torch.bfloat16":
|
| 208 |
+
gpc.config.model.dtype = torch.bfloat16
|
| 209 |
+
elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
|
| 210 |
+
gpc.config.model.dtype = torch.float16
|
| 211 |
+
elif gpc.config.model.dtype == "torch.float32":
|
| 212 |
+
gpc.config.model.dtype = torch.float32
|
| 213 |
+
elif gpc.config.model.dtype == "torch.tf32":
|
| 214 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 215 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 216 |
+
gpc.config.model.dtype = torch.float32
|
| 217 |
+
else:
|
| 218 |
+
assert gpc.config.model.dtype in [
|
| 219 |
+
"torch.float16",
|
| 220 |
+
"torch.half",
|
| 221 |
+
"torch.bfloat16",
|
| 222 |
+
"torch.float32",
|
| 223 |
+
"torch.tf32",
|
| 224 |
+
]
|
| 225 |
+
|
| 226 |
+
if "checkpoint" in model:
|
| 227 |
+
if model.checkpoint is True:
|
| 228 |
+
model.checkpoint = 1
|
| 229 |
+
elif model.checkpoint is False:
|
| 230 |
+
model.checkpoint = 0
|
| 231 |
+
else:
|
| 232 |
+
assert (
|
| 233 |
+
model.checkpoint >= 0 and model.checkpoint <= 1
|
| 234 |
+
), f'model.checkpoint: "{model.checkpoint}" should be >= 0 and <= 1'
|
| 235 |
+
|
| 236 |
+
if "teacher" in gpc.config:
|
| 237 |
+
teacher = gpc.config.teacher
|
| 238 |
+
if "dtype" not in teacher:
|
| 239 |
+
logger.warning("dtype is not set, use torch.float16 by defalut!")
|
| 240 |
+
teacher._add_item("dtype", torch.float16)
|
| 241 |
+
else:
|
| 242 |
+
if gpc.config.teacher.dtype == "torch.bfloat16":
|
| 243 |
+
gpc.config.teacher.dtype = torch.bfloat16
|
| 244 |
+
elif gpc.config.teacher.dtype in ("torch.float16", "torch.half"):
|
| 245 |
+
gpc.config.teacher.dtype = torch.float16
|
| 246 |
+
elif gpc.config.teacher.dtype == "torch.float32":
|
| 247 |
+
gpc.config.teacher.dtype = torch.float32
|
| 248 |
+
elif gpc.config.teacher.dtype == "torch.tf32":
|
| 249 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 250 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 251 |
+
gpc.config.teacher.dtype = torch.float32
|
| 252 |
+
else:
|
| 253 |
+
assert gpc.config.teacher.dtype in [
|
| 254 |
+
"torch.float16",
|
| 255 |
+
"torch.half",
|
| 256 |
+
"torch.bfloat16",
|
| 257 |
+
"torch.float32",
|
| 258 |
+
"torch.tf32",
|
| 259 |
+
]
|
| 260 |
+
|
| 261 |
+
if "checkpoint" in teacher:
|
| 262 |
+
if teacher.checkpoint is True:
|
| 263 |
+
teacher.checkpoint = 1
|
| 264 |
+
elif teacher.checkpoint is False:
|
| 265 |
+
teacher.checkpoint = 0
|
| 266 |
+
else:
|
| 267 |
+
assert (
|
| 268 |
+
teacher.checkpoint >= 0 and teacher.checkpoint <= 1
|
| 269 |
+
), f'teacher.checkpoint: "{teacher.checkpoint}" should be >= 0 and <= 1'
|
| 270 |
+
|
| 271 |
+
if gpc.is_rank_for_log():
|
| 272 |
+
logger.info("+" * 15 + " Model Info " + "+" * 15) # pylint: disable=W1201
|
| 273 |
+
logger.info(f"Model: {gpc.config.model}")
|
| 274 |
+
|
| 275 |
+
logger.info("+" * 15 + " grad_scaler Info " + "+" * 15) # pylint: disable=W1201
|
| 276 |
+
logger.info(f"grad_scaler: {gpc.config.grad_scaler}")
|
| 277 |
+
|
| 278 |
+
logger.info("+" * 15 + " hybrid_zero_optimizer Info " + "+" * 15) # pylint: disable=W1201
|
| 279 |
+
logger.info(f"hybrid_zero_optimizer: {gpc.config.hybrid_zero_optimizer}")
|
| 280 |
+
|
| 281 |
+
logger.info("+" * 15 + " adam Info " + "+" * 15) # pylint: disable=W1201
|
| 282 |
+
logger.info(f"adam: {gpc.config.adam}")
|
| 283 |
+
|
| 284 |
+
logger.info("+" * 15 + " beta2_scheduler Info " + "+" * 15) # pylint: disable=W1201
|
| 285 |
+
logger.info(f"beta2_scheduler: {gpc.config.beta2_scheduler}")
|
| 286 |
+
|
| 287 |
+
# process the model config
|
| 288 |
+
if "use_flash_attn" not in gpc.config.model:
|
| 289 |
+
gpc.config.model._add_item("use_flash_attn", True)
|
| 290 |
+
|
| 291 |
+
# process the parallel config
|
| 292 |
+
if "sequence_parallel" not in gpc.config.parallel:
|
| 293 |
+
gpc.config.parallel._add_item("sequence_parallel", False)
|
| 294 |
+
else:
|
| 295 |
+
assert not (
|
| 296 |
+
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
|
| 297 |
+
), "sequence parallel does not support use_flash_attn=False"
|
| 298 |
+
|
| 299 |
+
# monitoring default config
|
| 300 |
+
monitor_default_config = {
|
| 301 |
+
"alert_address": None, # compatible with old alert config
|
| 302 |
+
"monitor": { # new monitoring config
|
| 303 |
+
"alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None}
|
| 304 |
+
},
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
for key, value in monitor_default_config.items():
|
| 308 |
+
if key not in gpc.config:
|
| 309 |
+
gpc.config._add_item(key, value)
|
| 310 |
+
|
| 311 |
+
alert = gpc.config.monitor.alert
|
| 312 |
+
|
| 313 |
+
if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log():
|
| 314 |
+
logger.warning("alert is enable but alert_address is not set")
|
| 315 |
+
|
| 316 |
+
optim_ckpt = gpc.config.hybrid_zero_optimizer
|
| 317 |
+
if "zero_overlap_communication" in optim_ckpt:
|
| 318 |
+
# Compatible with the old interfaces.
|
| 319 |
+
optim_ckpt._add_item("overlap_sync_grad", optim_ckpt.zero_overlap_communication)
|
| 320 |
+
if "overlap_sync_grad" not in optim_ckpt:
|
| 321 |
+
optim_ckpt._add_item("overlap_sync_grad", False)
|
| 322 |
+
if "overlap_sync_param" not in optim_ckpt:
|
| 323 |
+
optim_ckpt._add_item("overlap_sync_param", False)
|
| 324 |
+
if gpc.is_rank_for_log():
|
| 325 |
+
logger.info(
|
| 326 |
+
f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def launch(
|
| 331 |
+
config: Union[str, Path, Config, Dict],
|
| 332 |
+
rank: int,
|
| 333 |
+
world_size: int,
|
| 334 |
+
host: str,
|
| 335 |
+
port: int,
|
| 336 |
+
backend: str = "nccl",
|
| 337 |
+
local_rank: int = None,
|
| 338 |
+
seed: int = 1024,
|
| 339 |
+
):
|
| 340 |
+
"""This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
|
| 341 |
+
arguments is not given. Then it initializes and sets up the distributed environment by calling global_context's functions.
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
config (Union[str, dict, Config]): Config file or config file path are both acceptable
|
| 345 |
+
rank (int): Rank for the default process group
|
| 346 |
+
world_size (int): World size of the default process group
|
| 347 |
+
host (str): The master address for distributed training
|
| 348 |
+
port (int): The master port for distributed training
|
| 349 |
+
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
|
| 350 |
+
local_rank (int, optional):
|
| 351 |
+
Rank for the process on the node and is used to set the default CUDA device,
|
| 352 |
+
defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
|
| 353 |
+
seed (int, optional): Specified random seed for every process. Defaults to 1024.
|
| 354 |
+
|
| 355 |
+
Raises:
|
| 356 |
+
Exception: Raise exception when config type is wrong
|
| 357 |
+
"""
|
| 358 |
+
|
| 359 |
+
# set config
|
| 360 |
+
assert isinstance(
|
| 361 |
+
config, (Config, str, Path, dict)
|
| 362 |
+
), f"expected argument config to be Config, str or Path, but got {type(config)}"
|
| 363 |
+
if not isinstance(config, Config) and isinstance(config, dict):
|
| 364 |
+
config = Config(config)
|
| 365 |
+
if isinstance(config, (str, Path)):
|
| 366 |
+
config = Config.from_file(config)
|
| 367 |
+
gpc.load_config(config)
|
| 368 |
+
|
| 369 |
+
# init default process group
|
| 370 |
+
gpc.init_global_dist(rank, world_size, backend, host, port)
|
| 371 |
+
|
| 372 |
+
# init process groups for different parallel modes from config
|
| 373 |
+
gpc.init_parallel_groups()
|
| 374 |
+
|
| 375 |
+
# set cuda device
|
| 376 |
+
if torch.cuda.is_available():
|
| 377 |
+
# if local rank is not given, calculate automatically
|
| 378 |
+
gpc.set_device(local_rank)
|
| 379 |
+
|
| 380 |
+
# set the number of processes running on the same node
|
| 381 |
+
gpc.detect_num_processes_on_current_node()
|
| 382 |
+
|
| 383 |
+
gpc.set_seed(seed)
|
| 384 |
+
|
| 385 |
+
if gpc.is_rank_for_log():
|
| 386 |
+
logger.info(
|
| 387 |
+
f"Distributed environment is initialized, "
|
| 388 |
+
f"data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, "
|
| 389 |
+
f"tensor parallel size: {gpc.tensor_parallel_size}",
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def launch_from_slurm(
|
| 394 |
+
config: Union[str, Path, Config, Dict],
|
| 395 |
+
host: str,
|
| 396 |
+
port: int,
|
| 397 |
+
backend: str = "nccl",
|
| 398 |
+
seed: int = 1024,
|
| 399 |
+
):
|
| 400 |
+
"""A wrapper for internlm.launch for SLURM launcher by reading rank and world size from the environment variables
|
| 401 |
+
set by SLURM
|
| 402 |
+
|
| 403 |
+
Args:
|
| 404 |
+
config (Union[str, dict, Config]): Config file or config file path are both acceptable
|
| 405 |
+
host (str): The master address for distributed training
|
| 406 |
+
port (int): The master port for distributed training
|
| 407 |
+
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
|
| 408 |
+
seed (int, optional): Specified random seed for every process. Defaults to 1024.
|
| 409 |
+
"""
|
| 410 |
+
try:
|
| 411 |
+
rank = int(os.environ["SLURM_PROCID"])
|
| 412 |
+
world_size = int(os.environ["SLURM_NPROCS"])
|
| 413 |
+
except KeyError as e:
|
| 414 |
+
raise RuntimeError(f"Could not find {e} in the SLURM environment")
|
| 415 |
+
|
| 416 |
+
launch(
|
| 417 |
+
config=config,
|
| 418 |
+
rank=rank,
|
| 419 |
+
world_size=world_size,
|
| 420 |
+
host=host,
|
| 421 |
+
port=port,
|
| 422 |
+
backend=backend,
|
| 423 |
+
seed=seed,
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def launch_from_torch(
|
| 428 |
+
config: Union[str, Path, Config, Dict],
|
| 429 |
+
backend: str = "nccl",
|
| 430 |
+
seed: int = 1024,
|
| 431 |
+
):
|
| 432 |
+
"""A wrapper for internlm.launch for torchrun or torch.distributed.launch by reading rank and world size
|
| 433 |
+
from the environment variables set by PyTorch
|
| 434 |
+
|
| 435 |
+
Args:
|
| 436 |
+
config (Union[str, dict, Config]): Config file or config file path are both acceptable
|
| 437 |
+
backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
|
| 438 |
+
seed (int, optional): Specified random seed for every process. Defaults to 1024.
|
| 439 |
+
"""
|
| 440 |
+
try:
|
| 441 |
+
rank = int(os.environ["RANK"])
|
| 442 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 443 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 444 |
+
host = os.environ["MASTER_ADDR"]
|
| 445 |
+
port = int(os.environ["MASTER_PORT"])
|
| 446 |
+
except KeyError as e:
|
| 447 |
+
raise RuntimeError(f"Could not find {e} in the torch environment")
|
| 448 |
+
|
| 449 |
+
launch(
|
| 450 |
+
config=config,
|
| 451 |
+
local_rank=local_rank,
|
| 452 |
+
rank=rank,
|
| 453 |
+
world_size=world_size,
|
| 454 |
+
host=host,
|
| 455 |
+
port=port,
|
| 456 |
+
backend=backend,
|
| 457 |
+
seed=seed,
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
@llm_timeout(func_name="initialize_distributed_env")
|
| 462 |
+
def initialize_distributed_env(
|
| 463 |
+
config: str,
|
| 464 |
+
launcher: str = "slurm",
|
| 465 |
+
master_port: int = 8888,
|
| 466 |
+
seed: int = 1024,
|
| 467 |
+
args_check=True,
|
| 468 |
+
):
|
| 469 |
+
"""
|
| 470 |
+
Initialize distributed environment for distributed training.
|
| 471 |
+
|
| 472 |
+
Args:
|
| 473 |
+
config (str): Config file path.
|
| 474 |
+
launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
|
| 475 |
+
master_port (int): The master port for distributed training. 8888 by default.
|
| 476 |
+
seed (int, optional): Specified random seed for every process. 1024 by default.
|
| 477 |
+
"""
|
| 478 |
+
|
| 479 |
+
torch.cuda.empty_cache()
|
| 480 |
+
|
| 481 |
+
if launcher == "torch":
|
| 482 |
+
launch_from_torch(config=config, seed=seed)
|
| 483 |
+
elif launcher == "slurm":
|
| 484 |
+
launch_from_slurm(
|
| 485 |
+
config=config,
|
| 486 |
+
host=get_master_node(),
|
| 487 |
+
port=master_port,
|
| 488 |
+
seed=seed,
|
| 489 |
+
)
|
| 490 |
+
else:
|
| 491 |
+
assert launcher in ["slurm", "torch"], "launcher only supports slurm or torch"
|
| 492 |
+
|
| 493 |
+
if args_check:
|
| 494 |
+
args_sanity_check()
|
| 495 |
+
|
| 496 |
+
# init light monitor client
|
| 497 |
+
alert_config = gpc.config.monitor.alert
|
| 498 |
+
if alert_config.enable_feishu_alert and gpc.is_rank_for_log():
|
| 499 |
+
light_monitor_address = alert_config.light_monitor_address
|
| 500 |
+
if light_monitor_address:
|
| 501 |
+
initialize_light_monitor(light_monitor_address)
|
| 502 |
+
else:
|
| 503 |
+
logger.warning("monitor address is none, monitor could not be used!")
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def get_config_value(config, key, default):
|
| 507 |
+
try:
|
| 508 |
+
value = config[key]
|
| 509 |
+
except KeyError:
|
| 510 |
+
value = default
|
| 511 |
+
return value
|
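A sketch of a minimal launch script built on the helpers above; train.py, the config path and the command lines are illustrative placeholders:

# train.py (hypothetical entry point)
from internlm.initialize import get_default_parser, initialize_distributed_env

if __name__ == "__main__":
    args = get_default_parser().parse_args()
    initialize_distributed_env(
        config=args.config,      # e.g. --config <path/to/config.py>
        launcher=args.launcher,  # "slurm" or "torch"
        master_port=args.port,
        seed=args.seed,
    )
    # model / dataloader / trainer construction would follow here

Launched, for example, with torchrun --nproc_per_node=8 train.py --config <path/to/config.py> --launcher torch, or under SLURM with srun python train.py --config <path/to/config.py> --launcher slurm (both command lines are illustrative).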
InternLM/internlm/initialize/legacy/__init__.py
ADDED
|
File without changes
|
InternLM/internlm/initialize/legacy/launch.py
ADDED
|
@@ -0,0 +1,40 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from internlm.initialize.launch import get_config_value
|
| 5 |
+
from internlm.utils.logger import get_logger
|
| 6 |
+
|
| 7 |
+
logger = get_logger(__file__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def auto_resume_sanity_check(ckpt_config):
|
| 11 |
+
load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None)
|
| 12 |
+
if load_given_ckpt is None:
|
| 13 |
+
return True # default value is True
|
| 14 |
+
else:
|
| 15 |
+
return not load_given_ckpt
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def ckpt_info_sanity_check(ckpt_config):
|
| 19 |
+
load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None)
|
| 20 |
+
|
| 21 |
+
load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None)
|
| 22 |
+
|
| 23 |
+
if load_model_only_folder is not None:
|
| 24 |
+
assert (
|
| 25 |
+
load_ckpt_folder is None
|
| 26 |
+
), "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
|
| 27 |
+
and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
|
| 28 |
+
return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm")
|
| 29 |
+
else:
|
| 30 |
+
load_optimizer = get_config_value(ckpt_config, "load_optimizer", True)
|
| 31 |
+
|
| 32 |
+
if isinstance(load_ckpt_folder, str):
|
| 33 |
+
if load_optimizer:
|
| 34 |
+
return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internlm")
|
| 35 |
+
else:
|
| 36 |
+
return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internlm")
|
| 37 |
+
elif load_ckpt_folder is None:
|
| 38 |
+
return None
|
| 39 |
+
else:
|
| 40 |
+
assert f"Unsupport data type:'{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'"
|
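A small sketch of how the legacy sanity checks above translate an old-style ckpt config into the new load-info dict, assuming a plain dict stands in for the Config object and the checkpoint path is illustrative:

from internlm.initialize.legacy.launch import (
    auto_resume_sanity_check,
    ckpt_info_sanity_check,
)

legacy_ckpt_cfg = {
    "load_given_ckpt": True,
    "load_ckpt_folder": "local:/path/to/ckpt",  # illustrative path
    "load_optimizer": False,
}

# auto-resume is disabled because load_given_ckpt is explicitly True
assert auto_resume_sanity_check(legacy_ckpt_cfg) is False

# the folder is mapped to the new dict form, without optimizer states
info = ckpt_info_sanity_check(legacy_ckpt_cfg)
# -> {"path": "local:/path/to/ckpt", "content": ("model", "sampler"), "ckpt_type": "internlm"}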
InternLM/internlm/model/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from .embedding import Embedding1D, RotaryEmbedding
|
| 5 |
+
from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
|
| 6 |
+
from .metrics import AccPerplex
|
| 7 |
+
from .modeling_internlm import build_model_with_cfg
|
| 8 |
+
from .modeling_vit import build_vit_model_with_cfg
|
| 9 |
+
from .multi_head_attention import MHA
|
| 10 |
+
from .utils import gather_forward_split_backward
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"Embedding1D",
|
| 14 |
+
"FeedForward",
|
| 15 |
+
"RotaryEmbedding",
|
| 16 |
+
"RewardModelLinear",
|
| 17 |
+
"ScaleColumnParallelLinear",
|
| 18 |
+
"AccPerplex",
|
| 19 |
+
"MHA",
|
| 20 |
+
"gather_forward_split_backward",
|
| 21 |
+
"build_model_with_cfg",
|
| 22 |
+
"build_vit_model_with_cfg"
|
| 23 |
+
]
|
InternLM/internlm/model/embedding.py
ADDED
|
@@ -0,0 +1,273 @@
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- encoding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
from typing import Tuple
|
| 5 |
+
|
| 6 |
+
import rotary_emb
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
|
| 11 |
+
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
|
| 12 |
+
from torch import Tensor, nn
|
| 13 |
+
|
| 14 |
+
from internlm.core.context import ParallelMode
|
| 15 |
+
from internlm.core.context import global_context as gpc
|
| 16 |
+
|
| 17 |
+
from .utils import gather_forward_split_backward, split_forward_gather_backward
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
from .muse import VQGANModel
|
| 21 |
+
|
| 22 |
+
class Embedding1DLVM(nn.Module):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
vq_model_path: str,
|
| 26 |
+
embedding_dim: int = None,
|
| 27 |
+
freeze_vq_model: bool = True
|
| 28 |
+
):
|
| 29 |
+
super().__init__()
|
| 30 |
+
|
| 31 |
+
self.vq_model = VQGANModel.from_pretrained(vq_model_path)
|
| 32 |
+
if freeze_vq_model:
|
| 33 |
+
self.vq_model.requires_grad_(False)
|
| 34 |
+
self.vq_model.eval()
|
| 35 |
+
|
| 36 |
+
self.num_embeddings, vq_embed_dim = self.vq_model.quantize.embedding.weight.shape
|
| 37 |
+
|
| 38 |
+
if embedding_dim is not None:
|
| 39 |
+
self.embed_proj = nn.Linear(vq_embed_dim, embedding_dim, bias=False)
|
| 40 |
+
self.embedding_dim = embedding_dim
|
| 41 |
+
else:
|
| 42 |
+
self.embed_proj = None
|
| 43 |
+
self.embedding_dim = vq_embed_dim
|
| 44 |
+
|
| 45 |
+
def forward(self, input_: Tensor) -> Tensor:
|
| 46 |
+
|
| 47 |
+
# input: N x seq
|
| 48 |
+
output_parallel = self.vq_model.quantize.get_codebook_entry_for_lvm(input_) # N x vq_embed_dim x sqrt(seq) x sqrt(seq)
|
| 49 |
+
|
| 50 |
+
if self.embed_proj is not None:
|
| 51 |
+
output_parallel = self.embed_proj(output_parallel)
|
| 52 |
+
|
| 53 |
+
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
|
| 54 |
+
|
| 55 |
+
if gpc.config.parallel.sequence_parallel:
|
| 56 |
+
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
|
| 57 |
+
|
| 58 |
+
return output
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class Embedding1D(nn.Module):
|
| 62 |
+
"""
|
| 63 |
+
1D Embedding.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
num_embeddings (int): The size of vocab.
|
| 67 |
+
embedding_dim (int): The dimension of the model.
|
| 68 |
+
padding_idx (int): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
|
| 69 |
+
therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
|
| 70 |
+
i.e. it remains as a fixed "pad". None by default.
|
| 71 |
+
dtype (Optional[torch.dtype]): Data type None by default.
|
| 72 |
+
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
num_embeddings: int,
|
| 78 |
+
embedding_dim: int,
|
| 79 |
+
*args,
|
| 80 |
+
padding_idx: int = None,
|
| 81 |
+
dtype: torch.dtype = None,
|
| 82 |
+
**kwargs,
|
| 83 |
+
):
|
| 84 |
+
super().__init__()
|
| 85 |
+
|
| 86 |
+
self.num_embeddings = num_embeddings
|
| 87 |
+
self.embed_dim = embedding_dim
|
| 88 |
+
embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size
|
| 89 |
+
|
| 90 |
+
self.padding_idx = padding_idx
|
| 91 |
+
self.embed_args = args
|
| 92 |
+
self.embed_kwargs = kwargs
|
| 93 |
+
|
| 94 |
+
self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))
|
| 95 |
+
|
| 96 |
+
def forward(self, input_: Tensor) -> Tensor:
|
| 97 |
+
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
| 98 |
+
|
| 99 |
+
output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
|
| 100 |
+
|
| 101 |
+
if gpc.config.parallel.sequence_parallel:
|
| 102 |
+
output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
|
| 103 |
+
|
| 104 |
+
return output
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
| 108 |
+
"""
|
| 109 |
+
ApplyRotaryEmbQKV_
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
@staticmethod
|
| 113 |
+
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
|
| 114 |
+
"""
|
| 115 |
+
qkv: (total, 3, nheads, headdim)
|
| 116 |
+
cos, sin: (seqlen, rotary_dim / 2)
|
| 117 |
+
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
|
| 118 |
+
rotary_dim must be <= headdim
|
| 119 |
+
Apply rotary embedding *inplace* to the first rotary_dim of q and k.
|
| 120 |
+
"""
|
| 121 |
+
_, three, _, headdim = qkv.shape
|
| 122 |
+
assert three == 3
|
| 123 |
+
rotary_seqlen, rotary_dim = cos.shape
|
| 124 |
+
rotary_dim *= 2
|
| 125 |
+
assert rotary_dim <= headdim
|
| 126 |
+
cos_k = cos if cos_k is None else cos_k
|
| 127 |
+
sin_k = sin if sin_k is None else sin_k
|
| 128 |
+
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
|
| 129 |
+
q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
|
| 130 |
+
rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False)
|
| 131 |
+
k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
|
| 132 |
+
rotary_emb.apply_rotary(
|
| 133 |
+
k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False
|
| 134 |
+
)
|
| 135 |
+
ctx.save_for_backward(cos, sin, cos_k, sin_k)
|
| 136 |
+
return qkv
|
| 137 |
+
|
| 138 |
+
@staticmethod
|
| 139 |
+
def backward(ctx, dqkv):
|
| 140 |
+
cos, sin, cos_k, sin_k = ctx.saved_tensors
|
| 141 |
+
rotary_dim = cos.shape[-1]
|
| 142 |
+
rotary_dim *= 2
|
| 143 |
+
dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
|
| 144 |
+
rotary_emb.apply_rotary(
|
| 145 |
+
dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True
|
| 146 |
+
)
|
| 147 |
+
dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
|
| 148 |
+
rotary_emb.apply_rotary(
|
| 149 |
+
dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True
|
| 150 |
+
)
|
| 151 |
+
return dqkv, None, None, None, None
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
|
| 155 |
+
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
|
| 156 |
+
legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class RotaryEmbedding(torch.nn.Module):
|
| 160 |
+
"""
|
| 161 |
+
The rotary position embeddings from RoFormer_ (Su et. al).
|
| 162 |
+
A crucial insight from the method is that the query and keys are
|
| 163 |
+
transformed by rotation matrices which depend on the relative positions.
|
| 164 |
+
|
| 165 |
+
Other implementations are available in the Rotary Transformer repo_ and in
|
| 166 |
+
GPT-NeoX_, which was an inspiration.
|
| 167 |
+
|
| 168 |
+
.. _RoFormer: https://arxiv.org/abs/2104.09864
|
| 169 |
+
.. _repo: https://github.com/ZhuiyiTechnology/roformer
|
| 170 |
+
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
|
| 171 |
+
|
| 172 |
+
If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
|
| 173 |
+
A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
|
| 174 |
+
Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
def __init__(self, dim: int, base=10000, scale_base=0, device=None):
|
| 178 |
+
""" """
|
| 179 |
+
super().__init__()
|
| 180 |
+
# Generate and save the inverse frequency buffer (non trainable)
|
| 181 |
+
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
| 182 |
+
self.scale_base = scale_base
|
| 183 |
+
self.scale = (
|
| 184 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
| 185 |
+
if scale_base > 0
|
| 186 |
+
else None
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
self._seq_len_cached = 0
|
| 190 |
+
self._cos_cached = None
|
| 191 |
+
self._sin_cached = None
|
| 192 |
+
self._cos_k_cached = None
|
| 193 |
+
self._sin_k_cached = None
|
| 194 |
+
|
| 195 |
+
def _update_cos_sin_cache(self, x, indexes):
|
| 196 |
+
"""x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
|
| 197 |
+
if not isinstance(indexes, int):
|
| 198 |
+
seqlen = indexes.max().item() + 1
|
| 199 |
+
else:
|
| 200 |
+
seqlen = indexes + 1 # eval_forward
|
| 201 |
+
# Reset the tables if the sequence length has changed,
|
| 202 |
+
# or if we're on a new device (possibly due to tracing for instance)
|
| 203 |
+
if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
|
| 204 |
+
self._seq_len_cached = seqlen
|
| 205 |
+
t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
|
| 206 |
+
# Don't do einsum, it converts fp32 to fp16
|
| 207 |
+
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 208 |
+
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
| 209 |
+
if self.scale is None:
|
| 210 |
+
self._cos_cached = torch.cos(freqs).to(x.dtype)
|
| 211 |
+
self._sin_cached = torch.sin(freqs).to(x.dtype)
|
| 212 |
+
else:
|
| 213 |
+
power = (
|
| 214 |
+
torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
|
| 215 |
+
) / self.scale_base
|
| 216 |
+
scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
|
| 217 |
+
# We want the multiplication by scale to happen in fp32
|
| 218 |
+
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
|
| 219 |
+
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
|
| 220 |
+
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
|
| 221 |
+
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
|
| 222 |
+
|
| 223 |
+
def forward(self, qkv: torch.Tensor, **kwargs):
|
| 224 |
+
if kwargs.get("indexes", None) is not None:
|
| 225 |
+
return self._forward(qkv, kwargs.pop("indexes"))
|
| 226 |
+
if kwargs.get("inference_params", None) is not None:
|
| 227 |
+
return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
|
| 228 |
+
else:
|
| 229 |
+
return self._eval_forward(qkv)
|
| 230 |
+
|
| 231 |
+
def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 232 |
+
self._update_cos_sin_cache(qkv, indexes)
|
| 233 |
+
if self.scale is None:
|
| 234 |
+
return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
|
| 235 |
+
else:
|
| 236 |
+
return apply_rotary_emb_qkv_(
|
| 237 |
+
qkv,
|
| 238 |
+
self._cos_cached[indexes],
|
| 239 |
+
self._sin_cached[indexes],
|
| 240 |
+
self._cos_k_cached[indexes],
|
| 241 |
+
self._sin_k_cached[indexes],
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
def _eval_forward(self, qkv, seqlen_offset=0):
|
| 245 |
+
"""
|
| 246 |
+
seqlen_offset: can be used in generation where the qkv being passed in is only the last
|
| 247 |
+
token in the batch.
|
| 248 |
+
"""
|
| 249 |
+
self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1])
|
| 250 |
+
if self.scale is None:
|
| 251 |
+
return legacy_apply_rotary_embed_qkv(
|
| 252 |
+
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
|
| 253 |
+
)
|
| 254 |
+
else:
|
| 255 |
+
return legacy_apply_rotary_embed_qkv(
|
| 256 |
+
qkv,
|
| 257 |
+
self._cos_cached[seqlen_offset:],
|
| 258 |
+
self._sin_cached[seqlen_offset:],
|
| 259 |
+
self._cos_k_cached[seqlen_offset:],
|
| 260 |
+
self._sin_k_cached[seqlen_offset:],
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
def _single_forward(self, x, indexes=0):
|
| 264 |
+
assert self.scale is None
|
| 265 |
+
self._update_cos_sin_cache(x, indexes)
|
| 266 |
+
x = x[None, ...]
|
| 267 |
+
ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
|
| 268 |
+
return ret
|
| 269 |
+
|
| 270 |
+
def _single_eval_forward(self, x, seqlen_offset=0):
|
| 271 |
+
assert self.scale is None
|
| 272 |
+
self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
|
| 273 |
+
return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
|
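A short sketch of applying the rotary embedding above to a packed qkv tensor, assuming a CUDA device with the flash-attn / rotary_emb kernels installed; sizes are illustrative and the layout follows the (total, 3, nheads, headdim) convention documented in ApplyRotaryEmbQKV_:

import torch

from internlm.model.embedding import RotaryEmbedding

total_tokens, n_heads, head_dim = 4096, 16, 64  # illustrative packed-batch sizes
rope = RotaryEmbedding(dim=head_dim)

qkv = torch.randn(total_tokens, 3, n_heads, head_dim, device="cuda", dtype=torch.float16)
indexes = torch.arange(total_tokens, device="cuda")  # position index of every packed token

qkv = rope(qkv, indexes=indexes)  # rotates q and k in place over the rotary dimensions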
InternLM/internlm/model/linear.py
ADDED
|
@@ -0,0 +1,201 @@
|
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional

import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import fused_dense_func_torch


class ScaleColumnParallelLinear(nn.Linear):
    """
    ScaleColumnParallelLinear.

    Args:
        in_features (int): size of each input sample
        out_features (int): size of each output sample
        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
            in the config.
        sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
            we do an all_gather of x before doing the matmul.
            If not, then the input is already gathered.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        weight_scale (int): For training stability. 1 by default.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        weight_scale: int = 1,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % world_size != 0:
            raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
        super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
        self.process_group = process_group
        self.weight_scale = weight_scale

    def forward(self, input):  # pylint: disable=W0622
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        if self.weight_scale != 1:
            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
        else:
            weight = self.weight
        return fused_dense_func_torch(
            input,
            weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
        )


class RewardModelLinear(ScaleColumnParallelLinear):
    """
    RewardModelLinear.

    Args:
        in_features (int): size of each input sample
        out_features (int): size of each output sample
        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
            in the config.
        sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
            we do an all_gather of x before doing the matmul.
            If not, then the input is already gathered.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        weight_scale (int): For training stability. 1 by default.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        weight_scale: int = 1,
    ) -> None:
        super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
        torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
        if bias:
            torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)

    def forward(self, input):  # pylint: disable=W0622
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        if self.weight_scale != 1:
            weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
        else:
            weight = self.weight
        return fused_dense_func_torch(
            input,
            weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
        )


class ColumnParallelLinearTorch(ColumnParallelLinear):
    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.

        return fused_dense_func_torch(
            x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
        )


class RowParallelLinearTorch(RowParallelLinear):
    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = fused_dense_func_torch(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)


class FeedForward(nn.Module):
    """
    FeedForward.

    Args:
        in_features (int): size of each input sample
        hidden_features (int): size of hidden state of FFN
        out_features (int): size of each output sample
        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
            in the config.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int = None,
        process_group: Optional[torch.distributed.ProcessGroup] = None,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
    ):
        super().__init__()

        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinearTorch(
            in_features,
            hidden_features,
            process_group,
            bias,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            device=device,
            dtype=dtype,
        )
        self.w2 = ColumnParallelLinearTorch(
            in_features,
            hidden_features,
            process_group,
            bias,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            device=device,
            dtype=dtype,
        )
        self.w3 = RowParallelLinearTorch(
            hidden_features,
            out_features,
            process_group,
            bias=bias,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        out = self.w3(F.silu(self.w1(x)) * self.w2(x))
        return out
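Note: `FeedForward` is a SwiGLU MLP split across tensor-parallel ranks: `w1`/`w2` are column-parallel, `w3` is row-parallel, so the element-wise `silu(w1(x)) * w2(x)` happens on local shards and only `w3`'s output is reduced (or reduce-scattered under sequence parallelism). A minimal single-GPU sketch of the same computation, using plain `nn.Linear` instead of the parallel layers (the class name here is illustrative, not part of the repo):

```python
import torch
import torch.nn.functional as F
from torch import nn

class SwiGLUFeedForward(nn.Module):
    """Single-device equivalent of FeedForward above: w3(silu(w1(x)) * w2(x))."""

    def __init__(self, dim: int, hidden: int, multiple_of: int = 256):
        super().__init__()
        # round the hidden size up to a multiple, as the parallel version does
        hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(dim, hidden, bias=False)
        self.w3 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

# usage
mlp = SwiGLUFeedForward(dim=1024, hidden=int(1024 * 8 / 3))
y = mlp(torch.randn(4, 16, 1024))
```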
InternLM/internlm/model/loss.py
ADDED
@@ -0,0 +1,81 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn.functional as F
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
from torch import nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc


class FlashGPTLMLoss(nn.Module):
    """
    Loss function for flash GPT Language Model.
    """

    def __init__(self, parallel_output=True, label_smoothing=0):
        super().__init__()

        if label_smoothing is not None:
            if label_smoothing != 0:
                if gpc.is_rank_for_log():
                    print(f"use label_smoothing: {label_smoothing}")
        else:
            label_smoothing = 0
        self.label_smoothing = label_smoothing

        if parallel_output:
            self.loss_fn = FlashCrossEntropyLoss(
                reduction="mean",
                inplace_backward=True,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                label_smoothing=label_smoothing,
            )  # The loss here is bound to the gather_output initialized by VocabParallelClassifier1D
        else:
            # Here the output is already gathered in the model (gather_output is set), so the ordinary loss is used
            self.loss_fn = nn.CrossEntropyLoss(reduction="mean", label_smoothing=label_smoothing)

    def forward(self, *args):
        if len(args) == 3:
            # residual is to match prenorm
            logits, _, labels = args
        elif len(args) == 2:
            # When using postnorm
            logits, labels = args
        else:
            raise RuntimeError(f"The number of criterion inputs are:{len(args)}")
        shift_logits = logits.contiguous().view(-1, logits.size(-1))
        shift_labels = labels.contiguous().view(-1)
        loss = self.loss_fn(
            shift_logits, shift_labels
        )  # There is no need to consider the ignore_index problem here, because the loss calculation will be
        # calculated through the calculation range, and -100 must be outside this range, so there is no problem

        return loss


class KLDivLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.temperature = gpc.config.kd_config.get('temperature', 1)
        self.inverse = gpc.config.kd_config.get('inverse', False)

    def forward(self, *args):
        if len(args) == 3:
            if self.inverse:
                logits_teacher, logits_student, _ = args
            else:
                logits_student, logits_teacher, _ = args
        else:
            raise RuntimeError(f"The number of criterion inputs are:{len(args)}")

        logits_teacher = logits_teacher.contiguous().view(-1, logits_teacher.size(-1))
        logits_student = logits_student.contiguous().view(-1, logits_student.size(-1))

        log_pred_student = F.log_softmax(logits_student / self.temperature, dim=1)
        pred_teacher = F.softmax(logits_teacher / self.temperature, dim=1)
        loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')
        loss_kd *= self.temperature ** 2

        return loss_kd
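Note: `KLDivLoss` implements the standard knowledge-distillation objective: student and teacher logits are softened by a temperature, the KL divergence between the resulting distributions is averaged over the batch, and the usual `T^2` factor compensates for the gradient shrinkage caused by the softening. A standalone sketch of the same formula (the function name is illustrative only):

```python
import torch
import torch.nn.functional as F

def kd_kl_loss(student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float = 2.0) -> torch.Tensor:
    """Temperature-scaled KL(teacher || student), scaled by T^2 as in KLDivLoss above."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2

# usage with dummy logits over a 32k vocabulary
loss = kd_kl_loss(torch.randn(4, 32000), torch.randn(4, 32000))
```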
InternLM/internlm/model/metrics.py
ADDED
@@ -0,0 +1,263 @@
from typing import List

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
from torch_scatter import scatter

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.parallel import is_no_pp_or_last_stage


class AccPerplex:
    """
    AccPerplex module for calculating model's accuracy and perplexity metrics.

    Args:
        device: The GPU device.
        tp_pg: The tensor parallel process group.
        dp_pg: The data parallel process group.
        tokenizer: For calculating BPB.
        dataset_types (List[str]): Various data types that will be used in the current training process,
            such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
            in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
    """

    def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
        self.device = device
        self.right = torch.Tensor([0]).to(device=device)
        self.total = torch.Tensor([0]).to(device=device)
        self.total_log_probs = torch.Tensor([0]).to(device=device)
        self.tp_pg = tp_pg
        self.dp_pg = dp_pg
        self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
        self.tokenizer = tokenizer
        self.total_bytes = torch.Tensor([0]).to(device=device).view(1)
        self.batch_shift = 0
        self.type_ids = None
        if dataset_types is not None:
            self.dataset_types = dataset_types
            self.total_type_count = len(dataset_types)
            self.ds_right = torch.zeros(self.total_type_count, dtype=torch.long, device=device)
            self.ds_tokens = torch.zeros(self.total_type_count, dtype=torch.long, device=device)

        self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types)

    def set_current_type_ids(self, type_ids: torch.Tensor):
        self.batch_shift = 0
        self.type_ids = type_ids.cuda()

    def __call__(self, logits, labels):
        return self.update(logits, labels, type_ids=self.type_ids)

    def update(self, logits, labels, type_ids=None):
        if gpc.config.model.use_flash_attn:
            micro_bsz = labels.size(0)
        else:
            micro_bsz = 1
        if type_ids is not None:
            type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
            self.batch_shift += 1
            self.loss_with_type_id.update(logits, labels, type_ids)

        with torch.no_grad():
            if isinstance(logits, (list, tuple)):
                logits = logits[0]

            logits = logits.detach().clone()
            labels = labels.detach().clone()

            if self.tokenizer:  # need to calculate bits per bytes
                sequences = self.tokenizer.decode_ids(labels.tolist())
                self.total_bytes += sum(map(lambda x: len(x.encode("utf-8")), sequences))

            shift_logits = logits.view(-1, logits.size(-1))
            shift_labels = labels.view(-1)
            # There is a shift according to the current rank, because the logits are split
            pred_shift = self.tp_local_rank * logits.shape[-1]

            logits_max = torch.max(shift_logits, dim=-1)[0]
            torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=self.tp_pg)
            # Determine whether the maximum value of the current local tensor is the global maximum value
            logits_global = logits_max == torch.max(shift_logits, dim=-1)[0]

            corrects = torch.logical_and(
                (shift_labels == (shift_logits.argmax(dim=-1) + pred_shift)), logits_global
            ).long()
            mask = shift_labels.ne(-100).long()
            if hasattr(self, "total_type_count"):
                ds_acc = scatter(corrects, type_ids, dim=0, reduce="sum")
                token_num_type = scatter(mask, type_ids, dim=0, reduce="sum")
                if len(ds_acc) < self.total_type_count:
                    ds_acc = torch.cat([ds_acc, ds_acc.new_zeros(self.total_type_count - len(ds_acc))])
                    token_num_type = torch.cat(
                        [token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
                    )
                self.ds_tokens += token_num_type
                sync_tensor = ds_acc
                torch.distributed.all_reduce(sync_tensor, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
                self.ds_right += sync_tensor.view(-1)

            acc = corrects.sum()
            torch.distributed.all_reduce(acc, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
            self.right += acc  # Masked_fill is not needed here because -100 is not available anyway
            self.total += mask.sum()

            # Subtract the maximum value.
            shift_logits = shift_logits.sub(logits_max.unsqueeze(dim=-1))

            # Get the partition's vocab indices
            partition_vocab_size = shift_logits.size()[-1]
            vocab_start_index = partition_vocab_size * self.tp_local_rank
            vocab_end_index = vocab_start_index + partition_vocab_size

            # Create a mask of valid vocab ids (1 means it needs to be masked).
            target_mask = (shift_labels < vocab_start_index) | (shift_labels >= vocab_end_index)
            masked_target = shift_labels - vocab_start_index
            masked_target[target_mask] = 0

            # Get predicted-logits = logits[target].
            # For simplicity, we reshape logits to a 2-D tensor with size
            # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
            logits_2d = shift_logits.view(-1, partition_vocab_size)
            masked_target_1d = masked_target.view(-1)
            arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
            predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
            predicted_logits_1d = predicted_logits_1d.clone().contiguous()
            predicted_logits = predicted_logits_1d.view_as(shift_labels)  # bsz x max_len
            predicted_logits[target_mask] = 0.0
            # All reduce is needed to get the chunks from other GPUs.
            torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)

            pred_exp_logits = torch.exp(predicted_logits)
            # Sum of exponential of logits along vocab dimension across all GPUs.
            sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
            torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)

            total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()
            self.total_log_probs += total_log_probs

    def get_metric(self, reset=True):
        if is_no_pp_or_last_stage() and self.dp_pg is not None:
            torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            if hasattr(self, "total_type_count"):
                torch.distributed.all_reduce(self.ds_right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
                torch.distributed.all_reduce(self.ds_tokens, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            if self.tokenizer:
                torch.distributed.all_reduce(self.total_bytes, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)

        acc = round((self.right / self.total).item(), 4)
        perplexity = round(torch.exp(self.total_log_probs / self.total).item(), 4)
        bits_per_bytes = round((self.total_log_probs / self.total_bytes).item(), 4) if self.tokenizer else 0

        if hasattr(self, "total_type_count"):
            ds_acc = {}
            ds_tokens = {}
            for i in range(self.total_type_count):
                ds_acc[f"acc/{self.dataset_types[i]}"] = round(
                    (self.ds_right[i].float() / (self.ds_tokens[i].float() + 1e-5)).item(), 4
                )
                ds_tokens[f"tokens/{self.dataset_types[i]}"] = self.ds_tokens[i].item()
        if reset:
            self.right.fill_(0)
            self.total.fill_(0)
            self.total_log_probs.fill_(0)
            self.total_bytes.fill_(0)
            if hasattr(self, "total_type_count"):
                self.ds_right.fill_(0)
                self.ds_tokens.fill_(0)
        if self.tokenizer is not None:
            res = {"acc": acc, "perplexity": perplexity, "BPB": bits_per_bytes}
        else:
            res = {"acc": acc, "perplexity": perplexity}
        if hasattr(self, "total_type_count"):
            res.update(ds_acc)
            res.update(ds_tokens)

        loss_res = self.loss_with_type_id.get_metric(reset)
        res.update(loss_res)

        return res


class LossWithTypeId:
    """
    Notice the loss value computed here may be not the same with the main info loss,
    cause loss here is the reduced result of the data parallel.
    """

    def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None:
        self.device = device
        self.dp_pg = dp_pg

        self.loss = torch.Tensor([0.0]).to(device=device)
        self.token_num = torch.Tensor([0.0]).to(device=device)

        if dataset_types is not None:
            self.dataset_types = dataset_types
            self.total_type_count = len(dataset_types)
            self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
            self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device)

        self.loss_fn = FlashCrossEntropyLoss(
            reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR)
        )

    def update(self, logits, labels, type_ids=None):
        with torch.no_grad():
            if isinstance(logits, (list, tuple)):
                logits = logits[0]
            logits = logits.contiguous().view(-1, logits.size(-1))
            labels = labels.contiguous().view(-1)
            loss_list = self.loss_fn(logits, labels)

            cond = labels != -100
            real_loss_list = loss_list[cond]
            self.loss += real_loss_list.sum()
            self.token_num += real_loss_list.numel()

            if hasattr(self, "total_type_count"):
                type_ids = type_ids.contiguous().view(-1).to(self.device)
                real_type_ids = type_ids[cond]

                loss_list_type = scatter(real_loss_list, real_type_ids, dim=0, reduce="sum")
                token_num_type = scatter(torch.ones_like(real_loss_list), real_type_ids, dim=0, reduce="sum")

                if len(loss_list_type) < self.total_type_count:
                    loss_list_type = torch.cat(
                        [loss_list_type, loss_list_type.new_zeros(self.total_type_count - len(loss_list_type))]
                    )
                    token_num_type = torch.cat(
                        [token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
                    )
                self.ds_loss += loss_list_type
                self.ds_token_num += token_num_type

    def get_metric(self, reset=True):
        if is_no_pp_or_last_stage() and self.dp_pg is not None:
            torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
            if hasattr(self, "total_type_count"):
                torch.distributed.all_reduce(self.ds_loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
                torch.distributed.all_reduce(self.ds_token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)

        loss = round((self.loss / self.token_num).item(), 4)
        res = {
            "loss_from_metric": loss,
        }
        if hasattr(self, "total_type_count"):
            ds_loss = {}
            for i in range(self.total_type_count):
                ds_loss[f"loss/{self.dataset_types[i]}"] = round((self.ds_loss[i] / self.ds_token_num[i]).item(), 4)
            res.update(ds_loss)

        if reset:
            self.loss.fill_(0.0)
            self.token_num.fill_(0.0)
            if hasattr(self, "total_type_count"):
                self.ds_loss.fill_(0.0)
                self.ds_token_num.fill_(0.0)

        return res
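Note: `AccPerplex` accumulates, per step, the number of correctly predicted tokens, the valid-token count, and the summed negative log-probability of the targets (computed from the tensor-parallel logit shards via all-reduces), then reduces the counters across data-parallel ranks in `get_metric`. The perplexity it reports is simply the exponential of the mean negative log-likelihood; a toy sketch of that final step (helper name is illustrative, numbers are made up):

```python
import math

def perplexity_from_counts(total_neg_log_prob: float, total_tokens: int) -> float:
    """Perplexity as exp(mean negative log-likelihood), mirroring AccPerplex.get_metric()."""
    return math.exp(total_neg_log_prob / total_tokens)

# e.g. 1000 tokens with a summed NLL of 2300 nats -> exp(2.3) ~ 9.97
print(round(perplexity_from_counts(2300.0, 1000), 2))
```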
InternLM/internlm/model/modeling_internlm.py
ADDED
@@ -0,0 +1,524 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math
from typing import Optional

import torch
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn

from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D, Embedding1DLVM
from internlm.model.linear import (
    FeedForward,
    RewardModelLinear,
    ScaleColumnParallelLinear,
)
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
from internlm.utils.logger import get_logger
from internlm.utils.registry import MODEL_INITIALIZER

MODEL_TYPE = "INTERNLM"

logger = get_logger(__file__)
RMSNorm = try_import_RMSNorm()


class PackedFlashBaseLayer1D(nn.Module):
    """
    1D Packed Flash Base Layer.

    Args:
        hidden_size (int): The hidden size of model. 768 by default.
        num_attention_heads (int): The number of attention heads. 12 by default.
        mlp_ratio (int): The ratio of MLP layers. 4 by default.
        attn_drop_rate (float): The dropout rate of attention module. 0 by default.
        drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
        dtype (torch.dtype): Type of data. torch.float by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
        checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
        layer_idx (int): The index of current layer. 0 by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
        device (Optional[Union[str, torch.device]]): The device will be used.
        norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        mlp_ratio: int = 4,
        attn_drop_rate: float = 0,
        drop_rate: float = 0.0,
        dtype: torch.dtype = torch.float,
        layer_norm_epsilon: float = 1e-6,
        checkpoint: bool = False,
        layer_idx: int = 0,
        residual_in_fp32: bool = False,
        device: Optional[torch.device] = None,
        norm_type: str = "rmsnorm",
        dropout_selective_checkpoint: bool = True,
        use_scaled_init: bool = True,
        use_swiglu: bool = True,
        use_flash_attn: bool = True,
    ):
        super().__init__()
        self.checkpoint = checkpoint
        # dropout selective checkpoint can only be enabled when checkpoint is disabled.
        self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
        self.layer_idx = layer_idx
        self.use_flash_attn = use_flash_attn

        head_dim = hidden_size // num_attention_heads
        self.mixer = MHA(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            process_group=gpc.get_group(ParallelMode.TENSOR),
            dropout=attn_drop_rate,
            softmax_scale=1 / math.sqrt(head_dim),
            causal=True,
            layer_idx=layer_idx,
            rotary_emb_dim=head_dim,
            rotary_emb_scale_base=0,
            use_flash_attn=use_flash_attn,
            device=device,
            dtype=dtype,
        )

        self.dropout1 = nn.Dropout(drop_rate)
        if norm_type == "rmsnorm":
            self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
        else:
            self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

        if use_swiglu:
            self.mlp = FeedForward(
                hidden_size,
                int(hidden_size * mlp_ratio),
                out_features=hidden_size,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                bias=False,
                device=device,
                dtype=dtype,
            )
        else:
            self.mlp = ParallelFusedMLP(
                hidden_size,
                int(hidden_size * mlp_ratio),
                out_features=hidden_size,
                activation="gelu_approx",
                process_group=gpc.get_group(ParallelMode.TENSOR),
                bias1=False,
                bias2=False,
                sequence_parallel=gpc.config.parallel.sequence_parallel,
                checkpoint_lvl=0,
                heuristic="auto",
                device=device,
                dtype=dtype,
            )
        for _, param in self.mlp.named_parameters():
            if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                setattr(param, IS_TENSOR_PARALLEL, True)
        self.dropout2 = nn.Dropout(drop_rate)
        self.use_swiglu = use_swiglu
        self.use_scaled_init = use_scaled_init
        self.residual_in_fp32 = residual_in_fp32  # only makes sense when using prenorm
        self.return_residual = False
        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            for name, param in self.mixer.named_parameters():
                if param.ndim == 1:
                    param.data.zero_()
                elif "Wqkv" in name:
                    normal_(std=0.006)(param.data)
                elif self.use_scaled_init:
                    scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                else:
                    normal_(std=0.0015)(param.data)

            for name, param in self.mlp.named_parameters():
                if param.ndim == 1 and "bias" in name:
                    param.data.zero_()
                elif self.use_swiglu:
                    if self.use_scaled_init and "w2" in name:
                        scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                    else:
                        normal_(std=0.006 if "w1" in name or "w2" in name else 0.0015)(param.data)
                else:
                    if self.use_scaled_init and "fc1" not in name:
                        scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                    else:
                        normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)

    def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
        if self.checkpoint and self.training:
            return activation_checkpoint(
                self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
            )
        else:
            return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)

    def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Attn/MLP(LN(residual))
            cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
            indexes: the length of index is the same as hidden states, which stands for the current position
        """
        mixer_kwargs = {
            "cu_seqlens": cu_seqlens,
            "max_seqlen": max_seqlen,
            "indexes": indexes,
            "inference_params": inference_params,
        }

        def _dropout_and_norm_attn(_hidden_states):
            _dropped = self.dropout1(_hidden_states)
            _residual = _dropped
            _hidden_states = self.norm1(_residual.float())
            return _residual, _hidden_states

        if self.dropout_selective_checkpoint:
            residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, hidden_states)
        else:
            residual, hidden_states = _dropout_and_norm_attn(hidden_states)

        if self.residual_in_fp32:
            residual = residual.to(torch.float32)

        hidden_states = self.mixer(hidden_states, **mixer_kwargs)

        def _dropout_and_norm_ffn(_residual, _hidden_states):
            _dropped = self.dropout2(_hidden_states)
            _residual = (_dropped + _residual) if _residual is not None else _dropped
            _hidden_states = self.norm2(_residual.float())
            return _residual, _hidden_states

        if self.dropout_selective_checkpoint:
            residual, hidden_states = activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states)
        else:
            residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)

        if self.residual_in_fp32:
            residual = residual.to(torch.float32)

        hidden_states = self.mlp(hidden_states)

        return hidden_states + residual


class PackedFlashInternLm1D(nn.Module):
    """
    1D Packed Flash InternLm.

    Args:
        num_layers (int): The number of layers. 12 by default.
        hidden_size (int): The size of hidden state. 768 by default.
        num_attention_heads (int): The number of attention heads. 12 by default.
        vocab_size (int): The size of vocabulary. 50304 by default.
        mlp_ratio (int): The ratio of MLP layers. 4 by default.
        attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
        drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
        dtype (torch.dtype): The type of data. torch.float by default.
        checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
            of layers. 0.0 by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
        first (bool): Whether input embedding layer or not. False by default.
        last (bool): Whether output embedding layer or not. False by default.
        embed_split_hidden (bool): Split the embedding layer in the hidden state dimension or vocabulary dimension.
            True by default.
        embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
        parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
        start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
        device (Optional[Union[str, torch.device]]): The device will be used. None by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
        norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.

    """

    def __init__(
        self,
        num_layers: int = 12,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        vocab_size: int = 50304,
        mlp_ratio: int = 4.0,
        attn_drop_rate: float = 0.0,
        drop_rate: float = 0.0,
        dtype: torch.dtype = torch.float,
        checkpoint: float = 0.0,
        layer_norm_epsilon: float = 1e-5,
        first: bool = False,
        last: bool = False,
        embed_split_hidden: bool = False,
        embed_grad_scale: float = 0.1,
        parallel_output: bool = True,
        start_layer_idx: int = 0,
        device: Optional[torch.device] = None,
        residual_in_fp32: bool = False,
        norm_type: str = "rmsnorm",
        is_reward: bool = False,
        dropout_selective_checkpoint: bool = True,
        use_scaled_init: bool = True,
        use_swiglu: bool = True,
        use_flash_attn: bool = True,
        lvm_config: dict = None,
    ):
        super().__init__()
        self.lvm_config = lvm_config

        checkpoint_layer_num = int(num_layers * checkpoint)

        if is_reward:
            head_cls = RewardModelLinear
        else:
            head_cls = ScaleColumnParallelLinear
        if first:
            if self.lvm_config.get('enable', False):
                self.embedding = Embedding1DLVM(**self.lvm_config.get('embedding_cfg'))
                if self.embedding.embed_proj is not None:
                    for _, param in self.embedding.embed_proj.named_parameters():
                        normal_(std=0.0052)(param)
                        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                            setattr(param, IS_TENSOR_PARALLEL, True)
            else:
                if embed_split_hidden:
                    self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
                else:
                    self.embedding = ParallelGPT2Embeddings(
                        embed_dim=hidden_size,
                        vocab_size=vocab_size,
                        max_position_embeddings=-1,
                        process_group=gpc.get_group(ParallelMode.TENSOR),
                        padding_idx=None,
                        sequence_parallel=gpc.config.parallel.sequence_parallel,
                        device=device,
                        dtype=dtype,
                    )
                for _, param in self.embedding.named_parameters():
                    normal_(std=0.0052)(param)
                    if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                        setattr(param, IS_TENSOR_PARALLEL, True)
        self.embed_grad_scale = embed_grad_scale
        self.blocks = nn.ModuleList(
            [
                PackedFlashBaseLayer1D(
                    hidden_size=hidden_size,
                    num_attention_heads=num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    dtype=dtype,
                    layer_norm_epsilon=layer_norm_epsilon,
                    checkpoint=lid < checkpoint_layer_num,
                    layer_idx=lid + start_layer_idx,  # This parameter is used for caching during generation
                    residual_in_fp32=residual_in_fp32,
                    device=device,
                    norm_type=norm_type,
                    dropout_selective_checkpoint=dropout_selective_checkpoint,
                    use_scaled_init=use_scaled_init,
                    use_swiglu=use_swiglu,
                    use_flash_attn=use_flash_attn,
                )
                for lid in range(num_layers)
            ]
        )
        if last:
            if norm_type == "rmsnorm":
                self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            else:
                self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
            self.head = head_cls(
                in_features=hidden_size,
                out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                bias=False,
                device=device,
                dtype=dtype,
                weight_scale=embed_grad_scale,
            )
            for _, param in self.head.named_parameters():
                normal_(std=0.0052)(param)
                if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                    setattr(param, IS_TENSOR_PARALLEL, True)
        self.parallel_output = parallel_output

    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
        # attention_mask: compute attention on the places where the value is 1
        if hasattr(self, "embedding"):
            hidden_states = self.embedding(input_ids)
            if self.embed_grad_scale != 1:
                hidden_states = (
                    self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
                )
        if isinstance(cu_seqlens, list):
            assert len(cu_seqlens) == 1
            cu_seqlens = cu_seqlens[0].to(hidden_states.device)

        if cu_seqlens is not None:
            cu_seqlens = cu_seqlens.squeeze(0)
            hidden_states = hidden_states.squeeze(0)  # If cu_seqlens is passed in, it indicates a packed state;
            # the batch dimension with a size of 1 should be directly squeezed off.

        if indexes is not None:
            assert len(indexes) == 1
            # The indexes are used to indicate the actual position IDs of each token in the packed input.
            indexes = indexes[0]
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

        for _, block in enumerate(self.blocks):
            hidden_states = block(
                hidden_states,
                cu_seqlens=cu_seqlens,
                indexes=indexes,
                inference_params=inference_params,
                max_seqlen=max_seqlen,
            )

        if hasattr(self, "norm"):
            hidden_states = self.norm(hidden_states.float())
        if hasattr(self, "head"):
            hidden_states = self.head(hidden_states)

        if not self.parallel_output:
            hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
        return hidden_states


def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
    """
    build generic model 1d

    Args:
        num_layers (int): The number of layers.
        num_chunks (int): The number of partitions in pipeline parallel.
        device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default.

    """
    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
    parts = all_parts[pipeline_rank]
    if gpc.is_rank_for_log():
        logger.info(f"The layer sharding is {all_parts}.")

    models = []

    for start, end in parts:
        kwargs["num_layers"] = end - start
        kwargs["first"] = start == 0
        # If there is no content in the final layer, assign the last layer.
        kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0
        kwargs["device"] = device
        kwargs["start_layer_idx"] = start
        chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device)

        models.append(chunk)
    torch.distributed.barrier()
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)

    return model


@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
def build_model_with_cfg(
    num_chunks=1,
    checkpoint=0.0,
    dtype=torch.float,
    embed_split_hidden=False,
    num_layers=48,
    hidden_size=2048,
    vocab_size=50304,
    embed_grad_scale=1,
    parallel_output=True,
    num_attention_heads=32,
    mlp_ratio=4.0,
    residual_in_fp32=False,
    norm_type="rmsnorm",
    drop_rate=0,
    attn_drop_rate=0,
    apply_post_layer_norm=False,  # pylint: disable=W0613
    layer_norm_epsilon=1e-5,
    is_reward=False,
    dropout_selective_checkpoint=True,
    use_scaled_init: bool = True,
    use_swiglu: bool = True,
    use_flash_attn: bool = True,
    lvm_config=None,
):
    """
    Build model with config.

    Args:
        num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
        checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
        dtype (torch.dtype): The type of data. torch.float by default.
        embed_split_hidden (bool): Split the embedding layer in the hidden state dimension or vocabulary dimension.
            False by default.
        num_layers (int): The number of layers. 48 by default.
        hidden_size (int): The size of hidden state. 2048 by default.
        vocab_size (int): The size of vocabulary. 50304 by default.
        embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
        parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
        num_attention_heads (int): The number of attention heads. 32 by default.
        mlp_ratio (int): The ratio of MLP layers. 4.0 by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily
            because this parameter requires inconsistent data types to be passed between pipelines,
            which requires significant modifications to internlm.
        norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
        drop_rate (float): The dropout rate of input hidden state. 0 by default.
        attn_drop_rate (float): The dropout rate of attention module. 0 by default.
        apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
        is_reward (bool): Whether to use reward model. False by default.
        dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
        use_scaled_init (bool): Whether to use scaled init. True by default.
        use_swiglu (bool): Whether to use swiglu. True by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.

    """

    cfg = dict(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        checkpoint=checkpoint,
        dtype=dtype,
        embed_split_hidden=embed_split_hidden,
        vocab_size=vocab_size,
        embed_grad_scale=embed_grad_scale,
        parallel_output=parallel_output,
        mlp_ratio=mlp_ratio,
        residual_in_fp32=residual_in_fp32,
        norm_type=norm_type,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,
        layer_norm_epsilon=layer_norm_epsilon,
        is_reward=is_reward,
        dropout_selective_checkpoint=dropout_selective_checkpoint,
        use_scaled_init=use_scaled_init,
        use_swiglu=use_swiglu,
        use_flash_attn=use_flash_attn,
        lvm_config=lvm_config,
    )

    return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
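Note: both the embedding output (`embed_grad_scale` in `PackedFlashInternLm1D.forward`) and the head weight (`weight_scale` in `ScaleColumnParallelLinear`) rely on the same GLM-130B-style stabilisation trick: mixing a tensor with its detached copy leaves the forward value unchanged but shrinks the gradient flowing into that tensor. A small self-contained sketch of the effect (the helper name is illustrative only):

```python
import torch

def scale_grad(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Identity in the forward pass; multiplies the gradient by `scale` in the backward pass."""
    return scale * x + (1 - scale) * x.detach()

x = torch.ones(3, requires_grad=True)
y = scale_grad(x, 0.1).sum()
y.backward()
print(y.item(), x.grad)  # forward value unchanged (3.0); per-element grad is 0.1 instead of 1.0
```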
InternLM/internlm/model/modeling_vit.py
ADDED
@@ -0,0 +1,527 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import math
from typing import Optional

import torch
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn

from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D, Embedding1DLVM
from internlm.model.linear import (
    FeedForward,
    RewardModelLinear,
    ScaleColumnParallelLinear,
)
from internlm.model.multi_head_attention import MHA
from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm, try_import_LayerNorm
from internlm.solver.pipeline_utils import partition_uniform
from internlm.utils.checkpoint import activation_checkpoint
from internlm.utils.common import filter_kwargs
from internlm.utils.logger import get_logger
from internlm.utils.registry import MODEL_INITIALIZER

MODEL_TYPE = "ViT"

logger = get_logger(__file__)
RMSNorm = try_import_RMSNorm()
LayerNorm = try_import_LayerNorm()

def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

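As an illustrative aside, separate from the file contents, the stochastic-depth behaviour of `drop_path` above can be checked with plain PyTorch: with `scale_by_keep=True` the surviving samples are rescaled by `1 / keep_prob`, so the expected activation is preserved. This is a sketch for illustration only:

import torch

def drop_path_demo():
    torch.manual_seed(0)
    x = torch.ones(10000, 4)                     # many "samples" with constant activations
    keep_prob = 0.9
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # one mask value per sample
    mask = x.new_empty(shape).bernoulli_(keep_prob).div_(keep_prob)
    out = x * mask
    # Roughly 10% of rows are zeroed, but the mean activation stays close to 1.0.
    print(f"zeroed rows: {(out.sum(dim=1) == 0).float().mean().item():.3f}")
    print(f"mean activation: {out.mean().item():.3f}")

drop_path_demo()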
class PackedFlashBaseLayer1D(nn.Module):
    """
    1D Packed Flash Base Layer.

    Args:
        hidden_size (int): The hidden size of model. 768 by default.
        num_attention_heads (int): The number of attention heads. 12 by default.
        mlp_ratio (int): The ratio of MLP layers. 4 by default.
        attn_drop_rate (float): The dropout rate of attention module. 0 by default.
        drop_path_rate (float): The drop path rate of the input hidden state. 0.0 by default.
        dtype (torch.dtype): Type of data. torch.float by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
        checkpoint (bool): Whether to use checkpointing to save VRAM. True by default.
        layer_idx (int): The index of current layer. 0 by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
        device (Optional[Union[str, torch.device]]): The device will be used.
        norm_type (str): Use RMSNorm or LayerNorm. "rmsnorm" by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        mlp_ratio: int = 4,
        mlp_bias: bool = False,
        attn_drop_rate: float = 0,
        drop_path_rate: float = 0.0,
        dtype: torch.dtype = torch.float,
        layer_norm_epsilon: float = 1e-6,
        checkpoint: bool = False,
        layer_idx: int = 0,
        residual_in_fp32: bool = False,
        device: Optional[torch.device] = None,
        norm_type: str = "rmsnorm",
        dropout_selective_checkpoint: bool = True,
        use_scaled_init: bool = True,
        use_swiglu: bool = True,
        use_flash_attn: bool = True,
    ):
        super().__init__()
        self.checkpoint = checkpoint
        # dropout selective checkpoint can only be enabled when checkpoint is disabled.
        self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
        self.layer_idx = layer_idx
        self.use_flash_attn = use_flash_attn

        head_dim = hidden_size // num_attention_heads
        self.mixer = MHA(
            embed_dim=hidden_size,
            num_heads=num_attention_heads,
            process_group=gpc.get_group(ParallelMode.TENSOR),
            dropout=attn_drop_rate,
            softmax_scale=1 / math.sqrt(head_dim),
            causal=True,
            layer_idx=layer_idx,
            rotary_emb_dim=head_dim,
            rotary_emb_scale_base=0,
            use_flash_attn=use_flash_attn,
            device=device,
            dtype=dtype,
        )

        self.dropout1 = DropPath(drop_path_rate)
        if norm_type == "rmsnorm":
            self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
        else:
            self.norm1 = LayerNorm(hidden_size, eps=layer_norm_epsilon)
            self.norm2 = LayerNorm(hidden_size, eps=layer_norm_epsilon)

        self.mlp = ParallelFusedMLP(
            hidden_size,
            int(hidden_size * mlp_ratio),
            out_features=hidden_size,
            activation="gelu_approx",
            process_group=gpc.get_group(ParallelMode.TENSOR),
            bias1=mlp_bias,
            bias2=mlp_bias,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            checkpoint_lvl=0,
            heuristic="auto",
            device=device,
            dtype=dtype,
        )
        for _, param in self.mlp.named_parameters():
            if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                setattr(param, IS_TENSOR_PARALLEL, True)
        self.dropout2 = DropPath(drop_path_rate)
        self.use_swiglu = use_swiglu
        self.use_scaled_init = use_scaled_init
        self.residual_in_fp32 = residual_in_fp32  # only makes sense when using prenorm
        self.return_residual = False
        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            for name, param in self.mixer.named_parameters():
                if param.ndim == 1:
                    param.data.zero_()
                elif "Wqkv" in name:
                    normal_(std=0.006)(param.data)
                elif self.use_scaled_init:
                    scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                else:
                    normal_(std=0.0015)(param.data)

            for name, param in self.mlp.named_parameters():
                if param.ndim == 1 and "bias" in name:
                    param.data.zero_()
                elif self.use_swiglu:
                    if self.use_scaled_init and "w2" in name:
                        scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                    else:
                        normal_(std=0.006 if "w1" in name or "w2" in name else 0.0015)(param.data)
                else:
                    if self.use_scaled_init and "fc1" not in name:
                        scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
                    else:
                        normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)

    def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
        if self.checkpoint and self.training:
            return activation_checkpoint(
                self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
            )
        else:
            return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)

    def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Attn/MLP(LN(residual))
            cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1
            indexes: the length of indexes is the same as hidden_states, which stands for the current position
        """
        mixer_kwargs = {
            "cu_seqlens": cu_seqlens,
            "max_seqlen": max_seqlen,
            "indexes": indexes,
            "inference_params": inference_params,
        }

        residual = hidden_states

        hidden_states = self.norm1(residual.float())
        hidden_states = self.mixer(hidden_states, **mixer_kwargs)
        hidden_states = self.dropout1(hidden_states)

        residual = residual + hidden_states

        hidden_states = self.norm2(residual.float())
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.dropout2(hidden_states)

        return hidden_states + residual


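As an illustrative aside, separate from the file contents, `PackedFlashBaseLayer1D._forward` above is the pre-norm residual pattern applied twice (norm -> attention -> drop-path -> add, then norm -> MLP -> drop-path -> add). A dependency-free sketch of that control flow, with plain `nn.LayerNorm` and `nn.Linear` standing in for the fused RMSNorm, MHA and ParallelFusedMLP modules, might look like this:

import torch
from torch import nn

class PreNormBlockSketch(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.mixer = nn.Linear(hidden_size, hidden_size)  # stand-in for the attention mixer
        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(), nn.Linear(4 * hidden_size, hidden_size)
        )

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.mixer(self.norm1(residual))  # attention branch
        residual = residual + hidden_states               # first residual add
        hidden_states = self.mlp(self.norm2(residual))    # MLP branch
        return hidden_states + residual                   # second residual add

block = PreNormBlockSketch(hidden_size=16)
print(block(torch.randn(2, 8, 16)).shape)  # torch.Size([2, 8, 16])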
class PackedFlashInternLm1D(nn.Module):
    """
    1D Packed Flash InternLm.

    Args:
        num_layers (int): The number of layer. 12 by default.
        hidden_size (int): The size of hidden state. 768 by default.
        num_attention_heads (int): The number of attention head. 12 by default.
        vocab_size (int): The size of vocabulary. 50304 by default.
        mlp_ratio (int): The ratio of MLP layers. 4 by default.
        attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
        drop_path_rate (float): The drop path rate of input hidden state. 0.0 by default.
        dtype (torch.dtype): The type of data. torch.float by default.
        checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
            of layers. 0.0 by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
        first (bool): Whether input embedding layer or not. False by default.
        last (bool): Whether output embedding layer or not. False by default.
        embed_split_hidden (bool): Split the embedding layer in the hidden state dimension or vocabulary dimension.
            True by default.
        embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
        parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
        start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
        device (Optional[Union[str, torch.device]]): The device will be used. None by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
        norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.

    """

    def __init__(
        self,
        num_layers: int = 12,
        hidden_size: int = 768,
        num_attention_heads: int = 12,
        vocab_size: int = 50304,
        mlp_ratio: int = 4.0,
        mlp_bias: bool = False,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        dtype: torch.dtype = torch.float,
        checkpoint: float = 0.0,
        layer_norm_epsilon: float = 1e-5,
        first: bool = False,
        last: bool = False,
        embed_split_hidden: bool = False,
        embed_grad_scale: float = 0.1,
        parallel_output: bool = True,
        start_layer_idx: int = 0,
        device: Optional[torch.device] = None,
        residual_in_fp32: bool = False,
        norm_type: str = "rmsnorm",
        is_reward: bool = False,
        dropout_selective_checkpoint: bool = True,
        use_scaled_init: bool = True,
        use_swiglu: bool = True,
        use_flash_attn: bool = True,
        lvm_config: dict = None,
    ):
        super().__init__()
        self.lvm_config = lvm_config

        checkpoint_layer_num = int(num_layers * checkpoint)

        head_cls = ScaleColumnParallelLinear
        if first:
            if self.lvm_config.get('enable', False):
                self.embedding = Embedding1DLVM(**self.lvm_config.get('embedding_cfg'))
                if self.embedding.embed_proj is not None:
                    for _, param in self.embedding.embed_proj.named_parameters():
                        normal_(std=0.0052)(param)
                        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                            setattr(param, IS_TENSOR_PARALLEL, True)
            else:
                if embed_split_hidden:
                    self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
                else:
                    self.embedding = ParallelGPT2Embeddings(
                        embed_dim=hidden_size,
                        vocab_size=vocab_size,
                        max_position_embeddings=-1,
                        process_group=gpc.get_group(ParallelMode.TENSOR),
                        padding_idx=None,
                        sequence_parallel=gpc.config.parallel.sequence_parallel,
                        device=device,
                        dtype=dtype,
                    )
                for _, param in self.embedding.named_parameters():
                    normal_(std=0.0052)(param)
                    if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                        setattr(param, IS_TENSOR_PARALLEL, True)
        self.embed_grad_scale = embed_grad_scale
        self.blocks = nn.ModuleList(
            [
                PackedFlashBaseLayer1D(
                    hidden_size=hidden_size,
                    num_attention_heads=num_attention_heads,
                    mlp_ratio=mlp_ratio,
                    mlp_bias=mlp_bias,
                    attn_drop_rate=attn_drop_rate,
                    drop_path_rate=drop_path_rate,
                    dtype=dtype,
                    layer_norm_epsilon=layer_norm_epsilon,
                    checkpoint=lid < checkpoint_layer_num,
                    layer_idx=lid + start_layer_idx,  # This parameter is used for caching during generation
                    residual_in_fp32=residual_in_fp32,
                    device=device,
                    norm_type=norm_type,
                    dropout_selective_checkpoint=dropout_selective_checkpoint,
                    use_scaled_init=use_scaled_init,
                    use_swiglu=use_swiglu,
                    use_flash_attn=use_flash_attn,
                )
                for lid in range(num_layers)
            ]
        )
        if last:
            if norm_type == "rmsnorm":
                self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
            else:
                self.norm = LayerNorm(hidden_size, eps=layer_norm_epsilon)
            self.head = head_cls(
                in_features=hidden_size,
                out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
                process_group=gpc.get_group(ParallelMode.TENSOR),
                bias=False,
                device=device,
                dtype=dtype,
                weight_scale=embed_grad_scale,
            )
            for _, param in self.head.named_parameters():
                normal_(std=0.0052)(param)
                if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                    setattr(param, IS_TENSOR_PARALLEL, True)
        self.parallel_output = parallel_output

    def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
        # attention_mask: compute attention on the places where the value is 1
        if hasattr(self, "embedding"):
            hidden_states = self.embedding(input_ids)
            if self.embed_grad_scale != 1:
                hidden_states = (
                    self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
                )
        if isinstance(cu_seqlens, list):
            assert len(cu_seqlens) == 1
            cu_seqlens = cu_seqlens[0].to(hidden_states.device)

        if cu_seqlens is not None:
            cu_seqlens = cu_seqlens.squeeze(0)
            hidden_states = hidden_states.squeeze(0)  # If cu_seqlens is passed in, it indicates a packed state,
            # so the batch dimension with a size of 1 should be directly squeezed off.

        if indexes is not None:
            assert len(indexes) == 1
            # The indexes are used to indicate the actual position IDs of each token in the packed input.
            indexes = indexes[0]
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None

        for _, block in enumerate(self.blocks):
            hidden_states = block(
                hidden_states,
                cu_seqlens=cu_seqlens,
                indexes=indexes,
                inference_params=inference_params,
                max_seqlen=max_seqlen,
            )

        if hasattr(self, "norm"):
            hidden_states = self.norm(hidden_states.float())
        if hasattr(self, "head"):
            hidden_states = self.head(hidden_states)

        if not self.parallel_output:
            hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
        return hidden_states


def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
    """
    build generic model 1d

    Args:
        num_layers (int): The number of layer.
        num_chunks (int): The number of partitions in pipeline parallel.
        device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default.

    """
    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
    parts = all_parts[pipeline_rank]
    if gpc.is_rank_for_log():
        logger.info(f"The layer sharding is {all_parts}.")

    models = []

    for start, end in parts:
        kwargs["num_layers"] = end - start
        kwargs["first"] = start == 0
        # If there is no content in the final layer, assign the last layer.
        kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0
        kwargs["device"] = device
        kwargs["start_layer_idx"] = start
        chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device)

        models.append(chunk)
    torch.distributed.barrier()
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)

    return model


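As an illustrative aside, separate from the file contents, `_build_generic_model_1d` above relies on `partition_uniform` to split `num_layers` across pipeline ranks. The sketch below shows one plausible uniform partitioning scheme for the single-chunk case; the real `internlm.solver.pipeline_utils.partition_uniform` may differ in detail, so treat it purely as an illustration of the (start, end) ranges consumed by the loop over `parts`:

def partition_uniform_sketch(num_layers: int, pipeline_size: int, num_chunks: int = 1):
    # Assumed behaviour: return, for every pipeline rank, a list of (start, end)
    # layer ranges whose sizes differ by at most one layer.
    assert num_chunks == 1, "sketch only covers the single-chunk case"
    base, remainder = divmod(num_layers, pipeline_size)
    parts, start = [], 0
    for rank in range(pipeline_size):
        size = base + (1 if rank < remainder else 0)
        parts.append([(start, start + size)])
        start += size
    return parts

print(partition_uniform_sketch(48, 4))
# [[(0, 12)], [(12, 24)], [(24, 36)], [(36, 48)]]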
@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
def build_vit_model_with_cfg(
    num_chunks=1,
    checkpoint=0.0,
    dtype=torch.float,
    embed_split_hidden=False,
    num_layers=48,
    hidden_size=2048,
    vocab_size=50304,
    embed_grad_scale=1,
    parallel_output=True,
    num_attention_heads=32,
    mlp_ratio=4.0,
    mlp_bias: bool = False,
    residual_in_fp32=False,
    norm_type="rmsnorm",
    drop_path_rate=0,
    attn_drop_rate=0,
    apply_post_layer_norm=False,  # pylint: disable=W0613
    layer_norm_epsilon=1e-5,
    is_reward=False,
    dropout_selective_checkpoint=True,
    use_scaled_init: bool = True,
    use_swiglu: bool = True,
    use_flash_attn: bool = True,
    lvm_config=None,
):
    """
    Build model with config.

    Args:
        num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
        checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
        dtype (torch.dtype): The type of data. torch.float by default.
        embed_split_hidden (bool): Split the embedding layer in the hidden state dimension or vocabulary dimension.
            False by default.
        num_layers (int): The number of layer. 48 by default.
        hidden_size (int): The size of hidden state. 2048 by default.
        vocab_size (int): The size of vocabulary. 50304 by default.
        embed_grad_scale (float): Refer to GLM-130B, for training stability. 1 by default.
        parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
        num_attention_heads (int): The number of attention head. 32 by default.
        mlp_ratio (int): The ratio of MLP layers. 4.0 by default.
        residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily
            because this parameter requires inconsistent data types to be passed between pipelines,
            which requires significant modifications to internlm.
        norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
        drop_path_rate (float): The drop path rate of input hidden state. 0 by default.
        attn_drop_rate (float): The dropout rate of attention module. 0 by default.
        apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
        layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
        is_reward (bool): Whether to use reward model. False by default.
        dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
        use_scaled_init (bool): Whether to use scaled init. True by default.
        use_swiglu (bool): Whether to use swiglu. True by default.
        use_flash_attn (bool): Whether to use flash-attn. True by default.

    """

    cfg = dict(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        checkpoint=checkpoint,
        dtype=dtype,
        embed_split_hidden=embed_split_hidden,
        vocab_size=vocab_size,
        embed_grad_scale=embed_grad_scale,
        parallel_output=parallel_output,
        mlp_ratio=mlp_ratio,
        mlp_bias=mlp_bias,
        residual_in_fp32=residual_in_fp32,
        norm_type=norm_type,
        drop_path_rate=drop_path_rate,
        attn_drop_rate=attn_drop_rate,
        layer_norm_epsilon=layer_norm_epsilon,
        is_reward=is_reward,
        dropout_selective_checkpoint=dropout_selective_checkpoint,
        use_scaled_init=use_scaled_init,
        use_swiglu=use_swiglu,
        use_flash_attn=use_flash_attn,
        lvm_config=lvm_config,
    )

    return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
InternLM/internlm/model/multi_head_attention.py
ADDED
@@ -0,0 +1,186 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional

import torch
from einops import rearrange
from flash_attn.modules.mha import (
    CrossAttention,
    FlashCrossAttention,
    FlashSelfAttention,
    SelfAttention,
    _update_kv_cache,
)
from torch import nn

from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import RotaryEmbedding
from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch


class MHA(nn.Module):
    """
    Multi-head self-attention and cross-attention.

    Args:
        embed_dim (int): The dimension of hidden state.
        num_heads (int): The number of attention heads.
        process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`.
        bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and
            output projection. True by default.
        dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
        softmax_scale (float): The temperature to use for the softmax attention.
        causal (boolean): Whether to apply causal attention mask. False by default.
        layer_idx (int): The index of current layer. None by default.
        rotary_emb_dim (int): The dimension of Rotary Embedding. 0 by default.
        rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
            XPos (Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
        use_flash_attn (boolean): Whether to use flash attention or not. If False, vanilla attention module will be
            used. False by default.
        sequence_parallel (boolean): If True, we're doing Tensor Parallel with sequence parallelism. An all_gather_raw
            of x will be done before doing the matmul.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        use_flash_attn (bool): Whether to use flash-attn. True by default.

    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        dropout: float = 0.0,
        softmax_scale: float = None,
        causal: bool = False,
        layer_idx: int = None,
        rotary_emb_dim: int = 0,
        rotary_emb_scale_base: int = 0,
        use_flash_attn: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

        # notice here should change bias=True
        self.Wqkv = ColumnParallelLinearTorch(
            embed_dim,
            3 * embed_dim,
            process_group,
            bias=True,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            **factory_kwargs,
        )  # according to https://spaces.ac.cn/archives/9577

        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )

        # output projection always has the bias (for now)
        self.out_proj = RowParallelLinearTorch(
            embed_dim,
            embed_dim,
            process_group,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            **factory_kwargs,
        )
        # need to assign tp attribute so that internlm knows it is a tensor parallel module
        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
            for name in ["out_proj", "Wqkv"]:
                for param in getattr(self, name).parameters():
                    setattr(param, IS_TENSOR_PARALLEL, True)

    def forward(self, x, seqlen=None, inference_params=None, **kwargs):
        if kwargs.get("indexes", None) is not None:
            return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
        else:
            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)

    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is None:
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
        else:
            qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)

        if self.rotary_emb_dim > 0:
            kwargs["inference_params"] = inference_params
            qkv = self.rotary_emb(qkv, **kwargs)

        if inference_params is None:
            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
                        qkv = qkv.to(torch.bfloat16)
                    context = self.inner_attn(qkv).to(x.dtype)
            else:
                context = self.inner_attn(qkv)
        else:
            q = qkv[:, :, 0]
            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
            kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
            # If we're processing the prompt, causal=None (use self.causal).
            # If we're decoding, then causal=False.
            causal = None if inference_params.sequence_len_offset == 0 else False
            context = self.inner_cross_attn(q, kv, causal=causal)

        if seqlen is None:
            context = rearrange(context, "b s h d -> b s (h d)")
        else:
            context = rearrange(context, "b s h d -> (b s) (h d)")

        out = self.out_proj(context)
        return out

    def _packed_forward(self, x, inference_params=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)  # total x hsz'
        qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim)  # total x 3 x n_head x d
        qkv = self.rotary_emb(qkv, **kwargs)
        kwargs.pop("indexes")

        if inference_params is None:
            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
                        qkv = qkv.to(torch.bfloat16)
                    context = self.inner_attn(qkv, **kwargs).to(x.dtype)
            else:
                context = self.inner_attn(qkv, **kwargs)

        else:
            raise RuntimeError("Not supported right now")

        context = rearrange(context, "b h d -> b (h d)")  # recover the shape
        out = self.out_proj(context)
        return out
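As an illustrative aside, separate from the file contents, `_packed_forward` above expects the caller to flatten several variable-length sequences into one token dimension and to describe them with `cu_seqlens` (cumulative sequence lengths) and per-token position ids (`indexes`), as consumed in `PackedFlashBaseLayer1D._forward`. A minimal sketch of that bookkeeping, for illustration only:

import torch

# Three variable-length sequences packed back to back along a single token axis.
seqlens = [3, 5, 2]
cu_seqlens = torch.cumsum(torch.tensor([0] + seqlens), dim=0)
indexes = torch.cat([torch.arange(n) for n in seqlens])  # position id of each packed token

print(cu_seqlens.tolist())  # [0, 3, 8, 10]
print(indexes.tolist())     # [0, 1, 2, 0, 1, 2, 3, 4, 0, 1]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
print(max_seqlen)           # 5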
InternLM/internlm/model/muse/__init__.py
ADDED
@@ -0,0 +1,18 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.0.1"

from .modeling_taming_vqgan import VQGANModel
InternLM/internlm/model/muse/modeling_taming_vqgan.py
ADDED
@@ -0,0 +1,591 @@
# coding=utf-8
# Copyright 2023 The Taming Transformers Authors and The HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial
from typing import Tuple

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from .modeling_utils import ConfigMixin, ModelMixin, register_to_config


class Upsample(nn.Module):
    def __init__(self, in_channels: int, with_conv: bool):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(
                in_channels,
                in_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            )

    def forward(self, hidden_states):
        hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            hidden_states = self.conv(hidden_states)
        return hidden_states


class Downsample(nn.Module):
    def __init__(self, in_channels: int, with_conv: bool):
        super().__init__()

        self.with_conv = with_conv
        if self.with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, hidden_states):
        if self.with_conv:
            pad = (0, 1, 0, 1)  # pad height and width dim
            hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
            hidden_states = self.conv(hidden_states)
        else:
            hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2)
        return hidden_states


class ResnetBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        use_conv_shortcut: bool = False,
        dropout_prob: float = 0.0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
        self.use_conv_shortcut = use_conv_shortcut

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(
            self.in_channels,
            self.out_channels_,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True)
        self.dropout = nn.Dropout(dropout_prob)
        self.conv2 = nn.Conv2d(
            self.out_channels_,
            self.out_channels_,
            kernel_size=3,
            stride=(1, 1),
            padding=1,
        )

        if self.in_channels != self.out_channels_:
            if use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(
                    self.in_channels,
                    self.out_channels_,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            else:
                self.nin_shortcut = nn.Conv2d(
                    self.in_channels,
                    self.out_channels_,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels_:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return hidden_states + residual


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()

        self.in_channels = in_channels
        conv = partial(nn.Conv2d, self.in_channels, self.in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True)
        self.q, self.k, self.v = conv(), conv(), conv()
        self.proj_out = conv()

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)

        query = self.q(hidden_states)
        key = self.k(hidden_states)
        value = self.v(hidden_states)

        # compute attentions
        batch, channels, height, width = query.shape
        query = query.reshape((batch, channels, height * width))
        query = query.permute(0, 2, 1)  # (b, hw, c)
        key = key.reshape((batch, channels, height * width))

        attn_weights = torch.bmm(query, key)  # b,hw,hw
        attn_weights = attn_weights * (int(channels) ** -0.5)
        attn_weights = nn.functional.softmax(attn_weights, dim=2)

        # attend to values
        value = value.reshape((batch, channels, height * width))
        attn_weights = attn_weights.permute(0, 2, 1)
        hidden_states = torch.bmm(value, attn_weights)
        hidden_states = hidden_states.reshape((batch, channels, height, width))

        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states


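As an illustrative aside, separate from the file contents, the `AttnBlock` above is single-head self-attention over the h*w spatial positions, with the channel dimension acting as the feature dimension. The check below reproduces its bmm-based math on random tensors and compares it against an equivalent einsum formulation:

import torch

torch.manual_seed(0)
b, c, h, w = 2, 8, 4, 4
query = torch.randn(b, c, h, w)
key = torch.randn(b, c, h, w)
value = torch.randn(b, c, h, w)

# bmm formulation, as in AttnBlock.forward
q = query.reshape(b, c, h * w).permute(0, 2, 1)           # (b, hw, c)
k = key.reshape(b, c, h * w)                              # (b, c, hw)
attn = torch.softmax(torch.bmm(q, k) * c ** -0.5, dim=2)  # (b, hw_q, hw_k)
v = value.reshape(b, c, h * w)
out_bmm = torch.bmm(v, attn.permute(0, 2, 1)).reshape(b, c, h, w)

# equivalent einsum formulation
scores = torch.einsum("bcq,bck->bqk", query.reshape(b, c, -1), key.reshape(b, c, -1)) * c ** -0.5
attn2 = torch.softmax(scores, dim=2)
out_einsum = torch.einsum("bqk,bck->bcq", attn2, value.reshape(b, c, -1)).reshape(b, c, h, w)

print(torch.allclose(out_bmm, out_einsum, atol=1e-5))  # True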
class UpsamplingBlock(nn.Module):
    def __init__(self, config, curr_res: int, block_idx: int):
        super().__init__()

        self.config = config
        self.block_idx = block_idx
        self.curr_res = curr_res

        if self.block_idx == self.config.num_resolutions - 1:
            block_in = self.config.hidden_channels * self.config.channel_mult[-1]
        else:
            block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1]

        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]

        res_blocks = []
        attn_blocks = []
        for _ in range(self.config.num_res_blocks + 1):
            res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
            block_in = block_out
            if self.curr_res in self.config.attn_resolutions:
                attn_blocks.append(AttnBlock(block_in))

        self.block = nn.ModuleList(res_blocks)
        self.attn = nn.ModuleList(attn_blocks)

        self.upsample = None
        if self.block_idx != 0:
            self.upsample = Upsample(block_in, self.config.resample_with_conv)

    def forward(self, hidden_states):
        for i, res_block in enumerate(self.block):
            hidden_states = res_block(hidden_states)
            if len(self.attn) > 1:
                hidden_states = self.attn[i](hidden_states)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)

        return hidden_states


class DownsamplingBlock(nn.Module):
    def __init__(self, config, curr_res: int, block_idx: int):
        super().__init__()

        self.config = config
        self.curr_res = curr_res
        self.block_idx = block_idx

        in_channel_mult = (1,) + tuple(self.config.channel_mult)
        block_in = self.config.hidden_channels * in_channel_mult[self.block_idx]
        block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]

        res_blocks = nn.ModuleList()
        attn_blocks = nn.ModuleList()
        for _ in range(self.config.num_res_blocks):
            res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
            block_in = block_out
            if self.curr_res in self.config.attn_resolutions:
                attn_blocks.append(AttnBlock(block_in))

        self.block = res_blocks
        self.attn = attn_blocks

        self.downsample = None
        if self.block_idx != self.config.num_resolutions - 1:
            self.downsample = Downsample(block_in, self.config.resample_with_conv)

    def forward(self, hidden_states):
        for i, res_block in enumerate(self.block):
            hidden_states = res_block(hidden_states)
            if len(self.attn) > 1:
                hidden_states = self.attn[i](hidden_states)

        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states


class MidBlock(nn.Module):
    def __init__(self, config, in_channels: int, no_attn: bool, dropout: float):
        super().__init__()

        self.config = config
        self.in_channels = in_channels
        self.no_attn = no_attn
        self.dropout = dropout

        self.block_1 = ResnetBlock(
            self.in_channels,
            self.in_channels,
            dropout_prob=self.dropout,
        )
        if not no_attn:
            self.attn_1 = AttnBlock(self.in_channels)
        self.block_2 = ResnetBlock(
            self.in_channels,
            self.in_channels,
            dropout_prob=self.dropout,
        )

    def forward(self, hidden_states):
        hidden_states = self.block_1(hidden_states)
        if not self.no_attn:
            hidden_states = self.attn_1(hidden_states)
        hidden_states = self.block_2(hidden_states)
        return hidden_states


class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        # downsampling
        self.conv_in = nn.Conv2d(
            self.config.num_channels,
            self.config.hidden_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        curr_res = self.config.resolution
        downsample_blocks = []
        for i_level in range(self.config.num_resolutions):
            downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level))

            if i_level != self.config.num_resolutions - 1:
                curr_res = curr_res // 2
        self.down = nn.ModuleList(downsample_blocks)

        # middle
        mid_channels = self.config.hidden_channels * self.config.channel_mult[-1]
        self.mid = MidBlock(config, mid_channels, self.config.no_attn_mid_block, self.config.dropout)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(
            mid_channels,
            self.config.z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, pixel_values):
        # downsampling
        hidden_states = self.conv_in(pixel_values)
        for block in self.down:
            hidden_states = block(hidden_states)

        # middle
        hidden_states = self.mid(hidden_states)

        # end
        hidden_states = self.norm_out(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states


class Decoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        # compute in_channel_mult, block_in and curr_res at lowest res
        block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1]
        curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
        self.z_shape = (1, self.config.z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(
            self.config.z_channels,
            block_in,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # middle
        self.mid = MidBlock(config, block_in, self.config.no_attn_mid_block, self.config.dropout)

        # upsampling
        upsample_blocks = []
        for i_level in reversed(range(self.config.num_resolutions)):
            upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level))
            if i_level != 0:
                curr_res = curr_res * 2
        self.up = nn.ModuleList(list(reversed(upsample_blocks)))  # reverse to get consistent order

        # end
        block_out = self.config.hidden_channels * self.config.channel_mult[0]
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(
            block_out,
            self.config.num_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, hidden_states):
        # z to block_in
        hidden_states = self.conv_in(hidden_states)

        # middle
        hidden_states = self.mid(hidden_states)

        # upsampling
        for block in reversed(self.up):
            hidden_states = block(hidden_states)

        # end
        hidden_states = self.norm_out(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states


class VectorQuantizer(nn.Module):
    """
    see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
    Discretization bottleneck part of the VQ-VAE.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        r"""
        Args:
            num_embeddings: number of vectors in the quantized space.
            embedding_dim: dimensionality of the tensors in the quantized space.
                Inputs to the modules must be in this format as well.
            commitment_cost: scalar which controls the weighting of the loss terms
                (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta).
        """
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

    def forward(self, hidden_states, return_loss=False):
        """
        Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the
        closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width)
        quantization pipeline:
            1. get encoder input (B,C,H,W)
            2. flatten input to (B*H*W,C)
        """
        # reshape z -> (batch, height, width, channel) and flatten
        hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()

        distances = self.compute_distances(hidden_states)
        min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1)
        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)

        # reshape to (batch, num_tokens)
        min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)

        # compute loss for embedding
        loss = None
        if return_loss:
            loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean(
                (z_q - hidden_states.detach()) ** 2
            )
            # preserve gradients
            z_q = hidden_states + (z_q - hidden_states).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, min_encoding_indices, loss

    def compute_distances(self, hidden_states):
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        hidden_states_flattended = hidden_states.reshape((-1, self.embedding_dim))
        emb_weights = self.embedding.weight.t()

        inputs_norm_sq = hidden_states_flattended.pow(2.0).sum(dim=1, keepdim=True)
        codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True)
        distances = torch.addmm(
            inputs_norm_sq + codebook_t_norm_sq,
            hidden_states_flattended,
            emb_weights,
            alpha=-2.0,
        )
        return distances

def get_codebook_entry(self, indices):
|
| 480 |
+
# indices are expected to be of shape (batch, num_tokens)
|
| 481 |
+
# get quantized latent vectors
|
| 482 |
+
batch, num_tokens = indices.shape
|
| 483 |
+
z_q = self.embedding(indices)
|
| 484 |
+
z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2)
|
| 485 |
+
return z_q
|
| 486 |
+
|
| 487 |
+
def get_codebook_entry_for_lvm(self, indices):
|
| 488 |
+
batch, num_tokens = indices.shape
|
| 489 |
+
z_q = self.embedding(indices)
|
| 490 |
+
z_q = z_q.reshape(batch, num_tokens, -1)
|
| 491 |
+
return z_q
|
| 492 |
+
|
| 493 |
+
# adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372
|
| 494 |
+
def get_soft_code(self, hidden_states, temp=1.0, stochastic=False):
|
| 495 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() # (batch, height, width, channel)
|
| 496 |
+
distances = self.compute_distances(hidden_states) # (batch * height * width, num_embeddings)
|
| 497 |
+
|
| 498 |
+
soft_code = F.softmax(-distances / temp, dim=-1) # (batch * height * width, num_embeddings)
|
| 499 |
+
if stochastic:
|
| 500 |
+
code = torch.multinomial(soft_code, 1) # (batch * height * width, 1)
|
| 501 |
+
else:
|
| 502 |
+
code = distances.argmin(dim=-1) # (batch * height * width)
|
| 503 |
+
|
| 504 |
+
code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width)
|
| 505 |
+
batch, num_tokens = code.shape
|
| 506 |
+
soft_code = soft_code.reshape(batch, num_tokens, -1) # (batch, height * width, num_embeddings)
|
| 507 |
+
return soft_code, code
|
| 508 |
+
|
| 509 |
+
def get_code(self, hidden_states):
|
| 510 |
+
# reshape z -> (batch, height, width, channel)
|
| 511 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
|
| 512 |
+
distances = self.compute_distances(hidden_states)
|
| 513 |
+
indices = torch.argmin(distances, axis=1).unsqueeze(1)
|
| 514 |
+
indices = indices.reshape(hidden_states.shape[0], -1)
|
| 515 |
+
return indices
|
| 516 |
+
|
| 517 |
+
|
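# Note on `compute_distances` above: it relies on the expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z . e,
# evaluated in a single fused `torch.addmm` call instead of materialising pairwise differences.
# A minimal sketch of the same identity (illustrative only; the shapes below are made up and this
# snippet is not part of the original module):
#
#     z = torch.randn(6, 4)    # six flattened latents of dim 4, i.e. (B*H*W, C)
#     e = torch.randn(10, 4)   # a codebook with 10 entries, i.e. (num_embeddings, C)
#     d = torch.addmm(z.pow(2).sum(1, keepdim=True) + e.pow(2).sum(1), z, e.t(), alpha=-2.0)
#     assert torch.allclose(d, torch.cdist(z, e).pow(2), atol=1e-4)

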
class VQGANModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        resolution: int = 256,
        num_channels: int = 3,
        hidden_channels: int = 128,
        channel_mult: Tuple = (1, 1, 2, 2, 4),
        num_res_blocks: int = 2,
        attn_resolutions: int = (16,),
        no_attn_mid_block: bool = False,
        z_channels: int = 256,
        num_embeddings: int = 1024,
        quantized_embed_dim: int = 256,
        dropout: float = 0.0,
        resample_with_conv: bool = True,
        commitment_cost: float = 0.25,
    ):
        super().__init__()

        self.config.num_resolutions = len(channel_mult)
        self.config.reduction_factor = 2 ** (self.config.num_resolutions - 1)
        self.config.latent_size = resolution // self.config.reduction_factor

        self.encoder = Encoder(self.config)
        self.decoder = Decoder(self.config)
        self.quantize = VectorQuantizer(
            self.config.num_embeddings, self.config.quantized_embed_dim, self.config.commitment_cost
        )
        self.quant_conv = nn.Conv2d(
            self.config.z_channels,
            self.config.quantized_embed_dim,
            kernel_size=1,
        )
        self.post_quant_conv = nn.Conv2d(
            self.config.quantized_embed_dim,
            self.config.z_channels,
            kernel_size=1,
        )

    def encode(self, pixel_values, return_loss=False):
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
        output = (quantized_states, codebook_indices)
        if return_loss:
            output = output + (codebook_loss,)
        return output

    def decode(self, quantized_states):
        hidden_states = self.post_quant_conv(quantized_states)
        reconstructed_pixel_values = self.decoder(hidden_states)
        return reconstructed_pixel_values

    def decode_code(self, codebook_indices):
        quantized_states = self.quantize.get_codebook_entry(codebook_indices)
        reconstructed_pixel_values = self.decode(quantized_states)
        return reconstructed_pixel_values

    def get_code(self, pixel_values):
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        codebook_indices = self.quantize.get_code(hidden_states)
        return codebook_indices

    def forward(self, pixel_values, return_loss=False):
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
        reconstructed_pixel_values = self.decode(quantized_states)
        outputs = (reconstructed_pixel_values, quantized_states, codebook_indices)
        if return_loss:
            outputs = outputs + (codebook_loss,)
        return outputs
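`VQGANModel` ties the encoder, `VectorQuantizer` and decoder together behind a small tokenizer-style API: `get_code` maps images to discrete codebook ids and `decode_code` maps ids back to pixels. A minimal usage sketch, assuming a local checkpoint directory produced by `save_pretrained` (the checkpoint path and the import path are placeholders):

```python
import torch

from InternLM.internlm.model.muse.modeling_taming_vqgan import VQGANModel  # import path is an assumption

# Placeholder directory: must contain config.json and pytorch_model.bin for this class.
vqgan = VQGANModel.from_pretrained("path/to/vqgan-checkpoint")

pixel_values = torch.randn(1, 3, 256, 256)  # one RGB image at the default training resolution
with torch.no_grad():
    codes = vqgan.get_code(pixel_values)             # (1, 256): a flattened 16x16 grid of codebook ids
    recon = vqgan.decode_code(codes)                 # (1, 3, 256, 256): reconstruction from the ids alone
    outputs = vqgan(pixel_values, return_loss=True)  # (reconstruction, quantized latents, ids, codebook loss)
```

With the default config (resolution 256 and five channel multipliers) the reduction factor is 2**4, so each image becomes a 16x16 latent grid, i.e. 256 tokens per image.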
InternLM/internlm/model/muse/modeling_utils.py ADDED
@@ -0,0 +1,1171 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import json
import os
from collections import OrderedDict
from functools import partial
from pathlib import PosixPath
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import accelerate
import numpy as np
import torch
from accelerate.utils import set_module_tensor_to_device
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import (
    EntryNotFoundError,
    RepositoryNotFoundError,
    RevisionNotFoundError,
)
from requests import HTTPError
from torch import Tensor, device

from . import __version__
from internlm.utils.logger import get_logger

logger = get_logger(__file__)


hf_cache_home = os.path.expanduser(
    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "muse")


CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
MUSE_CACHE = default_cache_path
MUSE_DYNAMIC_MODULE_NAME = "myse_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))


_LOW_CPU_MEM_USAGE_DEFAULT = True


def get_parameter_device(parameter: torch.nn.Module):
    try:
        return next(parameter.parameters()).device
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device


def get_parameter_dtype(parameter: torch.nn.Module):
    try:
        return next(parameter.parameters()).dtype
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype


def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
    """
    try:
        if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
            return torch.load(checkpoint_file, map_location="cpu")
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(
                f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
                f"at '{checkpoint_file}'. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
            )


def _load_state_dict_into_model(model_to_load, state_dict):
    # Convert old format to new format if needed from a PyTorch state_dict
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = state_dict.copy()
    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, prefix=""):
        args = (state_dict, prefix, {}, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load)

    return error_msgs


def _get_model_file(
    pretrained_model_name_or_path,
    *,
    weights_name,
    subfolder,
    cache_dir,
    force_download,
    proxies,
    resume_download,
    local_files_only,
    use_auth_token,
    user_agent,
    revision,
):
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isfile(pretrained_model_name_or_path):
        return pretrained_model_name_or_path
    elif os.path.isdir(pretrained_model_name_or_path):
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
            # Load from a PyTorch checkpoint
            model_file = os.path.join(pretrained_model_name_or_path, weights_name)
            return model_file
        elif subfolder is not None and os.path.isfile(
            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
        ):
            model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
            return model_file
        else:
            raise EnvironmentError(
                f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
            )
    else:
        try:
            # Load from URL or cache if already cached
            model_file = hf_hub_download(
                pretrained_model_name_or_path,
                filename=weights_name,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                user_agent=user_agent,
                subfolder=subfolder,
                revision=revision,
            )
            return model_file

        except RepositoryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                "login`."
            )
        except RevisionNotFoundError:
            raise EnvironmentError(
                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                "this model name. Check the model page at "
                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
            )
        except EntryNotFoundError:
            raise EnvironmentError(
                f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
            )
        except HTTPError as err:
            raise EnvironmentError(
                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
            )
        except ValueError:
            raise EnvironmentError(
                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                f" directory containing a file named {weights_name} or"
                " \nCheckout your internet connection or see how to run the library in"
                " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
            )
        except EnvironmentError:
            raise EnvironmentError(
                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                f"containing a file named {weights_name}"
            )


class ModelMixin(torch.nn.Module):
    r"""
    Base class for all models.

    [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
    and saving models.

        - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
          [`~models.ModelMixin.save_pretrained`].
    """
    config_name = CONFIG_NAME
    _automatically_saved_args = ["_version", "_class_name", "_name_or_path"]
    _supports_gradient_checkpointing = False

    def __init__(self):
        super().__init__()

    @property
    def is_gradient_checkpointing(self) -> bool:
        """
        Whether gradient checkpointing is activated for this model or not.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())

    def enable_gradient_checkpointing(self):
        """
        Activates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        if not self._supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def disable_gradient_checkpointing(self):
        """
        Deactivates gradient checkpointing for the current model.

        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
        activations".
        """
        if self._supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))

    def set_use_memory_efficient_attention_xformers(
        self, valid: bool, attention_op: Optional[Callable] = None
    ) -> None:
        # Recursively walk through all the children.
        # Any children which exposes the set_use_memory_efficient_attention_xformers method
        # gets the message
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
            if hasattr(module, "set_use_memory_efficient_attention_xformers"):
                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

            for child in module.children():
                fn_recursive_set_mem_eff(child)

        for module in self.children():
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)

    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
        r"""
        Enable memory efficient attention as implemented in xformers.

        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
        time. Speed up at training time is not guaranteed.

        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
        is used.

        Parameters:
            attention_op (`Callable`, *optional*):
                Override the default `None` operator for use as `op` argument to the
                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
                function of xFormers.

        Examples:

        ```py
        >>> import torch
        >>> from diffusers import UNet2DConditionModel
        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

        >>> model = UNet2DConditionModel.from_pretrained(
        ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
        ... )
        >>> model = model.to("cuda")
        >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
        ```
        """
        self.set_use_memory_efficient_attention_xformers(True, attention_op)

    def disable_xformers_memory_efficient_attention(self):
        r"""
        Disable memory efficient attention as implemented in xformers.
        """
        self.set_use_memory_efficient_attention_xformers(False)

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        save_function: Callable = None,
        state_dict: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        `[`~models.ModelMixin.from_pretrained`]` class method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful when in distributed training like
                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
                the main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                need to replace `torch.save` by another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            state_dict (`Dict[str, torch.Tensor]`, *optional*):
                The state dictionary to save. If `None`, the model's state dictionary will be saved.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if save_function is None:
            save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

        model_to_save = self

        # Attach architecture to the config
        # Save the config
        if is_main_process:
            model_to_save.save_config(save_directory)

        # Save the model
        if state_dict is None:
            state_dict = model_to_save.state_dict()

        weights_name = WEIGHTS_NAME

        # Save the model
        save_function(state_dict, os.path.join(save_directory, weights_name))

        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        r"""
        Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you should first set it back in training mode with `model.train()`.

        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
                      `./my_model_directory/`.

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            torch_dtype (`str` or `torch.dtype`, *optional*):
                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
                will be automatically derived from the model's weights.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `diffusers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            from_flax (`bool`, *optional*, defaults to `False`):
                Load the model weights from a Flax checkpoint save file.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.

            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
                Please refer to the mirror site for more information.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
                setting this argument to `True` will raise an error.

        <Tip>

        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
        this method in a firewalled environment.

        </Tip>

        """
        cache_dir = kwargs.pop("cache_dir", MUSE_CACHE)
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)  # TODO
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        subfolder = kwargs.pop("subfolder", None)
        device_map = kwargs.pop("device_map", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )

        user_agent = {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # Load config if we don't provide a configuration
        config_path = pretrained_model_name_or_path

        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
        # Load model

        model_file = None

        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=WEIGHTS_NAME,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )

        if low_cpu_mem_usage:
            # Instantiate model with empty weights
            with accelerate.init_empty_weights():
                config, unused_kwargs = cls.load_config(
                    config_path,
                    cache_dir=cache_dir,
                    return_unused_kwargs=True,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    device_map=device_map,
                    **kwargs,
                )
                model = cls.from_config(config, **unused_kwargs)

            # if device_map is None, load the state dict and move the params from meta device to the cpu
            if device_map is None:
                param_device = "cpu"
                state_dict = load_state_dict(model_file)
                # move the params from meta device to cpu
                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                if len(missing_keys) > 0:
                    raise ValueError(
                        f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
                        " those weights or else make sure your checkpoint file is correct."
                    )

                for param_name, param in state_dict.items():
                    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
                    if accepts_dtype:
                        set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
                    else:
                        set_module_tensor_to_device(model, param_name, param_device, value=param)
            else:  # else let accelerate handle loading and dispatching.
                # Load weights and dispatch according to the device_map
                # by deafult the device_map is None and the weights are loaded on the CPU
                accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)

            loading_info = {
                "missing_keys": [],
                "unexpected_keys": [],
                "mismatched_keys": [],
                "error_msgs": [],
            }
        else:
            config, unused_kwargs = cls.load_config(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                device_map=device_map,
                **kwargs,
            )
            model = cls.from_config(config, **unused_kwargs)

            state_dict = load_state_dict(model_file)

            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                model,
                state_dict,
                model_file,
                pretrained_model_name_or_path,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
            )

            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "mismatched_keys": mismatched_keys,
                "error_msgs": error_msgs,
            }

        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
            )
        elif torch_dtype is not None:
            model = model.to(torch_dtype)

        model.register_to_config(_name_or_path=pretrained_model_name_or_path)

        # Set model in evaluation mode to deactivate DropOut modules by default
        model.eval()
        if output_loading_info:
            return model, loading_info

        return model

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        resolved_archive_file,
        pretrained_model_name_or_path,
        ignore_mismatched_sizes=False,
    ):
        # Retrieve missing & unexpected_keys
        model_state_dict = model.state_dict()
        loaded_keys = [k for k in state_dict.keys()]

        expected_keys = list(model_state_dict.keys())

        original_loaded_keys = loaded_keys

        missing_keys = list(set(expected_keys) - set(loaded_keys))
        unexpected_keys = list(set(loaded_keys) - set(expected_keys))

        # Make sure we are able to load base models as well as derived models (with heads)
        model_to_load = model

        def _find_mismatched_keys(
            state_dict,
            model_state_dict,
            loaded_keys,
            ignore_mismatched_sizes,
        ):
            mismatched_keys = []
            if ignore_mismatched_sizes:
                for checkpoint_key in loaded_keys:
                    model_key = checkpoint_key

                    if (
                        model_key in model_state_dict
                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
                    ):
                        mismatched_keys.append(
                            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                        )
                        del state_dict[checkpoint_key]
            return mismatched_keys

        if state_dict is not None:
            # Whole checkpoint
            mismatched_keys = _find_mismatched_keys(
                state_dict,
                model_state_dict,
                original_loaded_keys,
                ignore_mismatched_sizes,
            )
            error_msgs = _load_state_dict_into_model(model_to_load, state_dict)

        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
            if "size mismatch" in error_msg:
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
                " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
                " identical (initializing a BertForSequenceClassification model from a"
                " BertForSequenceClassification model)."
            )
        else:
            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
                f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
                " without further training."
            )
        if len(mismatched_keys) > 0:
            mismatched_warning = "\n".join(
                [
                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                    for key, shape1, shape2 in mismatched_keys
                ]
            )
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
                " able to use it for predictions and inference."
            )

        return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

    @property
    def device(self) -> device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
        """
        Get number of (optionally, trainable or non-embeddings) parameters in the module.

        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of trainable parameters

            exclude_embeddings (`bool`, *optional*, defaults to `False`):
                Whether or not to return only the number of non-embeddings parameters

        Returns:
            `int`: The number of parameters.
        """

        if exclude_embeddings:
            embedding_param_names = [
                f"{name}.weight"
                for name, module_type in self.named_modules()
                if isinstance(module_type, torch.nn.Embedding)
            ]
            non_embedding_parameters = [
                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
            ]
            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
        else:
            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)


""" ConfigMixin base class and utilities."""


class FrozenDict(OrderedDict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        for key, value in self.items():
            setattr(self, key, value)

        self.__frozen = True

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __setattr__(self, name, value):
        if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setattr__(name, value)

    def __setitem__(self, name, value):
        if hasattr(self, "__frozen") and self.__frozen:
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)


class ConfigMixin:
    r"""
    Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
    methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
        - [`~ConfigMixin.from_config`]
        - [`~ConfigMixin.save_config`]

    Class attributes:
        - **config_name** (`str`) -- A filename under which the config should stored when calling
          [`~ConfigMixin.save_config`] (should be overridden by parent class).
        - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
          overridden by subclass).
        - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
        - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
          should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
          subclass).
    """
    config_name = None
    ignore_for_config = []
    has_compatibles = False

    _deprecated_kwargs = []

    def register_to_config(self, **kwargs):
        if self.config_name is None:
            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
        # Special case for `kwargs` used in deprecation warning added to schedulers
        # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
        # or solve in a more general way.
        kwargs.pop("kwargs", None)
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

        if not hasattr(self, "_internal_dict"):
            internal_dict = kwargs
        else:
            previous_dict = dict(self._internal_dict)
            internal_dict = {**self._internal_dict, **kwargs}
            logger.debug(f"Updating config from {previous_dict} to {internal_dict}")

        self._internal_dict = FrozenDict(internal_dict)

    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~ConfigMixin.from_config`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
        """
        if os.path.isfile(save_directory):
            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

        os.makedirs(save_directory, exist_ok=True)

        # If we save using the predefined names, we can load using `from_config`
        output_config_file = os.path.join(save_directory, self.config_name)

        self.to_json_file(output_config_file)
        logger.info(f"Configuration saved in {output_config_file}")

    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, **kwargs):
        r"""
        Instantiate a Python class from a config dictionary

        Parameters:
            config (`Dict[str, Any]`):
                A config dictionary from which the Python class will be instantiated. Make sure to only load
                configuration files of compatible classes.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.

            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the Python class.
                `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
                overwrite same named arguments of `config`.

        Examples:

        ```python
        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

        >>> # Download scheduler from huggingface.co and cache.
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")

        >>> # Instantiate DDIM scheduler class with same config as DDPM
        >>> scheduler = DDIMScheduler.from_config(scheduler.config)

        >>> # Instantiate PNDM scheduler class with same config as DDPM
        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
        ```
        """
        # <===== TO BE REMOVED WITH DEPRECATION
        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
        if "pretrained_model_name_or_path" in kwargs:
            config = kwargs.pop("pretrained_model_name_or_path")

        if config is None:
            raise ValueError("Please make sure to provide a config as the first positional argument.")
        # ======>

        # Return model and optionally state and/or unused_kwargs
        model = cls(**config)
        return model

    @classmethod
    def load_config(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        r"""
        Instantiate a Python class from a config dictionary

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
                      organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
|
| 930 |
+
`./my_model_directory/`.
|
| 931 |
+
|
| 932 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
| 933 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
| 934 |
+
standard cache should not be used.
|
| 935 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
| 936 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
| 937 |
+
cached versions if they exist.
|
| 938 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
| 939 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
| 940 |
+
file exists.
|
| 941 |
+
proxies (`Dict[str, str]`, *optional*):
|
| 942 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
| 943 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
| 944 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
| 945 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
| 946 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
| 947 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
| 948 |
+
use_auth_token (`str` or *bool*, *optional*):
|
| 949 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
| 950 |
+
when running `transformers-cli login` (stored in `~/.huggingface`).
|
| 951 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
| 952 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
| 953 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
| 954 |
+
identifier allowed by git.
|
| 955 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
| 956 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
| 957 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
| 958 |
+
|
| 959 |
+
<Tip>
|
| 960 |
+
|
| 961 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
| 962 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
| 963 |
+
|
| 964 |
+
</Tip>
|
| 965 |
+
|
| 966 |
+
<Tip>
|
| 967 |
+
|
| 968 |
+
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
|
| 969 |
+
use this method in a firewalled environment.
|
| 970 |
+
|
| 971 |
+
</Tip>
|
| 972 |
+
"""
|
| 973 |
+
cache_dir = kwargs.pop("cache_dir", MUSE_CACHE)
|
| 974 |
+
force_download = kwargs.pop("force_download", False)
|
| 975 |
+
resume_download = kwargs.pop("resume_download", False)
|
| 976 |
+
proxies = kwargs.pop("proxies", None)
|
| 977 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
| 978 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
| 979 |
+
revision = kwargs.pop("revision", None)
|
| 980 |
+
_ = kwargs.pop("mirror", None)
|
| 981 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 982 |
+
|
| 983 |
+
user_agent = {"file_type": "config"}
|
| 984 |
+
|
| 985 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
| 986 |
+
|
| 987 |
+
if cls.config_name is None:
|
| 988 |
+
raise ValueError(
|
| 989 |
+
"`self.config_name` is not defined. Note that one should not load a config from "
|
| 990 |
+
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
| 991 |
+
)
|
| 992 |
+
|
| 993 |
+
if os.path.isfile(pretrained_model_name_or_path):
|
| 994 |
+
config_file = pretrained_model_name_or_path
|
| 995 |
+
elif os.path.isdir(pretrained_model_name_or_path):
|
| 996 |
+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
| 997 |
+
# Load from a PyTorch checkpoint
|
| 998 |
+
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
| 999 |
+
elif subfolder is not None and os.path.isfile(
|
| 1000 |
+
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
| 1001 |
+
):
|
| 1002 |
+
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
| 1003 |
+
else:
|
| 1004 |
+
raise EnvironmentError(
|
| 1005 |
+
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
| 1006 |
+
)
|
| 1007 |
+
else:
|
| 1008 |
+
try:
|
| 1009 |
+
# Load from URL or cache if already cached
|
| 1010 |
+
config_file = hf_hub_download(
|
| 1011 |
+
pretrained_model_name_or_path,
|
| 1012 |
+
filename=cls.config_name,
|
| 1013 |
+
cache_dir=cache_dir,
|
| 1014 |
+
force_download=force_download,
|
| 1015 |
+
proxies=proxies,
|
| 1016 |
+
resume_download=resume_download,
|
| 1017 |
+
local_files_only=local_files_only,
|
| 1018 |
+
use_auth_token=use_auth_token,
|
| 1019 |
+
user_agent=user_agent,
|
| 1020 |
+
subfolder=subfolder,
|
| 1021 |
+
revision=revision,
|
| 1022 |
+
)
|
| 1023 |
+
|
| 1024 |
+
except RepositoryNotFoundError:
|
| 1025 |
+
raise EnvironmentError(
|
| 1026 |
+
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
| 1027 |
+
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
|
| 1028 |
+
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
|
| 1029 |
+
" login`."
|
| 1030 |
+
)
|
| 1031 |
+
except RevisionNotFoundError:
|
| 1032 |
+
raise EnvironmentError(
|
| 1033 |
+
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
|
| 1034 |
+
" this model name. Check the model page at"
|
| 1035 |
+
f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
| 1036 |
+
)
|
| 1037 |
+
except EntryNotFoundError:
|
| 1038 |
+
raise EnvironmentError(
|
| 1039 |
+
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
| 1040 |
+
)
|
| 1041 |
+
except HTTPError as err:
|
| 1042 |
+
raise EnvironmentError(
|
| 1043 |
+
"There was a specific connection error when trying to load"
|
| 1044 |
+
f" {pretrained_model_name_or_path}:\n{err}"
|
| 1045 |
+
)
|
| 1046 |
+
except ValueError:
|
| 1047 |
+
raise EnvironmentError(
|
| 1048 |
+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
| 1049 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
| 1050 |
+
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
|
| 1051 |
+
" run the library in offline mode at"
|
| 1052 |
+
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
| 1053 |
+
)
|
| 1054 |
+
except EnvironmentError:
|
| 1055 |
+
raise EnvironmentError(
|
| 1056 |
+
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
| 1057 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
| 1058 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
| 1059 |
+
f"containing a {cls.config_name} file"
|
| 1060 |
+
)
|
| 1061 |
+
|
| 1062 |
+
try:
|
| 1063 |
+
# Load config dict
|
| 1064 |
+
config_dict = cls._dict_from_json_file(config_file)
|
| 1065 |
+
except (json.JSONDecodeError, UnicodeDecodeError):
|
| 1066 |
+
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
| 1067 |
+
|
| 1068 |
+
if return_unused_kwargs:
|
| 1069 |
+
return config_dict, kwargs
|
| 1070 |
+
|
| 1071 |
+
return config_dict
|
| 1072 |
+
|
| 1073 |
+
@staticmethod
|
| 1074 |
+
def _get_init_keys(cls):
|
| 1075 |
+
return set(dict(inspect.signature(cls.__init__).parameters).keys())
|
| 1076 |
+
|
| 1077 |
+
@classmethod
|
| 1078 |
+
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
| 1079 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
| 1080 |
+
text = reader.read()
|
| 1081 |
+
return json.loads(text)
|
| 1082 |
+
|
| 1083 |
+
def __repr__(self):
|
| 1084 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
| 1085 |
+
|
| 1086 |
+
@property
|
| 1087 |
+
def config(self) -> Dict[str, Any]:
|
| 1088 |
+
"""
|
| 1089 |
+
Returns the config of the class as a frozen dictionary
|
| 1090 |
+
|
| 1091 |
+
Returns:
|
| 1092 |
+
`Dict[str, Any]`: Config of the class.
|
| 1093 |
+
"""
|
| 1094 |
+
return self._internal_dict
|
| 1095 |
+
|
| 1096 |
+
def to_json_string(self) -> str:
|
| 1097 |
+
"""
|
| 1098 |
+
Serializes this instance to a JSON string.
|
| 1099 |
+
|
| 1100 |
+
Returns:
|
| 1101 |
+
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
| 1102 |
+
"""
|
| 1103 |
+
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
|
| 1104 |
+
config_dict["_class_name"] = self.__class__.__name__
|
| 1105 |
+
config_dict["_version"] = __version__
|
| 1106 |
+
|
| 1107 |
+
def to_json_saveable(value):
|
| 1108 |
+
if isinstance(value, np.ndarray):
|
| 1109 |
+
value = value.tolist()
|
| 1110 |
+
elif isinstance(value, PosixPath):
|
| 1111 |
+
value = str(value)
|
| 1112 |
+
return value
|
| 1113 |
+
|
| 1114 |
+
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
| 1115 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
| 1116 |
+
|
| 1117 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
| 1118 |
+
"""
|
| 1119 |
+
Save this instance to a JSON file.
|
| 1120 |
+
|
| 1121 |
+
Args:
|
| 1122 |
+
json_file_path (`str` or `os.PathLike`):
|
| 1123 |
+
Path to the JSON file in which this configuration instance's parameters will be saved.
|
| 1124 |
+
"""
|
| 1125 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
| 1126 |
+
writer.write(self.to_json_string())
|
| 1127 |
+
|
| 1128 |
+
|
| 1129 |
+
def register_to_config(init):
|
| 1130 |
+
r"""
|
| 1131 |
+
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
|
| 1132 |
+
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
|
| 1133 |
+
shouldn't be registered in the config, use the `ignore_for_config` class variable
|
| 1134 |
+
|
| 1135 |
+
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
|
| 1136 |
+
"""
|
| 1137 |
+
|
| 1138 |
+
@functools.wraps(init)
|
| 1139 |
+
def inner_init(self, *args, **kwargs):
|
| 1140 |
+
# Ignore private kwargs in the init.
|
| 1141 |
+
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
| 1142 |
+
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
|
| 1143 |
+
if not isinstance(self, ConfigMixin):
|
| 1144 |
+
raise RuntimeError(
|
| 1145 |
+
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
| 1146 |
+
"not inherit from `ConfigMixin`."
|
| 1147 |
+
)
|
| 1148 |
+
|
| 1149 |
+
ignore = getattr(self, "ignore_for_config", [])
|
| 1150 |
+
# Get positional arguments aligned with kwargs
|
| 1151 |
+
new_kwargs = {}
|
| 1152 |
+
signature = inspect.signature(init)
|
| 1153 |
+
parameters = {
|
| 1154 |
+
name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
|
| 1155 |
+
}
|
| 1156 |
+
for arg, name in zip(args, parameters.keys()):
|
| 1157 |
+
new_kwargs[name] = arg
|
| 1158 |
+
|
| 1159 |
+
# Then add all kwargs
|
| 1160 |
+
new_kwargs.update(
|
| 1161 |
+
{
|
| 1162 |
+
k: init_kwargs.get(k, default)
|
| 1163 |
+
for k, default in parameters.items()
|
| 1164 |
+
if k not in ignore and k not in new_kwargs
|
| 1165 |
+
}
|
| 1166 |
+
)
|
| 1167 |
+
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
| 1168 |
+
getattr(self, "register_to_config")(**new_kwargs)
|
| 1169 |
+
init(self, *args, **init_kwargs)
|
| 1170 |
+
|
| 1171 |
+
return inner_init
|
InternLM/internlm/model/norm.py
ADDED
@@ -0,0 +1,46 @@
# adapted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm

import numbers

import torch
from torch.nn import init
from torch.nn.parameter import Parameter


def manual_rms_norm(my_input, normalized_shape, weight, eps):
    # the RMS statistics should always be calculated in float32
    dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
    variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
    my_input = my_input * torch.rsqrt(variance + eps)

    if weight is None:
        return my_input

    # convert into half-precision if necessary
    if weight.dtype in [torch.float16, torch.bfloat16]:
        my_input = my_input.to(weight.dtype)

    return weight * my_input


class RMSNormTorch(torch.nn.Module):
    """A custom PyTorch module for RMS normalization."""

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()

        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.weight = Parameter(torch.empty(*normalized_shape))
        self.reset_parameters()

    def forward(self, _input: torch.Tensor):
        return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps)

    def reset_parameters(self):
        init.ones_(self.weight)

    def extra_repr(self):
        return "{normalized_shape}, eps={eps}".format(**self.__dict__)
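`manual_rms_norm` implements RMSNorm: the input is scaled by the reciprocal root-mean-square over the normalized dimensions, computed in float32, and then multiplied by a learned gain. A minimal sanity-check sketch, assuming only that `internlm.model.norm` is importable (shapes are arbitrary):

```python
import torch

from internlm.model.norm import RMSNormTorch

hidden = 8
norm = RMSNormTorch(hidden, eps=1e-5)  # gain initialized to ones by reset_parameters()
x = torch.randn(2, 4, hidden)

# RMSNorm: x * rsqrt(mean(x^2 over the normalized dims) + eps), then scaled by the weight.
expected = x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-5)
assert torch.allclose(norm(x), expected, atol=1e-6)
```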
InternLM/internlm/model/utils.py
ADDED
@@ -0,0 +1,224 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Optional

import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import FusedDenseFunc
from flash_attn.utils.distributed import (
    all_gather_raw,
    all_reduce_raw,
    reduce_scatter_raw,
)
from torch import Tensor
from torch.cuda.amp import custom_bwd
from torch.distributed import ProcessGroup

from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger

logger = get_logger(__file__)


def _split(input_, parallel_mode, dim=-1):
    # skip if only one rank involved
    world_size = gpc.get_world_size(parallel_mode)
    if world_size == 1:
        return input_

    # Split along the given dimension (last dimension by default).
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, (
        f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
        f"cannot split tensor evenly"
    )

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    rank = gpc.get_local_rank(parallel_mode)
    output = tensor_list[rank].contiguous()

    return output


def _gather(input_, parallel_mode, dim=-1):
    # skip if only one rank involved
    world_size = gpc.get_world_size(parallel_mode)
    if world_size == 1:
        return input_

    # all gather
    rank = gpc.get_local_rank(parallel_mode)
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate.

    Args:
        input_: input matrix.
        parallel_mode: parallel mode.
        dim: dimension
    """

    @staticmethod
    def symbolic(input_):
        return _gather(input_, parallel_mode=None)

    @staticmethod
    def forward(ctx, input_, parallel_mode, dim):
        ctx.mode = parallel_mode
        ctx.dim = dim
        return _gather(input_, parallel_mode, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.mode, ctx.dim), None, None


def gather_forward_split_backward(input_, parallel_mode, dim):
    return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)


def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
    assert my_input.dtype == grad_output.dtype
    grad_weight = torch.matmul(grad_output.t(), my_input)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
    return grad_weight, grad_bias


# adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
class FusedDenseFuncTorch(FusedDenseFunc):
    """A custom autograd function extending FusedDenseFunc with a pure-torch weight-gradient path."""

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            # unlike flash_attn, compute the weight gradient with plain torch ops rather than the fused CUDA kernel.
            grad_weight, grad_bias = linear_bias_wgrad_torch(
                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None


def fused_dense_func_torch(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    return_residual: bool = False,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
    else:
        return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
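`fused_dense_func_torch` picks a backend per call: flash-attention's fused CUDA kernel when all tensors live on the GPU in fp16/bf16 (or fp32 under autocast), otherwise the pure-torch `FusedDenseFuncTorch` fallback above. The sketch below only mirrors that eligibility test for illustration; `uses_fused_kernel` is a hypothetical helper, not a function in this repo.

```python
import torch


def uses_fused_kernel(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> bool:
    # Mirrors the dispatch condition in fused_dense_func_torch above.
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    return x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible


# fp32 tensors on CPU are not eligible, so they take the torch fallback path.
print(uses_fused_kernel(torch.randn(4, 16), torch.randn(8, 16)))  # False
```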
class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input and keep only the chunk corresponding to the rank.

    Args:
        input_: input matrix.
        parallel_mode: parallel mode.
        dim: dimension
    """

    @staticmethod
    def symbolic(input_):
        return _split(input_, parallel_mode=None)

    @staticmethod
    def forward(ctx, input_, parallel_mode, dim):
        ctx.mode = parallel_mode
        ctx.dim = dim
        return _split(input_, parallel_mode, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.mode, ctx.dim), None, None


def split_forward_gather_backward(input_, parallel_mode, dim):
    return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)


def try_import_RMSNorm():
    """
    Try to import MixedFusedRMSNorm from apex; if that fails, fall back to our torch RMSNorm implementation.
    """
    try:
        from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm

        return RMSNorm
    except ModuleNotFoundError:
        logger.warning("The torch implementation of MixedFusedRMSNorm is slower than apex. Please note this!")
        from internlm.model.norm import RMSNormTorch as RMSNorm

        return RMSNorm


def try_import_LayerNorm():
    """
    Try to import MixedFusedLayerNorm from apex; if that fails, fall back to torch.nn.LayerNorm.
    """
    try:
        from apex.normalization.fused_layer_norm import MixedFusedLayerNorm as LayerNorm

        return LayerNorm
    except ModuleNotFoundError:
        import torch.nn as nn

        return nn.LayerNorm
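The two `try_import_*` helpers let model code pick the apex fused kernels when they are installed and silently degrade otherwise. A usage sketch, assuming the `internlm` package and its `flash_attn` dependency are importable (this module imports them at the top):

```python
from internlm.model.utils import try_import_LayerNorm, try_import_RMSNorm

RMSNorm = try_import_RMSNorm()      # apex MixedFusedRMSNorm if available, else RMSNormTorch
LayerNorm = try_import_LayerNorm()  # apex MixedFusedLayerNorm if available, else torch.nn.LayerNorm

# Both fallbacks accept (normalized_shape, eps)-style construction.
norm = RMSNorm(4096, eps=1e-5)
layer_norm = LayerNorm(4096)
```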