jirong committed
Commit ee3e701 · verified · 1 Parent(s): 225894b

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +14 -0
  2. InternLM/__init__.py +0 -0
  3. InternLM/configs/kd_1b_to_300m.py +208 -0
  4. InternLM/configs/pretrain_300m.py +168 -0
  5. InternLM/internlm/__init__.py +10 -0
  6. InternLM/internlm/apis/__init__.py +0 -0
  7. InternLM/internlm/apis/inference.py +848 -0
  8. InternLM/internlm/core/__init__.py +9 -0
  9. InternLM/internlm/core/communication/__init__.py +32 -0
  10. InternLM/internlm/core/communication/p2p.py +582 -0
  11. InternLM/internlm/core/communication/utils.py +125 -0
  12. InternLM/internlm/core/context/__init__.py +49 -0
  13. InternLM/internlm/core/context/parallel_context.py +569 -0
  14. InternLM/internlm/core/context/process_group_initializer.py +418 -0
  15. InternLM/internlm/core/context/random.py +131 -0
  16. InternLM/internlm/core/engine.py +227 -0
  17. InternLM/internlm/core/gradient_handler.py +76 -0
  18. InternLM/internlm/core/naive_amp.py +136 -0
  19. InternLM/internlm/core/scheduler/__init__.py +14 -0
  20. InternLM/internlm/core/scheduler/base_scheduler.py +187 -0
  21. InternLM/internlm/core/scheduler/no_pipeline_scheduler.py +266 -0
  22. InternLM/internlm/core/scheduler/pipeline_scheduler.py +1363 -0
  23. InternLM/internlm/core/trainer.py +190 -0
  24. InternLM/internlm/data/__init__.py +13 -0
  25. InternLM/internlm/data/batch_sampler.py +354 -0
  26. InternLM/internlm/data/collaters.py +88 -0
  27. InternLM/internlm/data/dataset.py +56 -0
  28. InternLM/internlm/data/dummy_dataset.py +44 -0
  29. InternLM/internlm/data/packed_dataset.py +421 -0
  30. InternLM/internlm/data/single_dataset.py +117 -0
  31. InternLM/internlm/data/utils.py +46 -0
  32. InternLM/internlm/initialize/__init__.py +16 -0
  33. InternLM/internlm/initialize/initialize_tensor.py +63 -0
  34. InternLM/internlm/initialize/initialize_trainer.py +235 -0
  35. InternLM/internlm/initialize/launch.py +511 -0
  36. InternLM/internlm/initialize/legacy/__init__.py +0 -0
  37. InternLM/internlm/initialize/legacy/launch.py +40 -0
  38. InternLM/internlm/model/__init__.py +23 -0
  39. InternLM/internlm/model/embedding.py +273 -0
  40. InternLM/internlm/model/linear.py +201 -0
  41. InternLM/internlm/model/loss.py +81 -0
  42. InternLM/internlm/model/metrics.py +263 -0
  43. InternLM/internlm/model/modeling_internlm.py +524 -0
  44. InternLM/internlm/model/modeling_vit.py +527 -0
  45. InternLM/internlm/model/multi_head_attention.py +186 -0
  46. InternLM/internlm/model/muse/__init__.py +18 -0
  47. InternLM/internlm/model/muse/modeling_taming_vqgan.py +591 -0
  48. InternLM/internlm/model/muse/modeling_utils.py +1171 -0
  49. InternLM/internlm/model/norm.py +46 -0
  50. InternLM/internlm/model/utils.py +224 -0
.gitattributes CHANGED
@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ InternLM/tools/data/derain_prompt/000000_img.png filter=lfs diff=lfs merge=lfs -text
37
+ InternLM/tools/data/derain_prompt/000000_label.png filter=lfs diff=lfs merge=lfs -text
38
+ InternLM/tools/data/derain_prompt/000001_img.png filter=lfs diff=lfs merge=lfs -text
39
+ InternLM/tools/data/derain_prompt/000001_label.png filter=lfs diff=lfs merge=lfs -text
40
+ InternLM/tools/data/derain_prompt/000002_img.png filter=lfs diff=lfs merge=lfs -text
41
+ InternLM/tools/data/derain_prompt/000002_label.png filter=lfs diff=lfs merge=lfs -text
42
+ InternLM/tools/data/examples/derain_1.png filter=lfs diff=lfs merge=lfs -text
43
+ InternLM/tools/data/examples/derain_2.png filter=lfs diff=lfs merge=lfs -text
44
+ InternLM/tools/data/examples/pose_2.png filter=lfs diff=lfs merge=lfs -text
45
+ InternLM/tools/data/examples/seg_1.png filter=lfs diff=lfs merge=lfs -text
46
+ InternLM/tools/data/examples/seg_2.png filter=lfs diff=lfs merge=lfs -text
47
+ InternLM/tools/data/pose_prompt/000002_img.png filter=lfs diff=lfs merge=lfs -text
48
+ InternLM/tools/data/seg_prompt/000000_img.png filter=lfs diff=lfs merge=lfs -text
49
+ figs/DeLVM.PNG filter=lfs diff=lfs merge=lfs -text
InternLM/__init__.py ADDED
File without changes
InternLM/configs/kd_1b_to_300m.py ADDED
@@ -0,0 +1,208 @@
1
+ kd_config = dict(gt_weight=1., kd_weight=1., temperature=1)
2
+ teacher_type = "INTERNLM"
3
+
4
+ teacher_ckpt_folder = '/path/to/teacher'
5
+
6
+ VQGAN_FOLDER = '/path/to/vqgan'
7
+ T_SEQ_LEN = 2048
8
+ T_HIDDEN_SIZE = 2048
9
+ T_NUM_ATTENTION_HEAD = 16
10
+ T_MLP_RATIO = 8 / 3
11
+ T_NUM_LAYER = 22
12
+ T_VOCAB_SIZE = 8192
13
+
14
+ teacher = dict(
15
+ checkpoint=False, # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
16
+ num_attention_heads=T_NUM_ATTENTION_HEAD,
17
+ embed_split_hidden=True,
18
+ vocab_size=T_VOCAB_SIZE,
19
+ embed_grad_scale=1,
20
+ parallel_output=True,
21
+ hidden_size=T_HIDDEN_SIZE,
22
+ num_layers=T_NUM_LAYER,
23
+ mlp_ratio=T_MLP_RATIO,
24
+ apply_post_layer_norm=False,
25
+ dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
26
+ norm_type="rmsnorm",
27
+ layer_norm_epsilon=1e-5,
28
+ use_flash_attn=True,
29
+ num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
30
+ lvm_config=dict(
31
+ enable=True,
32
+ embedding_cfg=dict(
33
+ vq_model_path=VQGAN_FOLDER,
34
+ embedding_dim=T_HIDDEN_SIZE,
35
+ freeze_vq_model=True,
36
+ ),
37
+ )
38
+ )
39
+
40
+ ########################################################
41
+ JOB_NAME = "lvm_llama_kd"
42
+ DO_ALERT = False
43
+ model_type = "INTERNLM"
44
+
45
+ SEQ_LEN = 2048
46
+ HIDDEN_SIZE = 1024
47
+ NUM_ATTENTION_HEAD = 8
48
+ MLP_RATIO = 8 / 3
49
+ NUM_LAYER = 22
50
+ VOCAB_SIZE = 8192
51
+
52
+ MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
53
+ SAVE_CKPT_FOLDER = "local:/path_to_save/"
54
+ LOAD_CKPT_FOLDER = "local:/path_to_load/"
55
+
56
+ CHECKPOINT_EVERY = 10000
57
+ ckpt = dict(
58
+ enable_save_ckpt=True, # set True to enable ckpt save.
59
+ save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
60
+ # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["all"], ckpt_type="normal"),
61
+ # load_ckpt_folder="local:llm_ckpts/",
62
+ # 'load_ckpt_info' setting guide:
63
+ # 1. the 'path' indicates the ckpt path,
64
+ # 2. the 'content' means what states will be loaded, supported: "model", "sampler", "optimizer", "scheduler", "all"
65
+ # 3. the 'ckpt_type' means the type of checkpoint to be loaded; now only the 'normal' type is supported.
66
+ # load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
67
+ checkpoint_every=CHECKPOINT_EVERY,
68
+ async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
69
+ async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
70
+ # oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
71
+ oss_snapshot_freq=0,
72
+ )
73
+
74
+ TRAIN_FOLDER = "/path/to/dataset"
75
+ VALID_FOLDER = "/path/to/dataset"
76
+ data = dict(
77
+ seq_len=SEQ_LEN,
78
+ # micro_num means the number of micro_batch contained in one gradient update
79
+ micro_num=1,
80
+ # packed_length = micro_bsz * SEQ_LEN
81
+ micro_bsz=16,
82
+ # defaults to the value of micro_num
83
+ valid_micro_num=1,
84
+ # defaults to 0, means disable evaluate
85
+ valid_every=0,
86
+ pack_sample_into_one=False,
87
+ train_one_epoch=False,
88
+ total_steps=40000,
89
+ skip_batches="",
90
+ rampup_batch_size="",
91
+ # Datasets with less than 50 rows will be discarded
92
+ min_length=50,
93
+ train_folder=TRAIN_FOLDER,
94
+ valid_folder=None,
95
+ empty_cache_and_diag_interval=10000,
96
+ diag_outlier_ratio=1.1,
97
+ )
98
+
99
+ grad_scaler = dict(
100
+ fp16=dict(
101
+ # the initial loss scale, defaults to 2**16
102
+ initial_scale=2**16,
103
+ # the minimum loss scale, defaults to None
104
+ min_scale=1,
105
+ # the number of steps to increase loss scale when no overflow occurs
106
+ growth_interval=1000,
107
+ ),
108
+ # the multiplication factor for increasing loss scale, defaults to 2
109
+ growth_factor=2,
110
+ # the multiplication factor for decreasing loss scale, defaults to 0.5
111
+ backoff_factor=0.5,
112
+ # the maximum loss scale, defaults to None
113
+ max_scale=2**24,
114
+ # the number of overflows before decreasing loss scale, defaults to 2
115
+ hysteresis=2,
116
+ )
117
+
118
+ hybrid_zero_optimizer = dict(
119
+ # Enable low_level_optimizer overlap_communication
120
+ overlap_sync_grad=True,
121
+ overlap_sync_param=True,
122
+ # bucket size for nccl communication params
123
+ reduce_bucket_size=512 * 1024 * 1024,
124
+ # grad clipping
125
+ clip_grad_norm=1.0,
126
+ )
127
+
128
+ loss = dict(
129
+ label_smoothing=0,
130
+ )
131
+
132
+ adam = dict(
133
+ lr=1.5e-4,
134
+ adam_beta1=0.9,
135
+ adam_beta2=0.95,
136
+ adam_beta2_c=0,
137
+ adam_eps=1e-8,
138
+ weight_decay=0.1,
139
+ )
140
+
141
+ lr_scheduler = dict(
142
+ total_steps=data["total_steps"],
143
+ init_steps=0, # optimizer_warmup_step
144
+ warmup_ratio=0.0056,
145
+ eta_min=1.5e-5,
146
+ last_epoch=-1,
147
+ )
148
+
149
+ beta2_scheduler = dict(
150
+ init_beta2=adam["adam_beta2"],
151
+ c=adam["adam_beta2_c"],
152
+ cur_iter=-1,
153
+ )
154
+
155
+ model = dict(
156
+ checkpoint=False, # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
157
+ num_attention_heads=NUM_ATTENTION_HEAD,
158
+ embed_split_hidden=True,
159
+ vocab_size=VOCAB_SIZE,
160
+ embed_grad_scale=1,
161
+ parallel_output=True,
162
+ hidden_size=HIDDEN_SIZE,
163
+ num_layers=NUM_LAYER,
164
+ mlp_ratio=MLP_RATIO,
165
+ apply_post_layer_norm=False,
166
+ dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
167
+ norm_type="rmsnorm",
168
+ layer_norm_epsilon=1e-5,
169
+ use_flash_attn=True,
170
+ num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
171
+ lvm_config=dict(
172
+ enable=True,
173
+ embedding_cfg=dict(
174
+ vq_model_path='/cache/ckpt/vqgan-f16-8192-laion/',
175
+ embedding_dim=HIDDEN_SIZE,
176
+ freeze_vq_model=True,
177
+ ),
178
+ )
179
+ )
180
+ """
181
+ zero1 parallel:
182
+ 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
183
+ so parameters will be divided within the range of dp.
184
+ 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
185
+ 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
186
+ For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
187
+ pipeline parallel (dict):
188
+ 1. size: int, the size of pipeline parallel.
189
+ 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
190
+ tensor parallel: tensor parallel size, usually the number of GPUs per node.
191
+ """
192
+ parallel = dict(
193
+ zero1=8,
194
+ pipeline=dict(size=1, interleaved_overlap=True),
195
+ sequence_parallel=False,
196
+ )
197
+
198
+ cudnn_deterministic = False
199
+ cudnn_benchmark = False
200
+
201
+ monitor = dict(
202
+ # feishu alert configs
203
+ alert=dict(
204
+ enable_feishu_alert=DO_ALERT,
205
+ feishu_alert_address=None, # feishu webhook to send alert message
206
+ light_monitor_address=None, # light_monitor address to send heartbeat
207
+ ),
208
+ )
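Note on the kd_config keys at the top of this file: the distillation loss itself is implemented elsewhere in the repo and is not part of this 50-file view, so the following is only an illustrative sketch of how a `gt_weight` / `kd_weight` / `temperature` triple is conventionally consumed: a weighted sum of the hard-label cross-entropy and a temperature-scaled KL term between student and teacher logits. The function name and signature below are hypothetical.

```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, labels,
            gt_weight=1.0, kd_weight=1.0, temperature=1.0):
    """Illustrative sketch only, not the repo's own implementation.

    Combines ground-truth cross-entropy with a soft-target KL term,
    mirroring kd_config = dict(gt_weight=1., kd_weight=1., temperature=1).
    """
    vocab = student_logits.size(-1)
    # Hard-label term: standard next-token cross-entropy against the labels.
    ce = F.cross_entropy(student_logits.view(-1, vocab), labels.view(-1))
    # Soft-label term: KL between temperature-softened distributions,
    # scaled by T**2 so its gradient magnitude stays comparable to the CE term.
    t = temperature
    kl = F.kl_div(
        F.log_softmax(student_logits / t, dim=-1),
        F.softmax(teacher_logits / t, dim=-1),
        reduction="batchmean",
    ) * (t * t)
    return gt_weight * ce + kd_weight * kl
```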
InternLM/configs/pretrain_300m.py ADDED
@@ -0,0 +1,168 @@
1
+ JOB_NAME = "lvm_llama"
2
+ DO_ALERT = False
3
+ model_type = "INTERNLM"
4
+
5
+ SEQ_LEN = 2048
6
+ HIDDEN_SIZE = 1024
7
+ NUM_ATTENTION_HEAD = 8
8
+ MLP_RATIO = 8 / 3
9
+ NUM_LAYER = 22
10
+ VOCAB_SIZE = 8192
11
+
12
+ MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
13
+ SAVE_CKPT_FOLDER = "local:/path_to_save/"
14
+ LOAD_CKPT_FOLDER = "local:/path_to_load/"
15
+
16
+ CHECKPOINT_EVERY = 10000
17
+ ckpt = dict(
18
+ enable_save_ckpt=True, # set True to enable ckpt save.
19
+ save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
20
+ # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["all"], ckpt_type="normal"),
21
+ # load_ckpt_folder="local:llm_ckpts/",
22
+ # 'load_ckpt_info' setting guide:
23
+ # 1. the 'path' indicates the ckpt path,
24
+ # 2. the 'content' means what states will be loaded, supported: "model", "sampler", "optimizer", "scheduler", "all"
25
+ # 3. the 'ckpt_type' means the type of checkpoint to be loaded; now only the 'normal' type is supported.
26
+ # load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
27
+ checkpoint_every=CHECKPOINT_EVERY,
28
+ async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
29
+ async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
30
+ # oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
31
+ oss_snapshot_freq=0,
32
+ )
33
+
34
+ TRAIN_FOLDER = "/path/to/dataset"
35
+ VALID_FOLDER = "/path/to/dataset"
36
+ data = dict(
37
+ seq_len=SEQ_LEN,
38
+ # micro_num means the number of micro_batch contained in one gradient update
39
+ micro_num=1,
40
+ # packed_length = micro_bsz * SEQ_LEN
41
+ micro_bsz=16,
42
+ # defaults to the value of micro_num
43
+ valid_micro_num=1,
44
+ # defaults to 0, means disable evaluate
45
+ valid_every=0,
46
+ pack_sample_into_one=False,
47
+ train_one_epoch=False,
48
+ total_steps=40000,
49
+ skip_batches="",
50
+ rampup_batch_size="",
51
+ # Datasets with less than 50 rows will be discarded
52
+ min_length=50,
53
+ train_folder=TRAIN_FOLDER,
54
+ valid_folder=None,
55
+ empty_cache_and_diag_interval=10000,
56
+ diag_outlier_ratio=1.1,
57
+ )
58
+
59
+ grad_scaler = dict(
60
+ fp16=dict(
61
+ # the initial loss scale, defaults to 2**16
62
+ initial_scale=2**16,
63
+ # the minimum loss scale, defaults to None
64
+ min_scale=1,
65
+ # the number of steps to increase loss scale when no overflow occurs
66
+ growth_interval=1000,
67
+ ),
68
+ # the multiplication factor for increasing loss scale, defaults to 2
69
+ growth_factor=2,
70
+ # the multiplication factor for decreasing loss scale, defaults to 0.5
71
+ backoff_factor=0.5,
72
+ # the maximum loss scale, defaults to None
73
+ max_scale=2**24,
74
+ # the number of overflows before decreasing loss scale, defaults to 2
75
+ hysteresis=2,
76
+ )
77
+
78
+ hybrid_zero_optimizer = dict(
79
+ # Enable low_level_optimizer overlap_communication
80
+ overlap_sync_grad=True,
81
+ overlap_sync_param=True,
82
+ # bucket size for nccl communication params
83
+ reduce_bucket_size=512 * 1024 * 1024,
84
+ # grad clipping
85
+ clip_grad_norm=1.0,
86
+ )
87
+
88
+ loss = dict(
89
+ label_smoothing=0,
90
+ )
91
+
92
+ adam = dict(
93
+ lr=1.5e-4,
94
+ adam_beta1=0.9,
95
+ adam_beta2=0.95,
96
+ adam_beta2_c=0,
97
+ adam_eps=1e-8,
98
+ weight_decay=0.1,
99
+ )
100
+
101
+ lr_scheduler = dict(
102
+ total_steps=data["total_steps"],
103
+ init_steps=0, # optimizer_warmup_step
104
+ warmup_ratio=0.0056,
105
+ eta_min=1.5e-5,
106
+ last_epoch=-1,
107
+ )
108
+
109
+ beta2_scheduler = dict(
110
+ init_beta2=adam["adam_beta2"],
111
+ c=adam["adam_beta2_c"],
112
+ cur_iter=-1,
113
+ )
114
+
115
+ model = dict(
116
+ checkpoint=False, # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
117
+ num_attention_heads=NUM_ATTENTION_HEAD,
118
+ embed_split_hidden=True,
119
+ vocab_size=VOCAB_SIZE,
120
+ embed_grad_scale=1,
121
+ parallel_output=True,
122
+ hidden_size=HIDDEN_SIZE,
123
+ num_layers=NUM_LAYER,
124
+ mlp_ratio=MLP_RATIO,
125
+ apply_post_layer_norm=False,
126
+ dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32"
127
+ norm_type="rmsnorm",
128
+ layer_norm_epsilon=1e-5,
129
+ use_flash_attn=True,
130
+ num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
131
+ lvm_config=dict(
132
+ enable=True,
133
+ embedding_cfg=dict(
134
+ vq_model_path='/cache/ckpt/vqgan-f16-8192-laion/',
135
+ embedding_dim=HIDDEN_SIZE,
136
+ freeze_vq_model=True,
137
+ ),
138
+ )
139
+ )
140
+ """
141
+ zero1 parallel:
142
+ 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
143
+ so parameters will be divided within the range of dp.
144
+ 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
145
+ 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
146
+ For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
147
+ pipeline parallel (dict):
148
+ 1. size: int, the size of pipeline parallel.
149
+ 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
150
+ tensor parallel: tensor parallel size, usually the number of GPUs per node.
151
+ """
152
+ parallel = dict(
153
+ zero1=8,
154
+ pipeline=dict(size=1, interleaved_overlap=True),
155
+ sequence_parallel=False,
156
+ )
157
+
158
+ cudnn_deterministic = False
159
+ cudnn_benchmark = False
160
+
161
+ monitor = dict(
162
+ # feishu alert configs
163
+ alert=dict(
164
+ enable_feishu_alert=DO_ALERT,
165
+ feishu_alert_address=None, # feishu webhook to send alert message
166
+ light_monitor_address=None, # light_monitor address to send heartbeat
167
+ ),
168
+ )
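Note: as a quick sanity check on the data settings shared by both configs, the comments define packed_length = micro_bsz * SEQ_LEN and one gradient update per micro_num micro-batches, so the implied token throughput can be worked out directly. A minimal sketch; the data-parallel world size is an assumption here, since it is determined at launch time rather than in this file:

```python
# Values taken from the config above.
SEQ_LEN, micro_bsz, micro_num, total_steps = 2048, 16, 1, 40000

packed_length = micro_bsz * SEQ_LEN                   # 32,768 tokens per micro-batch
tokens_per_step_per_rank = micro_num * packed_length  # tokens per gradient update, per rank

dp_world_size = 8  # hypothetical launch with 8 data-parallel ranks
global_tokens_per_step = tokens_per_step_per_rank * dp_world_size
total_training_tokens = global_tokens_per_step * total_steps

print(f"{packed_length=} {global_tokens_per_step=} {total_training_tokens=:,}")
# packed_length=32768 global_tokens_per_step=262144 total_training_tokens=10,485,760,000
```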
InternLM/internlm/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .initialize.initialize_trainer import initialize_trainer, initialize_kd_trainer
2
+ from .initialize.launch import get_default_parser, launch_from_slurm, launch_from_torch
3
+
4
+ __all__ = [
5
+ "get_default_parser",
6
+ "initialize_kd_trainer",
7
+ "initialize_trainer",
8
+ "launch_from_slurm",
9
+ "launch_from_torch",
10
+ ]
InternLM/internlm/apis/__init__.py ADDED
File without changes
InternLM/internlm/apis/inference.py ADDED
@@ -0,0 +1,848 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ __all__ = ["SequenceGenerator"]
9
+
10
+
11
+ class InferenceParams:
12
+ """
13
+ Intermediate cache objects for inference
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ max_sequence_len,
19
+ max_batch_size,
20
+ sequence_len_offset=0,
21
+ batch_size_offset=0,
22
+ key_value_memory_dict: dict = None,
23
+ lengths_per_sample=None,
24
+ attention_mask=None,
25
+ ) -> None:
26
+
27
+ self.max_sequence_len: int = max_sequence_len
28
+ self.max_batch_size: int = max_batch_size
29
+ self.sequence_len_offset: int = sequence_len_offset
30
+ self.batch_size_offset: int = batch_size_offset
31
+ if key_value_memory_dict is None:
32
+ key_value_memory_dict = {}
33
+ self.key_value_memory_dict: dict = key_value_memory_dict
34
+ self.fused_ft_kernel: bool = False
35
+ self.lengths_per_sample = lengths_per_sample
36
+ self.attention_mask = attention_mask
37
+
38
+ def reorder_state(self, indices):
39
+ if self.lengths_per_sample is not None:
40
+ self.lengths_per_sample = self.lengths_per_sample.index_select(index=indices, dim=0)
41
+ for key, value in list(self.key_value_memory_dict.items()):
42
+ value = value.index_select(index=indices, dim=0)
43
+ self.key_value_memory_dict[key] = value
44
+
45
+
46
+ def _get_model_device(model):
47
+ """
48
+ Obtain the device of an nn.Module model.
49
+
50
+ Args:
51
+ model: nn.Module
52
+
53
+ Return: torch.device, or None if the model has no parameters.
54
+ """
55
+ assert isinstance(model, nn.Module)
56
+
57
+ parameters = list(model.parameters())
58
+ if len(parameters) == 0:
59
+ return None
60
+ else:
61
+ return parameters[0].device
62
+
63
+
64
+ class SequenceGenerator:
65
+ """
66
+ Sequence Generator.
67
+ """
68
+
69
+ def __init__(self, decoder, eos_token_id, pad_token_id, bos_token_id):
70
+ self.decoder = decoder
71
+ self.eos_token_id = eos_token_id
72
+ self.pad_token_id = pad_token_id
73
+ self.bos_token_id = bos_token_id
74
+
75
+ @torch.no_grad()
76
+ def generate(
77
+ self,
78
+ tokens: "torch.LongTensor" = None,
79
+ num_return_sequences=1,
80
+ max_length: int = 20,
81
+ num_beams: int = 1,
82
+ do_sample: bool = True,
83
+ temperature: float = 1.0,
84
+ top_k: int = 50,
85
+ top_p: float = 1.0,
86
+ repetition_penalty: float = 1,
87
+ length_penalty: float = 1.0,
88
+ ):
89
+ """
90
+ Args:
91
+ tokens: the beginning tokens whose shape is [bsz, length]. If None, the default `bos_token` will be
92
+ added to conduct generation.
93
+ num_return_sequences: number of returned sequences.
94
+ max_length: the max length of generated sequence.
95
+ num_beams: the size of beam search.
96
+ do_sample: whether using sample.
97
+ temperature: it's meaningful when do_sample is True.
98
+ top_k: sampling from top_k.
99
+ top_p: sampling from top_p tokens(nucleus sampling).
100
+
101
+ Return:
102
+ the token sequence whose shape is [bsz, num_return_sequences, max_length]. If eos_token_id is not None,
103
+ the ending of each sequence must be eos_token_id.
104
+ """
105
+ assert num_return_sequences <= num_beams, f"num_return_sequences ({num_return_sequences}) must be <= num_beams ({num_beams})."
106
+ if do_sample:
107
+ return sample_generate(
108
+ self.decoder,
109
+ tokens=tokens,
110
+ max_length=max_length,
111
+ num_beams=num_beams,
112
+ num_return_sequences=num_return_sequences,
113
+ temperature=temperature,
114
+ top_k=top_k,
115
+ top_p=top_p,
116
+ eos_token_id=self.eos_token_id, # the ending token id
117
+ pad_token_id=self.pad_token_id,
118
+ repetition_penalty=repetition_penalty, # the penalty degree for repetition tokens
119
+ length_penalty=length_penalty, # the penalty for length. if it > 1, then encourages long sequence.
120
+ # Otherwise, encourages short sequence.
121
+ bos_token_id=self.bos_token_id,
122
+ )
123
+ else:
124
+ return greedy_generate(
125
+ self.decoder,
126
+ tokens=tokens,
127
+ max_length=max_length,
128
+ num_beams=num_beams,
129
+ num_return_sequences=num_return_sequences,
130
+ eos_token_id=self.eos_token_id,
131
+ pad_token_id=self.pad_token_id,
132
+ repetition_penalty=repetition_penalty,
133
+ length_penalty=length_penalty,
134
+ bos_token_id=self.bos_token_id,
135
+ )
136
+
137
+
138
+ @torch.no_grad()
139
+ def greedy_generate(
140
+ decoder,
141
+ tokens=None,
142
+ max_length=20,
143
+ num_beams=1,
144
+ num_return_sequences=1,
145
+ eos_token_id=None,
146
+ pad_token_id=0,
147
+ repetition_penalty=1,
148
+ length_penalty=1.0,
149
+ bos_token_id=1,
150
+ feat_mask=None,
151
+ ffn_mask=None,
152
+ layer_mask=None,
153
+ ):
154
+ """
155
+ Search sequence greedily.
156
+
157
+ Args:
158
+ decoder: the Decoder object.
159
+ tokens: the shape is [batch size, length]. If None, generation begins with bos_token_id.
160
+ max_length: the max length for generated sequence.
161
+ num_beams: the size of beam to decode.
162
+ eos_token_id: the ending token id. If None, the decode length is max_length.
163
+ pad_token_id: the token id of pad.
164
+ repetition_penalty: the penalty degree for repetition tokens
165
+ length_penalty: the penalty for length.
166
+
167
+ """
168
+ if num_beams == 1:
169
+ token_ids = _no_beam_search_generate(
170
+ decoder,
171
+ tokens=tokens,
172
+ max_length=max_length,
173
+ temperature=1,
174
+ top_k=50,
175
+ top_p=1,
176
+ eos_token_id=eos_token_id,
177
+ do_sample=False,
178
+ repetition_penalty=repetition_penalty,
179
+ length_penalty=length_penalty,
180
+ pad_token_id=pad_token_id,
181
+ bos_token_id=bos_token_id,
182
+ feat_mask=feat_mask,
183
+ ffn_mask=ffn_mask,
184
+ layer_mask=layer_mask,
185
+ )
186
+ else:
187
+ token_ids = _beam_search_generate(
188
+ decoder,
189
+ tokens=tokens,
190
+ max_length=max_length,
191
+ num_beams=num_beams,
192
+ num_return_sequences=num_return_sequences,
193
+ temperature=1,
194
+ top_k=50,
195
+ top_p=1,
196
+ eos_token_id=eos_token_id,
197
+ do_sample=False,
198
+ repetition_penalty=repetition_penalty,
199
+ length_penalty=length_penalty,
200
+ pad_token_id=pad_token_id,
201
+ bos_token_id=bos_token_id,
202
+ feat_mask=feat_mask,
203
+ ffn_mask=ffn_mask,
204
+ layer_mask=layer_mask,
205
+ )
206
+
207
+ return token_ids
208
+
209
+
210
+ @torch.no_grad()
211
+ def sample_generate(
212
+ decoder,
213
+ tokens,
214
+ max_length=20,
215
+ num_beams=1,
216
+ num_return_sequences=1,
217
+ temperature=1.0,
218
+ top_k=50,
219
+ top_p=1.0,
220
+ eos_token_id=None,
221
+ pad_token_id=0,
222
+ repetition_penalty=1.0,
223
+ length_penalty=1.0,
224
+ bos_token_id=1,
225
+ ):
226
+ """
227
+ generate sequence in sampling way.
228
+
229
+ Args:
230
+ decoder: the Decoder object.
231
+ tokens: the shape is [batch size, length]. If None, generation begins with bos_token_id.
232
+ max_length: the max length for generated sequence.
233
+ num_beams: the size of beam to decode.
234
+ num_return_sequences: number of returned sequence.
235
+ temperature: annealing magnitude during sampling.
236
+ top_k: sampling from top_k. (Default: 50)
237
+ top_p: sampling from top_p tokens(nucleus sampling). (Default: 1.0)
238
+ eos_token_id: the ending token id. If None, the decode length is max_length.
239
+ pad_token_id: the token id of pad.
240
+ repetition_penalty: the penalty degree for repetition tokens
241
+ length_penalty: the penalty for length.
242
+
243
+ """
244
+ if num_beams == 1:
245
+ token_ids = _no_beam_search_generate(
246
+ decoder,
247
+ tokens=tokens,
248
+ max_length=max_length,
249
+ temperature=temperature,
250
+ top_k=top_k,
251
+ top_p=top_p,
252
+ eos_token_id=eos_token_id,
253
+ do_sample=True,
254
+ repetition_penalty=repetition_penalty,
255
+ length_penalty=length_penalty,
256
+ pad_token_id=pad_token_id,
257
+ bos_token_id=bos_token_id,
258
+ )
259
+ else:
260
+ token_ids = _beam_search_generate(
261
+ decoder,
262
+ tokens=tokens,
263
+ max_length=max_length,
264
+ num_beams=num_beams,
265
+ num_return_sequences=num_return_sequences,
266
+ temperature=temperature,
267
+ top_k=top_k,
268
+ top_p=top_p,
269
+ eos_token_id=eos_token_id,
270
+ do_sample=True,
271
+ repetition_penalty=repetition_penalty,
272
+ length_penalty=length_penalty,
273
+ pad_token_id=pad_token_id,
274
+ bos_token_id=bos_token_id,
275
+ )
276
+ return token_ids
277
+
278
+
279
+ @torch.no_grad()
280
+ def _no_beam_search_generate(
281
+ decoder,
282
+ tokens,
283
+ inference_params=None,
284
+ max_length=20,
285
+ temperature=1.0,
286
+ top_k=50,
287
+ top_p=1.0,
288
+ eos_token_id=None,
289
+ do_sample=True,
290
+ repetition_penalty=1.0,
291
+ length_penalty=1.0,
292
+ pad_token_id=0,
293
+ bos_token_id=1,
294
+ feat_mask=None,
295
+ ffn_mask=None,
296
+ layer_mask=None,
297
+ ):
298
+ # delete num_return_sequences=1 for lint check;
299
+ batch_size = tokens.size(0)
300
+ if eos_token_id is None:
301
+ _eos_token_id = -1
302
+ else:
303
+ _eos_token_id = eos_token_id
304
+
305
+ has_bos = torch.all(tokens[:, 0].eq(bos_token_id))
306
+ if has_bos:
307
+ bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
308
+ bos_sum = bos_pos.cumsum(dim=-1)
309
+ bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
310
+ to_atten_x = bos_pos[:, :, None]
311
+ to_atten_y = bos_pos[:, None, :]
312
+ # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
313
+ else:
314
+ bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
315
+ to_atten_x = bos_pos[:, :, None]
316
+ to_atten_y = bos_pos[:, None, :]
317
+ # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
318
+ attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
319
+ if inference_params is None:
320
+ inference_params = InferenceParams(
321
+ max_sequence_len=max_length,
322
+ max_batch_size=tokens.size(0),
323
+ sequence_len_offset=0,
324
+ batch_size_offset=0,
325
+ key_value_memory_dict=None,
326
+ lengths_per_sample=None,
327
+ attention_mask=attention_mask,
328
+ )
329
+
330
+ if layer_mask is None:
331
+ if feat_mask is None and ffn_mask is None:
332
+ scores = decoder(**{"input_ids": tokens, "inference_params": inference_params})
333
+ else:
334
+ scores = decoder(
335
+ **{
336
+ "input_ids": tokens,
337
+ "inference_params": inference_params,
338
+ "feat_mask": feat_mask,
339
+ "ffn_mask": ffn_mask,
340
+ }
341
+ )
342
+ else:
343
+ scores = decoder(
344
+ **{
345
+ "input_ids": tokens,
346
+ "inference_params": inference_params,
347
+ "feat_mask": feat_mask,
348
+ "ffn_mask": ffn_mask,
349
+ "layer_mask": layer_mask,
350
+ }
351
+ )
352
+
353
+ if isinstance(scores, (list, tuple)):
354
+ scores = scores[0]
355
+ scores = scores[:, -1].float()
356
+ inference_params.sequence_len_offset += tokens.size(1)
357
+ if _eos_token_id != -1:
358
+ scores[:, _eos_token_id] = -1e12
359
+ next_tokens = scores.argmax(dim=-1, keepdim=True)
360
+ token_ids = torch.cat([tokens, next_tokens], dim=1)
361
+ cur_len = token_ids.size(1)
362
+ dones = token_ids.new_zeros(batch_size).eq(1)
363
+ # tokens = tokens[:, -1:]
364
+
365
+ real_max_length = max_length
366
+ max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
367
+
368
+ while cur_len < real_max_length:
369
+ # batch_size x vocab_size
370
+ if has_bos:
371
+ bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
372
+ bos_sum = bos_pos.cumsum(dim=-1)
373
+ bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
374
+ to_atten_x = bos_pos[:, :, None]
375
+ to_atten_y = bos_pos[:, None, :]
376
+ # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
377
+ else:
378
+ bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
379
+ to_atten_x = bos_pos[:, :, None]
380
+ to_atten_y = bos_pos[:, None, :]
381
+ # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
382
+ attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
383
+ inference_params.attention_mask = attention_mask
384
+ if layer_mask is None:
385
+ if feat_mask is None and ffn_mask is None:
386
+ scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
387
+ else:
388
+ scores = decoder(
389
+ **{
390
+ "input_ids": token_ids[:, -1:],
391
+ "inference_params": inference_params,
392
+ "feat_mask": feat_mask,
393
+ "ffn_mask": ffn_mask,
394
+ }
395
+ )
396
+ else:
397
+ scores = decoder(
398
+ **{
399
+ "input_ids": token_ids[:, -1:],
400
+ "inference_params": inference_params,
401
+ "feat_mask": feat_mask,
402
+ "ffn_mask": ffn_mask,
403
+ "layer_mask": layer_mask,
404
+ }
405
+ )
406
+
407
+ if isinstance(scores, (list, tuple)):
408
+ scores = scores[0]
409
+ scores = scores[:, -1].float()
410
+ inference_params.sequence_len_offset += 1
411
+
412
+ if repetition_penalty != 1.0:
413
+ token_scores = scores.gather(dim=1, index=token_ids)
414
+ lt_zero_mask = token_scores.lt(0).float()
415
+ ge_zero_mask = lt_zero_mask.eq(0).float()
416
+ token_scores = (
417
+ lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
418
+ )
419
+ scores.scatter_(dim=1, index=token_ids, src=token_scores)
420
+
421
+ if eos_token_id is not None and length_penalty != 1.0:
422
+ # batch_size x vocab_size
423
+ token_scores = scores / cur_len**length_penalty
424
+ eos_mask = scores.new_ones(scores.size(1))
425
+ eos_mask[eos_token_id] = 0
426
+ eos_mask = eos_mask.unsqueeze(0).eq(1)
427
+
428
+ scores = scores.masked_scatter(eos_mask, token_scores)
429
+
430
+ if do_sample:
431
+ if temperature > 0 and temperature != 1:
432
+ scores = scores / temperature
433
+
434
+ scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2)
435
+ # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
436
+ probs = F.softmax(scores, dim=-1) + 1e-12
437
+
438
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size
439
+ else:
440
+ next_tokens = torch.argmax(scores, dim=-1) # batch_size
441
+
442
+ if _eos_token_id != -1:
443
+ next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), _eos_token_id)
444
+ next_tokens = next_tokens.masked_fill(dones, pad_token_id)
445
+ tokens = next_tokens.unsqueeze(1)
446
+
447
+ token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len
448
+
449
+ end_mask = next_tokens.eq(_eos_token_id)
450
+ dones = dones.__or__(end_mask)
451
+ cur_len += 1
452
+
453
+ if dones.min() == 1:
454
+ break
455
+
456
+ # if eos_token_id is not None:
457
+ # # setting the eos at the maximum length position
458
+ # tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id)
459
+ # if cur_len == max_length:
460
+ # # If eos is not reached by the maximum length, forcibly replace the last word with eos
461
+ # token_ids[:, -1].masked_fill_(~dones, eos_token_id)
462
+ # TODO Here we are simply adding an extra dimension for interface compatibility, but in the future it will need to
463
+ # be able to return multiple real results
464
+ return token_ids[:, None]
465
+
466
+
467
+ @torch.no_grad()
468
+ def _beam_search_generate(
469
+ decoder,
470
+ tokens,
471
+ inference_params=None,
472
+ max_length=20,
473
+ num_beams=4,
474
+ num_return_sequences=1,
475
+ temperature=1.0,
476
+ top_k=50,
477
+ top_p=1.0,
478
+ eos_token_id=None,
479
+ do_sample=True,
480
+ repetition_penalty=1.0,
481
+ length_penalty=1.0,
482
+ pad_token_id=0,
483
+ bos_token_id=1,
484
+ feat_mask=None,
485
+ ffn_mask=None,
486
+ layer_mask=None,
487
+ ) -> torch.LongTensor:
488
+
489
+ device = _get_model_device(decoder)
490
+ batch_size = tokens.size(0)
491
+
492
+ if eos_token_id is None:
493
+ _eos_token_id = -1
494
+ else:
495
+ _eos_token_id = eos_token_id
496
+
497
+ has_bos = torch.all(tokens[:, 0].eq(bos_token_id))
498
+
499
+ if has_bos:
500
+ bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
501
+ bos_sum = bos_pos.cumsum(dim=-1)
502
+ bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
503
+ to_atten_x = bos_pos[:, :, None]
504
+ to_atten_y = bos_pos[:, None, :]
505
+ # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
506
+ else:
507
+ bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0)
508
+ to_atten_x = bos_pos[:, :, None]
509
+ to_atten_y = bos_pos[:, None, :]
510
+ # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
511
+ attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
512
+
513
+ if inference_params is None:
514
+ inference_params = InferenceParams(
515
+ max_sequence_len=max_length,
516
+ max_batch_size=tokens.size(0),
517
+ sequence_len_offset=0,
518
+ batch_size_offset=0,
519
+ key_value_memory_dict=None,
520
+ lengths_per_sample=None,
521
+ attention_mask=attention_mask,
522
+ )
523
+
524
+ if layer_mask is None:
525
+ if feat_mask is None and ffn_mask is None:
526
+ scores = decoder(**{"input_ids": tokens, "inference_params": inference_params})
527
+ else:
528
+ scores = decoder(
529
+ **{
530
+ "input_ids": tokens,
531
+ "inference_params": inference_params,
532
+ "feat_mask": feat_mask,
533
+ "ffn_mask": ffn_mask,
534
+ }
535
+ )
536
+ else:
537
+ scores = decoder(
538
+ **{
539
+ "input_ids": tokens,
540
+ "inference_params": inference_params,
541
+ "feat_mask": feat_mask,
542
+ "ffn_mask": ffn_mask,
543
+ "layer_mask": layer_mask,
544
+ }
545
+ )
546
+
547
+ if isinstance(scores, (list, tuple)):
548
+ scores = scores[0]
549
+ scores = scores[:, -1].float()
550
+ inference_params.sequence_len_offset += tokens.size(1)
551
+ if _eos_token_id != -1:
552
+ scores[:, _eos_token_id] = -1e12
553
+ vocab_size = scores.size(1)
554
+ assert vocab_size >= num_beams, "num_beams should not be larger than the vocabulary size."
555
+
556
+ if do_sample:
557
+ probs = F.softmax(scores, dim=-1) + 1e-12
558
+ # (batch_size, num_beams)
559
+ next_tokens = torch.multinomial(probs, num_samples=num_beams)
560
+ logits = probs.log()
561
+ # (batch_size, num_beams)
562
+ next_scores = logits.gather(dim=1, index=next_tokens)
563
+ else:
564
+ scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size)
565
+ # obtain (batch_size, num_beams), (batch_size, num_beams)
566
+ next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
567
+
568
+ indices = torch.arange(batch_size, dtype=torch.long).to(device)
569
+ indices = indices.repeat_interleave(num_beams)
570
+ inference_params.reorder_state(indices)
571
+
572
+ # batch_size * num_beams x length
573
+ tokens = tokens.index_select(dim=0, index=indices)
574
+ # generated tokens (batch_size', cur_len)
575
+ token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
576
+ dones = [False] * batch_size
577
+
578
+ beam_scores = next_scores.view(-1) # batch_size * num_beams
579
+
580
+ cur_len = token_ids.size(1)
581
+
582
+ real_max_length = max_length
583
+ max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
584
+ hypos = [
585
+ BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
586
+ ]
587
+ # 0, num_beams, 2*num_beams, ...
588
+ batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
589
+
590
+ while cur_len < real_max_length:
591
+ if has_bos:
592
+ bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
593
+ bos_sum = bos_pos.cumsum(dim=-1)
594
+ bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1)
595
+ to_atten_x = bos_pos[:, :, None]
596
+ to_atten_y = bos_pos[:, None, :]
597
+ # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
598
+ else:
599
+ bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0)
600
+ to_atten_x = bos_pos[:, :, None]
601
+ to_atten_y = bos_pos[:, None, :]
602
+ # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1)
603
+ attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1)
604
+
605
+ inference_params.attention_mask = attention_mask
606
+ # (bsz x num_beams, vocab_size)
607
+
608
+ if layer_mask is None:
609
+ if feat_mask is None and ffn_mask is None:
610
+ scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params})
611
+ else:
612
+ scores = decoder(
613
+ **{
614
+ "input_ids": token_ids[:, -1:],
615
+ "inference_params": inference_params,
616
+ "feat_mask": feat_mask,
617
+ "ffn_mask": ffn_mask,
618
+ }
619
+ )
620
+ else:
621
+ scores = decoder(
622
+ **{
623
+ "input_ids": token_ids[:, -1:],
624
+ "inference_params": inference_params,
625
+ "feat_mask": feat_mask,
626
+ "ffn_mask": ffn_mask,
627
+ "layer_mask": layer_mask,
628
+ }
629
+ )
630
+
631
+ if isinstance(scores, (list, tuple)):
632
+ scores = scores[0]
633
+ scores = scores[:, -1].float()
634
+ inference_params.sequence_len_offset += 1
635
+ if repetition_penalty != 1.0:
636
+ token_scores = scores.gather(dim=1, index=token_ids)
637
+ lt_zero_mask = token_scores.lt(0).float()
638
+ ge_zero_mask = lt_zero_mask.eq(0).float()
639
+ token_scores = (
640
+ lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
641
+ )
642
+ scores.scatter_(dim=1, index=token_ids, src=token_scores)
643
+
644
+ if _eos_token_id != -1:
645
+ max_len_eos_mask = max_lengths.eq(cur_len + 1)
646
+ eos_scores = scores[:, _eos_token_id]
647
+ scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores + 1e32, eos_scores)
648
+
649
+ if do_sample:
650
+ if temperature > 0 and temperature != 1:
651
+ scores = scores / temperature
652
+
653
+ scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1)
654
+ # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523
655
+ probs = F.softmax(scores, dim=-1) + 1e-12
656
+
657
+ # batch_size' x (num_beams+1)
658
+ _tokens = torch.multinomial(probs, num_samples=num_beams + 1)
659
+
660
+ logits = probs.log()
661
+ # batch_size' x (num_beams+1)
662
+ _scores = logits.gather(dim=1, index=_tokens)
663
+ # batch_size' x (num_beams+1)
664
+ _scores = _scores + beam_scores[:, None]
665
+ _scores = _scores.view(batch_size, num_beams * (num_beams + 1))
666
+ next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
667
+ _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
668
+ # (batch_size, 2*num_beams)
669
+ next_tokens = _tokens.gather(dim=1, index=ids)
670
+ # (batch_size, 2*num_beams)
671
+ from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long()
672
+ else:
673
+ # (batch_size * num_beams, vocab_size)
674
+ scores = F.log_softmax(scores, dim=-1)
675
+ # (batch_size * num_beams, vocab_size)
676
+ _scores = scores + beam_scores[:, None]
677
+ # (batch_size, num_beams*vocab_size)
678
+ _scores = _scores.view(batch_size, -1)
679
+ # (bsz, 2*num_beams)
680
+ next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
681
+ # (batch_size, 2*num_beams)
682
+ from_which_beam = torch.floor(ids.float() / vocab_size).long()
683
+ next_tokens = ids % vocab_size # (batch_size, 2*num_beams)
684
+
685
+ # next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True)
686
+ # next_tokens = next_tokens.gather(dim=1, index=sorted_inds)
687
+ # from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds)
688
+
689
+ not_eos_mask = next_tokens.ne(_eos_token_id)
690
+ keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
691
+ keep_mask = not_eos_mask.__and__(keep_mask)
692
+
693
+ _next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
694
+ _from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)
695
+ _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
696
+ beam_scores = _next_scores.view(-1)
697
+
698
+ flag = True
699
+ if cur_len + 1 == real_max_length:
700
+ eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
701
+ eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)
702
+ eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)
703
+ else:
704
+ effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams
705
+ if effective_eos_mask.sum().gt(0):
706
+ eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
707
+ eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
708
+ eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]
709
+ else:
710
+ flag = False
711
+
712
+ if flag:
713
+ _token_ids = torch.cat([token_ids, _next_tokens], dim=-1)
714
+ for batch_idx, beam_ind, beam_idx in zip(
715
+ eos_batch_idx.tolist(), eos_beam_ind.tolist(), eos_beam_idx.tolist()
716
+ ):
717
+ if not dones[batch_idx]:
718
+ score = next_scores[batch_idx, beam_ind].item()
719
+ if _eos_token_id != -1:
720
+ hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
721
+ else:
722
+ hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score)
723
+
724
+ reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)
725
+ inference_params.reorder_state(reorder_inds)
726
+ token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)
727
+
728
+ for batch_idx in range(batch_size):
729
+ dones[batch_idx] = (
730
+ dones[batch_idx]
731
+ or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item())
732
+ or max_lengths[batch_idx * num_beams] == cur_len + 1
733
+ )
734
+
735
+ cur_len += 1
736
+
737
+ if all(dones):
738
+ break
739
+
740
+ # select the best hypotheses
741
+ tgt_len = token_ids.new_zeros(batch_size, num_return_sequences)
742
+ best = []
743
+
744
+ for i, hypotheses in enumerate(hypos):
745
+ # best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
746
+ sorted_hyp = list(sorted(hypotheses.hyp, key=lambda x: x[0], reverse=True))
747
+ _best = []
748
+ for j, hyp in zip(range(num_return_sequences), sorted_hyp):
749
+ hyp = hyp[1]
750
+ if _eos_token_id != -1:
751
+ hyp = torch.cat([hyp, token_ids.new_ones(1) * _eos_token_id])
752
+ tgt_len[i, j] = len(hyp)
753
+ _best.append(hyp)
754
+ best.append(_best)
755
+
756
+ # generate target batch
757
+ decoded = token_ids.new_zeros(batch_size, num_return_sequences, tgt_len.max().item()).fill_(pad_token_id)
758
+ for i, hypo in enumerate(best):
759
+ for j, _hypo in enumerate(hypo):
760
+ decoded[i, j, : tgt_len[i, j]] = _hypo
761
+
762
+ return decoded
763
+
764
+
765
+ class BeamHypotheses(object):
766
+ """
767
+ BeamHypotheses
768
+ """
769
+
770
+ def __init__(self, num_beams, max_length, length_penalty, early_stopping):
771
+ """Initialize n-best list of hypotheses."""
772
+ self.max_length = max_length - 1 # ignoring bos_token
773
+ self.length_penalty = length_penalty
774
+ self.early_stopping = early_stopping
775
+ self.num_beams = num_beams
776
+ self.hyp = []
777
+ self.worst_score = 1e9
778
+
779
+ def __len__(self):
780
+ """Number of hypotheses in the list."""
781
+ return len(self.hyp)
782
+
783
+ def add(self, hyp, sum_logprobs):
784
+ """Add a new hypothesis to the list."""
785
+ score = sum_logprobs / len(hyp) ** self.length_penalty
786
+ if len(self) < self.num_beams or score > self.worst_score:
787
+ self.hyp.append((score, hyp))
788
+ if len(self) > self.num_beams:
789
+ sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
790
+ del self.hyp[sorted_scores[0][1]]
791
+ self.worst_score = sorted_scores[1][0]
792
+ else:
793
+ self.worst_score = min(score, self.worst_score)
794
+
795
+ def is_done(self, best_sum_logprobs):
796
+ """If there are enough hypotheses and that none of the hypotheses being
797
+ generated can become better than the worst one in the heap, then we are
798
+ done with this sentence."""
799
+ if len(self) < self.num_beams:
800
+ return False
801
+ elif self.early_stopping:
802
+ return True
803
+ else:
804
+ return self.worst_score >= best_sum_logprobs / self.max_length**self.length_penalty
805
+
806
+
807
+ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
808
+ """
809
+ Based on the values of top_k and top_p, set the values that do not meet the criteria to the filter_value.
810
+
811
+ Args:
812
+ logits: logit value, shape is [bsz, vocab_size].
813
+ top_k: If it is greater than 0, only the probabilities of the top_k vocabulary are kept, and the rest of
814
+ the positions are set to filter_value.
815
+ top_p: according to http://arxiv.org/abs/1904.09751.
816
+ filter_value: filter value
817
+ min_tokens_to_keep: the number of tokens kept in each sample's returned distribution will not be
818
+ lower than this value.
819
+
820
+ """
821
+ if top_k > 0:
822
+ # Safety check
823
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
824
+ # Remove all tokens with a probability less than the last token of
825
+ # the top-k
826
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
827
+ logits[indices_to_remove] = filter_value
828
+
829
+ if top_p < 1.0:
830
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
831
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
832
+
833
+ # Remove tokens with cumulative probability above the threshold
834
+ # (token with 0 are kept)
835
+ sorted_indices_to_remove = cumulative_probs > top_p
836
+ if min_tokens_to_keep > 1:
837
+ # Keep at least min_tokens_to_keep
838
+ # (set to min_tokens_to_keep-1 because we add the first one below)
839
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
840
+ # Shift the indices to the right to keep also the first token
841
+ # above the threshold
842
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
843
+ sorted_indices_to_remove[..., 0] = 0
844
+
845
+ # scatter sorted tensors to original indexing
846
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
847
+ logits[indices_to_remove] = filter_value
848
+ return logits
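Note: the module-level top_k_top_p_filtering helper defined at the end of inference.py can be exercised in isolation; a small usage sketch follows. The import path assumes the InternLM package root is on PYTHONPATH, and .clone() is used because the helper writes filter_value into the logits tensor in place:

```python
import torch
import torch.nn.functional as F

from internlm.apis.inference import top_k_top_p_filtering

# Toy batch of one sample over a 6-token vocabulary.
logits = torch.tensor([[2.0, 1.5, 1.0, 0.5, -1.0, -3.0]])

# top_k=3: everything below the third-highest logit is set to -inf.
top_k_only = top_k_top_p_filtering(logits.clone(), top_k=3, top_p=1.0)

# top_p=0.8: keep the smallest prefix of probability-sorted tokens whose
# cumulative probability exceeds 0.8 (here the first three tokens).
top_p_only = top_k_top_p_filtering(logits.clone(), top_k=0, top_p=0.8)

# The filtered logits are then renormalised and sampled from, as in
# _no_beam_search_generate above.
probs = F.softmax(top_p_only, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(top_k_only, probs, next_token, sep="\n")
```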
InternLM/internlm/core/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .engine import Engine
2
+ from .naive_amp import NaiveAMPModel
3
+ from .trainer import Trainer
4
+
5
+ __all__ = [
6
+ "NaiveAMPModel",
7
+ "Engine",
8
+ "Trainer",
9
+ ]
InternLM/internlm/core/communication/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ from .p2p import (
2
+ AsynCommunicator,
3
+ recv_backward,
4
+ recv_forward,
5
+ send_backward,
6
+ send_backward_and_recv_next_backward_async,
7
+ send_backward_recv_backward,
8
+ send_backward_recv_forward,
9
+ send_forward,
10
+ send_forward_and_recv_next_forward_async,
11
+ send_forward_backward_recv_forward_backward,
12
+ send_forward_recv_backward,
13
+ send_forward_recv_forward,
14
+ )
15
+ from .utils import recv_obj_meta, send_obj_meta
16
+
17
+ __all__ = [
18
+ "send_forward",
19
+ "send_forward_recv_forward",
20
+ "send_forward_backward_recv_forward_backward",
21
+ "send_backward",
22
+ "send_backward_recv_backward",
23
+ "send_backward_recv_forward",
24
+ "send_forward_recv_backward",
25
+ "recv_backward",
26
+ "recv_forward",
27
+ "send_obj_meta",
28
+ "recv_obj_meta",
29
+ "send_backward_and_recv_next_backward_async",
30
+ "send_forward_and_recv_next_forward_async",
31
+ "AsynCommunicator",
32
+ ]
InternLM/internlm/core/communication/p2p.py ADDED
@@ -0,0 +1,582 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ # adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
5
+
6
+ import operator
7
+ from functools import reduce
8
+ from typing import List, Tuple, Union
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+
13
+ from internlm.core.context import ParallelMode
14
+ from internlm.core.context import global_context as gpc
15
+ from internlm.utils.common import get_current_device
16
+
17
+ from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
18
+
19
+ TensorShape = Union[torch.Size, List[int], Tuple[int]]
20
+
21
+
22
+ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
23
+ """get the exact tensor shape when communicating and return whether the tensor is a chunk
24
+
25
+ Args:
26
+ tensor_shape (:class:`torch.Size`): shape of tensor
27
+ chunk_tensor (bool, optional): whether to chunk tensor, defaults to False
28
+
29
+ Returns:
30
+ Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
31
+ """
32
+ if chunk_tensor:
33
+ tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
34
+ tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
35
+ if tensor_chunk_shape % tensor_parallel_world_size == 0:
36
+ tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
37
+ else:
38
+ tensor_chunk_shape = tensor_shape
39
+ chunk_tensor = False
40
+ else:
41
+ tensor_chunk_shape = tensor_shape
42
+ return tensor_chunk_shape, chunk_tensor
43
+
44
+
45
+ def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
46
+ if isinstance(recv_shapes, torch.Size):
47
+ recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
48
+ buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
49
+ return buffer_recv, recv_split
50
+ buffer_recv = []
51
+ for recv_shape in recv_shapes:
52
+ recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
53
+ tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
54
+ buffer_recv.append(tensor_recv)
55
+ return buffer_recv, recv_split
56
+
57
+
58
+ def process_object_to_send(object_send, scatter_gather_tensors):
59
+ if isinstance(object_send, torch.Tensor):
60
+ send_split = _get_tensor_shape(object_send.shape, scatter_gather_tensors)[1]
61
+ if send_split:
62
+ object_send = split_tensor_into_1d_equal_chunks(object_send)
63
+ return object_send
64
+
65
+ object_send_list = []
66
+ for tensor_send in object_send:
67
+ send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
68
+ if send_split:
69
+ object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send))
70
+ else:
71
+ object_send_list.append(tensor_send)
72
+ object_send = tuple(object_send_list)
73
+
74
+ return object_send
75
+
76
+
77
+ def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
78
+ if isinstance(obj, torch.Tensor):
79
+ op_to_add = dist.P2POp(comm_op, obj, comm_rank)
80
+ ops_queue.append(op_to_add)
81
+ else:
82
+ for tensor_to_comm in obj:
83
+ op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
84
+ ops_queue.append(op_to_add)
85
+
86
+
87
+ def _communicate(
88
+ object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
89
+ object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
90
+ recv_prev: bool = False,
91
+ recv_next: bool = False,
92
+ recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
93
+ recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
94
+ prev_rank: int = None,
95
+ next_rank: int = None,
96
+ dtype: torch.dtype = None,
97
+ scatter_gather_tensors: bool = False,
98
+ ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
99
+ """
100
+ Adapted from megatron.p2p_communication.
101
+ Communicate tensors between stages. Used as helper method in other
102
+ communication methods that are used in pipeline schedule.
103
+ Takes the following arguments:
104
+ object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank
105
+ (no tensor sent if set to None).
106
+ object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank
107
+ (no tensor sent if set to None).
108
+ recv_prev (bool): boolean for whether tensor should be received from
109
+ previous rank.
110
+ recv_next (bool): boolean for whether tensor should be received from
111
+ next rank.
112
+ recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
113
+ from the previous stage, defaults to None.
114
+ recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received
115
+ from the next stage, defaults to None.
116
+ prev_rank (int): the rank of the previous pipeline stage, defaults to None.
117
+ next_rank (int): the rank of the next pipeline stage, defaults to None.
118
+ dtype (torch.dtype): data type of intermediate buffers, defaults to None
119
+ scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False
120
+
121
+ Returns:
122
+ Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next
123
+ """
124
+
125
+ # Create placeholder tensors for receive in forward and backward directions
126
+ # if needed.
127
+ tensor_recv_prev = None
128
+ tensor_recv_next = None
129
+
130
+ if recv_prev:
131
+ assert recv_prev_shape is not None
132
+ tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
133
+ recv_prev_shape, dtype, scatter_gather_tensors
134
+ )
135
+
136
+ if recv_next:
137
+ assert recv_next_shape is not None
138
+ tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
139
+ recv_next_shape, dtype, scatter_gather_tensors
140
+ )
141
+
142
+ if object_send_prev is not None or recv_prev:
143
+ if prev_rank is None:
144
+ prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
145
+
146
+ if object_send_next is not None or recv_next:
147
+ if next_rank is None:
148
+ next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
149
+
150
+ if object_send_prev is not None:
151
+ object_send_prev = process_object_to_send(object_send_prev, scatter_gather_tensors)
152
+
153
+ if object_send_next is not None:
154
+ object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)
155
+
156
+ ops = []
157
+ if object_send_prev is not None:
158
+ filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)
159
+
160
+ if tensor_recv_prev is not None:
161
+ filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
162
+
163
+ if tensor_recv_next is not None:
164
+ filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
165
+
166
+ if object_send_next is not None:
167
+ filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
168
+
169
+ if len(ops) > 0:
170
+ reqs = dist.batch_isend_irecv(ops)
171
+ for req in reqs:
172
+ req.wait()
173
+ # To protect against race condition when using batch_isend_irecv().
174
+ torch.cuda.synchronize()
175
+
176
+ if recv_prev and recv_prev_split:
177
+ if isinstance(tensor_recv_prev, torch.Tensor):
178
+ tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
179
+ else:
180
+ for index in range(len(tensor_recv_prev)):
181
+ tensor_recv_prev[index] = (
182
+ gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
183
+ )
184
+
185
+ if recv_next and recv_next_split:
186
+ if isinstance(tensor_recv_next, torch.Tensor):
187
+ tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
188
+ else:
189
+ for index in range(len(tensor_recv_next)):
190
+ tensor_recv_next[index] = (
191
+ gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
192
+ )
193
+
194
+ return tensor_recv_prev, tensor_recv_next
195
+
196
+
197
+ def recv_forward(
198
+ input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False
199
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
200
+ """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
201
+
202
+ Args:
203
+ input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
204
+ to be received.
205
+ prev_rank (int, optional): The rank of the source of the tensor.
206
+
207
+ Returns:
208
+ Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor or input tensor list.
209
+ """
210
+ input_tensor, _ = _communicate(
211
+ recv_prev=True,
212
+ recv_prev_shape=input_tensor_shape,
213
+ prev_rank=prev_rank,
214
+ dtype=dtype,
215
+ scatter_gather_tensors=scatter_gather_tensors,
216
+ )
217
+ return input_tensor
218
+
219
+
220
+ def recv_backward(
221
+ output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
222
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
223
+ """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
224
+
225
+ Args:
226
+ output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
227
+ to be received.
228
+ next_rank (int, optional): The rank of the source of the tensor.
229
+
230
+ Returns:
231
+ Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor or gradient tensor list.
232
+ """
233
+ _, output_tensor_grad = _communicate(
234
+ recv_next=True,
235
+ recv_next_shape=output_grad_shape,
236
+ next_rank=next_rank,
237
+ dtype=dtype,
238
+ scatter_gather_tensors=scatter_gather_tensors,
239
+ )
240
+ return output_tensor_grad
241
+
242
+
243
+ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False) -> None:
244
+ """Sends the input tensor to the next stage in pipeline.
245
+
246
+ Args:
247
+ output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
248
+ next_rank (int, optional): The rank of the recipient of the tensor.
249
+ """
250
+ _communicate(object_send_next=output_tensor, next_rank=next_rank, scatter_gather_tensors=scatter_gather_tensors)
251
+
252
+
253
+ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False) -> None:
254
+ """Sends the gradient tensor to the previous stage in pipeline.
255
+
256
+ Args:
257
+ input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent
258
+ prev_rank (int, optional): The rank of the recipient of the tensor
259
+ """
260
+
261
+ _communicate(object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors)
262
+
263
+
264
+ def send_forward_recv_backward(
265
+ output_tensor, output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
266
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
267
+ """Batched communication operation. Sends the input tensor to the
268
+ next stage in pipeline, while receives the gradient tensor from the
269
+ next stage in pipeline as the input gradient tensor of this stage.
270
+
271
+ Args:
272
+ output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
273
+ output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
274
+ to be received.
275
+
276
+ Returns:
277
+ Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
278
+ """
279
+ _, output_tensor_grad = _communicate(
280
+ object_send_next=output_tensor,
281
+ recv_next=output_grad_shape is not None,
282
+ recv_next_shape=output_grad_shape,
283
+ next_rank=next_rank,
284
+ dtype=dtype,
285
+ scatter_gather_tensors=scatter_gather_tensors,
286
+ )
287
+
288
+ return output_tensor_grad
289
+
290
+
291
+ def send_backward_recv_forward(
292
+ input_tensor_grad,
293
+ input_tensor_shape,
294
+ prev_rank=None,
295
+ dtype=torch.float,
296
+ scatter_gather_tensors=False,
297
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
298
+ """Batched communication operation. Sends the gradient tensor to the
299
+ previous stage in pipeline, while receives the output tensor from the
300
+ previous stage in pipeline as the input of this stage.
301
+
302
+ Args:
303
+ input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
304
+ input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
305
+ to be received.
306
+
307
+ Returns:
308
+ Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
309
+ """
310
+ input_tensor, _ = _communicate(
311
+ object_send_prev=input_tensor_grad,
312
+ recv_prev=input_tensor_shape is not None,
313
+ recv_prev_shape=input_tensor_shape,
314
+ prev_rank=prev_rank,
315
+ dtype=dtype,
316
+ scatter_gather_tensors=scatter_gather_tensors,
317
+ )
318
+
319
+ return input_tensor
320
+
321
+
322
+ def send_forward_recv_forward(
323
+ output_tensor,
324
+ input_tensor_shape,
325
+ prev_rank=None,
326
+ next_rank=None,
327
+ dtype=torch.float,
328
+ scatter_gather_tensors=False,
329
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
330
+ """Batched communication operation. Sends the input tensor to the
331
+ next stage in pipeline, while receives the output tensor from the
332
+ previous stage in pipeline as the input of this stage.
333
+
334
+ Args:
335
+ output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
336
+ input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
337
+ to be received.
338
+
339
+ Returns:
340
+ Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
341
+ """
342
+ input_tensor, _ = _communicate(
343
+ object_send_next=output_tensor,
344
+ recv_prev=input_tensor_shape is not None,
345
+ recv_prev_shape=input_tensor_shape,
346
+ prev_rank=prev_rank,
347
+ next_rank=next_rank,
348
+ dtype=dtype,
349
+ scatter_gather_tensors=scatter_gather_tensors,
350
+ )
351
+ return input_tensor
352
+
353
+
354
+ def send_backward_recv_backward(
355
+ input_tensor_grad,
356
+ output_grad_shape,
357
+ prev_rank=None,
358
+ next_rank=None,
359
+ dtype=torch.float,
360
+ scatter_gather_tensors=False,
361
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
362
+ """Batched communication operation. Sends the gradient tensor to the
363
+ previous stage in pipeline, while receives the gradient tensor from the
364
+ next member in pipeline as the input of this stage.
365
+
366
+ Args:
367
+ input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor to be sent.
368
+ output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor
369
+ to be received.
370
+
371
+ Returns:
372
+ Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
373
+ """
374
+ _, output_tensor_grad = _communicate(
375
+ object_send_prev=input_tensor_grad,
376
+ recv_next=output_grad_shape is not None,
377
+ recv_next_shape=output_grad_shape,
378
+ prev_rank=prev_rank,
379
+ next_rank=next_rank,
380
+ dtype=dtype,
381
+ scatter_gather_tensors=scatter_gather_tensors,
382
+ )
383
+ return output_tensor_grad
384
+
385
+
386
+ def send_forward_backward_recv_forward_backward(
387
+ output_tensor,
388
+ input_tensor_grad,
389
+ input_tensor_shape,
390
+ output_grad_shape,
391
+ prev_rank=None,
392
+ next_rank=None,
393
+ dtype=torch.float,
394
+ scatter_gather_tensors=False,
395
+ ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
396
+ """Batched communication operation. Sends the input tensor to the next stage in pipeline and
397
+ the gradient tensor to the previous stage, while receives the input gradient tensor from the
398
+ next stage and the input tensor from the previous stage.
399
+
400
+ Args:
401
+ output_tensor (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the next.
402
+ input_tensor_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Tensor sent to the previous.
403
+ input_tensor_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
404
+ from the previous.
405
+ output_grad_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the tensor received
406
+ from the next.
407
+
408
+ Returns:
409
+ Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`,
410
+ List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
411
+ """
412
+ input_tensor, output_tensor_grad = _communicate(
413
+ object_send_next=output_tensor,
414
+ object_send_prev=input_tensor_grad,
415
+ recv_prev=input_tensor_shape is not None,
416
+ recv_next=output_grad_shape is not None,
417
+ recv_prev_shape=input_tensor_shape,
418
+ recv_next_shape=output_grad_shape,
419
+ prev_rank=prev_rank,
420
+ next_rank=next_rank,
421
+ dtype=dtype,
422
+ scatter_gather_tensors=scatter_gather_tensors,
423
+ )
424
+ return input_tensor, output_tensor_grad
425
+
426
+
427
+ def send_forward_and_recv_next_forward_async(
428
+ output_tensor,
429
+ recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
430
+ dtype: torch.dtype = None,
431
+ scatter_gather_tensors=False,
432
+ ):
433
+ """send forward output to next rank and recv forward input from prev rank"""
434
+
435
+ reqs = []
436
+ tensor_recv_prev = None
437
+
438
+ # prepare send operations
439
+ if output_tensor is not None:
440
+ next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
441
+
442
+ output_tensor = process_object_to_send(output_tensor, scatter_gather_tensors)
443
+
444
+ if isinstance(output_tensor, torch.Tensor):
445
+ reqs.append(dist.P2POp(dist.isend, output_tensor, next_rank))
446
+ else:
447
+ for tensor_to_comm in output_tensor:
448
+ reqs.append(dist.P2POp(dist.isend, tensor_to_comm, next_rank))
449
+
450
+ # prepare receive operations
451
+ if recv_prev_shape is not None:
452
+ prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
453
+ # create receive buffer
454
+ tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
455
+ recv_prev_shape, dtype, scatter_gather_tensors
456
+ )
457
+ # generate async receive operations
458
+ if isinstance(tensor_recv_prev, torch.Tensor):
459
+ reqs.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
460
+ else:
461
+ for tensor_to_comm in tensor_recv_prev:
462
+ reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, prev_rank))
463
+
464
+ if len(reqs) > 0:
465
+ reqs = dist.batch_isend_irecv(reqs)
466
+
467
+ # return and do other things
468
+ yield
469
+
470
+ # check communication completed
471
+ for req in reqs:
472
+ req.wait()
473
+ # To protect against race condition when using batch_isend_irecv()
474
+ torch.cuda.synchronize()
475
+
476
+ # Process received data
477
+ if recv_prev_shape is not None and recv_prev_split:
478
+ if isinstance(tensor_recv_prev, torch.Tensor):
479
+ tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
480
+ else:
481
+ for index in range(len(tensor_recv_prev)):
482
+ tensor_recv_prev[index] = (
483
+ gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
484
+ )
485
+
486
+ yield tensor_recv_prev
487
+
488
+
489
+ def send_backward_and_recv_next_backward_async(
490
+ input_tensor,
491
+ recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
492
+ dtype: torch.dtype = None,
493
+ scatter_gather_tensors=False,
494
+ ):
495
+ reqs = []
496
+ tensor_recv_next = None
497
+
498
+ # prepare send operations
499
+ if input_tensor is not None:
500
+ prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
501
+
502
+ input_tensor = process_object_to_send(input_tensor, scatter_gather_tensors)
503
+
504
+ if isinstance(input_tensor, torch.Tensor):
505
+ reqs.append(dist.P2POp(dist.isend, input_tensor, prev_rank))
506
+ else:
507
+ for tensor_to_comm in input_tensor:
508
+ reqs.append(dist.P2POp(dist.isend, tensor_to_comm, prev_rank))
509
+
510
+ # prepare receive operations
511
+ if recv_next_shape is not None:
512
+ next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
513
+ # create receive buffer
514
+ tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
515
+ recv_next_shape, dtype, scatter_gather_tensors
516
+ )
517
+ # generate async receive operations
518
+ if isinstance(tensor_recv_next, torch.Tensor):
519
+ reqs.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
520
+ else:
521
+ for tensor_to_comm in tensor_recv_next:
522
+ reqs.append(dist.P2POp(dist.irecv, tensor_to_comm, next_rank))
523
+
524
+ if len(reqs) > 0:
525
+ reqs = dist.batch_isend_irecv(reqs)
526
+
527
+ # return and do other things
528
+ yield
529
+
530
+ # check communication completed
531
+ for req in reqs:
532
+ req.wait()
533
+ # To protect against race condition when using batch_isend_irecv()
534
+ torch.cuda.synchronize()
535
+
536
+ # Process received data
537
+ if recv_next_shape is not None and recv_next_split:
538
+ if isinstance(tensor_recv_next, torch.Tensor):
539
+ tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
540
+ else:
541
+ for index in range(len(tensor_recv_next)):
542
+ tensor_recv_next[index] = (
543
+ gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
544
+ )
545
+
546
+ yield tensor_recv_next
547
+
548
+
549
+ class AsynCommunicator:
550
+ """AsynCommunicator for managing async communication."""
551
+
552
+ def __init__(
553
+ self,
554
+ tensor_to_send: Union[torch.Tensor, List[torch.Tensor]],
555
+ recv_shape: Union[torch.Size, List[torch.Size]],
556
+ dtype: torch.dtype = None,
557
+ scatter_gather_tensors=False,
558
+ forward: bool = True,
559
+ ) -> None:
560
+ self._need_receive = recv_shape is not None
561
+
562
+ if forward:
563
+ self._coroutine = send_forward_and_recv_next_forward_async(
564
+ tensor_to_send, recv_shape, dtype, scatter_gather_tensors
565
+ )
566
+ else:
567
+ self._coroutine = send_backward_and_recv_next_backward_async(
568
+ tensor_to_send, recv_shape, dtype, scatter_gather_tensors
569
+ )
570
+
571
+ @property
572
+ def need_receive(self) -> bool:
573
+ return self._need_receive
574
+
575
+ def start(self) -> None:
576
+ next(self._coroutine)
577
+
578
+ def wait_and_receive(self) -> Union[torch.Tensor, List[torch.Tensor]]:
579
+ received = next(self._coroutine)
580
+ self._coroutine.close()
581
+
582
+ return received
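Taken together, the helpers in p2p.py implement a pipeline stage's hand-offs. The snippet below is a minimal sketch of a forward hand-off, assuming a launched pipeline-parallel job in which `internlm` is importable; `input_shape` and `stage_module` are stand-in names for the activation shape (normally obtained via the meta helpers in utils.py) and the local model partition, and are not part of the uploaded code.

import torch

from internlm.core.communication.p2p import recv_forward, send_forward
from internlm.core.context import global_context as gpc

# Stand-ins for values a real schedule would already hold.
input_shape = torch.Size([1, 16, 32])   # assumed activation shape for one micro-batch
stage_module = torch.nn.Identity()      # placeholder for this stage's model partition

# Forward hand-off for one micro-batch on an intermediate pipeline stage.
if gpc.is_pipeline_first_stage():
    input_tensor = None                 # the first stage reads data instead of receiving it
else:
    input_tensor = recv_forward(input_shape)
output_tensor = stage_module(input_tensor)
if not gpc.is_pipeline_last_stage():
    send_forward(output_tensor)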
InternLM/internlm/core/communication/utils.py ADDED
@@ -0,0 +1,125 @@
1
+ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication
2
+
3
+ from typing import List, Tuple, Union
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ from internlm.core.context import ParallelMode
9
+ from internlm.core.context import global_context as gpc
10
+ from internlm.utils.common import get_current_device
11
+
12
+ TensorShape = Union[torch.Size, List[int], Tuple[int]]
13
+
14
+
15
+ def send_meta_helper(obj, next_rank, tensor_kwargs):
16
+ send_shape = torch.tensor(obj.size(), **tensor_kwargs)
17
+ send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
18
+ dist.send(send_ndims, next_rank)
19
+ dist.send(send_shape, next_rank)
20
+
21
+
22
+ def send_obj_meta(obj, next_rank=None):
23
+ """Sends obj meta information before sending a specific obj.
24
+ Since the recipient must know the shape of the obj in p2p communications,
25
+ meta information of the obj should be sent before communications. This function
26
+ synchronizes with :func:`recv_obj_meta`.
27
+
28
+ Args:
29
+ obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
30
+ next_rank (int): The rank of the next member in pipeline parallel group.
32
+
33
+ Returns:
34
+ None
35
+ """
36
+ if next_rank is None:
37
+ next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
38
+
39
+ tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
40
+ if isinstance(obj, torch.Tensor):
41
+ send_obj_nums = torch.tensor(1, **tensor_kwargs)
42
+ dist.send(send_obj_nums, next_rank)
43
+ send_meta_helper(obj, next_rank, tensor_kwargs)
44
+ else:
45
+ send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
46
+ dist.send(send_obj_nums, next_rank)
47
+ for tensor_to_send in obj:
48
+ send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
49
+
50
+
51
+ def recv_meta_helper(prev_rank, tensor_kwargs):
52
+ recv_ndims = torch.empty((), **tensor_kwargs)
53
+ dist.recv(recv_ndims, prev_rank)
54
+ recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
55
+ dist.recv(recv_shape, prev_rank)
56
+ return recv_shape
57
+
58
+
59
+ def recv_obj_meta(prev_rank=None) -> torch.Size:
60
+ """Receives obj meta information before receiving a specific obj.
61
+ Since the recipient must know the shape of the obj in p2p communications,
62
+ meta information of the obj should be received before communications. This function
63
+ synchronizes with :func:`send_obj_meta`.
64
+
65
+ Args:
66
+ prev_rank (int): The rank of the source of the obj.
68
+
69
+ Returns:
70
+ Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
71
+ """
72
+ if prev_rank is None:
73
+ prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
74
+
75
+ tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
76
+ recv_obj_nums = torch.empty((), **tensor_kwargs)
77
+ dist.recv(recv_obj_nums, prev_rank)
78
+ if recv_obj_nums.item() == 1:
79
+ recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
80
+ obj_shape = torch.Size(recv_shape)
81
+ else:
82
+ obj_shape = []
83
+ for _ in range(recv_obj_nums.item()):
84
+ recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
85
+ obj_shape.append(torch.Size(recv_shape))
86
+
87
+ return obj_shape
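# Illustrative sketch (single process, hypothetical tensor, not part of this file): the
# message sequence the two helpers above exchange for one tensor is the object count,
# then the number of dimensions, then the shape itself, all as torch.long tensors; the
# receiver rebuilds torch.Size([2, 3, 4]) from the final message.
example_obj = torch.zeros(2, 3, 4)
meta_messages = [
    torch.tensor(1),                    # send_obj_meta: one object follows
    torch.tensor(example_obj.dim()),    # send_meta_helper: number of dimensions
    torch.tensor(example_obj.size()),   # send_meta_helper: the shape itself
]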
88
+
89
+
90
+ def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
91
+ """Break a tensor into equal 1D chunks.
92
+
93
+ Args:
94
+ tensor (:class:`torch.Tensor`): Tensor to be split before communication.
95
+ new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
96
+
97
+ Returns:
98
+ :class:`torch.Tensor`: The split tensor
99
+ """
100
+ partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR)
101
+ start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR)
102
+ end_index = start_index + partition_size
103
+ if new_buffer:
104
+ data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
105
+ data.copy_(tensor.view(-1)[start_index:end_index])
106
+ else:
107
+ data = tensor.view(-1)[start_index:end_index]
108
+ return data
109
+
110
+
111
+ def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
112
+ """Opposite of above function, gather values from model parallel ranks.
113
+
114
+ Args:
115
+ tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
116
+ Returns:
117
+ :class:`torch.Tensor`: The gathered tensor.
118
+ """
119
+ world_size = gpc.get_world_size(ParallelMode.TENSOR)
120
+ numel = torch.numel(tensor)
121
+ numel_gathered = world_size * numel
122
+ gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
123
+ chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
124
+ dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR))
125
+ return gathered
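The two tensor helpers at the end of this file are inverses of each other. Below is a minimal single-process sketch (simulating a tensor-parallel world size of 2, with no process groups involved) of the slicing arithmetic they rely on; it is not part of the uploaded code.

import torch

world_size = 2                               # assumed tensor-parallel world size
full = torch.arange(8, dtype=torch.float32)  # tensor to be split before communication
partition_size = full.numel() // world_size

# Each rank keeps one contiguous 1D slice, as in split_tensor_into_1d_equal_chunks.
chunks = [full.view(-1)[r * partition_size : (r + 1) * partition_size] for r in range(world_size)]

# gather_split_1d_tensor concatenates the per-rank slices back into the original buffer.
restored = torch.cat(chunks)
assert torch.equal(restored, full)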
InternLM/internlm/core/context/__init__.py ADDED
@@ -0,0 +1,49 @@
1
+ from .parallel_context import (
2
+ IS_TENSOR_PARALLEL,
3
+ Config,
4
+ ParallelContext,
5
+ global_context,
6
+ )
7
+ from .process_group_initializer import (
8
+ Initializer_Data,
9
+ Initializer_Model,
10
+ Initializer_Nettest,
11
+ Initializer_Pipeline,
12
+ Initializer_Tensor,
13
+ Initializer_Zero1,
14
+ ParallelMode,
15
+ ProcessGroupInitializer,
16
+ )
17
+ from .random import (
18
+ add_seed,
19
+ get_current_mode,
20
+ get_seeds,
21
+ get_states,
22
+ seed,
23
+ set_mode,
24
+ set_seed_states,
25
+ sync_states,
26
+ )
27
+
28
+ __all__ = [
29
+ "Config",
30
+ "IS_TENSOR_PARALLEL",
31
+ "global_context",
32
+ "ParallelContext",
33
+ "ParallelMode",
34
+ "Initializer_Tensor",
35
+ "Initializer_Pipeline",
36
+ "Initializer_Data",
37
+ "Initializer_Zero1",
38
+ "Initializer_Nettest",
39
+ "ProcessGroupInitializer",
40
+ "Initializer_Model",
41
+ "seed",
42
+ "set_mode",
43
+ "add_seed",
44
+ "get_seeds",
45
+ "get_states",
46
+ "get_current_mode",
47
+ "set_seed_states",
48
+ "sync_states",
49
+ ]
InternLM/internlm/core/context/parallel_context.py ADDED
@@ -0,0 +1,569 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
5
+
6
+ import inspect
7
+ import random
8
+ import socket
9
+ import sys
10
+ from collections import Counter
11
+ from importlib.machinery import SourceFileLoader
12
+ from pathlib import Path
13
+ from typing import Union
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.distributed as dist
18
+
19
+ from internlm.utils.common import SingletonMeta
20
+ from internlm.utils.logger import get_logger
21
+ from internlm.utils.timeout import LLM_NCCL_TIMEOUT
22
+
23
+ from . import process_group_initializer as pgroup_initializer
24
+ from .process_group_initializer import ParallelMode
25
+ from .random import add_seed, get_seeds, set_mode
26
+
27
+ IS_TENSOR_PARALLEL = "is_tensor_parallel"
28
+
29
+ logger = get_logger(__file__)
30
+
31
+
32
+ class Config(dict):
33
+ """This is a wrapper class for dict objects so that values of which can be
34
+ accessed as attributes.
35
+
36
+ Args:
37
+ config (dict): The dict object to be wrapped.
38
+ """
39
+
40
+ def __init__(self, config: dict = None): # pylint: disable=W0231
41
+ if config is not None:
42
+ for k, v in config.items():
43
+ self._add_item(k, v)
44
+
45
+ def __missing__(self, key):
46
+ raise KeyError(key)
47
+
48
+ def __getattr__(self, key):
49
+ try:
50
+ value = super().__getitem__(key)
51
+ return value
52
+ except KeyError:
53
+ raise AttributeError(key)
54
+
55
+ def __setattr__(self, key, value):
56
+ super().__setitem__(key, value)
57
+
58
+ def _add_item(self, key, value):
59
+ if isinstance(value, dict):
60
+ self.__setattr__(key, Config(value))
61
+ else:
62
+ self.__setattr__(key, value)
63
+
64
+ def update(self, config):
65
+ assert isinstance(config, (Config, dict)), "can only update dictionary or Config objects."
66
+ for k, v in config.items():
67
+ self._add_item(k, v)
68
+ return self
69
+
70
+ @staticmethod
71
+ def from_file(filename: str) -> object:
72
+ """Reads a python file and constructs a corresponding :class:`Config` object.
73
+
74
+ Args:
75
+ filename (str): Name of the file to construct the return object.
76
+
77
+ Returns:
78
+ :class:`Config`: A :class:`Config` object constructed with information in the file.
79
+
80
+ Raises:
81
+ AssertionError: Raises an AssertionError if the file does not exist, or the file is not .py file
82
+ """
83
+
84
+ # check config path
85
+ if isinstance(filename, str):
86
+ filepath = Path(filename).absolute()
87
+ elif isinstance(filename, Path):
88
+ filepath = filename.absolute()
89
+
90
+ assert filepath.exists(), f"{filename} is not found, please check your configuration path"
91
+
92
+ # check extension
93
+ extension = filepath.suffix
94
+ assert extension == ".py", "only .py files are supported"
95
+
96
+ # import the config as module
97
+ remove_path = False
98
+ if filepath.parent not in sys.path:
99
+ sys.path.insert(0, (filepath))
100
+ remove_path = True
101
+
102
+ module_name = filepath.stem
103
+ source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
104
+ module = source_file.load_module() # pylint: disable=W4902,E1120,W1505
105
+
106
+ # load into config
107
+ config = Config()
108
+
109
+ for k, v in module.__dict__.items():
110
+ if k.startswith("__") or inspect.ismodule(v) or inspect.isclass(v):
111
+ continue
112
+ else:
113
+ config._add_item(k, v)
114
+
115
+ # remove module
116
+ del sys.modules[module_name]
117
+ if remove_path:
118
+ sys.path.pop(0)
119
+
120
+ return config
121
+
122
+
123
+ class ParallelContext(metaclass=SingletonMeta):
124
+ """This class provides interface functions for users to get the parallel context,
125
+ such as the global rank, the local rank, the world size, etc. of each device.
126
+
127
+ """
128
+
129
+ def __init__(self):
130
+ # distributed settings
131
+ self._global_ranks = dict()
132
+ self._local_ranks = dict()
133
+ self._world_sizes = dict()
134
+ self._groups = dict()
135
+ self._cpu_groups = dict()
136
+ self._ranks_in_group = dict()
137
+
138
+ # load config from file
139
+ self._config = None
140
+
141
+ # default parallel args, will be overwritten during process group initialization
142
+ self.world_size = 1
143
+ self.data_parallel_size = 1
144
+ self.pipeline_parallel_size = 1
145
+ self.tensor_parallel_size = 1
146
+ self.zero1_parallel_size = -1
147
+ self.nettest_parallel_size = 1
148
+ self.num_processes_on_current_node = -1
149
+ self.virtual_pipeline_parallel_size = None
150
+ self.virtual_pipeline_parallel_rank = None
151
+
152
+ @property
153
+ def config(self):
154
+ return self._config
155
+
156
+ def load_config(self, config: Union[dict, str]):
157
+ """Loads the configuration from either a dict or a file.
158
+
159
+ Args:
160
+ config (dict or str): Either a dict containing the configuration information or the filename
161
+ of a file containing the configuration information.
162
+
163
+ Raises:
164
+ TypeError: Raises a TypeError if `config` is neither a dict nor a str.
165
+ """
166
+ if isinstance(config, str):
167
+ self._config = Config.from_file(config)
168
+ elif isinstance(config, dict):
169
+ self._config = Config(config)
170
+ else:
171
+ raise TypeError("Invalid type for config, only dictionary or string is supported")
172
+
173
+ def detect_num_processes_on_current_node(self):
174
+ hostname = socket.gethostname()
175
+ hostname_list = [None for _ in range(self.get_world_size(ParallelMode.GLOBAL))]
176
+ dist.all_gather_object(hostname_list, hostname, group=self.get_group(ParallelMode.GLOBAL))
177
+ counter = Counter(hostname_list)
178
+ self.num_processes_on_current_node = counter[hostname]
179
+
180
+ @staticmethod
181
+ def _check_parallel_mode(parallel_mode: ParallelMode):
182
+ assert isinstance(
183
+ parallel_mode, ParallelMode
184
+ ), f"expected the argument parallel_mode to be of enum ParallelMode, but got {type(parallel_mode)}"
185
+
186
+ def get_global_rank(self):
187
+ """Returns the global rank of the current device.
188
+
189
+ Returns:
190
+ int: The global rank of the current device
191
+ """
192
+ return self._global_ranks[ParallelMode.GLOBAL]
193
+
194
+ def get_local_rank(self, parallel_mode: ParallelMode):
195
+ """Returns the local rank of the current device.
196
+
197
+ Args:
198
+ parallel_mode: The parallel mode for the rank.
199
+
200
+ Returns:
201
+ int: The local rank of the current device for `parallel_mode`.
202
+ """
203
+ self._check_parallel_mode(parallel_mode)
204
+ return self._local_ranks.get(parallel_mode, 0)
205
+
206
+ def get_next_global_rank(self, parallel_mode: ParallelMode):
207
+ """Returns the global rank of the next device.
208
+
209
+ Args:
210
+ parallel_mode: The parallel mode for the rank.
211
+
212
+ Returns:
213
+ int: The global rank of the next device for `parallel_mode`.
214
+ """
215
+ self._check_parallel_mode(parallel_mode)
216
+
217
+ # get rank and world size
218
+ local_rank = self.get_local_rank(parallel_mode)
219
+ world_size = self.get_world_size(parallel_mode)
220
+ ranks_in_group = self.get_ranks_in_group(parallel_mode)
221
+
222
+ return ranks_in_group[(local_rank + 1) % world_size]
223
+
224
+ def get_prev_global_rank(self, parallel_mode: ParallelMode):
225
+ """Returns the global rank of the previous device.
226
+
227
+ Args:
228
+ parallel_mode: The chosen parallel mode.
229
+
230
+ Returns:
231
+ int: The global rank of the previous device for `parallel_mode`.
232
+ """
233
+ self._check_parallel_mode(parallel_mode)
234
+
235
+ # get rank and world size
236
+ local_rank = self.get_local_rank(parallel_mode)
237
+ world_size = self.get_world_size(parallel_mode)
238
+ ranks_in_group = self.get_ranks_in_group(parallel_mode)
239
+
240
+ return ranks_in_group[(local_rank - 1) % world_size]
241
+
242
+ def is_using_dp(self):
243
+ """Returns a boolean value indicating whether the current device is initilized with
244
+ ParallelMode.DATA and its world_size is greater than 1.
245
+ """
246
+ return self.is_initialized(ParallelMode.DATA) and self.get_world_size(ParallelMode.DATA) > 1
247
+
248
+ def is_using_tp(self):
249
+ """Returns a boolean value indicating whether the current device is initilized with
250
+ ParallelMode.TENSOR and its world_size is greater than 1.
251
+ """
252
+ return self.is_initialized(ParallelMode.TENSOR) and self.get_world_size(ParallelMode.TENSOR) > 1
253
+
254
+ def is_using_pp(self):
255
+ """Returns a boolean value indicating whether the current device is initilized with
256
+ ParallelMode.PIPELINE and its world_size is greater than 1.
257
+ """
258
+ return self.is_initialized(ParallelMode.PIPELINE) and self.get_world_size(ParallelMode.PIPELINE) > 1
259
+
260
+ def is_using_sequence(self):
261
+ """Returns a boolean value indicating whether the current device is initilized with
262
+ ParallelMode.SEQUENCE and its world_size is greater than 1.
263
+ """
264
+ return False
265
+ # return gpc.is_initialized(ParallelMode.SEQUENCE) and gpc.get_world_size(ParallelMode.SEQUENCE) > 1
266
+
267
+ def is_first_rank(self, parallel_mode: ParallelMode):
268
+ """Returns a boolean value indicating whether the current device is the first one
269
+ among its group for `parallel_mode`.
270
+
271
+ Args:
272
+ parallel_mode: The chosen parallel mode.
273
+
274
+ Returns:
275
+ bool: a boolean value indicating whether the current device is the first one
276
+ among its group for `parallel_mode`.
277
+ """
278
+ rank = 0
279
+ if self.is_initialized(parallel_mode):
280
+ rank = self.get_local_rank(parallel_mode)
281
+ return rank == 0
282
+
283
+ def is_rank_for_log(self):
284
+ """Returns a boolean value indicating whether the current device should print log."""
285
+ is_log_rank = (
286
+ self.is_first_rank(ParallelMode.DATA)
287
+ and self.is_first_rank(ParallelMode.TENSOR)
288
+ and self.is_last_rank(ParallelMode.PIPELINE)
289
+ )
290
+ return is_log_rank
291
+
292
+ def is_last_rank(self, parallel_mode: ParallelMode):
293
+ """Returns a boolean value indicating whether the current device is the last one
294
+ among its group for `parallel_mode`.
295
+
296
+ Args:
297
+ parallel_mode: The chosen parallel mode.
298
+
299
+ Returns:
300
+ bool: a boolean value indicating whether the current device is the last one
301
+ among its group for `parallel_mode`.
302
+ """
303
+ rank = 0
304
+ world_size = 1
305
+ if self.is_initialized(parallel_mode):
306
+ rank = self.get_local_rank(parallel_mode)
307
+ world_size = self.get_world_size(parallel_mode)
308
+ return rank == world_size - 1
309
+
310
+ def is_pipeline_first_stage(self, ignore_virtual=False):
311
+ if not ignore_virtual:
312
+ if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != 0:
313
+ return False
314
+ return self.is_first_rank(ParallelMode.PIPELINE)
315
+
316
+ def is_pipeline_last_stage(self, ignore_virtual=False):
317
+ if not ignore_virtual:
318
+ if (
319
+ self.virtual_pipeline_parallel_size is not None
320
+ and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1
321
+ ):
322
+ return False
323
+ return self.is_last_rank(ParallelMode.PIPELINE)
324
+
325
+ def get_world_size(self, parallel_mode: ParallelMode):
326
+ """Returns the world size for `parallel_mode`.
327
+
328
+ Args:
329
+ parallel_mode: The chosen parallel mode.
330
+
331
+ Returns:
332
+ int: The world size for `parallel_mode`.
333
+ """
334
+ self._check_parallel_mode(parallel_mode)
335
+ return self._world_sizes.get(parallel_mode, 1)
336
+
337
+ def get_group(self, parallel_mode: ParallelMode):
338
+ """Returns the group of the current device for `parallel_mode`.
339
+
340
+ Args:
341
+ parallel_mode: The chosen parallel mode.
342
+
343
+ Returns:
344
+ torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`.
345
+ """
346
+ self._check_parallel_mode(parallel_mode)
347
+ return self._groups[parallel_mode]
348
+
349
+ def get_ranks_in_group(self, parallel_mode: ParallelMode):
350
+ """Returns the rank of the current device for `parallel_mode` in the group.
351
+
352
+ Args:
353
+ parallel_mode: The chosen parallel mode.
354
+
355
+ Returns:
356
+ int: The rank of the current device for `parallel_mode` in the group.
357
+ """
358
+ self._check_parallel_mode(parallel_mode)
359
+ return self._ranks_in_group[parallel_mode]
360
+
361
+ def get_cpu_group(self, parallel_mode: ParallelMode):
362
+ self._check_parallel_mode(parallel_mode)
363
+ return self._cpu_groups[parallel_mode]
364
+
365
+ def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int, use_cpu: bool = False):
366
+ """Initializes the global distributed environment
367
+
368
+ Args:
369
+ rank (int): rank for the default process group.
370
+ world_size (int): world size of the default process group.
371
+ backend (str): backend for ``torch.distributed``
372
+ host (str): the master address for distributed training.
373
+ port (int): the master port for distributed training.
374
+ use_cpu (bool): whether to set up cpu process group.
375
+ """
376
+ # initialize the default process group
377
+ init_method = f"tcp://[{host}]:{port}"
378
+ dist.init_process_group(
379
+ rank=rank,
380
+ world_size=world_size,
381
+ backend=backend,
382
+ init_method=init_method,
383
+ timeout=LLM_NCCL_TIMEOUT,
384
+ )
385
+
386
+ # None will give the default global process group for pytorch dist operations
387
+ ranks = list(range(world_size))
388
+ if use_cpu:
389
+ cpu_group = (
390
+ dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
391
+ if dist.get_backend() != "gloo"
392
+ else None
393
+ )
394
+ else:
395
+ cpu_group = None
396
+ self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
397
+ self._global_ranks[ParallelMode.GLOBAL] = rank
398
+
399
+ def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
400
+ self._check_parallel_mode(mode)
401
+ self._local_ranks[mode] = local_rank
402
+ self._world_sizes[mode] = world_size
403
+ self._groups[mode] = process_group
404
+ self._cpu_groups[mode] = cpu_group
405
+ self._ranks_in_group[mode] = ranks_in_group
406
+
407
+ def check_sanity(self):
408
+ """Checks sanity of the parallel context.
409
+
410
+ Raises:
411
+ AssertionError: Raises an AssertionError if the world size does not equal to the product
412
+ of data parallel size, pipeline parallel size and tensor parallel size.
413
+ """
414
+ dps = self.data_parallel_size
415
+ pps = self.pipeline_parallel_size
416
+ tps = self.tensor_parallel_size
417
+ ws = self.world_size
418
+ assert ws == dps * pps * tps, (
419
+ f"Expected the world size {ws} to be equal to data"
420
+ f" parallel size ({dps}) * pipeline parallel size "
421
+ f"({pps}) * tensor parallel size ({tps})"
422
+ )
423
+ assert self.zero1_parallel_size > 0
424
+ assert self.data_parallel_size % self.zero1_parallel_size == 0
425
+
426
+ def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
427
+ if key in config:
428
+ ele = config[key]
429
+ if isinstance(ele, int):
430
+ setattr(self, attr_name, ele)
431
+ elif isinstance(ele, dict):
432
+ setattr(self, attr_name, ele["size"])
433
+ else:
434
+ raise NotImplementedError(
435
+ f'{"Parallel configuration does not support this kind of argument, please use int or dict"}'
436
+ )
437
+
438
+ def init_parallel_groups(self):
439
+ """Initializes the parallel groups."""
440
+
441
+ # get rank and world size
442
+ rank = self.get_global_rank()
443
+ world_size = self.get_world_size(ParallelMode.GLOBAL)
444
+ self.world_size = world_size
445
+
446
+ # set parallel size as attributes for global context
447
+ parallel_config = self.config.get("parallel", None)
448
+ if parallel_config is not None:
449
+ self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
450
+ self._set_parallel_size_from_config(parallel_config, "tensor", "tensor_parallel_size")
451
+ self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size")
452
+
453
+ # the user should not set the data parallel size manually
454
+ # instead, it should be calculated based on other parallel config
455
+ self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
456
+
457
+ # the recommended nettest_parallel_size is 32 GPUs
458
+ self.nettest_parallel_size = 32
459
+
460
+ if self.zero1_parallel_size <= 0:
461
+ self.zero1_parallel_size = self.data_parallel_size
462
+
463
+ self.check_sanity()
464
+
465
+ initializer_args = [
466
+ rank,
467
+ world_size,
468
+ self.data_parallel_size,
469
+ self.pipeline_parallel_size,
470
+ self.tensor_parallel_size,
471
+ self.zero1_parallel_size,
472
+ self.nettest_parallel_size,
473
+ ]
474
+
475
+ # run initialization of different process groups
476
+ initializers = []
477
+ initializers.append(pgroup_initializer.Initializer_Data(*initializer_args))
478
+ initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
479
+ initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
480
+ initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
481
+ initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
482
+ if self.pipeline_parallel_size > 1:
483
+ initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
484
+ for initializer in initializers:
485
+ parallel_setting = initializer.init_dist_group()
486
+ if isinstance(parallel_setting, list):
487
+ for args in parallel_setting:
488
+ self._register_dist(*args)
489
+ else:
490
+ self._register_dist(*parallel_setting)
491
+
492
+ def is_initialized(self, parallel_mode: ParallelMode):
493
+ """Returns a boolean value indicating whether `parallel_mode` is initialized
494
+ in the current system.
495
+ """
496
+ return parallel_mode in self._groups
497
+
498
+ def destroy(self):
499
+ """Destroys the current distributed parallel environment."""
500
+ for mode, group in self._groups.items():
501
+ if mode is not ParallelMode.GLOBAL:
502
+ dist.destroy_process_group(group)
503
+ # destroy global process group
504
+ dist.destroy_process_group()
505
+ self._groups.clear()
506
+
507
+ def set_device(self, device_ordinal: int = None):
508
+ """Sets distributed processes to be bound to devices.
509
+
510
+ Args:
511
+ device_ordinal (int, optional): the device id to be bound to
512
+ """
513
+ global_rank = self.get_global_rank()
514
+ if device_ordinal is None:
515
+ devices_per_node = torch.cuda.device_count()
516
+ device_ordinal = global_rank % devices_per_node
517
+
518
+ torch.cuda.set_device(device_ordinal)
519
+ logger.info(f"process rank {global_rank} is bound to host:{socket.gethostname()} device: {device_ordinal}")
520
+
521
+ def set_seed(self, seed: int, dpseed_with_tpoffset: bool = False):
522
+ """Sets seeds for all random libraries.
523
+
524
+ Args:
525
+ seed (int): seed for random states
526
+ """
527
+ pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
528
+ global_rank = self.get_global_rank()
529
+
530
+ random.seed(seed)
531
+ np.random.seed(seed)
532
+ torch.manual_seed(seed)
533
+ assert torch.cuda.is_available()
534
+
535
+ # data parallel seed are kept the same in the same pipeline stage
536
+ dp_seed = seed
537
+ if dpseed_with_tpoffset:
538
+ dp_seed = seed + pipeline_offset * 1024
539
+ add_seed(ParallelMode.DATA, dp_seed)
540
+ add_seed(ParallelMode.DUMMY, dp_seed)
541
+
542
+ # model parallel seeds are different across ranks
543
+ if self.is_initialized(ParallelMode.TENSOR):
544
+ tp_rank = self.get_local_rank(ParallelMode.TENSOR)
545
+ tp_seed = seed + tp_rank + pipeline_offset * 1024
546
+ add_seed(ParallelMode.TENSOR, tp_seed)
547
+
548
+ # we do not set the random state mode to ParallelMode.DATA until model is built (instead, we use a dummy mode
549
+ # during model construction), this is because the random state will be different in different tensor parallel
550
+ # device of the same data parallel group. The underlying reason is that the device of tp_rank = 0 will perform
551
+ # additional random operations during the RowParallelLinear module building process.
552
+ set_mode(ParallelMode.DUMMY)
553
+
554
+ seeds = get_seeds()
555
+ seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
556
+ logger.info(
557
+ f"initialized seed on rank {global_rank}, "
558
+ f"numpy: {seed}, python random: {seed}, {seed_str},"
559
+ f"the default parallel seed is {ParallelMode.DATA}."
560
+ )
561
+
562
+ def set_virtual_pipeline_parallel_size(self, size):
563
+ self.virtual_pipeline_parallel_size = size
564
+
565
+ def set_virtual_pipeline_parallel_rank(self, rank):
566
+ self.virtual_pipeline_parallel_rank = rank
567
+
568
+
569
+ global_context = ParallelContext()
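Because `Config` is a plain dict wrapper, it can be exercised without any distributed setup. The snippet below is a small sketch with illustrative values only (not part of the uploaded code), showing the attribute-style access and the int-or-dict parallel sizes that `_set_parallel_size_from_config` accepts.

from internlm.core.context import Config

cfg = Config({"parallel": {"pipeline": 1, "tensor": {"size": 2}, "zero1": -1}})

assert cfg.parallel.tensor.size == 2       # nested dicts become nested Config objects
assert cfg["parallel"]["pipeline"] == 1    # plain dictionary access still works
assert cfg.parallel.zero1 == -1            # -1 makes zero1 default to the data parallel size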
InternLM/internlm/core/context/process_group_initializer.py ADDED
@@ -0,0 +1,418 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
5
+
6
+ import math
7
+ from abc import ABC, abstractmethod
8
+ from enum import Enum
9
+
10
+ import torch.distributed as dist
11
+
12
+ from internlm.utils.timeout import LLM_NCCL_TIMEOUT
13
+
14
+
15
+ # parallel modes
16
+ class ParallelMode(Enum):
17
+ """This is an enumeration class containing all possible parallel modes."""
18
+
19
+ GLOBAL = "global"
20
+
21
+ # common parallel
22
+ DATA = "data"
23
+
24
+ # model parallel - containing tensor and pipeline parallel groups
25
+ # this is added to facilitate amp and grad clipping in hybrid parallel
26
+ MODEL = "model"
27
+
28
+ # pipeline parallel
29
+ PIPELINE = "pipe"
30
+
31
+ # containing all ranks in tensor parallel
32
+ TENSOR = "tensor"
33
+
34
+ # zero1 parallel
35
+ ZERO1 = "zero1"
36
+
37
+ # runntime network test
38
+ NETTEST = "nettest"
39
+
40
+ # dummy mode, only used during mode construction
41
+ DUMMY = "dummy"
42
+
43
+
44
+ class ProcessGroupInitializer(ABC):
45
+ """An object, knowing the parallelism configuration, that initializes parallel groups.
46
+
47
+ Args:
48
+ rank (int): The rank of current process.
49
+ world_size (int): Size of whole communication world.
50
+ data_parallel_size (int): Size of data parallel.
51
+ pipeline_parallel_size (int): Size of pipeline parallel.
52
+ tensor_parallel_size (int): Size of tensor parallel.
53
+ zero1_parallel_size (int): Size of zero1 parallel.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ rank: int,
59
+ world_size: int,
60
+ data_parallel_size: int,
61
+ pipeline_parallel_size: int,
62
+ tensor_parallel_size: int,
63
+ zero1_parallel_size: int,
64
+ nettest_parallel_size: int,
65
+ ):
66
+ self.rank = rank
67
+ self.world_size = world_size
68
+ self.data_parallel_size = data_parallel_size
69
+ self.pipeline_parallel_size = pipeline_parallel_size
70
+ self.tensor_parallel_size = tensor_parallel_size
71
+ self.zero1_parallel_size = zero1_parallel_size
72
+ self.nettest_parallel_size = nettest_parallel_size
73
+ super().__init__()
74
+
75
+ @abstractmethod
76
+ def init_dist_group(self, use_cpu: bool = False):
77
+ pass
78
+
79
+
80
+ class Initializer_Data(ProcessGroupInitializer):
81
+ """A ProcessGroupInitializer for data parallelism.
82
+
83
+ Args:
84
+ rank (int): The rank of current process.
85
+ world_size (int): Size of whole communication world.
86
+ data_parallel_size (int): Size of data parallel.
87
+ pipeline_parallel_size (int): Size of pipeline parallel.
88
+ tensor_parallel_size (int): Size of tensor parallel.
89
+ zero1_parallel_size (int): Size of zero1 parallel.
90
+ """
91
+
92
+ def __init__(self, *args, **kwargs):
93
+ super().__init__(*args, **kwargs)
94
+ self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
95
+
96
+ assert self.world_size % self.data_parallel_size == 0
97
+
98
+ def init_dist_group(self, use_cpu: bool = False):
99
+ """Initialize data parallel groups, and assign local_ranks and groups to each gpu.
100
+
101
+ Returns:
102
+ Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
103
+ A Data parallelism's information tuple.
104
+ """
105
+ local_rank = None
106
+ ranks_in_group = None
107
+ process_group = None
108
+ cpu_group = None
109
+ group_world_size = None
110
+ mode = ParallelMode.DATA
111
+
112
+ for i in range(self.rank_num_per_dp_group):
113
+ ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
114
+ group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
115
+ if use_cpu:
116
+ group_cpu = (
117
+ dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
118
+ if dist.get_backend() != "gloo"
119
+ else group
120
+ )
121
+ else:
122
+ group_cpu = None
123
+
124
+ if self.rank in ranks:
125
+ local_rank = ranks.index(self.rank)
126
+ group_world_size = len(ranks)
127
+ process_group = group
128
+ cpu_group = group_cpu
129
+ ranks_in_group = ranks
130
+
131
+ return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
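# Illustrative sketch (hypothetical sizes, not part of this file): with world_size = 8 and
# data_parallel_size = 4, rank_num_per_dp_group = 2, so the loop above enumerates strided
# rank groups, one data-parallel group per position inside a model-parallel block.
example_world_size, example_dp_size = 8, 4
example_stride = example_world_size // example_dp_size
example_dp_groups = [
    [i + j * example_stride for j in range(example_dp_size)]
    for i in range(example_stride)
]
# example_dp_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]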
132
+
133
+
134
+ class Initializer_Model(ProcessGroupInitializer):
135
+ """A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel
136
+ groups).
137
+
138
+ Args:
139
+ rank (int): The rank of current process.
140
+ world_size (int): Size of whole communication world.
141
+ data_parallel_size (int): Size of data parallel.
142
+ pipeline_parallel_size (int): Size of pipeline parallel.
143
+ tensor_parallel_size (int): Size of tensor parallel.
144
+ zero1_parallel_size (int): Size of zero1 parallel.
145
+ """
146
+
147
+ def __init__(self, *args, **kwargs):
148
+ super().__init__(*args, **kwargs)
149
+ self.rank_num_per_group = self.tensor_parallel_size * self.pipeline_parallel_size
150
+ self.num_group = self.world_size // self.rank_num_per_group
151
+
152
+ assert self.world_size % self.rank_num_per_group == 0
153
+
154
+ def init_dist_group(self, use_cpu: bool = False):
155
+ """Initialize model parallel groups, and assign local_ranks and groups to each gpu.
156
+
157
+ Returns:
158
+ Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
159
+ A Model parallelism's information tuple.
160
+ """
161
+ local_rank = None
162
+ ranks_in_group = None
163
+ process_group = None
164
+ cpu_group = None
165
+ group_world_size = None
166
+ mode = ParallelMode.MODEL
167
+
168
+ for i in range(self.num_group):
169
+ ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)]
170
+ group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
171
+ if use_cpu:
172
+ group_cpu = (
173
+ dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
174
+ if dist.get_backend() != "gloo"
175
+ else group
176
+ )
177
+ else:
178
+ group_cpu = None
179
+
180
+ if self.rank in ranks:
181
+ local_rank = ranks.index(self.rank)
182
+ group_world_size = len(ranks)
183
+ process_group = group
184
+ cpu_group = group_cpu
185
+ ranks_in_group = ranks
186
+
187
+ return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
188
+
189
+
190
+ class Initializer_Pipeline(ProcessGroupInitializer):
191
+ """A ProcessGroupInitializer for pipeline parallelism.
192
+
193
+ Args:
194
+ rank (int): The rank of current process
195
+ world_size (int): Size of whole communication world
196
+ data_parallel_size (int): Size of data parallel
197
+ pipeline_parallel_size (int): Size of pipeline parallel
198
+ tensor_parallel_size (int): Size of tensor parallel
199
+ zero1_parallel_size (int): Size of zero1 parallel.
200
+ """
201
+
202
+ def __init__(self, *args, **kwargs):
203
+ super().__init__(*args, **kwargs)
204
+ self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
205
+ self.pipeline_stage_size = self.rank_num_per_dp_group // self.pipeline_parallel_size
206
+
207
+ assert self.world_size % self.data_parallel_size == 0
208
+ assert self.rank_num_per_dp_group % self.pipeline_parallel_size == 0
209
+
210
+ def init_dist_group(self, use_cpu: bool = False):
211
+ """Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu.
212
+
213
+ Returns:
214
+ List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
215
+ A Pipeline parallelism's information in list of tuples.
216
+ """
217
+ local_rank = None
218
+ ranks_in_group = None
219
+ process_group = None
220
+ cpu_group = None
221
+ group_world_size = None
222
+ mode = ParallelMode.PIPELINE
223
+
224
+ for i in range(self.data_parallel_size):
225
+ for j in range(self.pipeline_stage_size):
226
+ ranks = list(
227
+ range(
228
+ i * self.rank_num_per_dp_group + j,
229
+ (i + 1) * self.rank_num_per_dp_group,
230
+ self.pipeline_stage_size,
231
+ )
232
+ )
233
+ pipe_group_size = len(ranks)
234
+ pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
235
+ if use_cpu:
236
+ group_cpu = (
237
+ dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
238
+ if dist.get_backend() != "gloo"
239
+ else pipe_group
240
+ )
241
+ else:
242
+ group_cpu = None
243
+
244
+ if self.rank in ranks:
245
+ local_rank = ranks.index(self.rank)
246
+ group_world_size = pipe_group_size
247
+ process_group = pipe_group
248
+ cpu_group = group_cpu
249
+ ranks_in_group = ranks
250
+
251
+ return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
252
+
253
+
254
+ class Initializer_Tensor(ProcessGroupInitializer):
255
+ """A ProcessGroupInitializer for tensor parallelism.
256
+
257
+ Args:
258
+ rank (int): The rank of current process.
259
+ world_size (int): Size of whole communication world.
260
+ data_parallel_size (int): Size of data parallel.
261
+ pipeline_parallel_size (int): Size of pipeline parallel.
262
+ tensor_parallel_size (int): Size of tensor parallel.
263
+ zero1_parallel_size (int): Size of zero1 parallel.
264
+ """
265
+
266
+ def __init__(self, *args, **kwargs):
267
+ super().__init__(*args, **kwargs)
268
+ self.num_tensor_parallel_group = self.world_size // self.tensor_parallel_size
269
+
270
+ assert self.world_size % self.tensor_parallel_size == 0
271
+
272
+ def init_dist_group(self, use_cpu: bool = False):
273
+ """Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.
274
+
275
+ Returns:
276
+ Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
277
+ A Tensor parallelism's information tuple.
278
+ """
279
+ local_rank = None
280
+ ranks_in_group = None
281
+ process_group = None
282
+ cpu_group = None
283
+ group_world_size = None
284
+ mode = ParallelMode.TENSOR
285
+
286
+ for i in range(self.num_tensor_parallel_group):
287
+ ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
288
+ group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
289
+ if use_cpu:
290
+ group_cpu = (
291
+ dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
292
+ if dist.get_backend() != "gloo"
293
+ else group
294
+ )
295
+ else:
296
+ group_cpu = None
297
+
298
+ if self.rank in ranks:
299
+ local_rank = ranks.index(self.rank)
300
+ group_world_size = len(ranks)
301
+ process_group = group
302
+ cpu_group = group_cpu
303
+ ranks_in_group = ranks
304
+
305
+ return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
306
+
307
+
308
+ class Initializer_Zero1(ProcessGroupInitializer):
309
+ """A ProcessGroupInitializer for zero-1 parallelism.
310
+
311
+ Args:
312
+ rank (int): The rank of current process.
313
+ world_size (int): Size of whole communication world.
314
+ data_parallel_size (int): Size of data parallel.
315
+ pipeline_parallel_size (int): Size of pipeline parallel.
316
+ tensor_parallel_size (int): Size of tensor parallel.
317
+ zero1_parallel_size (int): Size of zero-1 parallel.
318
+ """
319
+
320
+ def __init__(self, *args, **kwargs):
321
+ super().__init__(*args, **kwargs)
322
+ self.rank_num_per_dp_group = self.world_size // self.data_parallel_size
323
+ self.num_zero1_parallel_group = self.data_parallel_size // self.zero1_parallel_size
324
+
325
+ assert self.world_size % self.data_parallel_size == 0
326
+ assert self.world_size % self.zero1_parallel_size == 0
327
+
328
+ def init_dist_group(self, use_cpu: bool = False):
329
+ """Initialize zero1 parallel groups, and assign local_ranks and groups to each gpu.
330
+
331
+ Returns:
332
+ Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
333
+ A zero1 parallelism's information tuple.
334
+ """
335
+ local_rank = None
336
+ ranks_in_group = None
337
+ process_group = None
338
+ cpu_group = None
339
+ group_world_size = None
340
+ mode = ParallelMode.ZERO1
341
+
342
+ for i in range(self.rank_num_per_dp_group):
343
+ for j in range(self.num_zero1_parallel_group):
344
+ ranks = [
345
+ i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group
346
+ for k in range(self.zero1_parallel_size)
347
+ ]
348
+ group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
349
+ if use_cpu:
350
+ group_cpu = (
351
+ dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
352
+ if dist.get_backend() != "gloo"
353
+ else group
354
+ )
355
+ else:
356
+ group_cpu = None
357
+
358
+ if self.rank in ranks:
359
+ local_rank = ranks.index(self.rank)
360
+ group_world_size = len(ranks)
361
+ process_group = group
362
+ cpu_group = group_cpu
363
+ ranks_in_group = ranks
364
+
365
+ return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
366
+
367
+
368
+ class Initializer_Nettest(ProcessGroupInitializer):
369
+ """A ProcessGroupInitializer for network test, especially for NCCL.
370
+
371
+ Args:
372
+ rank (int): The rank of current process.
373
+ world_size (int): Size of whole communication world.
374
+ nettest_parallel_size (int): Size of a network test group.
375
+ """
376
+
377
+ def __init__(self, *args, **kwargs):
378
+ super().__init__(*args, **kwargs)
379
+ self.num_nettest_group = math.ceil(self.world_size / self.nettest_parallel_size)
380
+
381
+ def init_dist_group(self, use_cpu: bool = False):
382
+ """Initialize network test groups, and assign local_ranks and groups to each gpu.
383
+
384
+ Returns:
385
+ Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
386
+ A network test group's information tuple.
387
+ """
388
+ local_rank = None
389
+ ranks_in_group = None
390
+ process_group = None
391
+ cpu_group = None
392
+ group_world_size = None
393
+ mode = ParallelMode.NETTEST
394
+
395
+ for i in range(self.num_nettest_group):
396
+ ranks = []
397
+ for j in range(self.nettest_parallel_size):
398
+ rank = i * self.nettest_parallel_size + j
399
+ if rank < self.world_size:
400
+ ranks.append(rank)
401
+ group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
402
+ if use_cpu:
403
+ group_cpu = (
404
+ dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
405
+ if dist.get_backend() != "gloo"
406
+ else group
407
+ )
408
+ else:
409
+ group_cpu = None
410
+
411
+ if self.rank in ranks:
412
+ local_rank = ranks.index(self.rank)
413
+ group_world_size = len(ranks)
414
+ process_group = group
415
+ cpu_group = group_cpu
416
+ ranks_in_group = ranks
417
+
418
+ return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
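To make the grouping arithmetic above concrete, here is a minimal standalone sketch that mirrors how `Initializer_Tensor` partitions a hypothetical 8-rank world into tensor-parallel groups; the sizes are illustrative assumptions and no process groups are actually created.

```python
# Reproduces the rank-partitioning arithmetic of Initializer_Tensor for an
# assumed 8-rank job with tensor_parallel_size=2 (illustrative values only).
world_size = 8
tensor_parallel_size = 2
num_tensor_parallel_group = world_size // tensor_parallel_size

tensor_groups = [
    [i * tensor_parallel_size + j for j in range(tensor_parallel_size)]
    for i in range(num_tensor_parallel_group)
]
print(tensor_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

# Each rank finds its local rank inside its own group, mirroring the
# `if self.rank in ranks` branch of init_dist_group.
rank = 5
for ranks in tensor_groups:
    if rank in ranks:
        print(ranks.index(rank), ranks)  # 1 [4, 5]
```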
InternLM/internlm/core/context/random.py ADDED
@@ -0,0 +1,131 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
4
+
5
+ from contextlib import contextmanager
6
+
7
+ import torch
8
+ import torch.cuda
9
+ from torch import Tensor
10
+
11
+ from .process_group_initializer import ParallelMode
12
+
13
+
14
+ class SeedManager:
15
+ """This class is a manager of all random seeds involved in the system."""
16
+
17
+ def __init__(self):
18
+ self._current_mode = None
19
+ self._seeds = {}
20
+ self._seed_states = {}
21
+
22
+ @property
23
+ def current_mode(self):
24
+ return self._current_mode
25
+
26
+ @property
27
+ def seeds(self):
28
+ return self._seeds
29
+
30
+ @property
31
+ def seed_states(self):
32
+ return self._seed_states
33
+
34
+ def set_state(self, parallel_mode: ParallelMode, state: Tensor):
35
+ """Sets the state of the seed manager for `parallel_mode`."""
36
+ assert parallel_mode in self._seed_states, f"{parallel_mode} not found in seed manager"
37
+ self._seed_states[parallel_mode] = state
38
+
39
+ def set_mode(self, parallel_mode: ParallelMode):
40
+ """Sets the current mode of the seed manager."""
41
+ if self.current_mode:
42
+ # save state for current mode
43
+ self._seed_states[self._current_mode] = torch.cuda.get_rng_state()
44
+
45
+ # set new state for new mode
46
+ self._current_mode = parallel_mode
47
+ torch.cuda.set_rng_state(self._seed_states[parallel_mode])
48
+
49
+ def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
50
+ """Adds a seed to the seed manager for `parallel_mode`."""
51
+ assert isinstance(parallel_mode, ParallelMode), "Invalid ParallelMode"
52
+ if not overwrite:
53
+ assert parallel_mode not in self._seed_states, f"Seed for {parallel_mode} exists"
54
+ elif parallel_mode in self._seed_states:
55
+ print(f"Warning: {parallel_mode} seed overwritten.", flush=True)
56
+
57
+ current_state = torch.cuda.get_rng_state()
58
+ torch.cuda.manual_seed(seed)
59
+ self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
60
+ self._seeds[parallel_mode] = seed
61
+ torch.cuda.set_rng_state(current_state)
62
+
63
+ def reset(self):
64
+ self._current_mode = None
65
+ self._seeds = {}
66
+ self._seed_states = {}
67
+
68
+
69
+ _SEED_MANAGER = SeedManager()
70
+
71
+
72
+ def get_seeds():
73
+ """Returns the seeds of the seed manager.
74
+ Returns:
75
+ dict: The seeds of the seed manager.
76
+ """
77
+ return _SEED_MANAGER.seeds
78
+
79
+
80
+ def get_states(copy=False):
81
+ """Returns the seed states of the seed manager.
82
+ Returns:
83
+ dict: The seed states of the seed manager.
84
+ """
85
+ states = _SEED_MANAGER.seed_states
86
+ if copy:
87
+ new_states = dict()
88
+ for parallel_mode, state in states.items():
89
+ new_states[parallel_mode] = state.clone()
90
+ return new_states
91
+ else:
92
+ return _SEED_MANAGER.seed_states
93
+
94
+
95
+ def get_current_mode():
96
+ """Returns the current mode of the seed manager.
97
+ Returns:
98
+ :class:`ParallelMode`: The current mode of the seed manager.
99
+ """
100
+ return _SEED_MANAGER.current_mode
101
+
102
+
103
+ def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
104
+ """Adds a seed to the seed manager for `parallel_mode`."""
105
+ _SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)
106
+
107
+
108
+ def set_mode(parallel_mode: ParallelMode):
109
+ """Sets the current mode of the seed manager."""
110
+ _SEED_MANAGER.set_mode(parallel_mode)
111
+
112
+
113
+ def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
114
+ """Sets the state of the seed manager for `parallel_mode`."""
115
+ _SEED_MANAGER.set_state(parallel_mode, state)
116
+
117
+
118
+ def sync_states():
119
+ current_mode = get_current_mode()
120
+ current_states = torch.cuda.get_rng_state()
121
+ set_seed_states(current_mode, current_states)
122
+
123
+
124
+ @contextmanager
125
+ def seed(parallel_mode: ParallelMode):
126
+ """A context for seed switch"""
127
+ current_mode = _SEED_MANAGER.current_mode
128
+ try:
129
+ yield _SEED_MANAGER.set_mode(parallel_mode)
130
+ finally:
131
+ _SEED_MANAGER.set_mode(current_mode)
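A hedged usage sketch for the seed utilities defined above; it assumes a CUDA device is available, since `SeedManager` stores per-mode `torch.cuda` RNG states, and the seed values are arbitrary.

```python
import torch

from internlm.core.context import ParallelMode
from internlm.core.context.random import add_seed, seed, set_mode

if torch.cuda.is_available():
    # register independent RNG streams for two parallel modes (arbitrary seeds)
    add_seed(ParallelMode.DATA, 1024)
    add_seed(ParallelMode.TENSOR, 2048)
    set_mode(ParallelMode.DATA)

    # randomness drawn under the tensor-parallel seed; the data-parallel RNG
    # state is restored automatically when the context exits
    with seed(ParallelMode.TENSOR):
        noise = torch.rand(4, device="cuda")
```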
InternLM/internlm/core/engine.py ADDED
@@ -0,0 +1,227 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
5
+
6
+ from typing import List, Optional
7
+
8
+ import torch
9
+ from torch.nn import Module
10
+ from torch.nn.modules.loss import _Loss
11
+ from torch.optim.lr_scheduler import _LRScheduler
12
+
13
+ from internlm.core.gradient_handler import BaseGradientHandler
14
+ from internlm.solver.beta2_scheduler import Beta2Scheduler
15
+ from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
16
+ from internlm.utils.common import get_batch_size, move_to_device
17
+
18
+
19
+ class Engine:
20
+ """
21
+ The Engine class is responsible for managing the training and evaluation process of a neural network model.
22
+ It handles the forward and backward passes, parameter updates, gradient handling, and mode switching between
23
+ training and evaluation.
24
+
25
+ Args:
26
+ model (torch.nn.Module): The neural network model to be trained or evaluated.
27
+ optimizer (BaseOptimizer): The optimizer used for updating the parameters of the model.
28
+ lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): The learning rate scheduler for the optimizer.
29
+ Default is None.
30
+ beta2_scheduler (internlm.solver.beta2_scheduler.Beta2Scheduler, optional): The beta2 scheduler for the
31
+ optimizer. Default is None.
32
+ criterion (torch.nn.modules.loss._Loss, optional): The loss function used for calculating the loss during
33
+ training. Default is None.
34
+ gradient_handlers (List[BaseGradientHandler], optional): A list of gradient handlers used in the backward pass.
35
+ Default is None.
36
+ clip_grad_norm (float, optional): The norm value for gradient clipping. Default is 0.0.
37
+
38
+ Examples:
39
+ >>> # define model, criterion, optimizer, lr_scheduler, train_dataloader for your training
40
+ >>> model = ...
41
+ >>> criterion = ...
42
+ >>> optimizer = ...
43
+ >>> train_dataloader = ...
44
+ >>> engine, _, _, _ = internlm.initialize_engine(model, optimizer, criterion)
45
+ >>> engine.train()
46
+ >>> for inputs, labels in train_dataloader:
47
+ >>> # set gradients to zero
48
+ >>> engine.zero_grad()
49
+ >>> # run forward pass
50
+ >>> outputs = engine(inputs)
51
+ >>> # compute loss value and run backward pass
52
+ >>> loss = engine.criterion(outputs, labels)
53
+ >>> engine.backward(loss)
54
+ >>> # update parameters
55
+ >>> engine.step()
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ model: Module,
61
+ optimizer: BaseOptimizer,
62
+ lr_scheduler: Optional[_LRScheduler] = None,
63
+ beta2_scheduler: Optional[Beta2Scheduler] = None,
64
+ criterion: Optional[_Loss] = None,
65
+ gradient_handlers: Optional[List[BaseGradientHandler]] = None,
66
+ clip_grad_norm: float = 0.0,
67
+ ):
68
+ self._model = model
69
+ self._optimizer = optimizer
70
+ self._lr_scheduler = lr_scheduler
71
+ self._beta2_scheduler = beta2_scheduler
72
+ self._criterion = criterion
73
+ self._clip_grad_norm = clip_grad_norm
74
+
75
+ # state
76
+ self.training = True # default
77
+
78
+ # build gradient handler
79
+ self._gradient_handlers = gradient_handlers if gradient_handlers else []
80
+
81
+ @property
82
+ def model(self):
83
+ """Returns the model attached to the engine."""
84
+ return self._model
85
+
86
+ @property
87
+ def optimizer(self):
88
+ """Returns the optimizer attached to the engine."""
89
+ return self._optimizer
90
+
91
+ @property
92
+ def criterion(self):
93
+ """Returns the criterion (loss function) attached to the engine."""
94
+ return self._criterion
95
+
96
+ def _all_reduce_gradients(self):
97
+ """Handles all-reduce operations of gradients across different parallel groups."""
98
+ for handler in self._gradient_handlers:
99
+ handler.handle_gradient()
100
+
101
+ def zero_grad(self):
102
+ """Sets the gradient of all parameters in the model to zero."""
103
+ self.optimizer.zero_grad()
104
+
105
+ def step(self):
106
+ """
107
+ Executes the parameter update step. This includes all-reduce operations of gradients, gradient clipping,
108
+ and parameter update. If successful, it also steps the learning rate scheduler and beta2 scheduler
109
+ if they exist.
110
+
111
+ Returns:
112
+ success (bool): Whether the parameter update was successful.
113
+ grad_norm (float): The norm of the gradient after clipping.
114
+ """
115
+ self._all_reduce_gradients()
116
+ self.optimizer.clip_grad_norm(self.model, self._clip_grad_norm)
117
+
118
+ success, grad_norm = self.optimizer.step()
119
+
120
+ if success and self._lr_scheduler is not None:
121
+ self._lr_scheduler.step()
122
+
123
+ if success and self._beta2_scheduler is not None:
124
+ self._beta2_scheduler.step()
125
+
126
+ return success, grad_norm
127
+
128
+ def train(self):
129
+ """Sets the model to training mode."""
130
+ self.training = True
131
+ self._model.train()
132
+
133
+ def eval(self):
134
+ """Sets the model to evaluation mode."""
135
+ self.training = False
136
+ self._model.eval()
137
+
138
+ def backward(self, loss: torch.Tensor):
139
+ """
140
+ Starts the backward propagation given the loss value computed by a loss function.
141
+
142
+ Args:
143
+ loss (torch.Tensor): The loss value computed by a loss function.
144
+ """
145
+ return self.optimizer.backward(loss)
146
+
147
+ def backward_by_grad(self, tensor, grad):
148
+ """
149
+ Starts the backward propagation given the gradient of the output tensor.
150
+
151
+ Args:
152
+ tensor (torch.Tensor): The output tensor.
153
+ grad (torch.Tensor): The gradient passed back to the output tensor.
154
+ """
155
+ return self.optimizer.backward_by_grad(tensor, grad)
156
+
157
+ def __call__(self, *args, **kwargs):
158
+ """
159
+ Runs the forward step for the model.
160
+
161
+ Returns:
162
+ torch.Tensor: The output of the model.
163
+ """
164
+ return self.model(*args, **kwargs)
165
+
166
+ def load_batch(self, data_iter, to_gpu=True):
167
+ """
168
+ Loads a batch from the data iterator and returns it together with its batch size;
168
+ when `to_gpu` is True, the batch is moved to the same device as the model.
170
+
171
+ Args:
172
+ data_iter (Iterable): The data iterator from which to get a batch of data, obtained by calling
173
+ iter(dataloader).
174
+ to_gpu (bool, optional): Whether the data should be moved to the GPU. Default is True.
175
+
176
+ Returns:
177
+ Tuple (Any, int): A tuple of (batch_data, batch_size).
178
+ """
179
+ if data_iter is None:
180
+ raise RuntimeError("Dataloader is not defined.")
181
+ try:
182
+ batch_data = next(data_iter)
183
+ except TypeError:
184
+ batch_data = data_iter
185
+
186
+ if to_gpu:
187
+ batch_data = move_to_device(batch_data)
188
+ batch_size = get_batch_size(batch_data)
189
+
190
+ return batch_data, batch_size
191
+
192
+
193
+ class KDEngine(Engine):
194
+ def __init__(
195
+ self,
196
+ model: Module,
197
+ teacher: Module,
198
+ optimizer: BaseOptimizer,
199
+ lr_scheduler: Optional[_LRScheduler] = None,
200
+ beta2_scheduler: Optional[Beta2Scheduler] = None,
201
+ criterion: Optional[_Loss] = None,
202
+ kd_criterion: Optional[_Loss] = None,
203
+ gradient_handlers: Optional[List[BaseGradientHandler]] = None,
204
+ clip_grad_norm: float = 0.0,
205
+ ):
206
+ self._teacher = teacher
207
+ self._kd_criterion = kd_criterion
208
+
209
+ super().__init__(
210
+ model=model,
211
+ optimizer=optimizer,
212
+ lr_scheduler=lr_scheduler,
213
+ beta2_scheduler=beta2_scheduler,
214
+ criterion=criterion,
215
+ gradient_handlers=gradient_handlers,
216
+ clip_grad_norm=clip_grad_norm,
217
+ )
218
+
219
+ @property
220
+ def teacher(self):
221
+ """Returns the model attached to the engine."""
222
+ """Returns the teacher model attached to the engine."""
223
+
224
+ @property
225
+ def kd_criterion(self):
226
+ """Returns the knowledge-distillation criterion attached to the engine."""
227
+ return self._kd_criterion
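The KD criterion itself is defined elsewhere in the repo; as a hedged illustration, the sketch below shows one possible `kd_criterion` with the call signature that `KDNonPipelineScheduler` uses later, `kd_criterion(student_output, teacher_output, label)`. It is a plain temperature-scaled KL loss, not the repository's actual implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SoftKDLoss(nn.Module):
    """Illustrative distillation loss: KL between temperature-softened logits."""

    def __init__(self, temperature: float = 2.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits, label=None):  # label unused here
        t = self.temperature
        return F.kl_div(
            F.log_softmax(student_logits / t, dim=-1),
            F.softmax(teacher_logits / t, dim=-1),
            reduction="batchmean",
        ) * (t * t)


# toy check on random logits
print(SoftKDLoss()(torch.randn(4, 10), torch.randn(4, 10)).item())
```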
InternLM/internlm/core/gradient_handler.py ADDED
@@ -0,0 +1,76 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from abc import ABC, abstractmethod
5
+ from collections import defaultdict
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
10
+
11
+ from internlm.core.context import global_context as gpc
12
+
13
+
14
+ class BaseGradientHandler(ABC):
15
+ """A basic helper class to handle all-reduce operations of gradients across different parallel groups
16
+ before optimization.
17
+
18
+ Args:
19
+ model (Module): Model where the gradients accumulate.
20
+ optimizer (Optimizer): Optimizer for updating the parameters.
21
+ """
22
+
23
+ def __init__(self, model, optimizer):
24
+ self._model = model
25
+ self._optimizer = optimizer
26
+
27
+ @abstractmethod
28
+ def handle_gradient(self):
29
+ """A method to accumulate gradients across different parallel groups. Users should
30
+ write their own functions or just use the functions in pre-defined subclasses.
31
+ """
32
+ pass
33
+
34
+
35
+ class PipelineSharedModuleGradientHandler(BaseGradientHandler):
36
+ """A helper class to handle all-reduce operations in sub parallel groups.
37
+ A all-reduce collective communication will be operated in
38
+ :func:`handle_gradient` among all sub pipeline parallel groups.
39
+ For better performance, it bucketizes the gradients of all parameters that are
40
+ the same type to improve the efficiency of communication.
41
+
42
+ Args:
43
+ model (Module): Model where the gradients accumulate.
44
+ optimizer (Optimizer): Optimizer for updating the parameters.
45
+ """
46
+
47
+ def handle_gradient(self):
48
+ """A method running a all-reduce operation in sub pipeline parallel groups."""
49
+ if gpc.pipeline_parallel_size > 1:
50
+ # bucketize and all-reduce
51
+ buckets = defaultdict(lambda: defaultdict(list))
52
+ # Pack the buckets.
53
+ for param in self._model.parameters():
54
+ group = getattr(param, "pipeline_shared_module_pg", None)
55
+ if (
56
+ param.requires_grad
57
+ and group is not None
58
+ and (
59
+ (hasattr(param, "colo_attr") and not param.colo_attr.saved_grad.is_null())
60
+ or param.grad is not None
61
+ )
62
+ ):
63
+ tp = param.data.type()
64
+ buckets[group][tp].append(param)
65
+
66
+ # For each bucket, all-reduce and copy all-reduced grads.
67
+ for group, group_buckets in buckets.items():
68
+ for tp, bucket in group_buckets.items():
69
+ grads = [
70
+ param.colo_attr.grad_payload if hasattr(param, "colo_attr") else param.grad.data
71
+ for param in bucket
72
+ ]
73
+ coalesced = _flatten_dense_tensors(grads).to(torch.cuda.current_device())
74
+ dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=group)
75
+ for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
76
+ buf.copy_(synced)
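A single-process sketch of the coalesce / copy-back pattern used in `handle_gradient` above; the `dist.all_reduce` call is replaced by an in-place scale because it would require an initialized process group.

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.full((2, 2), 2.0)]

coalesced = _flatten_dense_tensors(grads)  # one contiguous buffer for the bucket
coalesced.mul_(0.5)                        # stand-in for dist.all_reduce(...)

# copy the synced values back into the original gradient tensors
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)

print(grads[0])  # tensor([0.5000, 0.5000, 0.5000])
```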
InternLM/internlm/core/naive_amp.py ADDED
@@ -0,0 +1,136 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ # adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp
5
+
6
+ from typing import Any
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from torch import Tensor, nn
11
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
12
+ from torch.distributed import ReduceOp
13
+
14
+ from internlm.core.context import ParallelMode
15
+ from internlm.core.context.parallel_context import global_context as gpc
16
+
17
+
18
+ class NaiveAMPModel(nn.Module):
19
+ """
20
+ This is a wrapper class for a model that automatically casts the model, its inputs, and outputs into fp16.
21
+ It also provides options to cast the output back to fp32 and to synchronize buffers.
22
+
23
+ Args:
24
+ model (torch.nn.Module): The model to be wrapped and cast into fp16.
25
+ output_to_fp32 (bool, optional): If True, the output of this module is cast into fp32. Defaults to True.
26
+ parallel_mode (:class:`internlm.core.context.ParallelMode`): The parallel group mode used in this module.
27
+ Defaults to ``ParallelMode.DATA``.
28
+ sync_buffer (bool, optional): If True, the buffers are synchronized. Defaults to True.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ model: nn.Module,
34
+ output_to_fp32: bool = True,
35
+ parallel_mode: ParallelMode = ParallelMode.DATA,
36
+ sync_buffer: bool = True,
37
+ dtype=torch.float16,
38
+ ):
39
+ super().__init__()
40
+ self.model = model.to(dtype)
41
+ self._output_to_fp32 = output_to_fp32
42
+ self._sync_buf = sync_buffer
43
+ self.dtype = dtype
44
+
45
+ if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
46
+ self._process_group = gpc.get_group(parallel_mode)
47
+ self._world_size = gpc.get_world_size(parallel_mode)
48
+ else:
49
+ self._process_group = None
50
+ self._world_size = 1
51
+ self._sync_buf = False
52
+ self._first_eval_run = False
53
+
54
+ @property
55
+ def sync_buffer(self):
56
+ """Returns the current state of the buffer synchronization."""
57
+ return self._sync_buf
58
+
59
+ @sync_buffer.setter
60
+ def sync_buffer(self, state: bool):
61
+ """Sets the state of the buffer synchronization."""
62
+ self._sync_buf = state
63
+
64
+ def _convert_to_fp16(self, input_: Any):
65
+ """Converts the input to fp16 if it is a Tensor of dtype float32."""
66
+ if isinstance(input_, Tensor) and input_.dtype == torch.float32:
67
+ input_ = input_.to(self.dtype)
68
+ return input_
69
+
70
+ def _convert_to_fp32(self, input_: Any):
71
+ """Converts the input to fp32 if it is a Tensor of dtype float16."""
72
+ if isinstance(input_, Tensor) and input_.dtype == torch.float16:
73
+ input_ = input_.float()
74
+ return input_
75
+
76
+ def convert_to_fp32(self, out):
77
+ """Converts the output to fp32"""
78
+ if isinstance(out, Tensor):
79
+ out = self._convert_to_fp32(out)
80
+ elif isinstance(out, (tuple, list)):
81
+ out = [self._convert_to_fp32(val) for val in out]
82
+ elif isinstance(out, dict):
83
+ out = {key: self._convert_to_fp32(val) for key, val in out.items()}
84
+
85
+ return out
86
+
87
+ def _reduce_module_buffer(self):
88
+ """
89
+ All-reduces the buffers (e.g., running stats of batch normalization) across
90
+ data parallel ranks so that all the ranks will produce consistent results
91
+ when given the same input.
92
+ """
93
+ buf_list = []
94
+
95
+ # find valid buffers
96
+ for buf in self.model.buffers():
97
+ if buf is not None:
98
+ buf_list.append(buf)
99
+
100
+ # reduce buffers across data parallel ranks
101
+ if buf_list:
102
+ coalesced_buf = _flatten_dense_tensors(buf_list)
103
+ coalesced_buf.div_(self._world_size)
104
+ dist.all_reduce(coalesced_buf, op=ReduceOp.SUM, group=self._process_group)
105
+ unflattened_buf_list = _unflatten_dense_tensors(coalesced_buf, buf_list)
106
+ for old, new in zip(buf_list, unflattened_buf_list):
107
+ old.copy_(new)
108
+
109
+ def eval(self):
110
+ """Sets the model to evaluation mode. Buffers are only synchronized in the first eval iteration."""
111
+ self.model.eval()
112
+ self._first_eval_run = True
113
+
114
+ def forward(self, *args, **kwargs):
115
+ """
116
+ Performs a forward pass on the model. Buffers are synchronized before the forward pass.
117
+ The inputs are converted to fp16 and the outputs are optionally converted back to fp32.
118
+ """
119
+ if (self.training or self._first_eval_run) and self._sync_buf:
120
+ with torch.no_grad():
121
+ self._reduce_module_buffer()
122
+
123
+ if self._first_eval_run:
124
+ self._first_eval_run = False
125
+
126
+ if args:
127
+ args = [self._convert_to_fp16(arg) for arg in args]
128
+ if kwargs:
129
+ for k, v in kwargs.items():
130
+ kwargs[k] = self._convert_to_fp16(v)
131
+
132
+ out = self.model(*args, **kwargs)
133
+
134
+ if self._output_to_fp32:
135
+ out = self.convert_to_fp32(out)
136
+ return out
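A standalone sketch of the cast-in / cast-out behaviour that `NaiveAMPModel` wraps around a model (fp32 inputs cast to the working dtype, outputs cast back to fp32); it skips the buffer synchronization, which needs an initialized parallel context, and the dtype choice below is an assumption for the example.

```python
import torch
from torch import nn

use_cuda = torch.cuda.is_available()
dtype = torch.float16 if use_cuda else torch.bfloat16
device = "cuda" if use_cuda else "cpu"

model = nn.Linear(8, 4).to(device=device, dtype=dtype)

x = torch.randn(2, 8, device=device)  # fp32 input
out = model(x.to(dtype))              # forward runs in reduced precision
out = out.float()                     # cast back, as output_to_fp32=True does
print(out.dtype)                      # torch.float32
```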
InternLM/internlm/core/scheduler/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from .base_scheduler import BaseScheduler, SchedulerHook, SchedulerMetricHook
2
+ from .no_pipeline_scheduler import NonPipelineScheduler, KDNonPipelineScheduler
3
+ from .pipeline_scheduler import InterleavedPipelineScheduler, PipelineScheduler, KDPipelineScheduler
4
+
5
+ __all__ = [
6
+ "BaseScheduler",
7
+ "NonPipelineScheduler",
8
+ "KDNonPipelineScheduler",
9
+ "InterleavedPipelineScheduler",
10
+ "PipelineScheduler",
11
+ "KDPipelineScheduler",
12
+ "SchedulerHook",
13
+ "SchedulerMetricHook",
14
+ ]
InternLM/internlm/core/scheduler/base_scheduler.py ADDED
@@ -0,0 +1,187 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
5
+
6
+ from abc import ABC, abstractmethod
7
+ from typing import Any, Callable, Iterable, Optional
8
+
9
+ import torch
10
+
11
+ from internlm.core.engine import Engine
12
+ from internlm.utils.megatron_timers import megatron_timer as timer
13
+
14
+
15
+ class BaseScheduler(ABC):
16
+ """A basic helper class to control the process of training or evaluation.
17
+ It mainly composes of forward_backward_step for gradient backward and
18
+ optimizer_step for parameters update.
19
+ For the convenience to enable FP16, we aggregate all codes that contain the
20
+ control of FP16 in class schedule.
21
+
22
+ Args:
23
+ data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges
24
+ them into data and label.
25
+ """
26
+
27
+ def __init__(self, data_process_func: Callable = None):
28
+ self.data_process_func = data_process_func
29
+
30
+ @abstractmethod
31
+ def pre_processing(self, engine: Engine):
32
+ """To perform actions before running the schedule.
33
+
34
+ Args:
35
+ engine (internlm.core.Engine): InternLM engine for training and inference.
36
+ """
37
+ pass
38
+
39
+ def _load_micro_batch(self, data, label, offset, micro_bsz):
40
+ assert isinstance(data, dict) and isinstance(label, torch.Tensor)
41
+ micro_batch_data = {k: v[offset : offset + micro_bsz] for k, v in data.items()}
42
+ micro_batch_label = label[offset : offset + micro_bsz]
43
+
44
+ return micro_batch_data, micro_batch_label
45
+
46
+ @abstractmethod
47
+ def forward_backward_step(
48
+ self,
49
+ engine: Engine,
50
+ data_iter: Iterable,
51
+ forward_only: bool,
52
+ return_loss: bool = True,
53
+ return_output_label: bool = True,
54
+ ):
55
+ """The process function over a batch of dataset for training or evaluation.
56
+
57
+ Args:
58
+ engine (internlm.core.Engine): InternLM engine for training and inference.
59
+ data_iter (Iterable): Data iterator from which get a batch of data, obtained by calling iter(dataloader).
60
+ forward_only (bool): If True, the process won't include backward.
61
+ return_loss (bool, optional): If False, the loss won't be returned.
62
+ return_output_label (bool, optional): If False, the output and label won't be returned.
63
+ """
64
+ pass
65
+
66
+ @staticmethod
67
+ def _call_engine(engine: Engine, inputs: Any):
68
+ """Calls the engine with the given inputs.
69
+
70
+ Args:
71
+ engine (internlm.core.Engine): InternLM engine for training and inference.
72
+ inputs (Any): The inputs to the engine, can be of type torch.Tensor, list, tuple, or dict.
73
+ """
74
+ if isinstance(inputs, torch.Tensor):
75
+ return engine(inputs)
76
+ elif isinstance(inputs, (list, tuple)):
77
+ return engine(*inputs)
78
+ elif isinstance(inputs, dict):
79
+ return engine(**inputs)
80
+ else:
81
+ raise TypeError(
82
+ f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}"
83
+ )
84
+
85
+ @staticmethod
86
+ def _call_engine_criterion(criterion, outputs: Any, labels: Any):
87
+ """Calls the engine's criterion with the given outputs and labels.
88
+
89
+ Args:
90
+ engine (internlm.core.Engine): InternLM engine for training and inference.
91
+ outputs (Any): The outputs from the model, can be of type torch.Tensor, list, tuple, or dict.
92
+ labels (Any): The labels for the outputs, can be of type torch.Tensor, list, tuple, or dict.
93
+ """
94
+ assert isinstance(
95
+ outputs, (torch.Tensor, list, tuple, dict)
96
+ ), f"Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}"
97
+ if isinstance(outputs, torch.Tensor):
98
+ outputs = (outputs,)
99
+ if isinstance(labels, torch.Tensor):
100
+ labels = (labels,)
101
+
102
+ if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
103
+ return criterion(*outputs, *labels)
104
+ elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
105
+ return criterion(*outputs, **labels)
106
+ elif isinstance(outputs, dict) and isinstance(labels, dict):
107
+ return criterion(**outputs, **labels)
108
+ elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
109
+ raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
110
+ else:
111
+ raise TypeError(
112
+ f"Expected model outputs and labels to be of type torch.Tensor ' \
113
+ '(which is auto-converted to tuple), list, tuple, or dict, ' \
114
+ 'but got {type(outputs)} (model outputs) and {type(labels)} (labels)"
115
+ )
116
+
117
+
118
+ class SchedulerHook(ABC):
119
+ """
120
+ Scheduler Hook.
121
+ """
122
+
123
+ @abstractmethod
124
+ def before_forward(self, scheduler, inputs) -> None:
125
+ """Actions before forward"""
126
+
127
+ @abstractmethod
128
+ def after_forward(self, scheduler, outputs) -> None:
129
+ """Actions after forward"""
130
+
131
+ @abstractmethod
132
+ def before_criterion(self, scheduler, outputs, label) -> None:
133
+ """Actions before criterion"""
134
+
135
+ @abstractmethod
136
+ def after_criterion(self, scheduler, loss) -> None:
137
+ """Actions after criterion"""
138
+
139
+ @abstractmethod
140
+ def before_backward(self, scheduler, outputs, outputs_grad) -> None:
141
+ """Actions before backward"""
142
+
143
+ @abstractmethod
144
+ def after_backward(self, scheduler, inputs_grad) -> None:
145
+ """Actions after backward"""
146
+
147
+ @abstractmethod
148
+ def post_helper_func(self, scheduler, outputs, label) -> None:
149
+ """A post helper function"""
150
+
151
+
152
+ class SchedulerMetricHook(SchedulerHook):
153
+ """
154
+ Scheduler Metric Hook.
155
+ """
156
+
157
+ def __init__(self, metric: Optional[Callable] = None, skip: bool = False) -> None:
158
+ self._post_func = metric
159
+ self._skip = skip
160
+
161
+ def before_forward(self, scheduler, inputs) -> None:
162
+ if not self._skip:
163
+ timer("fwd").start()
164
+
165
+ def after_forward(self, scheduler, outputs) -> None:
166
+ if not self._skip:
167
+ timer("fwd").stop()
168
+
169
+ def before_criterion(self, scheduler, outputs, label) -> None:
170
+ if not self._skip:
171
+ timer("cal_loss").start()
172
+
173
+ def after_criterion(self, scheduler, loss) -> None:
174
+ if not self._skip:
175
+ timer("cal_loss").stop()
176
+
177
+ def before_backward(self, scheduler, outputs, outputs_grad) -> None:
178
+ if not self._skip:
179
+ timer("bwd").start()
180
+
181
+ def after_backward(self, scheduler, inputs_grad) -> None:
182
+ if not self._skip:
183
+ timer("bwd").stop()
184
+
185
+ def post_helper_func(self, scheduler, outputs, label) -> None:
186
+ if self._post_func is not None:
187
+ self._post_func(outputs, label)
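As a hedged example of the hook interface above, the sketch below implements a custom `SchedulerHook` that simply counts forward and backward micro-steps; the schedulers call these methods through `_call_hooks` at the corresponding points.

```python
from internlm.core.scheduler import SchedulerHook


class CountingHook(SchedulerHook):
    """Toy hook: counts how many forward/backward micro-steps were run."""

    def __init__(self):
        self.num_forward = 0
        self.num_backward = 0

    def before_forward(self, scheduler, inputs) -> None:
        pass

    def after_forward(self, scheduler, outputs) -> None:
        self.num_forward += 1

    def before_criterion(self, scheduler, outputs, label) -> None:
        pass

    def after_criterion(self, scheduler, loss) -> None:
        pass

    def before_backward(self, scheduler, outputs, outputs_grad) -> None:
        pass

    def after_backward(self, scheduler, inputs_grad) -> None:
        self.num_backward += 1

    def post_helper_func(self, scheduler, outputs, label) -> None:
        pass
```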
InternLM/internlm/core/scheduler/no_pipeline_scheduler.py ADDED
@@ -0,0 +1,266 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
5
+
6
+ from typing import Any, Callable, Iterable, List, Optional
7
+
8
+ import torch
9
+
10
+ from internlm.core.context import global_context as gpc
11
+ from internlm.core.engine import Engine, KDEngine
12
+ from internlm.utils.common import conditional_context
13
+ from internlm.utils.timeout import llm_timeout
14
+ from collections import defaultdict
15
+ from .base_scheduler import BaseScheduler, SchedulerHook
16
+
17
+
18
+ class NonPipelineScheduler(BaseScheduler):
19
+ """A helper schedule class for no pipeline parallelism running environment.
20
+ During one process, it loads a batch of dataset and feeds it to the model.
21
+ After getting the output and calculating the loss, it will use :meth:`step`
22
+ to update the parameters if it is in training mode.
23
+
24
+ Args:
25
+ data_process_func (Callable, optional): The preprocessing function which receives a batch of data
26
+ and returns a tuple in the form of (data, label), and it will be executed in load_batch.
27
+ gradient_accumulation_size (int, optional): the number of gradient accumulation steps, 1 to disable
28
+ gradient accumulation.
29
+
30
+ Examples:
31
+ >>> # this shows an example of a customized data_process_func
32
+ >>> def data_process_func(dataloader_output):
33
+ >>> item1, item2, item3 = dataloader_output
34
+ >>> data = (item1, item2)
35
+ >>> label = item3
36
+ >>> return data, label
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ data_process_func: Callable = None,
42
+ gradient_accumulation_size: int = 1,
43
+ scheduler_hooks: Optional[List[SchedulerHook]] = None,
44
+ ):
45
+ self._grad_accum_size = gradient_accumulation_size
46
+ self._grad_accum_offset = 0
47
+
48
+ self._hooks = scheduler_hooks
49
+
50
+ super().__init__(data_process_func)
51
+
52
+ def pre_processing(self, engine: Engine):
53
+ """Performs actions before running the schedule.
54
+
55
+ Args:
56
+ engine (internlm.core.Engine): InternLM engine for training and inference.
57
+ """
58
+ pass
59
+
60
+ def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
61
+ for hook in self._hooks:
62
+ getattr(hook, func_name)(self, *args, **kwargs)
63
+
64
+ def _load_accum_batch(self, data: Any, label: Any):
65
+ """Loads a batch of data and label for gradient accumulation.
66
+
67
+ Args:
68
+ data (Any): The data to be loaded.
69
+ label (Any): The label to be loaded.
70
+ """
71
+
72
+ _data, _label = self._load_micro_batch(
73
+ data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
74
+ )
75
+ self._grad_accum_offset += self._grad_accum_batch_size
76
+
77
+ if self.data_process_func:
78
+ _data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
79
+ _label = self.data_process_func(_label, _data["cu_seqlens"])
80
+ _data.pop("cu_seqlens")
81
+ _data.pop("indexes")
82
+
83
+ return _data, _label
84
+
85
+ def _train_one_batch(
86
+ self,
87
+ data: Any,
88
+ label: Any,
89
+ engine: Engine,
90
+ forward_only: bool = False,
91
+ return_loss: bool = True,
92
+ scale_loss: int = 1,
93
+ ):
94
+ """Trains one batch of data.
95
+
96
+ Args:
97
+ data (Any): The data to be trained.
98
+ label (Any): The label for the data.
99
+ engine (internlm.core.Engine): InternLM engine for training and inference.
100
+ forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will
101
+ be executed.
102
+ return_loss (bool, optional): Loss will be returned if True.
103
+ scale_loss (int, optional): The scale factor for the loss.
104
+ """
105
+
106
+ # forward
107
+ with conditional_context(torch.no_grad(), enable=forward_only):
108
+ self._call_hooks("before_forward", data)
109
+ output = self._call_engine(engine, data)
110
+ self._call_hooks("after_forward", output)
111
+
112
+ self._call_hooks("post_helper_func", output, label)
113
+
114
+ if return_loss:
115
+ self._call_hooks("before_criterion", output, label)
116
+ loss = self._call_engine_criterion(engine.criterion, output, label)
117
+ self._call_hooks("after_criterion", loss)
118
+ loss /= scale_loss
119
+
120
+ # backward
121
+ if not forward_only:
122
+ self._call_hooks("before_backward", None, None)
123
+ engine.backward(loss)
124
+ self._call_hooks("after_backward", None)
125
+
126
+ if not return_loss:
127
+ loss = None
128
+
129
+ return output, dict(loss=loss)
130
+
131
+ @llm_timeout(func_name="nopp_forward_backward_step")
132
+ def forward_backward_step(
133
+ self,
134
+ engine: Engine,
135
+ data_iter: Iterable,
136
+ forward_only: bool = False,
137
+ return_loss: bool = True,
138
+ return_output_label: bool = True,
139
+ ):
140
+ """The process function that loads a batch of data and feeds it to the model.
142
+ The returned labels and loss will be None if :attr:`return_loss` is False.
142
+
143
+ Args:
144
+ engine (internlm.core.Engine): InternLM engine for training and inference.
145
+ data_iter (Iterable): Dataloader in the form of an iterator, obtained by calling iter(dataloader).
146
+ forward_only (bool, optional):
147
+ If True, the model is run for the forward pass, else back propagation will be executed.
148
+ return_loss (bool, optional): Loss will be returned if True.
149
+ return_output_label (bool, optional): Output and label will be returned if True.
150
+
151
+ Returns:
152
+ Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
153
+ """
154
+ assert (
155
+ forward_only or return_loss
156
+ ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
157
+
158
+ batch_data, batch_size = engine.load_batch(data_iter)
159
+
160
+ assert (
161
+ batch_size % self._grad_accum_size == 0
162
+ ), f"batch_size:{batch_size} must be an integer multiple of gradient accumulation steps:{self._grad_accum_size}"
163
+ self._grad_accum_batch_size = batch_size // self._grad_accum_size
164
+
165
+ data, label = batch_data
166
+
167
+ loss = defaultdict(int) if return_loss else None
168
+ outputs = []
169
+ labels = []
170
+
171
+ # reset accumulation microbatch offset
172
+ self._grad_accum_offset = 0
173
+
174
+ for _current_accum_step in range(self._grad_accum_size):
175
+ if _current_accum_step == self._grad_accum_size - 1:
176
+ engine.optimizer.skip_grad_reduce = False
177
+ else:
178
+ engine.optimizer.skip_grad_reduce = True
179
+
180
+ _data, _label = self._load_accum_batch(data, label)
181
+
182
+ _output, _loss = self._train_one_batch(
183
+ _data, _label, engine, forward_only, return_loss, self._grad_accum_size
184
+ )
185
+
186
+ if return_loss:
187
+ for k in _loss:
188
+ loss[k] += _loss[k]
189
+ if return_output_label:
190
+ outputs.append(_output)
191
+ labels.append(_label)
192
+
193
+ if not return_output_label:
194
+ outputs, labels = None, None
195
+
196
+ return outputs, labels, loss
197
+
198
+
199
+ class KDNonPipelineScheduler(NonPipelineScheduler):
200
+
201
+ def __init__(
202
+ self,
203
+ data_process_func: Callable = None,
204
+ gradient_accumulation_size: int = 1,
205
+ scheduler_hooks: Optional[List[SchedulerHook]] = None,
206
+ ):
207
+ super().__init__(
208
+ data_process_func=data_process_func,
209
+ gradient_accumulation_size=gradient_accumulation_size,
210
+ scheduler_hooks=scheduler_hooks,
211
+ )
212
+
213
+ def _train_one_batch(
214
+ self,
215
+ data: Any,
216
+ label: Any,
217
+ engine: KDEngine,
218
+ forward_only: bool = False,
219
+ return_loss: bool = True,
220
+ scale_loss: int = 1,
221
+ ):
222
+ """Trains one batch of data.
223
+
224
+ Args:
225
+ data (Any): The data to be trained.
226
+ label (Any): The label for the data.
227
+ engine (internlm.core.Engine): InternLM engine for training and inference.
228
+ forward_only (bool, optional): If True, the model is run for the forward pass, else back propagation will
229
+ be executed.
230
+ return_loss (bool, optional): Loss will be returned if True.
231
+ scale_loss (int, optional): The scale factor for the loss.
232
+ """
233
+
234
+ # forward
235
+ with conditional_context(torch.no_grad(), enable=forward_only):
236
+ self._call_hooks("before_forward", data)
237
+ output = self._call_engine(engine, data)
238
+ self._call_hooks("after_forward", output)
239
+
240
+ self._call_hooks("post_helper_func", output, label)
241
+
242
+ if return_loss:
243
+ self._call_hooks("before_criterion", output, label)
244
+ loss_gt = gpc.config.kd_config['gt_weight'] * self._call_engine_criterion(engine.criterion, output, label)
245
+
246
+ with torch.no_grad():
247
+ engine.teacher.eval()
248
+ output_t = self._call_engine(engine.teacher, data)
249
+
250
+ loss_kd = gpc.config.kd_config['kd_weight'] * self._call_engine_criterion(engine.kd_criterion, output, (output_t, label))
251
+
252
+ self._call_hooks("after_criterion", loss_gt + loss_kd)
253
+ loss_gt /= scale_loss
254
+ loss_kd /= scale_loss
255
+
256
+ # backward
257
+ if not forward_only:
258
+ self._call_hooks("before_backward", None, None)
259
+ engine.backward(loss_gt+loss_kd)
260
+ self._call_hooks("after_backward", None)
261
+
262
+ if not return_loss:
263
+ loss_gt = None
264
+ loss_kd = None
265
+
266
+ return output, dict(loss_gt=loss_gt, loss_kd=loss_kd)
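A small arithmetic sketch of how `KDNonPipelineScheduler._train_one_batch` combines the two losses each micro-step: the weights come from `gpc.config.kd_config` (the values below are illustrative assumptions) and each micro-loss is divided by the gradient-accumulation size before `engine.backward`.

```python
import torch

gt_weight, kd_weight = 1.0, 0.5   # assumed kd_config["gt_weight"] / ["kd_weight"] values
grad_accum_size = 4

ce_loss = torch.tensor(2.0)       # stand-in for engine.criterion(output, label)
kd_loss = torch.tensor(0.8)       # stand-in for engine.kd_criterion(output, output_t, label)

loss_gt = gt_weight * ce_loss / grad_accum_size
loss_kd = kd_weight * kd_loss / grad_accum_size
total = loss_gt + loss_kd         # what engine.backward receives for this micro-step
print(total.item())               # 0.6 (up to float rounding)
```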
InternLM/internlm/core/scheduler/pipeline_scheduler.py ADDED
@@ -0,0 +1,1363 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
5
+
6
+ from contextlib import contextmanager
7
+ from typing import Callable, List, Optional, Tuple, Union
8
+
9
+ import torch.cuda
10
+
11
+ import internlm.core.communication as comm
12
+ from internlm.core.context import ParallelMode
13
+ from internlm.core.context import global_context as gpc
14
+ from internlm.core.engine import Engine
15
+ from internlm.core.naive_amp import NaiveAMPModel
16
+ from internlm.utils.common import get_current_device, move_to_device
17
+ from internlm.utils.logger import get_logger
18
+ from internlm.utils.timeout import llm_timeout
19
+ from collections import defaultdict
20
+ from .base_scheduler import BaseScheduler, SchedulerHook
21
+
22
+ logger = get_logger(__file__)
23
+
24
+
25
+ def get_tensor_shape():
26
+ if hasattr(gpc.config, "TENSOR_SHAPE"):
27
+ return gpc.config.TENSOR_SHAPE
28
+
29
+ if not gpc.is_initialized(ParallelMode.PIPELINE):
30
+ return None
31
+
32
+ if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
33
+ if gpc.config.model.use_flash_attn:
34
+ if gpc.config.parallel.sequence_parallel:
35
+ sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
36
+ tensor_shape = (
37
+ gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
38
+ gpc.config.HIDDEN_SIZE,
39
+ )
40
+ else:
41
+ tensor_shape = (
42
+ gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
43
+ gpc.config.HIDDEN_SIZE,
44
+ )
45
+ else:
46
+ tensor_shape = (
47
+ gpc.config.data["micro_bsz"],
48
+ gpc.config.SEQ_LEN,
49
+ gpc.config.HIDDEN_SIZE,
50
+ )
51
+ return tensor_shape
52
+ else:
53
+ return None
54
+
55
+
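For reference, the flash-attention + sequence-parallel branch of `get_tensor_shape` reduces to the arithmetic below; the config values are assumptions chosen only to illustrate the resulting shape.

```python
# Assumed config values for illustration only.
SEQ_LEN, micro_bsz, HIDDEN_SIZE = 2048, 2, 1024
sequence_world_size = 4  # stand-in for gpc.get_world_size(ParallelMode.TENSOR)

tensor_shape = (SEQ_LEN * micro_bsz // sequence_world_size, HIDDEN_SIZE)
print(tensor_shape)  # (1024, 1024)
```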
56
+ def pack_return_tensors(return_tensors):
57
+ output, label = tuple(zip(*return_tensors))
58
+ if isinstance(output[0], torch.Tensor):
59
+ output = torch.cat(output, dim=0)
60
+ elif isinstance(output[0], (list, tuple)):
61
+ output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
62
+ else:
63
+ raise TypeError("Output of model must be tensor or list/tuple of tensors")
64
+ if isinstance(label[0], torch.Tensor):
65
+ label = torch.cat(label, dim=0)
66
+ else:
67
+ merged_label = {k: [] for k in label[0].keys()}
68
+ for d in label:
69
+ for k, v in d.items():
70
+ merged_label[k].append(v)
71
+ label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
72
+ return output, label
73
+
74
+
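A toy check of `pack_return_tensors` defined just above: it concatenates per-microbatch `(output, label)` pairs back into one batch along dim 0. It assumes the InternLM package and its dependencies are importable so the function can be imported.

```python
import torch

from internlm.core.scheduler.pipeline_scheduler import pack_return_tensors

micro1 = (torch.zeros(2, 4), {"labels": torch.zeros(2, dtype=torch.long)})
micro2 = (torch.ones(2, 4), {"labels": torch.ones(2, dtype=torch.long)})

output, label = pack_return_tensors([micro1, micro2])
print(output.shape, label["labels"].shape)  # torch.Size([4, 4]) torch.Size([4])
```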
75
+ @contextmanager
76
+ def switch_virtual_pipeline_parallel_rank(rank):
77
+ prev_rank = gpc.virtual_pipeline_parallel_rank
78
+ try:
79
+ gpc.set_virtual_pipeline_parallel_rank(rank)
80
+ yield
81
+ finally:
82
+ gpc.set_virtual_pipeline_parallel_rank(prev_rank)
83
+
84
+
85
+ @contextmanager
86
+ def switch_optimizer_grad_sync_skip_mode(optimizer, skip: bool = True):
87
+ prev_mode = optimizer.skip_grad_reduce
88
+ try:
89
+ optimizer.skip_grad_reduce = skip
90
+ yield
91
+ finally:
92
+ optimizer.skip_grad_reduce = prev_mode
93
+
94
+
95
+ class PipelineScheduler(BaseScheduler):
96
+ """
97
+ A helper schedule class for pipeline parallelism running environment.
98
+ It uses the non-interleaved 1F1B strategy. Other properties are similar to
99
+ :class:`NonPipelineScheduler`.
100
+
101
+ Args:
102
+ num_microbatches (int): The number of microbatches.
103
+ dtype (torch.dtype): Type of data. torch.float by default.
104
+ data_process_func (Callable, optional):
105
+ The post processing function which receives a micro batch of data, and it will be executed
106
+ in `load_micro_batch`.
107
+ tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
108
+ scatter_gather_tensors (bool, optional):
109
+ If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
110
+ scheduler_hooks (Optional[List[SchedulerHook]], optional): List of scheduler hooks.
111
+ """
112
+
113
+ def __init__(
114
+ self,
115
+ num_microbatches: int,
116
+ dtype: torch.dtype = torch.float,
117
+ data_process_func: Callable = None,
118
+ tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
119
+ scatter_gather_tensors: bool = False,
120
+ scheduler_hooks: Optional[List[SchedulerHook]] = None,
121
+ ):
122
+ assert num_microbatches > 0, f"expected num_microbatches to be larger than 0, but got {num_microbatches}"
123
+
124
+ assert not isinstance(
125
+ tensor_shape, int
126
+ ), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
127
+
128
+ super().__init__(data_process_func=data_process_func)
129
+
130
+ self.num_microbatches = num_microbatches
131
+ self.dtype = dtype
132
+ self._hooks = scheduler_hooks
133
+
134
+ self._tensor_shape = (
135
+ tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape)
136
+ )
137
+
138
+ self.scatter_gather_tensors = (
139
+ scatter_gather_tensors
140
+ and gpc.is_initialized(ParallelMode.TENSOR)
141
+ and gpc.get_world_size(ParallelMode.TENSOR) > 1
142
+ )
143
+
144
+ if gpc.config.parallel.sequence_parallel:
145
+ self.scatter_gather_tensors = False
146
+
147
+ # cache for the batch data
148
+ self.batch_data = None
149
+
150
+ @property
151
+ def tensor_shape(self) -> torch.Size:
152
+ return self._tensor_shape
153
+
154
+ @tensor_shape.setter
155
+ def tensor_shape(self, tensor_shape: torch.Size):
156
+ self._tensor_shape = tensor_shape
157
+
158
+ def pre_processing(self, engine):
159
+ types = set()
160
+
161
+ for param in engine.model.parameters():
162
+ types.add(param.dtype)
163
+ assert len(types) == 1, f"Mixed types of parameter detected, {types}"
164
+
165
+ self.dtype = types.pop()
166
+
167
+ @staticmethod
168
+ def _call_engine(engine, data): # pylint: disable=W0237
169
+ if data is None:
170
+ return None
171
+
172
+ if isinstance(data, torch.Tensor):
173
+ return engine(data)
174
+ elif isinstance(data, (list, tuple)):
175
+ return engine(*data)
176
+ elif isinstance(data, dict):
177
+ stage_output = data.pop("stage_output", None)
178
+
179
+ if stage_output is None:
180
+ return engine(**data)
181
+ elif isinstance(stage_output, torch.Tensor):
182
+ return engine(stage_output, **data)
183
+ elif isinstance(stage_output, (tuple, list)):
184
+ return engine(*stage_output, **data)
185
+ else:
186
+ raise TypeError(
187
+ f"Expected stage_output to be of type torch.Tensor, list, or tuple, "
188
+ f"but got {type(stage_output)}"
189
+ )
190
+ else:
191
+ raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
192
+
193
+ def load_batch(self, engine, data_iter):
194
+ # Pipeline schedule just puts data in memory
195
+ batch_data, batch_size = engine.load_batch(data_iter, to_gpu=False)
196
+ assert batch_size % self.num_microbatches == 0, "Batch size should be divisible by the number of microbatches"
197
+
198
+ self.microbatch_offset = 0
199
+ self.batch_size = batch_size
200
+ self.batch_data, self.batch_label = batch_data
201
+ self.microbatch_size = self.batch_size // self.num_microbatches
202
+
203
+ def load_micro_batch(self):
204
+ micro_batch_data, micro_batch_label = self._load_micro_batch(
205
+ data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
206
+ )
207
+ if self.data_process_func:
208
+ micro_batch_data["input_ids"] = self.data_process_func(
209
+ micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
210
+ )
211
+ micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
212
+
213
+ micro_batch_data.pop("cu_seqlens")
214
+ micro_batch_data.pop("indexes")
215
+
216
+ micro_batch_data["label"] = micro_batch_label
217
+ self.microbatch_offset += self.microbatch_size
218
+
219
+ return move_to_device(micro_batch_data)
220
+
221
+ def _get_data_label_for_current_step(self, stage_output, micro_batch_data):
222
+ if isinstance(micro_batch_data, (tuple, list)):
223
+ if gpc.is_first_rank(ParallelMode.PIPELINE):
224
+ # for the first stage, we use the data from the
225
+ # dataloader output by default
226
+ data, label = micro_batch_data
227
+ else:
228
+ # for non-first stage, we use the output passed
229
+ # by the previous stage as the model input
230
+ data = stage_output
231
+ _, label = micro_batch_data
232
+ elif isinstance(micro_batch_data, dict):
233
+ label = micro_batch_data.pop("label", None)
234
+ data = {"stage_output": stage_output, **micro_batch_data}
235
+
236
+ return data, label
237
+
238
+ def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
239
+ for hook in self._hooks:
240
+ getattr(hook, func_name)(self, *args, **kwargs)
241
+
242
+ def _get_current_microbatch_id(self, step_id: int) -> int:
243
+ """
244
+ Get the current microbatch ID based on the step ID.
245
+ In 1f1b scheduler, the microbatch ID is the same as the step ID,
246
+ but it is important to note that the step ID is calculated separately
247
+ for forward and backward passes.
248
+ """
249
+ return step_id
250
+
251
+ def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
252
+ """
253
+ Forward step for passed-in model. If it is the first stage, the input tensor
254
+ is obtained from data_iterator, otherwise the passed-in input_obj is used.
255
+ Returns output tensor. This is a helper function and can be ignored by users.
256
+
257
+ Args:
258
+ engine (colossalai.engine.Engine): Colossalai engine for training and inference.
259
+ input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
260
+ return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
261
+ return_output_label (bool, optional): Whether to return output labels.
262
+ accum_loss (optional): Where the accumulated loss is stored.
263
+ Returns:
264
+ Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
265
+ pipeline stage.
266
+ """
267
+ micro_batch_data = self.load_micro_batch()
268
+ data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
269
+
270
+ self._call_hooks("before_forward", data)
271
+ output_obj = self._call_engine(engine.model, data)
272
+ self._call_hooks("after_forward", output_obj)
273
+
274
+ if gpc.is_last_rank(ParallelMode.PIPELINE):
275
+ self._call_hooks("post_helper_func", output_obj, label)
276
+ if return_output_label:
277
+ return_tensors.append((output_obj, label))
278
+ if accum_loss is not None:
279
+ self._call_hooks("before_criterion", output_obj, label)
280
+ loss = self._call_engine_criterion(engine.criterion, output_obj, label)
281
+ self._call_hooks("after_criterion", loss)
282
+
283
+ loss_reduced = loss / self.num_microbatches
284
+ accum_loss['loss'].add_(loss_reduced.detach())
285
+ output_obj = loss_reduced
286
+
287
+ return output_obj
288
+
289
+ def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad):
290
+ """
291
+ Backward step through the passed-in output tensor. If it is the last stage, the
292
+ output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
293
+ Returns the gradients with respect to the input tensor (None if first stage).
294
+ This is a helper function and can be ignored by users.
295
+
296
+ Args:
297
+ engine (colossalai.engine.Engine): Colossalai engine for training and inference.
298
+ step_id (int): The ID of the current step.
299
+ input_obj (Union[torch.Tensor, List[torch.Tensor]]): Input tensor for this stage.
300
+ output_obj (Union[torch.Tensor, List[torch.Tensor]]): Output tensor for this stage.
301
+ output_obj_grad (Union[torch.Tensor, List[torch.Tensor]]): Gradient of output tensor for this stage.
302
+
303
+ Returns:
304
+ Union[torch.Tensor, List[torch.Tensor]]: Gradient of input tensor.
305
+ """
306
+
307
+ # Retain the grad on the input_obj.
308
+ if input_obj is not None:
309
+ if isinstance(input_obj, torch.Tensor):
310
+ input_obj.retain_grad()
311
+ else:
312
+ for in_tensor in input_obj:
313
+ if in_tensor is not None:
314
+ in_tensor.retain_grad()
315
+
316
+ # Backward pass.
317
+
318
+ # Only the last microbatch synchronizes gradients.
319
+ skip_grad_sync = self._get_current_microbatch_id(step_id) != self.num_microbatches - 1
320
+
321
+ self._call_hooks("before_backward", output_obj, output_obj_grad)
322
+ with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
323
+ if output_obj_grad is None:
324
+ engine.backward(output_obj)
325
+ else:
326
+ engine.backward_by_grad(output_obj, output_obj_grad)
327
+
328
+ # Collect the grad of the input_obj.
329
+ input_obj_grad = None
330
+ if input_obj is not None:
331
+ if isinstance(input_obj, torch.Tensor):
332
+ input_obj_grad = input_obj.grad
333
+ else:
334
+ input_obj_grad = []
335
+ for in_tensor in input_obj:
336
+ input_obj_grad.append(in_tensor.grad)
337
+ self._call_hooks("after_backward", input_obj_grad)
338
+
339
+ return input_obj_grad
340
+
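The autograd mechanics behind `_backward_step` can be reproduced in isolation. The sketch below assumes that `engine.backward_by_grad` ultimately calls `torch.autograd.backward` with an explicit `grad_tensors` argument (an assumption for illustration; the real engine code lives elsewhere in the repo):

import torch

# Non-leaf stand-in for an activation received from the previous stage.
input_obj = torch.randn(2, 3, requires_grad=True) * 1.0
input_obj.retain_grad()                      # keep .grad on a non-leaf tensor, as in _backward_step
output_obj = input_obj * 2.0                 # stand-in for this stage's forward output

# On an intermediate stage, the next stage sends back d(loss)/d(output_obj);
# on the last stage, output_obj would be the scalar loss and this grad would be None.
output_obj_grad = torch.ones_like(output_obj)
torch.autograd.backward(output_obj, grad_tensors=output_obj_grad)

input_obj_grad = input_obj.grad              # the gradient that gets sent to the previous stage
print(input_obj_grad)                        # all 2.0 here, since d(output)/d(input) = 2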
341
+ def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
342
+ """
343
+ This function performs forward only computation process. The scheduling of microbatches is similar to the
344
+ warmup phase, where each microbatch first receives the forward input from the previous stage, then performs
345
+ the forward computation, and finally passes the forward computation output to the next stage. There are two
346
+ special cases to note:
347
+ 1. The first stage of the pipeline does not need to receive forward input; its input comes from the dataloader.
348
+ 2. The last stage of the pipeline does not need to send forward output; its output is returned to the user code
349
+ for processing.
350
+
351
+ Args:
352
+ engine (colossalai.engine.Engine): internlm engine for training and inference.
353
+ return_loss (bool, optional): Whether to return the accumulated loss.
354
+ return_output_label (bool, optional): Whether to return outputs and labels.
355
+
356
+ Returns:
357
+ Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]:
358
+ output, label, and accumulated loss.
359
+ """
360
+
361
+ # Input, output tensors only need to be saved when doing backward passes
362
+ return_tensors = []
363
+ accum_loss_init_func = lambda: torch.zeros(1, device=get_current_device())
364
+ accum_loss = defaultdict(accum_loss_init_func) if return_loss and gpc.is_pipeline_last_stage(
365
+ ignore_virtual=True) else None
366
+
367
+ # Used for tensor meta information communication
368
+ forward_recv_shapes = self.tensor_shape
369
+ need_forward_meta = self.tensor_shape is None
370
+
371
+ # Run all forward passes.
372
+ for _ in range(self.num_microbatches):
373
+ # Receive input from the previous stage
374
+ if not gpc.is_first_rank(ParallelMode.PIPELINE):
375
+ if forward_recv_shapes is None:
376
+ forward_recv_shapes = comm.recv_obj_meta()
377
+ input_obj = comm.recv_forward(
378
+ forward_recv_shapes,
379
+ dtype=self.dtype,
380
+ scatter_gather_tensors=self.scatter_gather_tensors,
381
+ )
382
+ else:
383
+ input_obj = None
384
+
385
+ # Perform forward computation
386
+ output_obj = self._forward_step(
387
+ engine,
388
+ input_obj,
389
+ return_tensors,
390
+ return_output_label=return_output_label,
391
+ accum_loss=accum_loss,
392
+ )
393
+
394
+ if not gpc.is_last_rank(ParallelMode.PIPELINE):
395
+ if need_forward_meta:
396
+ comm.send_obj_meta(output_obj)
397
+ need_forward_meta = False # send only once.
398
+ # Send the forward computation output to the next stage
399
+ comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
400
+
401
+ output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
402
+
403
+ return output, label, accum_loss
404
+
405
+ def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
406
+ """
407
+ This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
408
+ It consists of three stages: warmup, 1F1B, and cooldown.
409
+
410
+ 1. Warmup Stage:
411
+ The warmup stage performs num_warmup forward microsteps. The calculation of num_warmup is the pipeline length
412
+ minus the rank of the current pipeline stage minus 1. For each microstep, it receives data as input from the previous
413
+ stage, performs the forward computation, and then sends the result to the next stage.
414
+
415
+ 2. 1F1B Stage:
416
+ The 1F1B stage consists of pairs of forward and backward microsteps. It performs num_1f1b_micropairs iterations,
417
+ where num_1f1b_micropairs is calculated as the total number of microbatches minus the number of microbatches in
418
+ the warmup stage. In each iteration, it first performs a forward computation, sends the result to the next
419
+ stage, receives input for the backward computation, performs the backward computation, and finally sends the
420
+ result to the previous stage to receive input for the next forward computation.
421
+
422
+ 3. Cooldown Stage:
423
+ The cooldown stage performs the same number of iterations as the warmup stage. In each iteration, it receives
424
+ input for the backward computation, performs the backward computation, and finally sends the result to the
425
+ previous stage.
426
+
427
+ There are two special cases to consider:
428
+ 1. The first stage of the pipeline does not need to receive forward input or send backward output. The last
429
+ stage does not need to send forward output or receive backward input.
430
+ 2. Extra communication is inserted at the phase boundaries (for example, receiving the first forward input before entering the 1F1B phase) to bridge the gap between phases.
431
+
432
+ Args:
433
+ engine (Engine): The engine used for computation.
434
+ return_loss (bool, optional): Whether to return the accumulated loss.
435
+ return_output_label (bool, optional): Whether to return outputs and labels.
436
+
437
+ Returns:
438
+ Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]:
439
+ The output, label, and accumulated loss.
440
+ """
441
+
442
+ num_warmup_microsteps = (
443
+ gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
444
+ )
445
+ num_warmup_microsteps = min(num_warmup_microsteps, self.num_microbatches)
446
+ num_1f1b_micropairs = self.num_microbatches - num_warmup_microsteps
447
+
448
+ # Input, output tensors only need to be saved when doing backward passes
449
+ input_objs = []
450
+ output_objs = []
451
+ return_tensors = []
452
+ accum_loss_init_func = lambda: torch.zeros(1, device=get_current_device())
453
+ accum_loss = defaultdict(accum_loss_init_func) if return_loss and gpc.is_pipeline_last_stage(
454
+ ignore_virtual=True) else None
455
+
456
+ # Used for tensor meta information communication
457
+ forward_recv_shapes = self.tensor_shape
458
+ backward_recv_shapes = None
459
+ need_forward_meta = self.tensor_shape is None
460
+
461
+ # Run warmup forward passes.
462
+ for i in range(num_warmup_microsteps):
463
+ # Receive the input from the previous stage
464
+ if not gpc.is_first_rank(ParallelMode.PIPELINE):
465
+ if forward_recv_shapes is None:
466
+ forward_recv_shapes = comm.recv_obj_meta()
467
+ input_obj = comm.recv_forward(
468
+ forward_recv_shapes,
469
+ dtype=self.dtype,
470
+ scatter_gather_tensors=self.scatter_gather_tensors,
471
+ )
472
+ else:
473
+ input_obj = None
474
+
475
+ # Perform forward computation
476
+ output_obj = self._forward_step(
477
+ engine,
478
+ input_obj,
479
+ return_tensors,
480
+ return_output_label=return_output_label,
481
+ accum_loss=accum_loss,
482
+ )
483
+
484
+ if not gpc.is_last_rank(ParallelMode.PIPELINE):
485
+ if isinstance(output_obj, torch.Tensor):
486
+ backward_recv_shapes = output_obj.shape
487
+ else:
488
+ backward_recv_shapes = [out_tensor.shape for out_tensor in output_obj]
489
+
490
+ if need_forward_meta:
491
+ comm.send_obj_meta(output_obj)
492
+ need_forward_meta = False # send only once.
493
+
494
+ # Send the output of forward computation of this pipeline stage to the next pipeline stage as input for
495
+ # forward computation
496
+ if not gpc.is_last_rank(ParallelMode.PIPELINE):
497
+ comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
498
+
499
+ input_objs.append(input_obj)
500
+ output_objs.append(output_obj)
501
+
502
+ # Before running 1F1B, need to receive first forward tensor.
503
+ # If all microbatches are run in warmup / cooldown phase, then no need to
504
+ # receive this tensor here.
505
+ if num_1f1b_micropairs > 0:
506
+ if not gpc.is_first_rank(ParallelMode.PIPELINE):
507
+ if forward_recv_shapes is None:
508
+ forward_recv_shapes = comm.recv_obj_meta(forward_recv_shapes)
509
+ input_obj = comm.recv_forward(
510
+ forward_recv_shapes,
511
+ dtype=self.dtype,
512
+ scatter_gather_tensors=self.scatter_gather_tensors,
513
+ )
514
+ else:
515
+ input_obj = None
516
+
517
+ # Run 1F1B in steady state.
518
+ for i in range(num_1f1b_micropairs):
519
+ # Perform forward computation
520
+ output_obj = self._forward_step(
521
+ engine,
522
+ input_obj,
523
+ return_tensors,
524
+ return_output_label=return_output_label,
525
+ accum_loss=accum_loss,
526
+ )
527
+
528
+ if gpc.is_last_rank(ParallelMode.PIPELINE):
529
+ output_obj_grad = None
530
+ else:
531
+ output_obj_grad = comm.send_forward_recv_backward(
532
+ output_obj,
533
+ backward_recv_shapes,
534
+ dtype=self.dtype,
535
+ scatter_gather_tensors=self.scatter_gather_tensors,
536
+ )
537
+
538
+ # Add input_obj and output_obj to end of list.
539
+ input_objs.append(input_obj)
540
+ output_objs.append(output_obj)
541
+
542
+ # Pop input_obj and output_obj from the start of the list for
543
+ # the backward pass.
544
+ input_obj = input_objs.pop(0)
545
+ output_obj = output_objs.pop(0)
546
+
547
+ input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad)
548
+
549
+ if i == (num_1f1b_micropairs - 1):
550
+ input_obj = None
551
+ if not gpc.is_first_rank(ParallelMode.PIPELINE):
552
+ comm.send_backward(
553
+ input_obj_grad,
554
+ scatter_gather_tensors=self.scatter_gather_tensors,
555
+ )
556
+ else:
557
+ if gpc.is_first_rank(ParallelMode.PIPELINE):
558
+ input_obj = None
559
+ else:
560
+ input_obj = comm.send_backward_recv_forward(
561
+ input_obj_grad,
562
+ forward_recv_shapes,
563
+ dtype=self.dtype,
564
+ scatter_gather_tensors=self.scatter_gather_tensors,
565
+ )
566
+
567
+ # Run cooldown backward passes.
568
+ for i in range(num_warmup_microsteps):
569
+ input_obj = input_objs.pop(0)
570
+ output_obj = output_objs.pop(0)
571
+
572
+ if not gpc.is_last_rank(ParallelMode.PIPELINE):
573
+ output_obj_grad = comm.recv_backward(
574
+ backward_recv_shapes,
575
+ dtype=self.dtype,
576
+ scatter_gather_tensors=self.scatter_gather_tensors,
577
+ )
578
+ else:
579
+ output_obj_grad = None
580
+
581
+ input_obj_grad = self._backward_step(
582
+ engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad
583
+ )
584
+
585
+ if not gpc.is_first_rank(ParallelMode.PIPELINE):
586
+ comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
587
+
588
+ output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
589
+
590
+ return output, label, accum_loss
591
+
592
+ @llm_timeout(func_name="nointerleaved_forward_backward_step")
593
+ def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
594
+ """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
595
+ Returns a tuple of (output, label, loss); the loss is only available on the last pipeline stage.
596
+
597
+ Args:
598
+ engine (colossalai.engine.Engine): Colossalai engine for training and inference.
599
+ data_iter (Iterable): Dataloader in the form of an iterator, obtained by calling iter(dataloader).
600
+ forward_only (bool, optional):
601
+ Whether to run the forward step only. Default is False. If True, no backward pass will be run.
602
+ return_loss (bool, optional): Whether to return the loss value. Default is True.
603
+ return_output_label (bool, optional): If False, the output and label won't be returned.
604
+ Returns:
605
+ Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
606
+ """
607
+
608
+ assert (
609
+ forward_only or return_loss
610
+ ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
611
+
612
+ # Load data first
613
+ self.load_batch(engine, data_iter)
614
+
615
+ if forward_only:
616
+ return self._forward_only_step(engine, return_loss, return_output_label)
617
+ else:
618
+ return self._forward_backward_step(engine, return_loss, return_output_label)
619
+
620
+
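For a concrete feel of the three phases described in `_forward_backward_step` above, the following standalone arithmetic (hypothetical configuration: 4 pipeline stages, 8 microbatches) shows how the warmup, steady 1F1B, and cooldown phase sizes fall out per rank:

pp_size, num_microbatches = 4, 8

for pp_rank in range(pp_size):
    num_warmup = min(pp_size - pp_rank - 1, num_microbatches)
    num_1f1b_pairs = num_microbatches - num_warmup
    print(f"rank {pp_rank}: warmup={num_warmup}, 1F1B pairs={num_1f1b_pairs}, cooldown={num_warmup}")

# rank 0: warmup=3, 1F1B pairs=5, cooldown=3
# rank 1: warmup=2, 1F1B pairs=6, cooldown=2
# rank 2: warmup=1, 1F1B pairs=7, cooldown=1
# rank 3: warmup=0, 1F1B pairs=8, cooldown=0   (the last stage starts 1F1B immediately)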
621
+ class InterleavedPipelineScheduler(PipelineScheduler):
622
+ """
623
+ Interleaved Pipeline Scheduler.
624
+ """
625
+
626
+ def __init__(
627
+ self,
628
+ num_microbatches: int,
629
+ num_chunks: int,
630
+ dtype: torch.dtype = torch.float,
631
+ data_process_func: Callable = None,
632
+ tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
633
+ scatter_gather_tensors: bool = False,
634
+ scheduler_hooks: Optional[List[SchedulerHook]] = None,
635
+ communication_overlap: bool = False,
636
+ ):
637
+ """A helper schedule class for pipeline parallelism running environment.
638
+ It uses the interleaved 1F1B strategy. Other properties are similar to
639
+ :class:`NonPipelineScheduler`.
640
+
641
+ Args:
642
+ num_microbatches (int): The number of microbatches.
643
+ num_chunks (int): The number of model chunks.
644
+ dtype (torch.dtype, optional): The data type of the tensors. Default is torch.float.
645
+ data_process_func (Callable, optional):
646
+ The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
647
+ tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
648
+ scatter_gather_tensors (bool, optional):
649
+ If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
650
+ scheduler_hooks (List[SchedulerHook], optional): List of scheduler hooks. Default is None.
651
+ communication_overlap (bool, optional): Whether to enable communication overlap. Default is False.
652
+ """
653
+ assert (
654
+ num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0
655
+ ), "num_microbatches must be an integer multiple of pipeline parallel world size"
656
+
657
+ assert (
658
+ isinstance(num_chunks, int) and num_chunks > 0
659
+ ), f"expected num_chunks to be an integer and larger than 0, but got {num_chunks}"
660
+
661
+ super().__init__(
662
+ num_microbatches,
663
+ dtype=dtype,
664
+ data_process_func=data_process_func,
665
+ tensor_shape=tensor_shape,
666
+ scatter_gather_tensors=scatter_gather_tensors,
667
+ scheduler_hooks=scheduler_hooks,
668
+ )
669
+
670
+ gpc.set_virtual_pipeline_parallel_size(num_chunks)
671
+ gpc.set_virtual_pipeline_parallel_rank(0)
672
+
673
+ self._num_chunks = num_chunks
674
+ self._communication_overlap = communication_overlap
675
+ # switch 1f1b loop runner function according to communication overlap
676
+ self._run_1f1b_loop = (
677
+ self._run_1f1b_loop_with_overlap if communication_overlap else self._run_1f1b_loop_without_overlap
678
+ )
679
+
680
+ # states
681
+ self._pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
682
+ self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
683
+
684
+ self._accum_loss = None
685
+ self._return_tensors = None
686
+ self._input_objs = [[] for _ in range(num_chunks)]
687
+ self._output_objs = [[] for _ in range(num_chunks)]
688
+ self._output_obj_grads = [[] for _ in range(num_chunks)]
689
+
690
+ self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
691
+ self._output_obj_shapes = [None for _ in range(num_chunks)]
692
+ self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]
693
+
694
+ @property
695
+ def tensor_shape(self) -> torch.Size:
696
+ return self._tensor_shape
697
+
698
+ @tensor_shape.setter
699
+ def tensor_shape(self, tensor_shape: torch.Size):
700
+ self._tensor_shape = tensor_shape
701
+ self._input_obj_shapes = [self._tensor_shape for _ in range(self._num_chunks)]
702
+ self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(self._num_chunks)]
703
+
704
+ def _clear_state(self) -> None:
705
+ self._accum_loss = None
706
+ self._return_tensors = None
707
+ self._input_objs = [[] for _ in range(self._num_chunks)]
708
+ self._output_objs = [[] for _ in range(self._num_chunks)]
709
+ self._output_obj_grads = [[] for _ in range(self._num_chunks)]
710
+
711
+ self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
712
+ self._output_obj_shapes = [None for _ in range(self._num_chunks)]
713
+ self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(self._num_chunks)]
714
+
715
+ def load_batch(self, engine, data_iter):
716
+ super().load_batch(engine, data_iter)
717
+ # overwrite microbatch_offset, since every model chunk loads from the same batch and must track its own offset
718
+ self.microbatch_offset = [0 for _ in range(self._num_chunks)]
719
+
720
+ def load_micro_batch(self, model_chunk_id):
721
+ micro_batch_data, micro_batch_label = self._load_micro_batch(
722
+ data=self.batch_data,
723
+ label=self.batch_label,
724
+ offset=self.microbatch_offset[model_chunk_id],
725
+ micro_bsz=self.microbatch_size,
726
+ )
727
+ micro_batch_data["label"] = micro_batch_label
728
+ self.microbatch_offset[model_chunk_id] += self.microbatch_size
729
+ return move_to_device(micro_batch_data)
730
+
731
+ def _forward_step(self, engine, chunk_id):
732
+ """Forward step for passed-in model. If it is the first stage, the input tensor
733
+ is obtained from data_iterator, otherwise the passed-in input_obj is used.
734
+ Returns output tensor. This is a helper function and can be ignored by users.
735
+
736
+ Args:
737
+ engine (colossalai.engine.Engine): Colossalai engine for training and inference.
738
+ chunk_id (int): The id of model chunks.
739
+ Returns:
740
+ Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
741
+ pipeline stage.
742
+ """
743
+ gpc.set_virtual_pipeline_parallel_rank(chunk_id)
744
+
745
+ if gpc.is_pipeline_first_stage() and len(self._input_objs[chunk_id]) == len(self._output_objs[chunk_id]):
746
+ self._input_objs[chunk_id].append(None)
747
+ input_obj = self._input_objs[chunk_id][-1]
748
+
749
+ micro_batch_data = self.load_micro_batch(chunk_id)
750
+ data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
751
+
752
+ self._call_hooks("before_forward", data)
753
+ output_obj = self._call_engine(engine.model[chunk_id], data)
754
+ # Convert output_obj to fp32 when last model chunk of last stage
755
+ if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
756
+ output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
757
+ self._call_hooks("after_forward", output_obj)
758
+
759
+ if gpc.is_pipeline_last_stage():
760
+ self._call_hooks("post_helper_func", output_obj, label)
761
+
762
+ if self._return_tensors is not None:
763
+ self._return_tensors.append((output_obj, label))
764
+ if self._accum_loss is not None:
765
+ self._call_hooks("before_criterion", output_obj, label)
766
+ loss = self._call_engine_criterion(engine.criterion, output_obj, label)
767
+ self._call_hooks("after_criterion", loss)
768
+
769
+ loss_reduced = loss / self.num_microbatches
770
+ self._accum_loss.add_(loss_reduced.detach())
771
+ output_obj = loss_reduced
772
+
773
+ self._output_objs[chunk_id].append(output_obj)
774
+
775
+ return output_obj
776
+
777
+ def _backward_step(self, engine, chunk_id, step_id):
778
+ """
779
+ Backward step for the given model chunk. The input, output, and output gradient for this step
780
+ are popped from the queues that were filled during the corresponding forward steps.
781
+ Returns input tensor gradient. This is a helper function and can be ignored by users.
782
+
783
+ Args:
784
+ engine (colossalai.engine.Engine): Colossalai engine for training and inference.
785
+ chunk_id (int): The id of model chunks.
786
+ step_id (int): The current step id.
787
+
788
+ Returns:
789
+ Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: input tensor gradient.
790
+ """
791
+ gpc.set_virtual_pipeline_parallel_rank(chunk_id)
792
+
793
+ if gpc.is_pipeline_last_stage() and len(self._output_obj_grads[chunk_id]) == 0:
794
+ self._output_obj_grads[chunk_id].append(None)
795
+
796
+ input_obj = self._input_objs[chunk_id].pop(0)
797
+ output_obj = self._output_objs[chunk_id].pop(0)
798
+ output_obj_grad = self._output_obj_grads[chunk_id].pop(0)
799
+
800
+ input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad)
801
+
802
+ return input_obj_grad
803
+
804
+ def _get_chunk_by_microbatch(self, step_id: int, backward: bool = False) -> int:
805
+ """Helper method to get the model chunk ID given the iteration number."""
806
+ microbatch_id_in_group = step_id % (self._pp_size * self._num_chunks)
807
+ chunk_id = microbatch_id_in_group // self._pp_size
808
+
809
+ if backward:
810
+ chunk_id = self._num_chunks - chunk_id - 1
811
+
812
+ return chunk_id
813
+
814
+ def _get_current_microbatch_id(self, step_id: int) -> int:
815
+ # format:
816
+ # microstep_id : 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
817
+ # microbatch_id: 1 2 3 4 1 2 3 4 5 6 7 8 5 6 7 8
818
+ num_microbatch_group = step_id // (self._pp_size * self._num_chunks)
819
+ step_id_in_group = step_id % (self._pp_size * self._num_chunks)
820
+
821
+ microbatch_id = num_microbatch_group * self._pp_size + step_id_in_group % self._pp_size
822
+
823
+ return microbatch_id
824
+
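The two index helpers above can be checked in isolation. The sketch below (hypothetical pp_size=4, num_chunks=2) reproduces the microstep-to-chunk and microstep-to-microbatch mapping shown in the comment, using 0-based IDs:

pp_size, num_chunks = 4, 2

def chunk_of(step_id, backward=False):
    chunk_id = (step_id % (pp_size * num_chunks)) // pp_size
    return num_chunks - chunk_id - 1 if backward else chunk_id

def microbatch_of(step_id):
    group = step_id // (pp_size * num_chunks)
    return group * pp_size + (step_id % (pp_size * num_chunks)) % pp_size

print([chunk_of(k) for k in range(16)])       # [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]
print([microbatch_of(k) for k in range(16)])  # [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7]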
825
+ def _run_warmup_loop(
826
+ self,
827
+ engine: Engine,
828
+ num_microsteps: int,
829
+ num_warmup_microsteps: int,
830
+ receive_extra_backward: bool = False,
831
+ forward_only: bool = False,
832
+ ) -> None:
833
+ """
834
+ Run the warm-up loop and prepare data for the 1F1B stage.
835
+
836
+ During the warm-up process, for each execution, it first performs a forward computation,
837
+ and then sends the computation result to the next stage.
838
+ It also receives data for the next forward computation.
839
+ Since the input for the first forward computation is not considered initially,
840
+ it needs to receive data once at the beginning.
841
+
842
+ After the warm-up is completed, we need to prepare data for the 1F1B stage.
843
+ The data preparation process should be consistent with the communication method of the 1F1B stage.
844
+
845
+ Args:
846
+ engine (Engine): The engine to run the warm-up loop.
847
+ num_microsteps (int): The total number of microsteps.
848
+ num_warmup_microsteps (int): The number of warm-up microsteps.
849
+ receive_extra_backward (bool, optional): Whether to receive extra backward input for the 1F1B stage.
850
+ Default is False.
851
+ forward_only (bool, optional): Whether to only perform forward pass. Default is False.
852
+ """
853
+ if not gpc.is_pipeline_first_stage():
854
+ if self._input_obj_shapes[0] is None:
855
+ self._input_obj_shapes[0] = comm.recv_obj_meta(self._input_obj_shapes[0])
856
+ self._input_objs[0].append(
857
+ comm.recv_forward(
858
+ self._input_obj_shapes[0],
859
+ dtype=self.dtype,
860
+ scatter_gather_tensors=self.scatter_gather_tensors,
861
+ )
862
+ )
863
+ else:
864
+ self._input_objs[0].append(None)
865
+
866
+ for k in range(num_warmup_microsteps):
867
+ chunk_id = self._get_chunk_by_microbatch(k)
868
+
869
+ output_obj = self._forward_step(engine, chunk_id)
870
+
871
+ if forward_only:
872
+ # when forward-only, no need to save tensors for a backward pass
873
+ self._input_objs[chunk_id].pop()
874
+ self._output_objs[chunk_id].pop()
875
+
876
+ if not gpc.is_pipeline_last_stage():
877
+ if isinstance(output_obj, torch.Tensor):
878
+ self._output_obj_shapes[chunk_id] = output_obj.shape
879
+ else:
880
+ self._output_obj_shapes[chunk_id] = [out_tensor.shape for out_tensor in output_obj]
881
+
882
+ if self._send_tensor_shape_flags[chunk_id]:
883
+ comm.send_obj_meta(output_obj)
884
+ self._send_tensor_shape_flags[chunk_id] = False # send only once for each chunk.
885
+
886
+ # Determine if tensor should be received from previous stage.
887
+ next_forward_chunk_id = self._get_chunk_by_microbatch(k + 1)
888
+
889
+ with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
890
+ if not gpc.is_pipeline_first_stage() and self._input_obj_shapes[next_forward_chunk_id] is None:
891
+ self._input_obj_shapes[next_forward_chunk_id] = comm.recv_obj_meta()
892
+ if k == (num_microsteps - 1) or gpc.is_pipeline_first_stage():
893
+ input_shape = None
894
+ else:
895
+ input_shape = self._input_obj_shapes[next_forward_chunk_id]
896
+
897
+ # Don't send tensor downstream if on last stage.
898
+ if gpc.is_pipeline_last_stage():
899
+ output_obj = None
900
+
901
+ # Send and receive tensors as appropriate (send tensors computed
902
+ # in this iteration; receive tensors for next iteration).
903
+ if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
904
+ # Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
905
+ input_obj = comm.send_forward_recv_forward(
906
+ output_obj,
907
+ input_shape,
908
+ dtype=self.dtype,
909
+ scatter_gather_tensors=self.scatter_gather_tensors,
910
+ )
911
+ else:
912
+ # Receive output_obj_grad for next backward, if receive_extra_backward is True.
913
+ if self._communication_overlap:
914
+ # In this case, we should handle forward and backward communication separately, consistent with the
915
+ # overlap version of the 1F1B stage
916
+ input_obj = comm.send_forward_recv_forward(
917
+ output_obj,
918
+ input_shape,
919
+ dtype=self.dtype,
920
+ scatter_gather_tensors=self.scatter_gather_tensors,
921
+ )
922
+ output_obj_grad = comm.send_backward_recv_backward(
923
+ None, # nothing to send
924
+ self._output_obj_shapes[self._num_chunks - 1],
925
+ dtype=self.dtype,
926
+ scatter_gather_tensors=self.scatter_gather_tensors,
927
+ )
928
+ self._output_obj_grads[self._num_chunks - 1].append(output_obj_grad)
929
+ else:
930
+ # In this case, we should handle forward and backward communication together, consistent with the
931
+ # non-overlap version of the 1F1B stage
932
+ input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
933
+ output_obj,
934
+ None, # no backward grad to send
935
+ input_shape,
936
+ self._output_obj_shapes[self._num_chunks - 1],
937
+ dtype=self.dtype,
938
+ scatter_gather_tensors=self.scatter_gather_tensors,
939
+ )
940
+ self._output_obj_grads[self._num_chunks - 1].append(output_obj_grad)
941
+
942
+ self._input_objs[next_forward_chunk_id].append(input_obj)
943
+
944
+ def _run_1f1b_loop_with_overlap(
945
+ self,
946
+ engine: Engine,
947
+ num_warmup_microsteps: int,
948
+ num_1f1b_micropairs: int,
949
+ all_warmup_microsteps: bool = False,
950
+ ) -> None:
951
+ """
952
+ Run the 1F1B loop with overlap.
953
+
954
+ The 1F1B loop with overlap consists of the following steps:
955
+ 1. Perform the forward pass.
956
+ 2. Check if the backward input is ready.
957
+ 3. Send the forward output and receive the forward input for the next iteration.
958
+ 4. Perform the backward pass.
959
+ 5. Check if the forward input is ready.
960
+ 6. Send the backward output and receive the backward input for the next iteration.
961
+
962
+ Args:
963
+ engine (Engine): The engine to run the 1F1B loop.
964
+ num_warmup_microsteps (int): The number of warm-up microsteps.
965
+ num_1f1b_micropairs (int): The number of 1F1B micropairs.
966
+ all_warmup_microsteps (bool, optional): Whether to run all warm-up microsteps. Default is False.
967
+ """
968
+
969
+ backward_async_communicator = None
970
+
971
+ # Run 1F1B in steady state.
972
+ for k in range(num_1f1b_micropairs):
973
+ forward_microstep_id = k + num_warmup_microsteps
974
+ backward_microstep_id = k
975
+ forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
976
+ backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
977
+
978
+ # 1. Forward pass.
979
+ output_obj = self._forward_step(engine, forward_chunk_id)
980
+
981
+ # 2. Check if the backward input is ready.
982
+ if backward_async_communicator is not None:
983
+ output_obj_grad = backward_async_communicator.wait_and_receive()
984
+
985
+ if backward_async_communicator.need_receive:
986
+ self._output_obj_grads[backward_chunk_id].append(output_obj_grad)
987
+
988
+ # 3. Send the forward outputs and receive the forward inputs from the previous rank.
989
+
990
+ # Check if it is the last model chunk of the last pipeline stage, no need to send forward output.
991
+ gpc.set_virtual_pipeline_parallel_rank(forward_chunk_id)
992
+ if gpc.is_pipeline_last_stage():
993
+ output_obj = None
994
+
995
+ # Check if it needs to receive the results from the previous rank.
996
+ next_forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id + 1)
997
+ with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
998
+ if gpc.is_pipeline_first_stage() or k == num_1f1b_micropairs - 1:
999
+ input_obj_shape = None
1000
+ else:
1001
+ input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
1002
+
1003
+ forward_async_communicator = comm.AsynCommunicator(
1004
+ output_obj,
1005
+ input_obj_shape,
1006
+ self.dtype,
1007
+ self.scatter_gather_tensors,
1008
+ forward=True,
1009
+ )
1010
+ forward_async_communicator.start()
1011
+
1012
+ # 4. Backward pass.
1013
+
1014
+ input_obj_grad = self._backward_step(engine, backward_chunk_id, backward_microstep_id)
1015
+
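+ # 5. Check if the forward input is ready.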
1016
+ input_obj = forward_async_communicator.wait_and_receive()
1017
+ if forward_async_communicator.need_receive:
1018
+ self._input_objs[next_forward_chunk_id].append(input_obj)
1019
+
1020
+ # 6. Send the backward output and receive the backward input for the next iteration.
1021
+ gpc.set_virtual_pipeline_parallel_rank(backward_chunk_id)
1022
+ if gpc.is_pipeline_first_stage():
1023
+ input_obj_grad = None
1024
+
1025
+ next_backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id + 1, backward=True)
1026
+ with switch_virtual_pipeline_parallel_rank(next_backward_chunk_id):
1027
+ if gpc.is_pipeline_last_stage():
1028
+ output_obj_shape = None
1029
+ else:
1030
+ output_obj_shape = self._output_obj_shapes[next_backward_chunk_id]
1031
+
1032
+ backward_async_communicator = comm.AsynCommunicator(
1033
+ input_obj_grad,
1034
+ output_obj_shape,
1035
+ self.dtype,
1036
+ self.scatter_gather_tensors,
1037
+ forward=False,
1038
+ )
1039
+ backward_async_communicator.start()
1040
+
1041
+ if all_warmup_microsteps:
1042
+ if not gpc.is_pipeline_last_stage():
1043
+ self._output_obj_grads[self._num_chunks - 1].append(
1044
+ comm.recv_backward(
1045
+ self._output_obj_shapes[self._num_chunks - 1],
1046
+ dtype=self.dtype,
1047
+ scatter_gather_tensors=self.scatter_gather_tensors,
1048
+ )
1049
+ )
1050
+ else:
1051
+ self._output_obj_grads[self._num_chunks - 1].append(None)
1052
+ else:
1053
+ output_obj_grad = backward_async_communicator.wait_and_receive()
1054
+ if backward_async_communicator.need_receive:
1055
+ backward_chunk_id = self._get_chunk_by_microbatch(num_1f1b_micropairs, backward=True)
1056
+ self._output_obj_grads[backward_chunk_id].append(output_obj_grad)
1057
+
1058
+ def _run_1f1b_loop_without_overlap(
1059
+ self,
1060
+ engine: Engine,
1061
+ num_warmup_microsteps: int,
1062
+ num_1f1b_micropairs: int,
1063
+ all_warmup_microsteps: bool = False,
1064
+ ) -> None:
1065
+ """
1066
+ Run the 1F1B loop without overlap.
1067
+
1068
+ The 1F1B loop without overlap consists of the following steps:
1069
+ 1. Perform the forward pass.
1070
+ 2. Perform the backward pass.
1071
+ 3. Send the forward output of this iteration to the next stage, and send the backward output of this iteration
1072
+ to the previous stage, and receive the forward and backward inputs for the next iteration.
1073
+
1074
+ Args:
1075
+ engine (Engine): The engine to use for computation.
1076
+ num_warmup_microsteps (int): The number of warmup microsteps.
1077
+ num_1f1b_micropairs (int): The number of 1F1B micro-pairs.
1078
+ all_warmup_microsteps (bool, optional): Whether to run all warmup microsteps. Defaults to False.
1079
+ """
1080
+ for k in range(num_1f1b_micropairs):
1081
+ # Forward pass.
1082
+ forward_microstep_id = k + num_warmup_microsteps
1083
+ forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
1084
+ output_obj = self._forward_step(engine, forward_chunk_id)
1085
+
1086
+ # Backward pass.
1087
+ backward_microstep_id = k
1088
+ backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
1089
+ input_obj_grad = self._backward_step(engine, backward_chunk_id, backward_microstep_id)
1090
+
1091
+ # Send output_obj and input_obj_grad, receive input_obj
1092
+ # and output_obj_grad.
1093
+
1094
+ # Determine if current stage has anything to send in either direction,
1095
+ # otherwise set obj to None.
1096
+ gpc.set_virtual_pipeline_parallel_rank(forward_chunk_id)
1097
+ if gpc.is_pipeline_last_stage():
1098
+ output_obj = None
1099
+
1100
+ gpc.set_virtual_pipeline_parallel_rank(backward_chunk_id)
1101
+ if gpc.is_pipeline_first_stage():
1102
+ input_obj_grad = None
1103
+
1104
+ # Determine if peers are sending, and where in data structure to put
1105
+ # received tensors.
1106
+ next_forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id + 1)
1107
+ with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
1108
+ if gpc.is_pipeline_first_stage() or k == num_1f1b_micropairs - 1:
1109
+ recv_prev = False
1110
+ else:
1111
+ recv_prev = True
1112
+
1113
+ next_backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id + 1, backward=True)
1114
+ with switch_virtual_pipeline_parallel_rank(next_backward_chunk_id):
1115
+ if gpc.is_pipeline_last_stage():
1116
+ recv_next = False
1117
+ else:
1118
+ recv_next = True
1119
+
1120
+ input_shape = self._input_obj_shapes[next_forward_chunk_id] if recv_prev else None
1121
+ output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
1122
+
1123
+ # Communicate objs.
1124
+ input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
1125
+ output_obj,
1126
+ input_obj_grad,
1127
+ input_shape,
1128
+ output_shape,
1129
+ dtype=self.dtype,
1130
+ scatter_gather_tensors=self.scatter_gather_tensors,
1131
+ )
1132
+
1133
+ # Put input_obj and output_obj_grad in data structures in the
1134
+ # right location.
1135
+ if recv_prev:
1136
+ self._input_objs[next_forward_chunk_id].append(input_obj)
1137
+ if recv_next:
1138
+ self._output_obj_grads[next_backward_chunk_id].append(output_obj_grad)
1139
+
1140
+ # receive necessary data for next cooldown loop
1141
+ if all_warmup_microsteps:
1142
+ if not gpc.is_pipeline_last_stage():
1143
+ self._output_obj_grads[self._num_chunks - 1].append(
1144
+ comm.recv_backward(
1145
+ self._output_obj_shapes[self._num_chunks - 1],
1146
+ dtype=self.dtype,
1147
+ scatter_gather_tensors=self.scatter_gather_tensors,
1148
+ )
1149
+ )
1150
+ else:
1151
+ self._output_obj_grads[self._num_chunks - 1].append(None)
1152
+
1153
+ def _run_cooldown_loop(self, engine: Engine, num_microsteps: int, num_1f1b_micropairs: int) -> None:
1154
+ """
1155
+ Run the cooldown loop.
1156
+
1157
+ The cooldown loop consists of the following steps:
1158
+ 1. Perform the backward step.
1159
+ 2. Send the backward output to the next stage and receive inputs for next backward.
1160
+
1161
+ Args:
1162
+ engine (Engine): The engine to use for computation.
1163
+ num_microsteps (int): The total number of microsteps.
1164
+ num_1f1b_micropairs (int): The number of 1F1B micro-pairs.
1165
+ """
1166
+ for k in range(num_1f1b_micropairs, num_microsteps):
1167
+ chunk_id = self._get_chunk_by_microbatch(k, backward=True)
1168
+
1169
+ input_obj_grad = self._backward_step(engine, chunk_id, k)
1170
+
1171
+ next_backward_chunk_id = self._get_chunk_by_microbatch(k + 1, backward=True)
1172
+
1173
+ if k != (num_microsteps - 1) and not (
1174
+ gpc.is_pipeline_last_stage(ignore_virtual=True) and next_backward_chunk_id == (self._num_chunks - 1)
1175
+ ):
1176
+ output_shape = self._output_obj_shapes[next_backward_chunk_id]
1177
+ else:
1178
+ output_shape = None
1179
+
1180
+ self._output_obj_grads[next_backward_chunk_id].append(
1181
+ comm.send_backward_recv_backward(
1182
+ input_obj_grad,
1183
+ output_shape,
1184
+ dtype=self.dtype,
1185
+ scatter_gather_tensors=self.scatter_gather_tensors,
1186
+ )
1187
+ )
1188
+
1189
+ def _forward_only_step(self, engine: Engine):
1190
+ num_microsteps = self.num_microbatches * self._num_chunks
1191
+ num_warmup_microsteps = num_microsteps
1192
+
1193
+ self._run_warmup_loop(
1194
+ engine,
1195
+ num_microsteps,
1196
+ num_warmup_microsteps,
1197
+ receive_extra_backward=False,
1198
+ forward_only=True,
1199
+ )
1200
+
1201
+ def _forward_backward_step(self, engine: Engine):
1202
+ # Compute number of warmup and remaining microbatches.
1203
+ all_warmup_microsteps = False
1204
+ num_microsteps = self.num_microbatches * self._num_chunks
1205
+
1206
+ # Run all forward passes and then all backward passes if number of
1207
+ # microbatches is just the number of pipeline stages.
1208
+ # Otherwise, perform (num_chunks-1)*pipeline_parallel_size warmup microsteps on
1209
+ # all workers, followed by additional microsteps depending on the
1210
+ # stage ID (more forward passes for earlier stages, later stages can
1211
+ # immediately start with 1F1B).
1212
+ if self.num_microbatches == self._pp_size:
1213
+ num_warmup_steps = num_microsteps
1214
+ all_warmup_microsteps = True
1215
+ else:
1216
+ num_warmup_steps = (self._pp_size - self._pp_rank - 1) * 2
1217
+ num_warmup_steps += (self._num_chunks - 1) * self._pp_size
1218
+ num_warmup_steps = min(num_warmup_steps, num_microsteps)
1219
+ num_1f1b_micropairs = num_microsteps - num_warmup_steps
1220
+
1221
+ # We usually need to prepare an extra backward data for the 1F1B stage when the WarmUp stage ends,
1222
+ # because the 1F1B stage typically performs one forward and backward pass together,
1223
+ # except in the following cases:
1224
+ receive_extra_backward = not (
1225
+ all_warmup_microsteps # Only warmup microsteps
1226
+ or gpc.is_pipeline_last_stage(ignore_virtual=True) # The rank is the last pipeline stage
1227
+ )
1228
+
1229
+ # 1. Warmup
1230
+ self._run_warmup_loop(
1231
+ engine,
1232
+ num_microsteps,
1233
+ num_warmup_steps,
1234
+ receive_extra_backward=receive_extra_backward,
1235
+ )
1236
+
1237
+ # 2. 1F1B
1238
+ self._run_1f1b_loop(
1239
+ engine,
1240
+ num_warmup_steps,
1241
+ num_1f1b_micropairs=num_1f1b_micropairs,
1242
+ all_warmup_microsteps=all_warmup_microsteps,
1243
+ )
1244
+
1245
+ # 3. Cooldown
1246
+ self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
1247
+
1248
+ @llm_timeout(func_name="interleaved_forward_backward_step")
1249
+ def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
1250
+ """Run interleaved 1F1B schedule (model split into model chunks), with
1251
+ communication between pipeline stages as needed.
1252
+
1253
+ Args:
1254
+ engine (colossalai.engine.Engine): Colossalai engine for training and inference.
1255
+ data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
1256
+ forward_only (bool, optional):
1257
+ Whether run forward step only. Default is false. If true, no backward will be run.
1258
+ return_loss (bool, optional): Whether returns the loss value. Default is true.
1259
+ return_output_label (bool, optional): If False, the output and label won't be returned.
1260
+
1261
+ Returns:
1262
+ Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
1263
+ The loss is returned only on the last pipeline stage.
1264
+ """
1265
+ assert (
1266
+ forward_only or return_loss
1267
+ ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
1268
+
1269
+ gpc.set_virtual_pipeline_parallel_rank(0)
1270
+
1271
+ self.load_batch(engine, data_iter)
1272
+
1273
+ if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
1274
+ self._accum_loss = torch.zeros(1, device=get_current_device())
1275
+ if return_output_label:
1276
+ self._return_tensors = []
1277
+
1278
+ if forward_only:
1279
+ self._forward_only_step(engine)
1280
+ else:
1281
+ self._forward_backward_step(engine)
1282
+
1283
+ if return_output_label and len(self._return_tensors) > 0:
1284
+ output, label = pack_return_tensors(self._return_tensors)
1285
+ else:
1286
+ output, label = (None, None)
1287
+ accum_loss = self._accum_loss
1288
+
1289
+ self._clear_state()
1290
+
1291
+ return output, label, accum_loss
1292
+
1293
+
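The interleaved warmup sizing in `_forward_backward_step` above is easiest to verify numerically. A standalone check under a hypothetical configuration (pp_size=4, num_chunks=2, num_microbatches=8):

pp_size, num_chunks, num_microbatches = 4, 2, 8
num_microsteps = num_microbatches * num_chunks          # 16 microsteps in total

for pp_rank in range(pp_size):
    if num_microbatches == pp_size:
        num_warmup = num_microsteps                     # run everything as warmup/cooldown
    else:
        num_warmup = (pp_size - pp_rank - 1) * 2 + (num_chunks - 1) * pp_size
        num_warmup = min(num_warmup, num_microsteps)
    print(f"rank {pp_rank}: warmup={num_warmup}, 1F1B pairs={num_microsteps - num_warmup}")

# rank 0: warmup=10, 1F1B pairs=6
# rank 1: warmup=8,  1F1B pairs=8
# rank 2: warmup=6,  1F1B pairs=10
# rank 3: warmup=4,  1F1B pairs=12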
1294
+ class KDPipelineScheduler(PipelineScheduler):
1295
+
1296
+ def __init__(
1297
+ self,
1298
+ num_microbatches: int,
1299
+ dtype: torch.dtype = torch.float,
1300
+ data_process_func: Callable = None,
1301
+ tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
1302
+ scatter_gather_tensors: bool = False,
1303
+ scheduler_hooks: Optional[List[SchedulerHook]] = None,
1304
+ ):
1305
+ super().__init__(
1306
+ num_microbatches=num_microbatches,
1307
+ dtype=dtype,
1308
+ data_process_func=data_process_func,
1309
+ tensor_shape=tensor_shape,
1310
+ scatter_gather_tensors=scatter_gather_tensors,
1311
+ scheduler_hooks=scheduler_hooks,
1312
+ )
1313
+
1314
+ def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
1315
+ """
1316
+ Forward step for passed-in model. If it is the first stage, the input tensor
1317
+ is obtained from data_iterator, otherwise the passed-in input_obj is used.
1318
+ Returns output tensor. This is a helper function and can be ignored by users.
1319
+
1320
+ Args:
1321
+ engine (colossalai.engine.Engine): Colossalai engine for training and inference.
1322
+ input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
1323
+ return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
1324
+ return_output_label (bool, optional): Whether to return output labels.
1325
+ accum_loss (optional): Where the accumulated loss is stored.
1326
+ Returns:
1327
+ Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
1328
+ pipeline stage.
1329
+ """
1330
+ micro_batch_data = self.load_micro_batch()
1331
+ data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
1332
+
1333
+ self._call_hooks("before_forward", data)
1334
+ output_obj = self._call_engine(engine.model, data)
1335
+ self._call_hooks("after_forward", output_obj)
1336
+
1337
+ if gpc.is_last_rank(ParallelMode.PIPELINE):
1338
+ self._call_hooks("post_helper_func", output_obj, label)
1339
+ if return_output_label:
1340
+ return_tensors.append((output_obj, label))
1341
+ if accum_loss is not None:
1342
+ self._call_hooks("before_criterion", output_obj, label)
1343
+ loss_gt = gpc.config.kd_config['gt_weight'] * self._call_engine_criterion(engine.criterion, output_obj,
1344
+ label)
1345
+
1346
+ with torch.no_grad():
1347
+ engine.teacher.eval()
1348
+ output_obj_t = self._call_engine(engine.teacher, data)
1349
+
1350
+ loss_kd = gpc.config.kd_config['kd_weight'] * self._call_engine_criterion(engine.kd_criterion,
1351
+ output_obj,
1352
+ (output_obj_t, label))
1353
+ # loss = (gpc.config.kd_config['gt_weight'] * loss_gt +
1354
+ # gpc.config.kd_config['kd_weight'] * loss_kd)
1355
+ self._call_hooks("after_criterion", loss_gt + loss_kd)
1356
+
1357
+ loss_gt_reduced = loss_gt / self.num_microbatches
1358
+ loss_kd_reduced = loss_kd / self.num_microbatches
1359
+ accum_loss['loss_gt'].add_(loss_gt_reduced.detach())
1360
+ accum_loss['loss_kd'].add_(loss_kd_reduced.detach())
1361
+ output_obj = loss_gt_reduced + loss_kd_reduced
1362
+
1363
+ return output_obj
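The KD scheduler above combines a ground-truth loss and a distillation loss using the weights from `gpc.config.kd_config`. The concrete `kd_criterion` is defined elsewhere in the repo; the sketch below shows one common choice (temperature-scaled KL divergence between student and teacher logits) purely as an illustration, not as the uploaded implementation:

import torch
import torch.nn.functional as F

def kd_kl_loss(student_logits, teacher_logits, temperature=2.0):
    # Soft-target distillation loss; the T**2 factor keeps its gradient scale
    # comparable to the hard-label cross-entropy term.
    t = temperature
    return F.kl_div(
        F.log_softmax(student_logits / t, dim=-1),
        F.softmax(teacher_logits / t, dim=-1),
        reduction="batchmean",
    ) * (t * t)

student_logits = torch.randn(4, 16, requires_grad=True)
teacher_logits = torch.randn(4, 16)            # produced under torch.no_grad() from the teacher
labels = torch.randint(0, 16, (4,))

gt_weight, kd_weight = 1.0, 0.5                # stand-ins for gpc.config.kd_config values
loss_gt = gt_weight * F.cross_entropy(student_logits, labels)
loss_kd = kd_weight * kd_kl_loss(student_logits, teacher_logits)
(loss_gt + loss_kd).backward()                 # matches the summed loss the scheduler backpropagates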
InternLM/internlm/core/trainer.py ADDED
@@ -0,0 +1,190 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
5
+
6
+ import json
7
+ from typing import Iterable, Optional
8
+
9
+ from internlm.core.engine import Engine
10
+ from internlm.core.scheduler import (
11
+ BaseScheduler,
12
+ InterleavedPipelineScheduler,
13
+ NonPipelineScheduler,
14
+ PipelineScheduler,
15
+ )
16
+
17
+
18
+ class TrainState:
19
+ """
20
+ The TrainState class is used to record the current state of training.
21
+
22
+ Args:
23
+ train_dl (DataLoader): The DataLoader object used for training.
24
+ """
25
+
26
+ def __init__(self, config, batch_sampler) -> None:
27
+ """
28
+ Args:
29
+ config (Config): internlm config
30
+ batch_sampler (torch.utils.data.Sampler): Because the dataloader loading is
31
+ asynchronous and prefetched, the batch_sampler state maintained inside the
32
+ dataloader are faster then the actual training progress, so we copy the
33
+ batch_sampler as the anchor point of ckpt reload.
34
+ """
35
+ # The number of batches produced by the data iterator
36
+ self.batch_count: int = 0
37
+ # Used to store the number of samples consumed in the current epoch
38
+ self.num_consumed_samples_in_epoch: int = 0
39
+ # Total number of tokens consumed
40
+ self.num_consumed_tokens: int = 0
41
+ # Number of batches skipped due to inf or nan values
42
+ self.inf_nan_skip_batches: int = 0
43
+ # Records the number of updates, skipped batches and inf batches are not counted
44
+ self.step_count: int = 0
45
+
46
+ # Total step count
47
+ self.total_steps: int = config.data.total_steps
48
+
49
+ # resume tensorboard folder, need load from checkpoint or set manually.
50
+ self.resume_tb_folder = config.resume_tb_folder
51
+
52
+ self.tensorboard_folder = config.tensorboard_folder
53
+
54
+ # learning rate
55
+ self.lr = config.adam.lr
56
+
57
+ # smapler state
58
+ if batch_sampler:
59
+ self.init_batch_sampler(batch_sampler)
60
+
61
+ def init_batch_sampler(self, batch_sampler):
62
+ """
63
+ Args:
64
+ batch_sampler (torch.utils.data.Sampler): sampler.
65
+ """
66
+ # make a copy of batch_sampler.
67
+ self.batch_sampler = batch_sampler.copy()
68
+ # Iterator for the batch sampler
69
+ self.batch_sampler_iter = iter(self.batch_sampler)
70
+
71
+ def __str__(self) -> str:
72
+ """Returns a string representation of the training state in JSON format."""
73
+ info = {
74
+ "batch_count": self.batch_count,
75
+ "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
76
+ "num_consumed_tokens": self.num_consumed_tokens,
77
+ "inf_nan_skip_batches": self.inf_nan_skip_batches,
78
+ "step_count": self.step_count,
79
+ }
80
+
81
+ return json.dumps(info, indent=4, sort_keys=True)
82
+
83
+ def load_state_dict(self, other_stuffs):
84
+ """
85
+ Resumes training from a checkpoint.
86
+
87
+ Args:
88
+ other_stuffs (dict): Other information needed to resume training.
89
+ """
90
+ self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"]
91
+ self.num_consumed_tokens = other_stuffs["num_consumed_tokens"]
92
+ self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"]
93
+
94
+ # Because the ckpt save occurs after updating 'step_count',
95
+ # there is no need to increment 'step_count' here (Does our step count start from 0 ?),
96
+ # However, 'batch_count' is updating before ckpt storage, so it need to inc 1 when resume.
97
+ self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward
98
+ self.step_count = other_stuffs.get("step_count", self.batch_count)
99
+
100
+ # resume tensorboard from older tensorboard_folder
101
+ self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
102
+
103
+ def state_dict(self):
104
+ return {
105
+ "batch_count": self.batch_count,
106
+ "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
107
+ "num_consumed_tokens": self.num_consumed_tokens,
108
+ "inf_nan_skip_batches": self.inf_nan_skip_batches,
109
+ "step_count": self.step_count,
110
+ "tensorboard_folder": self.tensorboard_folder,
111
+ }
112
+
113
+
114
+ class Trainer:
115
+ """This is a class tending for easy deployments of users' training and evaluation instead of
116
+ writing their own scripts.
117
+
118
+ Args:
119
+ engine (:class:`Engine`): Engine responsible for the process function.
120
+ schedule (:class:`BaseScheduler`, optional): Runtime schedule. Defaults to None.
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ engine: Engine,
126
+ schedule: Optional[BaseScheduler] = None,
127
+ ):
128
+ """Initializes the Trainer class.
129
+
130
+ Args:
131
+ engine (Engine): The engine responsible for the process function.
132
+ schedule (Optional[BaseScheduler], optional): The runtime schedule. Defaults to None.
133
+ """
134
+ self._engine = engine
135
+
136
+ # build schedule
137
+ if schedule is None:
138
+ self._schedule = NonPipelineScheduler()
139
+ else:
140
+ assert isinstance(
141
+ schedule, BaseScheduler
142
+ ), f"expected schedule to be of type BaseSchedule, but got {type(schedule)}"
143
+ self._schedule = schedule
144
+
145
+ self._schedule.pre_processing(self._engine)
146
+
147
+ @property
148
+ def engine(self):
149
+ """Returns the engine that responsible for managing the training and evaluation process."""
150
+ return self._engine
151
+
152
+ @property
153
+ def schedule(self):
154
+ """Returns the runtime scheduler."""
155
+ return self._schedule
156
+
157
+ @property
158
+ def uses_pipeline(self):
159
+ """Returns whether the pipeline parallel is used or not."""
160
+ return isinstance(self._schedule, (PipelineScheduler, InterleavedPipelineScheduler))
161
+
162
+ def train(self):
163
+ """Sets the model to training mode."""
164
+ self._engine.train()
165
+
166
+ def eval(self):
167
+ """Sets the model to evaluation mode."""
168
+ self._engine.eval()
169
+
170
+ def zero_grad(self):
171
+ """Sets the gradient of all parameters in the model to zero."""
172
+ self._engine.zero_grad()
173
+
174
+ def step(self):
175
+ """Executes the parameter update step."""
176
+ return self._engine.step()
177
+
178
+ def execute_schedule(self, data_iter: Iterable, **kwargs):
179
+ """Runs the forward, loss computation, and backward for the model.
180
+ Returns a tuple of (output, label, loss).
181
+
182
+ Args:
183
+ data_iter (Iterable): The data iterator.
184
+ **kwargs: Additional keyword arguments.
185
+
186
+ Returns:
187
+ Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss).
188
+ """
189
+ output, label, loss = self._schedule.forward_backward_step(self._engine, data_iter, **kwargs)
190
+ return output, label, loss
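A typical way to drive the Trainer API above, assuming `engine`, `scheduler`, `train_dataloader`, and `total_steps` have already been produced by the usual InternLM initialization path (the loop itself is illustrative, not taken from the repo):

trainer = Trainer(engine=engine, schedule=scheduler)
trainer.train()                                # put the wrapped model into training mode

train_iter = iter(train_dataloader)
for _ in range(total_steps):
    trainer.zero_grad()
    # forward + loss + backward for one global batch, split into microbatches by the schedule
    _, _, loss = trainer.execute_schedule(train_iter, forward_only=False, return_loss=True)
    trainer.step()                             # optimizer update via engine.step()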
InternLM/internlm/data/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .batch_sampler import get_dpsampler_dataloader
2
+ from .collaters import jsonl_ds_collate_fn, packed_collate_fn
3
+ from .dummy_dataset import RandomDataset
4
+ from .packed_dataset import PackedDataset, PackedDatasetWithoutCuSeqlen
5
+
6
+ __all__ = [
7
+ "jsonl_ds_collate_fn",
8
+ "packed_collate_fn",
9
+ "RandomDataset",
10
+ "PackedDataset",
11
+ "PackedDatasetWithoutCuSeqlen",
12
+ "get_dpsampler_dataloader",
13
+ ]
InternLM/internlm/data/batch_sampler.py ADDED
@@ -0,0 +1,354 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import math
5
+ import random
6
+ from typing import Iterator, TypeVar
7
+
8
+ import numpy as np
9
+ import torch
10
+ from torch.utils.data import DataLoader, Dataset, Sampler
11
+
12
+ from internlm.core.context import ParallelMode
13
+ from internlm.core.context import global_context as gpc
14
+ from internlm.utils.logger import get_logger
15
+
16
+ logger = get_logger(__file__)
17
+
18
+ T_co = TypeVar("T_co", covariant=True)
19
+
20
+
21
+ class DataParallelSampler(Sampler):
22
+ """A data sampler for distributed data parallelism.
23
+
24
+ Args:
25
+ dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling.
26
+ shuffle (bool, optional): Whether to shuffle data, defaults to False.
27
+ seed (int, optional): The random seed used for sampling, defaults to 0.
28
+ drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
29
+ is not divisible by the batch size. If False and the size of dataset is not divisible by
30
+ the batch size, then the last batch will be smaller, defaults to False.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ dataset: Dataset,
36
+ shuffle: bool = False,
37
+ seed: int = 0,
38
+ drop_last: bool = False,
39
+ ) -> None:
40
+ self.dataset = dataset
41
+ self.num_replicas = gpc.get_world_size(ParallelMode.DATA)
42
+ self.rank = gpc.get_local_rank(ParallelMode.DATA)
43
+ self.epoch = 0
44
+ self.drop_last = drop_last
45
+ # If the dataset length is evenly divisible by # of replicas, then there
46
+ # is no need to drop any data, since the dataset will be split equally.
47
+ # type: ignore[arg-type]
48
+ if self.drop_last and len(self.dataset) % self.num_replicas != 0:
49
+ # Split to nearest available length that is evenly divisible.
50
+ # This is to ensure each rank receives the same amount of data when
51
+ # using this Sampler.
52
+ self.num_samples = math.ceil(
53
+ # `type:ignore` is required because Dataset cannot provide a default __len__
54
+ # see NOTE in pytorch/torch/utils/data/sampler.py
55
+ (len(self.dataset) - self.num_replicas)
56
+ / self.num_replicas # type: ignore[arg-type]
57
+ )
58
+ else:
59
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
60
+ self.total_size = self.num_samples * self.num_replicas
61
+ self.shuffle = shuffle
62
+ self.seed = seed
63
+
64
+ def __iter__(self) -> Iterator[T_co]:
65
+ if self.shuffle:
66
+ # deterministically shuffle based on epoch and seed
67
+ g = torch.Generator()
68
+ g.manual_seed(self.seed + self.epoch)
69
+ # type: ignore[arg-type]
70
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
71
+
72
+ # update for next epoch so that there is no need to call
73
+ # set_epoch manually
74
+ self.epoch += 1
75
+ else:
76
+ indices = list(range(len(self.dataset))) # type: ignore[arg-type]
77
+
78
+ if not self.drop_last:
79
+ # add extra samples to make it evenly divisible
80
+ padding_size = self.total_size - len(indices)
81
+ if padding_size <= len(indices):
82
+ indices += indices[:padding_size]
83
+ else:
84
+ indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
85
+ else:
86
+ # remove tail of data to make it evenly divisible.
87
+ indices = indices[: self.total_size]
88
+ assert len(indices) == self.total_size
89
+
90
+ # subsample
91
+ indices = indices[self.rank : self.total_size : self.num_replicas]
92
+ assert len(indices) == self.num_samples
93
+
94
+ return iter(indices)
95
+
96
+ def __len__(self) -> int:
97
+ return self.num_samples
98
+
99
+ def set_epoch(self, epoch: int) -> None:
100
+ r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
101
+ use a different random ordering for each epoch. Otherwise, the next iteration of this
102
+ sampler will yield the same ordering.
103
+
104
+ Args:
105
+ epoch (int): Epoch number.
106
+ """
107
+ self.epoch = epoch
108
+
109
+
110
+ def get_dpsampler_dataloader(
111
+ dataset,
112
+ shuffle=False,
113
+ seed=1024,
114
+ add_sampler=True,
115
+ drop_last=False,
116
+ pin_memory=False,
117
+ num_workers=0,
118
+ **kwargs,
119
+ ):
120
+ r"""Set up a deterministic dataloader (also configures worker seeds, samplers, and whether to shuffle).
121
+
122
+ Note:
123
+ When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data
124
+ on the 1st stage and label on the last stage.
125
+
126
+ Args:
127
+ dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded.
128
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
129
+ seed (int, optional): Random worker seed for sampling, defaults to 1024.
130
+ add_sampler (bool, optional): Whether to add a ``DataParallelSampler`` for the dataset. Defaults to True.
131
+ drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
132
+ is not divisible by the batch size. If False and the size of dataset is not divisible by
133
+ the batch size, then the last batch will be smaller, defaults to False.
134
+ pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
135
+ num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0.
136
+ kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
137
+ `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
138
+
139
+ Returns:
140
+ :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
141
+ """
142
+ _kwargs = kwargs.copy()
143
+
144
+ if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
145
+ sampler = DataParallelSampler(dataset, shuffle=shuffle, drop_last=drop_last)
146
+ else:
147
+ sampler = None
148
+
149
+ # Deterministic dataloader
150
+ def seed_worker(_):  # worker_init_fn is called with the worker id, which is unused here
151
+ worker_seed = seed
152
+ np.random.seed(worker_seed)
153
+ torch.manual_seed(worker_seed)
154
+ random.seed(worker_seed)
155
+
156
+ if sampler is None:
157
+ return DataLoader(
158
+ dataset,
159
+ worker_init_fn=seed_worker,
160
+ shuffle=shuffle,
161
+ drop_last=drop_last,
162
+ pin_memory=pin_memory,
163
+ num_workers=num_workers,
164
+ **_kwargs,
165
+ )
166
+ else:
167
+ return DataLoader(
168
+ dataset,
169
+ sampler=sampler,
170
+ worker_init_fn=seed_worker,
171
+ drop_last=drop_last,
172
+ pin_memory=pin_memory,
173
+ num_workers=num_workers,
174
+ **_kwargs,
175
+ )
176
+
177
+
178
+ class StaticBatchSampler:
179
+ """
180
+ A static batch sampler that generates batches with a fixed micro-batch size.
181
+
182
+ Args:
183
+ datasets (list): The datasets providing the samples.
184
+ batch_size (int): The batch size for the current rank. Defaults to 192.
185
+ rampup_batch_size (str): A string with three space-separated integers representing the
186
+ starting batch size, the increment, and the number of steps between
187
+ each increment. For example, "192 24 8" means that the batch size
188
+ starts at 192 and increases by 24 every 8 steps. Defaults to
189
+ "6 2 8", i.e. the batch size starts at 6 and increases by 2 every 8 steps.
190
+ micro_bsz (int): The micro-batch size. Defaults to 2.
191
+ seed (int): The random seed for shuffling the indices. Defaults to 0.
192
+ drop_last (bool): If True, drop the last incomplete batch. Currently only supports True. Defaults to True.
193
+ data_rank (int): The rank of the current process in the data parallel group. Defaults to 0.
194
+ data_world_size (int): The number of processes in the data parallel group. Defaults to 1.
195
+ """
196
+
197
+ def __init__(
198
+ self,
199
+ datasets,
200
+ batch_size=192,
201
+ rampup_batch_size="6 2 8",
202
+ micro_bsz=2,
203
+ seed=0,
204
+ drop_last=True,
205
+ data_rank=0,
206
+ data_world_size=1,
207
+ ):
208
+ assert drop_last is True, "Currently only support drop last"
209
+ if rampup_batch_size:
210
+ # Ramp the batch size up gradually until it reaches batch_size
211
+ start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
212
+ else:
213
+ start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1
214
+ self.raw_rampup_batch_size = rampup_batch_size
215
+ self.start_bsz = start_bsz
216
+ self.bsz_incre = bsz_incre
217
+ self.incre_every = incre_every
218
+ if gpc.is_initialized(ParallelMode.PIPELINE):
219
+ assert (
220
+ batch_size - self.start_bsz
221
+ ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
222
+ assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
223
+ assert (
224
+ self.start_bsz % micro_bsz == 0
225
+ ), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})"
226
+ assert (
227
+ self.bsz_incre % micro_bsz == 0
228
+ ), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})"
229
+
230
+ self.batch_size = batch_size
231
+ self.epoch = 0
232
+ self.seed = seed
233
+ self.rng = np.random.RandomState(seed)
234
+ self.batch_count = 0
235
+ self.micro_bsz = micro_bsz
236
+ self.data_rank = data_rank
237
+ self.data_world_size = data_world_size
238
+ self.num_consumed_samples_in_epoch = 0
239
+ self.datasets = datasets
240
+ self.num_samples = sum([len(ds) for ds in datasets])
241
+
242
+ self.get_indices() # get data
243
+
244
+ def get_indices(self, old_indices=None):
245
+ if old_indices is not None:
246
+ assert (
247
+ len(old_indices) <= self.num_samples
248
+ ), f"The checkpoint has {len(old_indices)} samples, \
249
+ while the new restart uses fewer samples ({self.num_samples})"
250
+
251
+ else:
252
+ old_indices = np.array([])
253
+
254
+ # indices covers the range [len(old_indices), self.num_samples)
255
+ indices = np.arange(len(old_indices), self.num_samples)
256
+ self.rng_state = self.rng.get_state()
257
+ self.rng.shuffle(indices)
258
+ # Need to consider drop_last
259
+ ramp_steps = (self.batch_size - self.start_bsz) // self.bsz_incre
260
+ if self.batch_count < ramp_steps * self.incre_every:
261
+ rampup_samples = 0
262
+ for i in range(ramp_steps):
263
+ rampup_samples += (i * self.bsz_incre + self.start_bsz) * self.incre_every
264
+ assert (
265
+ rampup_samples * self.data_world_size <= self.num_samples
266
+ ), f"Too many rampup samples: \
267
+ {rampup_samples*self.data_world_size} Vs. self.num_samples: {self.num_samples}"
268
+
269
+ num_samples = (self.num_samples - rampup_samples * self.data_world_size) // (
270
+ self.batch_size * self.data_world_size
271
+ )
272
+ num_samples = num_samples * self.batch_size * self.data_world_size + rampup_samples * self.data_world_size
273
+ else:
274
+ num_samples = self.num_samples // (self.batch_size * self.data_world_size)
275
+ num_samples = num_samples * self.batch_size * self.data_world_size
276
+ indices = np.concatenate([old_indices, indices]).astype(int)  # splice with the previously consumed indices
277
+ indices = indices[:num_samples]
278
+ self.indices = indices
279
+ assert len(self.indices) >= self.batch_size, "The number of samples should be larger than batch_size"
280
+ self.num_consumed_samples_in_epoch = 0
281
+
282
+ def set_epoch(self, epoch):
283
+ self.epoch = epoch
284
+ self.rng = np.random.RandomState(self.seed + self.epoch)
285
+
286
+ def __len__(self):
287
+ ramp_steps = (self.batch_size - self.start_bsz) // self.bsz_incre
288
+ if self.batch_count < ramp_steps * self.incre_every:
289
+ rampup_samples = 0
290
+ for i in range(ramp_steps):
291
+ rampup_samples += (i * self.bsz_incre + self.start_bsz) * self.incre_every
292
+ assert (
293
+ rampup_samples * self.data_world_size <= self.num_samples
294
+ ), f"Too many rampup samples: {rampup_samples*self.data_world_size} \
295
+ Vs. self.num_samples: {self.num_samples}"
296
+
297
+ num_batches = (self.num_samples - rampup_samples * self.data_world_size) // self.batch_size
298
+ num_batches = num_batches // self.data_world_size + self.incre_every * ramp_steps
299
+ else:
300
+ num_batches = self.num_samples // self.batch_size // self.data_world_size
301
+
302
+ return num_batches
303
+
304
+ def __iter__(self):
305
+ indices = self.indices[self.data_rank :: self.data_world_size]
306
+ while self.num_consumed_samples_in_epoch < len(indices):
307
+ batch_rampup_idx = self.batch_count // self.incre_every
308
+ cur_batch_size = batch_rampup_idx * self.bsz_incre + self.start_bsz
309
+ cur_batch_size = min(cur_batch_size, self.batch_size)
310
+ batch = indices[self.num_consumed_samples_in_epoch : self.num_consumed_samples_in_epoch + cur_batch_size]
311
+ yield batch
312
+ self.num_consumed_samples_in_epoch += len(batch) # Consider multiple processes.
313
+ self.batch_count += 1
314
+ self.get_indices() # get a new round
315
+
316
+ def state_dict(self):
317
+ states = {
318
+ "batch_size": self.batch_size,
319
+ "raw_rampup_batch_size": self.raw_rampup_batch_size,
320
+ "rng_state": self.rng_state,
321
+ "epoch": self.epoch,
322
+ "seed": self.seed,
323
+ "data_world_size": self.data_world_size,
324
+ "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch,
325
+ "batch_count": self.batch_count,  # with multiple processes, extra batches may already have been dispatched,
326
+ # so this value needs to be overwritten by the external batch_count
327
+ "indices": self.indices,  # keeps the sample order identical when resuming from a checkpoint
328
+ }
329
+
330
+ return states
331
+
332
+ def load_state_dict(self, states):
333
+ for name in ("data_world_size", "raw_rampup_batch_size", "seed"): # 'batch_size'
334
+ assert states[name] == getattr(self, name), (name, states[name], getattr(self, name)) # should not change
335
+ self.rng.set_state(states["rng_state"])
336
+ self.get_indices(old_indices=None) # Regenerate indices based on random state
337
+ self.epoch = states["epoch"]
338
+ self.batch_count = states["batch_count"]
339
+ self.num_consumed_samples_in_epoch = states["num_consumed_samples_in_epoch"]
340
+
341
+ def copy(self):
342
+ copy_sampler = StaticBatchSampler(
343
+ self.datasets,
344
+ self.batch_size,
345
+ self.raw_rampup_batch_size,
346
+ self.micro_bsz,
347
+ self.seed,
348
+ drop_last=True,
349
+ data_rank=self.data_rank,
350
+ data_world_size=self.data_world_size,
351
+ )
352
+
353
+ copy_sampler.load_state_dict(self.state_dict())
354
+ return copy_sampler
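The batch-size ramp-up encoded by `rampup_batch_size` can be summarized with a small standalone helper. The sketch below mirrors the arithmetic in `StaticBatchSampler.__iter__` and is only an illustration, not part of the uploaded file.

```python
# Illustration of the "start incre every" ramp-up used by StaticBatchSampler:
# the batch size starts at `start`, grows by `incre` every `every` batches,
# and is capped at the target batch size.
def rampup_batch_size_at(batch_count: int, rampup: str = "6 2 8", target: int = 192) -> int:
    start, incre, every = map(int, rampup.split())
    return min(start + (batch_count // every) * incre, target)

for step in (0, 7, 8, 16, 800):
    print(step, rampup_batch_size_at(step))   # 6, 6, 8, 10, 192
```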
InternLM/internlm/data/collaters.py ADDED
@@ -0,0 +1,88 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+
6
+
7
+ def packed_collate_fn(batch, packed_length):
8
+
9
+ """
10
+ Collate function for packed input sequences.
11
+
12
+ Args:
13
+ batch (List[Dict]): List of dictionaries representing each sample in batch.
14
+ Each dictionary contains "tokens", "labels", "type_ids", "cu_seqlens", and "indexes" keys.
15
+ packed_length (int): The length of packed sequence.
16
+
17
+ Returns:
18
+ Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids",
19
+ "cu_seqlens", "indexes", and "type_ids" keys, and the tensor of padded "labels".
20
+
21
+ Raises:
22
+ AssertionError: If the length of a sample is not equal to packed_length.
23
+ AssertionError: If the shape of the padded "input_ids" tensor does not have the correct shape.
24
+ """
25
+
26
+ xs, ys, cu_seqlens, indexes, ts = [], [], [], [], []
27
+ for b in batch:
28
+ assert (
29
+ len(b["tokens"]) == packed_length
30
+ ), f"length of a sample should be equal to packed_length, but got {len(b['tokens'])} and {packed_length}"
31
+ assert (
32
+ len(b["labels"]) == packed_length
33
+ ), f"length of a sample should be equal to packed_length, but got {len(b['labels'])} and {packed_length}"
34
+ assert (
35
+ len(b["type_ids"]) == packed_length
36
+ ), f"length of a sample should be equal to packed_length, but got {len(b['type_ids'])} and {packed_length}"
37
+
38
+ tokens = [abs(w) for w in b["tokens"]]
39
+ labels = [w if w > 0 else -100 for w in b["labels"]]
40
+
41
+ xs.append(torch.LongTensor(tokens))
42
+ # The labels have been shifted here, so they are aligned with the output corresponding to the token
43
+ ys.append(torch.LongTensor(labels))
44
+ ts.append(torch.LongTensor(b["type_ids"]))
45
+ cu_seqlens.append(torch.IntTensor(b["cu_seqlens"]))
46
+ indexes.append(torch.LongTensor(b["indexes"]))
47
+
48
+ xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)
49
+ ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100)
50
+ ts = torch.nn.utils.rnn.pad_sequence(ts, batch_first=True, padding_value=0)
51
+ indexes = torch.stack(indexes, dim=0)
52
+ if len(set(map(len, cu_seqlens))) == 1:  # if the lengths are uniform, stack them to save device transfer time
53
+ cu_seqlens = torch.stack(cu_seqlens, dim=0)
54
+
55
+ assert xs.shape[1] == packed_length, (xs.shape[1], packed_length)
56
+
57
+ return {"input_ids": xs, "cu_seqlens": cu_seqlens, "indexes": indexes, "type_ids": ts}, ys
58
+
59
+
60
+ def jsonl_ds_collate_fn(batch, max_length_per_sample):
61
+ """
62
+ Collate function for json dataset.
63
+
64
+ Args:
65
+ batch (List[Dict]): List of dictionaries representing each sample in batch.
66
+ Each dictionary contains "tokens".
67
+ max_length_per_sample (int): The length of output sequence.
68
+
69
+ Returns:
70
+ Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids",
71
+ and the tensor of padded "labels".
72
+
73
+ """
74
+ xs, ys = [], []
75
+ for x in batch:
76
+ x["tokens"] = x["tokens"][:max_length_per_sample]
77
+ tokens = [abs(w) for w in x["tokens"]]
78
+ labels = [w if w > 0 else -100 for w in x["tokens"]]
79
+ labels = labels[1:] + [-100]
80
+ xs.append(torch.as_tensor(tokens))
81
+ ys.append(torch.as_tensor(labels)) # y has been shifted
82
+ xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)
83
+ ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100)
84
+
85
+ xs = torch.cat([xs, xs.new_zeros(len(xs), max_length_per_sample - len(xs[0]))], dim=-1)
86
+ ys = torch.cat([ys, ys.new_full((len(ys), max_length_per_sample - len(ys[0])), fill_value=-100)], dim=-1)
87
+
88
+ return {"input_ids": xs}, ys
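Both collate functions rely on the same label convention: non-positive token ids are masked with -100, and labels are the tokens shifted left by one so that position i is supervised with token i+1. A small self-contained illustration (not part of the uploaded file):

```python
import torch

# Same convention as jsonl_ds_collate_fn: a non-positive id marks a position
# that must not be predicted, and labels are shifted left by one.
tokens = [5, 9, -3, 7]
input_ids = torch.tensor([abs(t) for t in tokens])
labels = [t if t > 0 else -100 for t in tokens]
labels = torch.tensor(labels[1:] + [-100])     # shift; pad the tail with -100

print(input_ids.tolist())   # [5, 9, 3, 7]
print(labels.tolist())      # [9, -100, 7, -100]
```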
InternLM/internlm/data/dataset.py ADDED
@@ -0,0 +1,56 @@
1
+ import os
2
+ from typing import Dict
3
+
4
+ from torch.utils.data import ConcatDataset
5
+
6
+ from internlm.data.single_dataset import JsonlDataset
7
+
8
+
9
+ def get_dataset_dict(folder, split="valid") -> Dict:
10
+ """
11
+ Return a dictionary of Datasets from a folder containing data files for validation.
12
+
13
+ Args:
14
+ folder (str): The path to the folder containing data files.
15
+ split (str): The split of the data files to be used, default is "valid".
16
+
17
+ Returns:
18
+ A dictionary containing Datasets for each folder in the given path
19
+ that contains data files with the specified split.
20
+
21
+ Raises:
22
+ AssertionError: If the given folder does not exist.
23
+
24
+ Example:
25
+ If the given folder is as follows,
26
+ - data
27
+ - zhihu
28
+ - xxx.bin
29
+ - valid.bin
30
+ - baike
31
+ - xxx.bin
32
+ - valid.bin
33
+
34
+ The returned dictionary will be,
35
+ {
36
+ 'zhihu': Dataset,
37
+ 'baike': Dataset
38
+ }
39
+ """
40
+
41
+ assert os.path.exists(folder), f"folder `{folder}` does not exist"
42
+ data_dict = {}
43
+
44
+ for root, dirs, files in os.walk(folder, followlinks=True):
45
+ dirs.sort()  # Guarantee a fixed traversal order; newly added data prefixed with "z" is placed at the end
46
+ datasets = []
47
+ for fn in sorted(files): # Need sorted to ensure that the order is consistent
48
+ if fn.endswith(".bin") and split in fn:
49
+ fp = os.path.join(root, fn)
50
+ ds = JsonlDataset(fp)
51
+ datasets.append(ds)
52
+ if datasets:
53
+ ds = ConcatDataset(datasets=datasets)
54
+ data_dict[os.path.basename(root)] = ds
55
+
56
+ return data_dict
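A usage sketch for `get_dataset_dict` (not part of the uploaded file); the folder layout mirrors the docstring above, and every `*.bin` file is assumed to have the `*.bin.meta` cache that `JsonlDataset` requires. The folder name `data` is only an example.

```python
from internlm.data.dataset import get_dataset_dict

# data/
#   zhihu/  xxx.bin, valid.bin  (+ .meta caches)
#   baike/  xxx.bin, valid.bin  (+ .meta caches)
valid_dict = get_dataset_dict(folder="data", split="valid")
for name, ds in valid_dict.items():
    print(name, len(ds))    # e.g. "baike 5678", "zhihu 1234"
```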
InternLM/internlm/data/dummy_dataset.py ADDED
@@ -0,0 +1,44 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import numpy as np
5
+ from torch.utils.data import Dataset
6
+
7
+
8
+ class RandomDataset(Dataset):
9
+ """
10
+ RandomDataset for generating random dataset.
11
+
12
+ Args:
13
+ num_samples (int): The number of samples to generate.
14
+ max_len (int): The maximum length of each sample.
15
+
16
+ """
17
+
18
+ def __init__(self, num_samples=10000, max_len=1024) -> None:
19
+ super().__init__()
20
+ rng = np.random.RandomState(1999)
21
+ max_num = rng.randint(1, 30, size=(num_samples,))
22
+ rep_num = rng.randint(10, 200, size=(num_samples,))
23
+ data = []
24
+ lengths = []
25
+ for n, r in zip(max_num, rep_num):
26
+ d = list(range(n)) * r
27
+ d = [n, r] + d
28
+ d = d[:max_len]
29
+ data.append(d)
30
+ lengths.append(len(d))
31
+ self.data = data
32
+ self.max_len = max_len
33
+ self.lengths = np.array(lengths, dtype=int)
34
+
35
+ def __getitem__(self, index):
36
+ d = self.data[index]
37
+ input_ids = np.array(d, dtype=int)
38
+ return {"tokens": list(input_ids), "type_id": 0}
39
+
40
+ def get_dataset_name(self):
41
+ return "dummy_path/dummy_lang/dummy_ds/train.bin"
42
+
43
+ def __len__(self):
44
+ return len(self.data)
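A quick sketch of what `RandomDataset` produces (not part of the uploaded file), assuming the `internlm` package is importable:

```python
from internlm.data.dummy_dataset import RandomDataset

ds = RandomDataset(num_samples=4, max_len=16)
sample = ds[0]
print(len(ds))             # 4
print(sample["type_id"])   # 0
print(sample["tokens"])    # [n, r, 0, 1, ..., n - 1, 0, 1, ...], truncated to 16 tokens
```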
InternLM/internlm/data/packed_dataset.py ADDED
@@ -0,0 +1,421 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import itertools as it
5
+ import operator
6
+ import os
7
+ from copy import deepcopy
8
+ from typing import Dict
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.utils.data import ConcatDataset
13
+ from tqdm import tqdm
14
+
15
+ from internlm.core.context import global_context as gpc
16
+ from internlm.data.single_dataset import JsonlDataset
17
+ from internlm.data.utils import get_dataset_type_id
18
+ from internlm.utils.logger import get_logger
19
+
20
+ DEFAULT_SEED = 1024
21
+ logger = get_logger(__file__)
22
+
23
+
24
+ class PackedDataset(torch.utils.data.Dataset):
25
+ """
26
+ The class PackedDataset takes in a dataset and aggregates samples of different
27
+ lengths together based on the packed_length.
28
+
29
+ Args:
30
+ dataset: The original dataset to pack.
31
+ max_length_per_sample: The maximum length of each original sample. Default is 2048.
32
+ packed_length: The length of each packed sample. Default is 4096.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ dataset,
38
+ max_length_per_sample: int = 2048,
39
+ packed_length: int = 4096,
40
+ ):
41
+ assert hasattr(dataset, "lengths")
42
+ assert len(getattr(dataset, "lengths")) == len(
43
+ dataset
44
+ ), "The dataset must have lengths attribute and have the same length as the dataset"
45
+ self.dataset = dataset
46
+ self.max_length_per_sample = max_length_per_sample
47
+ self.lengths = getattr(self.dataset, "lengths")
48
+ self.packed_length = packed_length
49
+ # Force a seed to be fixed to prevent problems caused by the seed not being restored when restarting
50
+
51
+ self.seed = DEFAULT_SEED
52
+ self.sample_indices, self.len_samples_shuffled, self.acm_len_samples = self.accu_sample_len(seed=self.seed)
53
+ self.num_tokens = sum(self.lengths)
54
+
55
+ def get_dataset_name(self):
56
+ return self.dataset.get_dataset_name()
57
+
58
+ def accu_sample_len(self, seed=None):
59
+ """accumulative length of samples"""
60
+ if seed is not None:
61
+ rng = np.random.RandomState(seed)
62
+ else:
63
+ rng = np.random.RandomState(self.seed - 1)
64
+
65
+ sample_indices = np.arange(len(self.lengths))
66
+ rng.shuffle(sample_indices)
67
+ len_samples_shuffled = list(map(self.lengths.__getitem__, sample_indices))
68
+ acm_len_samples = list(it.accumulate(len_samples_shuffled, operator.add))
69
+ return sample_indices, len_samples_shuffled, acm_len_samples
70
+
71
+ def __len__(self):
72
+ # Samples are spliced directly, following line 405 of metaseq's document_to_sequence.py,
73
+ # without any special handling of sos or eos tokens
74
+ n_packs = self.num_tokens // self.packed_length
75
+ return n_packs
76
+
77
+ def cal_map(self, carriage_idx: int = 0):
78
+ assert carriage_idx >= 0
79
+ length_train = (carriage_idx + 1) * self.packed_length
80
+ post_pos = np.searchsorted(self.acm_len_samples, length_train, side="left")
81
+ return post_pos
82
+
83
+ def mapping(self, pack_idx: int = 0):
84
+ # pack_idx is zero-based
85
+ pre_pos, pre_token_id = 0, 0
86
+ if pack_idx > 0:
87
+ pre_pos = self.cal_map(pack_idx - 1)
88
+ pre_token_id = self.len_samples_shuffled[pre_pos] - (
89
+ self.acm_len_samples[pre_pos] - (pack_idx) * self.packed_length
90
+ )
91
+ if pre_token_id == self.len_samples_shuffled[pre_pos]:
92
+ pre_pos += 1
93
+ pre_token_id = 0
94
+
95
+ pos = self.cal_map(pack_idx)
96
+ token_id = self.len_samples_shuffled[pos] - (self.acm_len_samples[pos] - (pack_idx + 1) * self.packed_length)
97
+ return pre_pos, pre_token_id, pos, token_id
98
+
99
+ def build_pack(self, pre_pos: int, pre_token_id: int, pos: int, token_id: int):
100
+ pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
101
+
102
+ while pre_pos < pos:
103
+ sample_idx = self.sample_indices[pre_pos]
104
+ sample = self.dataset[sample_idx]
105
+ chunk = sample["tokens"][pre_token_id:]
106
+ pack.extend(chunk)
107
+ _labels = deepcopy(chunk)
108
+ _labels = list(_labels[1:]) + [-100]
109
+ assert len(_labels) == len(chunk), (_labels, chunk)
110
+ labels.extend(_labels)
111
+ type_ids.extend([sample.get("type_id", 0)] * len(chunk))
112
+ num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample)
113
+ for _ in range(num_new_samples):
114
+ cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample)
115
+ indexes.extend(list(range(self.max_length_per_sample)))
116
+ if tokens_left > 0:
117
+ cu_seqlens.append(cu_seqlens[-1] + tokens_left)
118
+ indexes.extend(list(range(tokens_left)))
119
+ pre_pos = pre_pos + 1
120
+ pre_token_id = 0
121
+
122
+ sample_idx = self.sample_indices[pos]
123
+ sample = self.dataset[sample_idx]
124
+ chunk = sample["tokens"][pre_token_id:token_id]  # fragment of a sample
125
+ pack.extend(chunk)
126
+ _labels = deepcopy(chunk)
127
+ if token_id == len(sample["tokens"]):
128
+ _labels = list(_labels[1:]) + [-100]
129
+ else:
130
+ if token_id > len(sample["tokens"]):
131
+ print(f"token_id {token_id}, len of sample {len(sample['tokens'])}")
132
+ _labels = list(_labels[1:]) + [sample["tokens"][token_id]]
133
+ assert len(_labels) == len(chunk), (_labels, chunk)
134
+ labels.extend(_labels)
135
+ type_ids.extend([sample.get("type_id", 0)] * len(chunk))
136
+ num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample)
137
+ for _ in range(num_new_samples):
138
+ cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample)
139
+ indexes.extend(list(range(self.max_length_per_sample)))
140
+ if tokens_left > 0:
141
+ cu_seqlens.append(cu_seqlens[-1] + tokens_left)
142
+ indexes.extend(list(range(tokens_left)))
143
+
144
+ out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
145
+ return out
146
+
147
+ def cal_pos_unpack(self, index):
148
+ if index == 0:
149
+ pre_pos = 0
150
+ else:
151
+ pre_pos = index * gpc.config.data["micro_bsz"]
152
+
153
+ pos = (index + 1) * gpc.config.data["micro_bsz"]
154
+ return pre_pos, pos
155
+
156
+ def build_unpack(self, index):
157
+
158
+ pre_pos, pos = self.cal_pos_unpack(index)
159
+
160
+ pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
161
+
162
+ while pre_pos < pos and pre_pos < len(self.dataset):
163
+ sample_idx = self.sample_indices[pre_pos]
164
+ sample = self.dataset[sample_idx]
165
+ length = min(len(sample["tokens"]), self.max_length_per_sample)
166
+ chunk = sample["tokens"][0:length]
167
+ pack.extend(chunk)
168
+ _labels = deepcopy(chunk)
169
+ _labels = list(_labels[1:]) + [-100]
170
+ assert len(_labels) == len(chunk), (_labels, chunk)
171
+ labels.extend(_labels)
172
+ type_ids.extend([sample.get("type_id", 0)] * len(chunk))
173
+ cu_seqlens.append(cu_seqlens[-1] + len(chunk))
174
+ indexes.extend(list(range(length)))
175
+ pre_pos = pre_pos + 1
176
+
177
+ if cu_seqlens[-1] != self.packed_length:
178
+ pack = pack + [0] * (self.packed_length - cu_seqlens[-1])
179
+ labels = labels + [0] * (self.packed_length - cu_seqlens[-1])
180
+ type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1])
181
+ indexes.extend(list(range(self.packed_length - cu_seqlens[-1])))
182
+ cu_seqlens.append(self.packed_length)
183
+
184
+ assert len(pack) == self.packed_length
185
+
186
+ out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
187
+ return out
188
+
189
+ def __getitem__(self, item: int) -> Dict:
190
+ """Given the index, it returns a dict as
191
+ {
192
+ 'tokens': List[int],
193
+ 'cu_seqlens': List[int],
194
+ 'indexes': List[int], # position indices aligned with 'tokens'
195
+ 'labels': List[int], # labels aligned with 'tokens', already shifted; -100 marks positions excluded from the loss
196
+ }
197
+ """
198
+
199
+ if gpc.config.model.use_flash_attn:
200
+ pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
201
+ return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
202
+
203
+ return self.build_unpack(item)
204
+
205
+
206
+ class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):
207
+ """
208
+ A dataset wrapper that aggregates samples with different lengths based on packed_length.
209
+ If a sample is shorter than max_length_per_sample, it will be merged with other samples.
210
+ For example, given a dataset whose (shuffled) samples are:
211
+ [1, 2, 3, 4, 5]
212
+ [6, 7]
213
+ [8, 9, 10, 11]
214
+ [12, ..., 100]
215
+ ...
216
+
217
+ Args:
218
+ dataset: The original dataset to be wrapped.
219
+ max_length_per_sample (int): The maximum length allowed for each sample.
220
+ packed_length (int): The desired length for each packed sample.
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ dataset,
226
+ max_length_per_sample: int = 2048,
227
+ packed_length: int = 4096,
228
+ debug=False,
229
+ ):
230
+ assert packed_length % max_length_per_sample == 0
231
+ assert hasattr(dataset, "lengths")
232
+ assert len(getattr(dataset, "lengths")) == len(
233
+ dataset
234
+ ), "The dataset must have lengths attribute and have the same length as the dataset"
235
+ self.dataset = dataset
236
+ self.max_length_per_sample = max_length_per_sample
237
+ self.lengths = getattr(self.dataset, "lengths")
238
+ self.bsz = packed_length // max_length_per_sample
239
+ self.packed_length = packed_length
240
+ self.debug = debug
241
+ # Force a seed to be fixed to prevent problems caused by the seed not being restored when restarting
242
+
243
+ self.seed = DEFAULT_SEED
244
+ indices = np.arange(len(self.lengths))
245
+ rng = np.random.RandomState(self.seed)
246
+ rng.shuffle(indices)
247
+ self.indices = indices
248
+ self.cum_lens = np.cumsum(self.lengths[self.indices])
249
+ self.num_tokens = sum(self.lengths)
250
+
251
+ def get_dataset_name(self):
252
+ return self.dataset.get_dataset_name()
253
+
254
+ def __len__(self):
255
+ n_packs = self.num_tokens // self.packed_length
256
+ return n_packs
257
+
258
+ def find_offset(self, offset):
259
+ idx = np.searchsorted(self.cum_lens, offset, side="right")
260
+ if idx == 0:
261
+ return idx, offset
262
+ length = offset - self.cum_lens[idx - 1]
263
+ return idx, length
264
+
265
+ def pdebug(self, line):
266
+ if self.debug:
267
+ print(line, flush=True)
268
+
269
+ def __getitem__(self, item: int) -> Dict:
270
+ """Given the index, it returns a dict as
271
+ {
272
+ 'tokens': List[int],
273
+ 'cu_seqlens': List[int],
274
+ 'indexes': List[int], # position indices aligned with 'tokens'
275
+ 'labels': List[int], # labels aligned with 'tokens', already shifted; -100 marks positions excluded from the loss
276
+ }
277
+ """
278
+
279
+ start_idx, start_length = self.find_offset(item * self.packed_length)
280
+ end_idx, end_length = self.find_offset((item + 1) * self.packed_length)
281
+ pack_tokens = []
282
+ pack_labels = []
283
+ type_ids = []
284
+
285
+ self.pdebug(f"item : {item}, start_idx:{start_idx}, start_length:{start_length} ")
286
+ self.pdebug(f"item : {item}, end_idx:{end_idx}, end_length:{end_length} ")
287
+
288
+ if start_idx == end_idx:
289
+ idx = self.indices[start_idx]
290
+ sample = self.dataset[idx]
291
+ self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
292
+ tokens = sample["tokens"][start_length:end_length]
293
+ pack_tokens.extend(tokens)
294
+ pack_labels.extend(tokens[1:] + [-100])
295
+ type_ids.extend([sample["type_id"]] * len(tokens))
296
+ return {
297
+ "tokens": pack_tokens,
298
+ "cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)],
299
+ "indexes": list(range(self.max_length_per_sample)) * self.bsz,
300
+ "labels": pack_labels,
301
+ "type_ids": type_ids,
302
+ }
303
+
304
+ idx = self.indices[start_idx]
305
+ sample = self.dataset[idx]
306
+ self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
307
+ tokens = sample["tokens"][start_length:]
308
+ pack_tokens.extend(tokens)
309
+ pack_labels.extend(tokens[1:] + [-100])
310
+ type_ids.extend([sample["type_id"]] * len(tokens))
311
+
312
+ for i in range(start_idx + 1, end_idx):
313
+ idx = self.indices[i]
314
+ sample = self.dataset[idx]
315
+ self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
316
+ tokens = sample["tokens"]
317
+ pack_tokens.extend(tokens)
318
+ pack_labels.extend(tokens[1:] + [-100])
319
+ type_ids.extend([sample.get("type_id")] * len(tokens))
320
+
321
+ # corner case, the last sample is useless
322
+ if end_length == 0:
323
+ pass
324
+ else:
325
+ idx = self.indices[end_idx]
326
+ sample = self.dataset[idx]
327
+ self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}")
328
+ tokens = sample["tokens"][:end_length]
329
+ pack_tokens.extend(tokens)
330
+ pack_labels.extend(tokens[1:] + [-100])
331
+ type_ids.extend([sample.get("type_id")] * len(tokens))
332
+
333
+ return {
334
+ "tokens": pack_tokens,
335
+ "cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)],
336
+ "indexes": list(range(self.max_length_per_sample)) * self.bsz,
337
+ "labels": pack_labels,
338
+ "type_ids": type_ids,
339
+ }
340
+
341
+
342
+ def get_packed_dataset_without_short_length(
343
+ folder,
344
+ max_length_per_sample=2048,
345
+ packed_length=4096,
346
+ show_progress=False,
347
+ min_length=50,
348
+ min_length_dict=None,
349
+ pack_into_one_sample=False,
350
+ ):
351
+ """
352
+ Given a folder, combine all the .bin files into a single large dataset.
353
+ Samples shorter than 'min_length' are filtered out.
354
+
355
+ Each .bin file is treated as a separate dataset.
356
+
357
+ Args:
358
+ folder (str): Path to the folder containing the .bin files.
359
+ max_length_per_sample (int): Maximum length of each sample.
360
+ packed_length (int): Length to pack samples to.
361
+ show_progress (bool): Whether to show the progress bar.
362
+ min_length (int): The minimum length of the sample.
363
+ min_length_dict (dict): The minimum length of the sample for each dataset.
364
+ The format is something like {'pile-arxiv': 50}
365
+ pack_into_one_sample (bool): If True, wrap each dataset with PackedDatasetWithoutCuSeqlen instead of PackedDataset.
366
+
367
+ Returns:
368
+ A packed dataset containing all the data from the .bin files.
369
+ """
370
+
371
+ assert os.path.exists(folder), f"{folder} does not exist."
372
+ datasets = []
373
+ delete_samples = 0
374
+
375
+ for root, dirs, files in os.walk(folder, followlinks=True):
376
+ dirs.sort()  # Traverse folders in a fixed, deterministic order
377
+ if gpc.is_rank_for_log():
378
+ logger.info(f"Reading {root}...")
379
+ num_token_in_folder = 0
380
+
381
+ for fn in tqdm(sorted(files), total=len(files), leave=False, disable=not show_progress):
382
+ if fn.endswith(".bin"):
383
+ fp = os.path.join(root, fn)
384
+ catch_ml_keys = []
385
+ min_length_num = min_length
386
+ if min_length_dict is not None:
387
+ for k, v in min_length_dict.items():
388
+ if k in fp:
389
+ min_length_num = v
390
+ catch_ml_keys.append(k)
391
+ assert (
392
+ len(catch_ml_keys) < 2
393
+ ), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}"
394
+
395
+ ds_type_id = get_dataset_type_id(path=fp)
396
+ ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num)
397
+
398
+ if hasattr(ds, "old_length"):
399
+ delete_samples += ds.old_length - len(ds)
400
+ if len(ds) == 0:
401
+ if gpc.is_rank_for_log():
402
+ logger.info(f"None of the data in `{fp}` is longer than {min_length}")
403
+ continue
404
+
405
+ if pack_into_one_sample:
406
+ ds = PackedDatasetWithoutCuSeqlen(ds, max_length_per_sample, packed_length)
407
+ else:
408
+ ds = PackedDataset(ds, max_length_per_sample, packed_length)
409
+
410
+ num_token_in_folder += len(ds) * packed_length
411
+ datasets.append(ds)
412
+
413
+ dataset = ConcatDataset(datasets=datasets)
414
+ if gpc.is_rank_for_log():
415
+ logger.info(
416
+ f"Found `{len(datasets)}` datasets, \
417
+ {len(dataset)} samples, \
418
+ deleted `{delete_samples}` samples because of short length",
419
+ )
420
+
421
+ return dataset
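The index arithmetic behind `PackedDataset.cal_map`/`mapping` can be reproduced with a few lines of numpy. The sketch below (not part of the uploaded file) shows how a pack boundary is mapped to a (sample index, offset) pair via cumulative lengths and `searchsorted`.

```python
import itertools
import numpy as np

lengths = [5, 3, 9, 4, 7]                      # shuffled sample lengths
packed_length = 8
acm = list(itertools.accumulate(lengths))      # [5, 8, 17, 21, 28]

def boundary(pack_idx: int):
    end = (pack_idx + 1) * packed_length       # token position of the pack end
    pos = int(np.searchsorted(acm, end, side="left"))
    offset = lengths[pos] - (acm[pos] - end)   # how far into sample `pos` it lands
    return pos, offset

print(boundary(0))   # (1, 3): pack 0 ends exactly at the end of sample 1
print(boundary(1))   # (2, 8): pack 1 ends 8 tokens into sample 2
```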
InternLM/internlm/data/single_dataset.py ADDED
@@ -0,0 +1,117 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ """
5
+ A .bin file corresponds to a Dataset instance here.
6
+ """
7
+
8
+ import json
9
+ import mmap
10
+ import os
11
+ import threading
12
+ from pathlib import Path
13
+
14
+ import numpy as np
15
+ import torch
16
+
17
+
18
+ class JsonlDataset(torch.utils.data.Dataset):
19
+ """
20
+
21
+ JSONL format is expected to roughly follow that of The Pile.
22
+ One-line-per-document of the form:
23
+ ```
24
+ {
25
+ "tokens": List[int],
26
+ }
27
+ ```
28
+
29
+ Note that only the "tokens" key is used.
30
+ """
31
+
32
+ def __init__(self, path: str, dataset_type_id: int = 0, min_length=50):
33
+ self.path = path
34
+ self.threadlocal = threading.local()
35
+ resolved_path = Path(path).resolve()
36
+ self.resolved_path = resolved_path
37
+ self.meta = Path(f"{resolved_path}.meta")
38
+ self.type_id = dataset_type_id
39
+
40
+ # only build the cache on the primary worker to prevent overloading NFS
41
+ assert os.path.exists(self.meta), f"The cache file:{self.meta} is not found for file:{self.path}"
42
+ try:
43
+ with open(self.meta, "rb") as f:
44
+ meta = np.load(f)
45
+ except Exception as e:
46
+ print(f"Cannot load file {self.meta}...")
47
+ raise e
48
+ self.offsets = meta[:, 0]
49
+ self.lengths = meta[:, -1]
50
+
51
+ if min_length > 0:
52
+ mask = self.lengths >= min_length
53
+ self.old_lengths = self.lengths.copy()
54
+ self.old_length = len(self.offsets)
55
+ self.offsets = self.offsets[mask]
56
+ self.lengths = self.lengths[mask]
57
+
58
+ def __getitem__(self, idx):
59
+ f = self._get_mmap()
60
+ position = self.offsets[idx]
61
+ f.seek(position)
62
+ item = f.readline().decode("utf-8")
63
+ try:
64
+ item = json.loads(item)
65
+ item["length"] = len(item["tokens"]) # add a length info
66
+ item["type_id"] = self.type_id
67
+ except Exception as err:
68
+ raise json.decoder.JSONDecodeError(
69
+ doc=self.path,
70
+ pos=position,
71
+ msg=(
72
+ f"Error while loading JSONL line in file {self.path} at byte "
73
+ f"{position}. Contents of line:\n{item}\n{err}"
74
+ ),
75
+ )
76
+ return item
77
+
78
+ def get_dataset_name(self):
79
+ return str(self.resolved_path)
80
+
81
+ def _get_mmap(self):
82
+ if not hasattr(self.threadlocal, "handles"):
83
+ with open(self.path, "rb") as f:
84
+ mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
85
+ self.threadlocal.handles = [f, mm]
86
+ if self.path.endswith(".gz") or self.path.endswith(".bz") or self.path.endswith(".bz2"):
87
+ raise NotImplementedError(
88
+ "Compressed files are not supported because .seek() would require "
89
+ "rereading the entire file, making performance too slow."
90
+ )
91
+ return self.threadlocal.handles[-1]
92
+
93
+ def __setstate__(self, state):
94
+ self.__dict__ = state
95
+ self.threadlocal = threading.local()
96
+
97
+ def __getstate__(self):
98
+ d = {}
99
+ for i, v in self.__dict__.items():
100
+ if i != "threadlocal":
101
+ d[i] = v
102
+ return d
103
+
104
+ def __del__(self):
105
+ if hasattr(self.threadlocal, "handles"):
106
+ # cleanup files we opened on initialization
107
+ while self.threadlocal.handles:
108
+ self.threadlocal.handles.pop().close()
109
+
110
+ @staticmethod
111
+ def exists(path):
112
+ return os.path.exists(path)
113
+
114
+ def __len__(self):
115
+ # The length is the number of documents remaining after the
116
+ # min_length filter applied at construction time
117
+ return len(self.offsets)
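`JsonlDataset` refuses to open a `.bin` file without a sibling `.bin.meta` cache whose first column stores the byte offset of each line and whose last column stores its token count. The repository ships its own tooling for building these caches; the snippet below is only a toy reconstruction of the format (not part of the uploaded file).

```python
import json
import numpy as np

path = "toy_dataset.bin"
samples = [{"tokens": [1, 2, 3, 4]}, {"tokens": [5, 6, 7]}]

rows = []
with open(path, "wb") as f:
    for s in samples:
        rows.append((f.tell(), len(s["tokens"])))      # (byte offset, token count)
        f.write((json.dumps(s) + "\n").encode("utf-8"))

with open(path + ".meta", "wb") as f:
    np.save(f, np.array(rows, dtype=np.int64))
# JsonlDataset(path, min_length=0) can now read both samples;
# the default min_length=50 would filter such short samples out.
```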
InternLM/internlm/data/utils.py ADDED
@@ -0,0 +1,46 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+
6
+ from internlm.core.context import global_context as gpc
7
+
8
+ DATASET_TYPE_IDS_MAP = {"vision": 0}
9
+
10
+
11
+ def get_dataset_type_id(path):
12
+ import re
13
+
14
+ match_idxes = []
15
+ for key, idx in DATASET_TYPE_IDS_MAP.items():
16
+ if re.search(rf"/[z_]*{key}/", path):
17
+ match_idxes.append(idx)
18
+ assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
19
+ return match_idxes[0]
20
+
21
+
22
+ def unpack_data(input_ids, cu_seqlens):
23
+ """
24
+ input_ids: (n, packed_length)
25
+ Return:
26
+ output: (batch_size, max_length)
27
+ """
28
+
29
+ bsz = input_ids.shape[0]
30
+
31
+ num_sequence = gpc.config.data["micro_bsz"]
32
+
33
+ outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
34
+
35
+ for i in range(bsz):
36
+ output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
37
+ cu_seqlens_slice = cu_seqlens[i]
38
+ for j in range(num_sequence):
39
+ seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
40
+ output[j, 0:seq_length] = input_ids[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
41
+ outputs[i] = output
42
+
43
+ if bsz == 1:
44
+ outputs = outputs.squeeze(0)
45
+
46
+ return outputs
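A gpc-free sketch of the slicing that `unpack_data` performs (not part of the uploaded file): `cu_seqlens` marks the boundaries of the sequences packed into one row, and each sequence is copied into its own fixed-length, zero-padded row.

```python
import torch

packed = torch.tensor([[11, 12, 13, 21, 22, 31, 32, 33, 34]])
cu_seqlens = [0, 3, 5, 9]                 # three sequences of lengths 3, 2 and 4
seq_len, micro_bsz = 4, 3

out = torch.zeros(micro_bsz, seq_len, dtype=packed.dtype)
for j in range(micro_bsz):
    length = cu_seqlens[j + 1] - cu_seqlens[j]
    out[j, :length] = packed[0, cu_seqlens[j]:cu_seqlens[j + 1]]

print(out.tolist())
# [[11, 12, 13, 0], [21, 22, 0, 0], [31, 32, 33, 34]]
```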
InternLM/internlm/initialize/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ from .initialize_trainer import initialize_trainer, initialize_kd_trainer
2
+ from .launch import (
3
+ get_default_parser,
4
+ initialize_distributed_env,
5
+ launch_from_slurm,
6
+ launch_from_torch,
7
+ )
8
+
9
+ __all__ = [
10
+ "get_default_parser",
11
+ "initialize_trainer",
12
+ "initialize_kd_trainer",
13
+ "launch_from_slurm",
14
+ "launch_from_torch",
15
+ "initialize_distributed_env",
16
+ ]
InternLM/internlm/initialize/initialize_tensor.py ADDED
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import math
5
+
6
+ from torch import Tensor, nn
7
+
8
+
9
+ def scaled_init_method_normal(sigma: float = 1.0, num_layers: int = 1):
10
+ """Init method based on N(0, sigma/sqrt(2*num_layers)."""
11
+ std = sigma / math.sqrt(2.0 * num_layers)
12
+
13
+ def init_(tensor):
14
+ return nn.init.normal_(tensor, mean=0.0, std=std)
15
+
16
+ return init_
17
+
18
+
19
+ def normal_(mean: float = 0.0, std: float = 1.0):
20
+ r"""Return the initializer filling the input Tensor with values drawn from the normal distribution
21
+
22
+ .. math::
23
+ \mathcal{N}(\text{mean}, \text{std}^2)
24
+
25
+ Args:
26
+ mean (float): the mean of the normal distribution. Defaults 0.0.
27
+ std (float): the standard deviation of the normal distribution. Defaults 1.0.
28
+ """
29
+
30
+ def initializer(tensor: Tensor):
31
+ return nn.init.normal_(tensor, mean, std)
32
+
33
+ return initializer
34
+
35
+
36
+ def scaled_init_method_uniform(sigma: float = 1.0, num_layers: int = 1):
37
+ """Init method based on p(x)=Uniform(-a, a) where std(x)=sigma/sqrt(2*num_layers)."""
38
+ std = sigma / math.sqrt(2.0 * num_layers)
39
+ a = math.sqrt(3.0 * std)
40
+
41
+ def init_(tensor):
42
+ return nn.init.uniform_(tensor, -a, a)
43
+
44
+ return init_
45
+
46
+
47
+ def uniform_(mean: float = 0.0, std: float = 1.0):
48
+ r"""Return the initializer filling the input Tensor with values drawn from the uniform distribution
49
+
50
+ .. math::
51
+ \mathcal{U}(mean-a, mean+a), where a satisfies \mathcal{U}_{std}=std.
52
+
53
+ Args:
54
+ mean (float): the mean of the uniform distribution. Defaults 0.0.
55
+ std (float): the standard deviation of the uniform distribution. Defaults 1.0.
56
+ """
57
+
58
+ a = math.sqrt(3.0 * std)
59
+
60
+ def initializer(tensor: Tensor):
61
+ return nn.init.uniform_(tensor, mean - a, mean + a)
62
+
63
+ return initializer
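These helpers return initializer closures that are later applied to parameter tensors when the model is built; a short sketch (not part of the uploaded file):

```python
import math
from torch import nn
from internlm.initialize.initialize_tensor import scaled_init_method_normal

init_fn = scaled_init_method_normal(sigma=0.02, num_layers=24)
layer = nn.Linear(1024, 1024)
init_fn(layer.weight)            # N(0, std^2) with std = 0.02 / sqrt(2 * 24)

print(layer.weight.std().item(), 0.02 / math.sqrt(48))   # the two values should be close
```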
InternLM/internlm/initialize/initialize_trainer.py ADDED
@@ -0,0 +1,235 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ # adapted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize
5
+
6
+ from typing import Callable, Iterable, List, Optional, Tuple
7
+
8
+ from torch import nn
9
+ from torch.nn.modules.loss import _Loss
10
+ from torch.optim.lr_scheduler import _LRScheduler
11
+ from torch.optim.optimizer import Optimizer
12
+ from torch.utils.data import DataLoader
13
+
14
+ from internlm.core.context import global_context as gpc
15
+ from internlm.core.context import ParallelMode
16
+ from internlm.core.engine import Engine, KDEngine
17
+ from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
18
+ from internlm.core.scheduler import (InterleavedPipelineScheduler, KDNonPipelineScheduler, KDPipelineScheduler,
19
+ NonPipelineScheduler, PipelineScheduler, SchedulerHook)
20
+ from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
21
+ from internlm.core.trainer import Trainer
22
+ from internlm.data.utils import unpack_data
23
+ from internlm.solver.beta2_scheduler import Beta2Scheduler
24
+ from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
25
+ from internlm.utils.common import get_current_device
26
+
27
+
28
+ def initialize_kd_trainer(
29
+ model: nn.Module,
30
+ teacher: nn.Module,
31
+ optimizer: Optimizer,
32
+ criterion: Optional[_Loss] = None,
33
+ kd_criterion: Optional[_Loss] = None,
34
+ train_dataloader: Optional[Iterable] = None,
35
+ test_dataloader: Optional[Iterable] = None,
36
+ lr_scheduler: Optional[_LRScheduler] = None,
37
+ beta2_scheduler: Optional[Beta2Scheduler] = None,
38
+ scheduler_hooks: Optional[List[SchedulerHook]] = None,
39
+ ) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
40
+ """Core function to wrap the essential training components with our functionality based on the config which is
41
+ loaded into gpc.config.
42
+
43
+ Args:
44
+ model (:class:`torch.nn.Module` or `Callable`): Your model instance or a function to build the model.
45
+ optimizer (:class:`BaseOptimizer`): Your optimizer for training.
46
+ criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
47
+ train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
48
+ test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
49
+ lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
50
+
51
+ Returns:
52
+ Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler):
53
+ A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``
54
+ where only ``trainer`` is guaranteed to be non-None.
55
+ """
56
+
57
+ if isinstance(model, nn.Module):
58
+ # first sync model across dp ranks
59
+ model.to(get_current_device())
60
+ elif isinstance(model, Callable):
61
+ model = model().to(get_current_device())
62
+
63
+ # clip grad norm
64
+ clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)
65
+
66
+ assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
67
+
68
+ # gradient handler, only support PipelineSharedModuleGradientHandler now
69
+ if gpc.is_using_pp():
70
+ gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
71
+ gradient_handler_cfg = gpc.config.get("gradient_handler", [])
72
+ gradient_handlers = []
73
+ assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
74
+ for config in gradient_handler_cfg:
75
+ if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
76
+ handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
77
+ gradient_handlers.append(handler)
78
+
79
+ # initialize scheduler for trainer
80
+ scheduler = None
81
+ if gpc.config.model.use_flash_attn:
82
+ data_fn = None
83
+ else:
84
+ data_fn = unpack_data
85
+ if gpc.is_using_pp():
86
+ gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
87
+ tensor_shape = get_tensor_shape()
88
+ use_interleaved = (
89
+ hasattr(gpc.config, "model") and hasattr(gpc.config.model,
90
+ "num_chunks") and gpc.config.model.num_chunks > 1
91
+ )
92
+ scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)
93
+ if use_interleaved:
94
+ raise NotImplementedError('InterleavedPipelineScheduler for KD is not implemented')
95
+
96
+ else:
97
+ scheduler = KDPipelineScheduler(
98
+ data_process_func=data_fn,
99
+ num_microbatches=gpc.config.NUM_MICRO_BATCHES,
100
+ dtype=gpc.config.model["dtype"],
101
+ tensor_shape=tensor_shape,
102
+ scatter_gather_tensors=scatter_gather,
103
+ scheduler_hooks=scheduler_hooks,
104
+ )
105
+ else:
106
+ scheduler = KDNonPipelineScheduler(
107
+ data_process_func=data_fn,
108
+ gradient_accumulation_size=gpc.config.data.gradient_accumulation,
109
+ scheduler_hooks=scheduler_hooks,
110
+ )
111
+
112
+ # initialize engine for trainer
113
+ engine = KDEngine(
114
+ model=model,
115
+ teacher=teacher,
116
+ optimizer=optimizer,
117
+ lr_scheduler=lr_scheduler,
118
+ beta2_scheduler=beta2_scheduler,
119
+ criterion=criterion,
120
+ kd_criterion=kd_criterion,
121
+ gradient_handlers=gradient_handlers,
122
+ clip_grad_norm=clip_grad_norm,
123
+ )
124
+
125
+ trainer = Trainer(engine, scheduler)
126
+
127
+ return trainer, train_dataloader, test_dataloader, lr_scheduler
128
+
129
+
130
+ def initialize_trainer(
131
+ model: nn.Module,
132
+ optimizer: Optimizer,
133
+ criterion: Optional[_Loss] = None,
134
+ train_dataloader: Optional[Iterable] = None,
135
+ test_dataloader: Optional[Iterable] = None,
136
+ lr_scheduler: Optional[_LRScheduler] = None,
137
+ beta2_scheduler: Optional[Beta2Scheduler] = None,
138
+ scheduler_hooks: Optional[List[SchedulerHook]] = None,
139
+ ) -> Tuple[Trainer, DataLoader, DataLoader, _LRScheduler]:
140
+ """Core function to wrap the essential training components with our functionality based on the config which is
141
+ loaded into gpc.config.
142
+
143
+ Args:
144
+ model (:class:`torch.nn.Module` or `Callable`): Your model instance or a function to build the model.
145
+ optimizer (:class:`BaseOptimizer`): Your optimizer for training.
146
+ criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
147
+ train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
148
+ test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
149
+ lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional.
150
+
151
+ Returns:
152
+ Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler):
153
+ A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``
154
+ where only ``trainer`` is guaranteed to be non-None.
155
+ """
156
+
157
+ if isinstance(model, nn.Module):
158
+ # first sync model across dp ranks
159
+ model.to(get_current_device())
160
+ elif isinstance(model, Callable):
161
+ model = model().to(get_current_device())
162
+
163
+ # clip grad norm
164
+ clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)
165
+
166
+ assert isinstance(optimizer, BaseOptimizer), "optimizer must be instance of BaseOptimizer"
167
+
168
+ # gradient handler, only support PipelineSharedModuleGradientHandler now
169
+ if gpc.is_using_pp():
170
+ gpc.config.gradient_handler = [dict(type="PipelineSharedModuleGradientHandler")]
171
+ gradient_handler_cfg = gpc.config.get("gradient_handler", [])
172
+ gradient_handlers = []
173
+ assert isinstance(gradient_handler_cfg, list), f"gradient_handler must be list but got {type(gradient_handler_cfg)}"
174
+ for config in gradient_handler_cfg:
175
+ if isinstance(config, dict) and config.get("type") == "PipelineSharedModuleGradientHandler":
176
+ handler = PipelineSharedModuleGradientHandler(model=model, optimizer=optimizer)
177
+ gradient_handlers.append(handler)
178
+
179
+ # initialize scheduler for trainer
180
+ scheduler = None
181
+ if gpc.config.model.use_flash_attn:
182
+ data_fn = None
183
+ else:
184
+ data_fn = unpack_data
185
+ if gpc.is_using_pp():
186
+ gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
187
+ tensor_shape = get_tensor_shape()
188
+ use_interleaved = (
189
+ hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
190
+ )
191
+ scatter_gather = gpc.is_initialized(ParallelMode.TENSOR)
192
+ if use_interleaved:
193
+ if isinstance(model, nn.Sequential):
194
+ model = nn.ModuleList([model])
195
+
196
+ communication_overlap = gpc.config.parallel["pipeline"].get("interleaved_overlap", False)
197
+ scheduler = InterleavedPipelineScheduler(
198
+ num_microbatches=gpc.config.NUM_MICRO_BATCHES,
199
+ num_chunks=gpc.config.model.num_chunks,
200
+ dtype=gpc.config.model["dtype"],
201
+ tensor_shape=tensor_shape,
202
+ scatter_gather_tensors=scatter_gather,
203
+ scheduler_hooks=scheduler_hooks,
204
+ communication_overlap=communication_overlap,
205
+ )
206
+ else:
207
+ scheduler = PipelineScheduler(
208
+ data_process_func=data_fn,
209
+ num_microbatches=gpc.config.NUM_MICRO_BATCHES,
210
+ dtype=gpc.config.model["dtype"],
211
+ tensor_shape=tensor_shape,
212
+ scatter_gather_tensors=scatter_gather,
213
+ scheduler_hooks=scheduler_hooks,
214
+ )
215
+ else:
216
+ scheduler = NonPipelineScheduler(
217
+ data_process_func=data_fn,
218
+ gradient_accumulation_size=gpc.config.data.gradient_accumulation,
219
+ scheduler_hooks=scheduler_hooks,
220
+ )
221
+
222
+ # initialize engine for trainer
223
+ engine = Engine(
224
+ model=model,
225
+ optimizer=optimizer,
226
+ lr_scheduler=lr_scheduler,
227
+ beta2_scheduler=beta2_scheduler,
228
+ criterion=criterion,
229
+ gradient_handlers=gradient_handlers,
230
+ clip_grad_norm=clip_grad_norm,
231
+ )
232
+
233
+ trainer = Trainer(engine, scheduler)
234
+
235
+ return trainer, train_dataloader, test_dataloader, lr_scheduler
InternLM/internlm/initialize/launch.py ADDED
@@ -0,0 +1,511 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import argparse
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Dict, Union
8
+
9
+ import torch
10
+
11
+ from internlm.core.context import Config
12
+ from internlm.core.context import global_context as gpc
13
+ from internlm.monitor import initialize_light_monitor
14
+ from internlm.utils.common import get_master_node
15
+ from internlm.utils.logger import get_logger
16
+ from internlm.utils.timeout import llm_timeout
17
+
18
+ logger = get_logger(__file__)
19
+
20
+
21
+ def get_default_parser():
22
+ """Reads user command line and uses an argument parser to parse the input arguments.
23
+ Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
24
+
25
+ Returns:
26
+ Parser: Returns the parser with the default arguments, the user may add customized arguments into this parser.
27
+ """
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument("--config", type=str, help="path to the config file")
30
+ parser.add_argument(
31
+ "--launcher",
32
+ type=str,
33
+ default="slurm",
34
+ choices=["slurm", "torch"],
35
+ help="launcher for launching distributed environment",
36
+ )
37
+ parser.add_argument("--host", type=str, help="the master address for distributed training")
38
+ parser.add_argument("--port", type=int, default=8888, help="the master port for distributed training")
39
+ parser.add_argument("--world_size", type=int, help="world size for distributed training")
40
+ parser.add_argument("--rank", type=int, help="rank for the default process group")
41
+ parser.add_argument("--local_rank", type=int, help="local rank on the node")
42
+ parser.add_argument("--backend", type=str, default="nccl", help="backend for distributed communication")
43
+ parser.add_argument("--seed", type=int, default=1024)
44
+ parser.add_argument("--profiling", default=False, action="store_true", help="enable/disable profiling.")
45
+ return parser
46
+
47
+
48
+ def args_sanity_check():
49
+ assert gpc.config is not None, "config is not loaded!"
50
+
51
+ # the default model type is INTERNLM
52
+ if "model_type" not in gpc.config:
53
+ gpc.config._add_item("model_type", "INTERNLM")
54
+
55
+ # processing the parallel config in gpc
56
+ if "zero1" not in gpc.config.parallel:
57
+ gpc.config.parallel._add_item("zero1", -1)
58
+
59
+ if "pipeline" not in gpc.config.parallel:
60
+ gpc.config.parallel._add_item("pipeline", 1)
61
+
62
+ if "tensor" not in gpc.config.parallel:
63
+ gpc.config.parallel._add_item("tensor", 1)
64
+
65
+ # processing the data config in gpc
66
+ data = gpc.config.data
67
+
68
+ assert data.seq_len is not None, "'seq_len' must be given a value"
69
+ assert data.micro_bsz is not None, "'micro_bsz' must be given a value"
70
+
71
+ if "packed_length" in data and gpc.is_rank_for_log():
72
+ logger.warning("packed_length in the config will be ignored and set to seq_len * micro_bsz.")
73
+
74
+ data._add_item("packed_length", data.seq_len * data.micro_bsz)
75
+
76
+ if "micro_num" not in data:
77
+ data._add_item("micro_num", 1)
78
+
79
+ data._add_item("gradient_accumulation", data.micro_num)
80
+ if gpc.is_rank_for_log():
81
+ logger.info(f"gradient_accumulation size will be set to {data.micro_num}.")
82
+
83
+ # batch_size should be equal to micro_num; do not use it directly
84
+ data._add_item("batch_size", data.micro_num)
85
+
86
+ if "min_length" not in data:
87
+ data._add_item("min_length", 0)
88
+
89
+ if "train_folder" not in data:
90
+ data._add_item("train_folder", None)
91
+
92
+ if "valid_folder" not in data:
93
+ data._add_item("valid_folder", None)
94
+
95
+ if "valid_micro_num" not in data:
96
+ data._add_item("valid_micro_num", data.micro_num)
97
+
98
+ if "valid_every" not in data:
99
+ data._add_item("valid_every", 0)
100
+
101
+ if "empty_cache_and_diag_interval" not in data:
102
+ data._add_item("empty_cache_and_diag_interval", 50)
103
+
104
+ if "diag_outlier_ratio" not in data:
105
+ data._add_item("diag_outlier_ratio", 1.1)
106
+ data.diag_outlier_ratio = max(1, data.diag_outlier_ratio)
107
+
108
+ if gpc.is_rank_for_log():
109
+ logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201
110
+ logger.info(f"seq_len: {data.seq_len}")
111
+ logger.info(f"micro_num: {data.micro_num}")
112
+ logger.info(f"micro_bsz: {data.micro_bsz}")
113
+ logger.info(f"packed_length: {data.packed_length}")
114
+ logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
115
+ logger.info(f"min_length: {data.min_length}")
116
+ logger.info(f"valid_micro_num: {data.valid_micro_num}")
117
+ logger.info(f"valid_every: {data.valid_every}")
118
+
119
+ # processing the checkpoint config
120
+ ckpt = gpc.config.ckpt
121
+ if "enable_save_ckpt" not in ckpt:
122
+ ckpt._add_item("enable_save_ckpt", True)
123
+
124
+ # Saving checkpoint args.
125
+ if ckpt.enable_save_ckpt:
126
+ assert "checkpoint_every" in ckpt, "If checkpoint saving is enabled, checkpoint_every must be given in config.ckpt!"
127
+ assert ckpt.checkpoint_every > 0
128
+ assert "save_ckpt_folder" in ckpt, "If checkpoint saving is enabled, save_ckpt_folder must be given in config.ckpt!"
129
+
130
+ if "async_upload" not in ckpt:
131
+ ckpt._add_item("async_upload", False) # async default is False.
132
+ else:
133
+ if ckpt.async_upload:
134
+ assert "save_ckpt_folder" in ckpt
135
+ if "boto3:" not in ckpt.save_ckpt_folder:
136
+ if gpc.is_rank_for_log():
137
+ logger.warning(
138
+ "Storing ckpt on file system does not support asynchronous storage, will use sync save!"
139
+ )
140
+ ckpt.async_upload = False
141
+ else:
142
+ if "async_upload_tmp_folder" not in ckpt:
143
+ ckpt._add_item("async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/")
144
+
145
+ if not ckpt.async_upload:
146
+ ckpt._add_item("async_upload_tmp_folder", None)
147
+
148
+ if "oss_snapshot_freq" not in ckpt:
149
+ ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable.
150
+ else:
151
+ ckpt._add_item("checkpoint_every", float("inf"))
152
+ ckpt._add_item("oss_snapshot_freq", float("inf"))
153
+ ckpt._add_item("save_ckpt_folder", None)
154
+ ckpt._add_item("async_upload", False)
155
+ ckpt._add_item("async_upload_tmp_folder", None)
156
+ ckpt._add_item("snapshot_ckpt_folder", None)
157
+
158
+ if "load_ckpt_folder" not in ckpt:
159
+ ckpt._add_item("load_ckpt_folder", None)
160
+
161
+ if "stop_file_path" not in ckpt:
162
+ ckpt._add_item("stop_file_path", None)
163
+
164
+ if "auto_resume" not in ckpt:
165
+ # If 'auto_resume' is not given, we set it to True, so internlm has the opportunity
166
+ # to auto-load latest checkpoint.
167
+ ckpt._add_item("auto_resume", True)
168
+
169
+ if gpc.is_rank_for_log():
170
+ logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201
171
+ logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
172
+ logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
173
+ logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
174
+
175
+ # tensorboard writer config
176
+ if "enable_tb" not in gpc.config:
177
+ gpc.config._add_item("enable_tb", True)
178
+ if "tensorboard_folder" not in gpc.config:
179
+ gpc.config._add_item(
180
+ "tensorboard_folder", os.environ["tensorboard_folder"] if "tensorboard_folder" in os.environ else None
181
+ )
182
+ if "resume_tb_folder" not in gpc.config:
183
+ gpc.config._add_item(
184
+ "resume_tb_folder", os.environ["resume_tb_folder"] if "resume_tb_folder" in os.environ else None
185
+ )
186
+
187
+ if gpc.is_rank_for_log():
188
+ logger.info(f"tensorboard_folder: {gpc.config.tensorboard_folder}")
189
+ logger.info(f"resume_tb_folder: {gpc.config.resume_tb_folder}")
190
+
191
+ # cudnn
192
+ torch.backends.cudnn.benchmark = gpc.config.get("cudnn_benchmark", False)
193
+ torch.backends.cudnn.deterministic = gpc.config.get("cudnn_deterministic", False)
194
+ clip_grad_norm = gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0)
195
+
196
+ if gpc.is_rank_for_log():
197
+ logger.info("+" * 15 + " Other Info " + "+" * 15) # pylint: disable=W1201
198
+ logger.info(f"cudnn.benchmark: {torch.backends.cudnn.benchmark}")
199
+ logger.info(f"cudnn.deterministic: {torch.backends.cudnn.deterministic}")
200
+ logger.info(f"clip_grad_norm: {clip_grad_norm}")
201
+
202
+ model = gpc.config.model
203
+ if "dtype" not in model:
204
+ logger.warning("dtype is not set, using torch.float16 by default!")
205
+ model._add_item("dtype", torch.float16)
206
+ else:
207
+ if gpc.config.model.dtype == "torch.bfloat16":
208
+ gpc.config.model.dtype = torch.bfloat16
209
+ elif gpc.config.model.dtype in ("torch.float16", "torch.half"):
210
+ gpc.config.model.dtype = torch.float16
211
+ elif gpc.config.model.dtype == "torch.float32":
212
+ gpc.config.model.dtype = torch.float32
213
+ elif gpc.config.model.dtype == "torch.tf32":
214
+ torch.backends.cudnn.allow_tf32 = True
215
+ torch.backends.cuda.matmul.allow_tf32 = True
216
+ gpc.config.model.dtype = torch.float32
217
+ else:
218
+ assert gpc.config.model.dtype in [
219
+ "torch.float16",
220
+ "torch.half",
221
+ "torch.bfloat16",
222
+ "torch.float32",
223
+ "torch.tf32",
224
+ ]
225
+
226
+ if "checkpoint" in model:
227
+ if model.checkpoint is True:
228
+ model.checkpoint = 1
229
+ elif model.checkpoint is False:
230
+ model.checkpoint = 0
231
+ else:
232
+ assert (
233
+ model.checkpoint >= 0 and model.checkpoint <= 1
234
+ ), f'model.checkpoint: "{model.checkpoint}" should >=0 and <=1'
235
+
236
+ if "teacher" in gpc.config:
237
+ teacher = gpc.config.teacher
238
+ if "dtype" not in teacher:
239
+ logger.warning("dtype is not set, using torch.float16 by default!")
240
+ teacher._add_item("dtype", torch.float16)
241
+ else:
242
+ if gpc.config.teacher.dtype == "torch.bfloat16":
243
+ gpc.config.teacher.dtype = torch.bfloat16
244
+ elif gpc.config.teacher.dtype in ("torch.float16", "torch.half"):
245
+ gpc.config.teacher.dtype = torch.float16
246
+ elif gpc.config.teacher.dtype == "torch.float32":
247
+ gpc.config.teacher.dtype = torch.float32
248
+ elif gpc.config.teacher.dtype == "torch.tf32":
249
+ torch.backends.cudnn.allow_tf32 = True
250
+ torch.backends.cuda.matmul.allow_tf32 = True
251
+ gpc.config.teacher.dtype = torch.float32
252
+ else:
253
+ assert gpc.config.teacher.dtype in [
254
+ "torch.float16",
255
+ "torch.half",
256
+ "torch.bfloat16",
257
+ "torch.float32",
258
+ "torch.tf32",
259
+ ]
260
+
261
+ if "checkpoint" in teacher:
262
+ if teacher.checkpoint is True:
263
+ teacher.checkpoint = 1
264
+ elif teacher.checkpoint is False:
265
+ teacher.checkpoint = 0
266
+ else:
267
+ assert (
268
+ teacher.checkpoint >= 0 and teacher.checkpoint <= 1
269
+ ), f'teacher.checkpoint: "{teacher.checkpoint}" should >=0 and <=1'
270
+
271
+ if gpc.is_rank_for_log():
272
+ logger.info("+" * 15 + " Model Info " + "+" * 15) # pylint: disable=W1201
273
+ logger.info(f"Model: {gpc.config.model}")
274
+
275
+ logger.info("+" * 15 + " grad_scaler Info " + "+" * 15) # pylint: disable=W1201
276
+ logger.info(f"grad_scaler: {gpc.config.grad_scaler}")
277
+
278
+ logger.info("+" * 15 + " hybrid_zero_optimizer Info " + "+" * 15) # pylint: disable=W1201
279
+ logger.info(f"hybrid_zero_optimizer: {gpc.config.hybrid_zero_optimizer}")
280
+
281
+ logger.info("+" * 15 + " adam Info " + "+" * 15) # pylint: disable=W1201
282
+ logger.info(f"adam: {gpc.config.adam}")
283
+
284
+ logger.info("+" * 15 + " beta2_scheduler Info " + "+" * 15) # pylint: disable=W1201
285
+ logger.info(f"beta2_scheduler: {gpc.config.beta2_scheduler}")
286
+
287
+ # process the model config
288
+ if "use_flash_attn" not in gpc.config.model:
289
+ gpc.config.model._add_item("use_flash_attn", True)
290
+
291
+ # process the parallel config
292
+ if "sequence_parallel" not in gpc.config.parallel:
293
+ gpc.config.parallel._add_item("sequence_parallel", False)
294
+ else:
295
+ assert not (
296
+ gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
297
+ ), "sequence parallel does not support use_flash_attn=False"
298
+
299
+ # monitoring default config
300
+ monitor_default_config = {
301
+ "alert_address": None, # compatible with old alert config
302
+ "monitor": { # new monitoring config
303
+ "alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None}
304
+ },
305
+ }
306
+
307
+ for key, value in monitor_default_config.items():
308
+ if key not in gpc.config:
309
+ gpc.config._add_item(key, value)
310
+
311
+ alert = gpc.config.monitor.alert
312
+
313
+ if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log():
314
+ logger.warning("alert is enabled but alert_address is not set")
315
+
316
+ optim_ckpt = gpc.config.hybrid_zero_optimizer
317
+ if "zero_overlap_communication" in optim_ckpt:
318
+ # Compatible with the old interfaces.
319
+ optim_ckpt._add_item("overlap_sync_grad", optim_ckpt.zero_overlap_communication)
320
+ if "overlap_sync_grad" not in optim_ckpt:
321
+ optim_ckpt._add_item("overlap_sync_grad", False)
322
+ if "overlap_sync_param" not in optim_ckpt:
323
+ optim_ckpt._add_item("overlap_sync_param", False)
324
+ if gpc.is_rank_for_log():
325
+ logger.info(
326
+ f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}"
327
+ )
328
+
329
+
330
+ def launch(
331
+ config: Union[str, Path, Config, Dict],
332
+ rank: int,
333
+ world_size: int,
334
+ host: str,
335
+ port: int,
336
+ backend: str = "nccl",
337
+ local_rank: int = None,
338
+ seed: int = 1024,
339
+ ):
340
+ """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
341
+ arguments is not given. Then it initializes and sets up the distributed environment via global_context's functions.
342
+
343
+ Args:
344
+ config (Union[str, dict, Config]): Config file or config file path are both acceptable
345
+ rank (int): Rank for the default process group
346
+ world_size (int): World size of the default process group
347
+ host (str): The master address for distributed training
348
+ port (int): The master port for distributed training
349
+ backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
350
+ local_rank (int, optional):
351
+ Rank for the process on the node and is used to set the default CUDA device,
352
+ defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
353
+ seed (int, optional): Specified random seed for every process. Defaults to 1024.
354
+
355
+ Raises:
356
+ Exception: Raise exception when config type is wrong
357
+ """
358
+
359
+ # set config
360
+ assert isinstance(
361
+ config, (Config, str, Path, dict)
362
+ ), f"expected argument config to be Config, str or Path, but got {type(config)}"
363
+ if not isinstance(config, Config) and isinstance(config, dict):
364
+ config = Config(config)
365
+ if isinstance(config, (str, Path)):
366
+ config = Config.from_file(config)
367
+ gpc.load_config(config)
368
+
369
+ # init default process group
370
+ gpc.init_global_dist(rank, world_size, backend, host, port)
371
+
372
+ # init process groups for different parallel modes from config
373
+ gpc.init_parallel_groups()
374
+
375
+ # set cuda device
376
+ if torch.cuda.is_available():
377
+ # if local rank is not given, calculate automatically
378
+ gpc.set_device(local_rank)
379
+
380
+ # set the number of processes running on the same node
381
+ gpc.detect_num_processes_on_current_node()
382
+
383
+ gpc.set_seed(seed)
384
+
385
+ if gpc.is_rank_for_log():
386
+ logger.info(
387
+ f"Distributed environment is initialized, "
388
+ f"data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, "
389
+ f"tensor parallel size: {gpc.tensor_parallel_size}",
390
+ )
391
+
392
+
393
+ def launch_from_slurm(
394
+ config: Union[str, Path, Config, Dict],
395
+ host: str,
396
+ port: int,
397
+ backend: str = "nccl",
398
+ seed: int = 1024,
399
+ ):
400
+ """A wrapper for internlm.launch for SLURM launcher by reading rank and world size from the environment variables
401
+ set by SLURM
402
+
403
+ Args:
404
+ config (Union[str, dict, Config]): Config file or config file path are both acceptable
405
+ host (str): The master address for distributed training
406
+ port (int): The master port for distributed training
407
+ backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
408
+ seed (int, optional): Specified random seed for every process. Defaults to 1024.
409
+ """
410
+ try:
411
+ rank = int(os.environ["SLURM_PROCID"])
412
+ world_size = int(os.environ["SLURM_NPROCS"])
413
+ except KeyError as e:
414
+ raise RuntimeError(f"Could not find {e} in the SLURM environment")
415
+
416
+ launch(
417
+ config=config,
418
+ rank=rank,
419
+ world_size=world_size,
420
+ host=host,
421
+ port=port,
422
+ backend=backend,
423
+ seed=seed,
424
+ )
425
+
426
+
427
+ def launch_from_torch(
428
+ config: Union[str, Path, Config, Dict],
429
+ backend: str = "nccl",
430
+ seed: int = 1024,
431
+ ):
432
+ """A wrapper for internlm.launch for torchrun or torch.distributed.launch by reading rank and world size
433
+ from the environment variables set by PyTorch
434
+
435
+ Args:
436
+ config (Union[str, dict, Config]): Config file or config file path are both acceptable
437
+ backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``
438
+ seed (int, optional): Specified random seed for every process. Defaults to 1024.
439
+ """
440
+ try:
441
+ rank = int(os.environ["RANK"])
442
+ local_rank = int(os.environ["LOCAL_RANK"])
443
+ world_size = int(os.environ["WORLD_SIZE"])
444
+ host = os.environ["MASTER_ADDR"]
445
+ port = int(os.environ["MASTER_PORT"])
446
+ except KeyError as e:
447
+ raise RuntimeError(f"Could not find {e} in the torch environment")
448
+
449
+ launch(
450
+ config=config,
451
+ local_rank=local_rank,
452
+ rank=rank,
453
+ world_size=world_size,
454
+ host=host,
455
+ port=port,
456
+ backend=backend,
457
+ seed=seed,
458
+ )
459
+
460
+
461
+ @llm_timeout(func_name="initialize_distributed_env")
462
+ def initialize_distributed_env(
463
+ config: str,
464
+ launcher: str = "slurm",
465
+ master_port: int = 8888,
466
+ seed: int = 1024,
467
+ args_check=True,
468
+ ):
469
+ """
470
+ Initialize distributed environment for distributed training.
471
+
472
+ Args:
473
+ config (str): Config file path.
474
+ launcher (str): Launcher for launching distributed environment, can be slurm or torch. "slurm" by default.
475
+ master_port (int): The master port for distributed training. 8888 by default.
476
+ seed (int, optional): Specified random seed for every process. 1024 by default.
477
+ """
478
+
479
+ torch.cuda.empty_cache()
480
+
481
+ if launcher == "torch":
482
+ launch_from_torch(config=config, seed=seed)
483
+ elif launcher == "slurm":
484
+ launch_from_slurm(
485
+ config=config,
486
+ host=get_master_node(),
487
+ port=master_port,
488
+ seed=seed,
489
+ )
490
+ else:
491
+ assert launcher in ["slurm", "torch"], "launcher only supports slurm or torch"
492
+
493
+ if args_check:
494
+ args_sanity_check()
495
+
496
+ # init light monitor client
497
+ alert_config = gpc.config.monitor.alert
498
+ if alert_config.enable_feishu_alert and gpc.is_rank_for_log():
499
+ light_monitor_address = alert_config.light_monitor_address
500
+ if light_monitor_address:
501
+ initialize_light_monitor(light_monitor_address)
502
+ else:
503
+ logger.warning("light monitor address is None, so the light monitor cannot be used!")
504
+
505
+
506
+ def get_config_value(config, key, default):
507
+ try:
508
+ value = config[key]
509
+ except KeyError:
510
+ value = default
511
+ return value
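
For orientation, a minimal sketch of how the helpers above are usually wired together from a training entry point; the script name and config path are placeholders, not files defined in this diff:

# torchrun --nproc_per_node=8 train.py --config ./configs/pretrain_300m.py --launcher torch
from internlm.initialize.launch import get_default_parser, initialize_distributed_env

args = get_default_parser().parse_args()
initialize_distributed_env(
    config=args.config,      # path to the python config file
    launcher=args.launcher,  # "torch" reads RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT set by torchrun
    master_port=args.port,   # only used on the slurm launcher path
    seed=args.seed,
)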
InternLM/internlm/initialize/legacy/__init__.py ADDED
File without changes
InternLM/internlm/initialize/legacy/launch.py ADDED
@@ -0,0 +1,40 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from internlm.initialize.launch import get_config_value
5
+ from internlm.utils.logger import get_logger
6
+
7
+ logger = get_logger(__file__)
8
+
9
+
10
+ def auto_resume_sanity_check(ckpt_config):
11
+ load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None)
12
+ if load_given_ckpt is None:
13
+ return True # default value is True
14
+ else:
15
+ return not load_given_ckpt
16
+
17
+
18
+ def ckpt_info_sanity_check(ckpt_config):
19
+ load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None)
20
+
21
+ load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None)
22
+
23
+ if load_model_only_folder is not None:
24
+ assert (
25
+ load_ckpt_folder is None
26
+ ), "Detected 'load_ckpt_folder' and 'load_model_only_folder' set at the same time; \
27
+ please keep only one of them in config.ckpt"
28
+ return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm")
29
+ else:
30
+ load_optimizer = get_config_value(ckpt_config, "load_optimizer", True)
31
+
32
+ if isinstance(load_ckpt_folder, str):
33
+ if load_optimizer:
34
+ return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internlm")
35
+ else:
36
+ return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internlm")
37
+ elif load_ckpt_folder is None:
38
+ return None
39
+ else:
40
+ assert False, f"Unsupported data type '{type(load_ckpt_folder)}' for config.ckpt arg 'load_ckpt_folder'"
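
A brief sketch of how ckpt_info_sanity_check above maps an old-style ckpt section onto the new load-info dict; the folder path is a placeholder:

old_ckpt_config = dict(load_ckpt_folder="local:llm_ckpts/500", load_optimizer=False)
load_info = ckpt_info_sanity_check(old_ckpt_config)
# -> {"path": "local:llm_ckpts/500", "content": ("model", "sampler"), "ckpt_type": "internlm"}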
InternLM/internlm/model/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from .embedding import Embedding1D, RotaryEmbedding
5
+ from .linear import FeedForward, RewardModelLinear, ScaleColumnParallelLinear
6
+ from .metrics import AccPerplex
7
+ from .modeling_internlm import build_model_with_cfg
8
+ from .modeling_vit import build_vit_model_with_cfg
9
+ from .multi_head_attention import MHA
10
+ from .utils import gather_forward_split_backward
11
+
12
+ __all__ = [
13
+ "Embedding1D",
14
+ "FeedForward",
15
+ "RotaryEmbedding",
16
+ "RewardModelLinear",
17
+ "ScaleColumnParallelLinear",
18
+ "AccPerplex",
19
+ "MHA",
20
+ "gather_forward_split_backward",
21
+ "build_model_with_cfg",
22
+ "build_vit_model_with_cfg"
23
+ ]
InternLM/internlm/model/embedding.py ADDED
@@ -0,0 +1,273 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from typing import Tuple
5
+
6
+ import rotary_emb
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
11
+ from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
12
+ from torch import Tensor, nn
13
+
14
+ from internlm.core.context import ParallelMode
15
+ from internlm.core.context import global_context as gpc
16
+
17
+ from .utils import gather_forward_split_backward, split_forward_gather_backward
18
+
19
+
20
+ from .muse import VQGANModel
21
+
22
+ class Embedding1DLVM(nn.Module):
23
+ def __init__(
24
+ self,
25
+ vq_model_path: str,
26
+ embedding_dim: int = None,
27
+ freeze_vq_model: bool = True
28
+ ):
29
+ super().__init__()
30
+
31
+ self.vq_model = VQGANModel.from_pretrained(vq_model_path)
32
+ if freeze_vq_model:
33
+ self.vq_model.requires_grad_(False)
34
+ self.vq_model.eval()
35
+
36
+ self.num_embeddings, vq_embed_dim = self.vq_model.quantize.embedding.weight.shape
37
+
38
+ if embedding_dim is not None:
39
+ self.embed_proj = nn.Linear(vq_embed_dim, embedding_dim, bias=False)
40
+ self.embedding_dim = embedding_dim
41
+ else:
42
+ self.embed_proj = None
43
+ self.embedding_dim = vq_embed_dim
44
+
45
+ def forward(self, input_: Tensor) -> Tensor:
46
+
47
+ # input: N x seq
48
+ output_parallel = self.vq_model.quantize.get_codebook_entry_for_lvm(input_) # N x vq_embed_dim x sqrt(seq) x sqrt(seq)
49
+
50
+ if self.embed_proj is not None:
51
+ output_parallel = self.embed_proj(output_parallel)
52
+
53
+ output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
54
+
55
+ if gpc.config.parallel.sequence_parallel:
56
+ output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
57
+
58
+ return output
59
+
60
+
61
+ class Embedding1D(nn.Module):
62
+ """
63
+ 1D Embedding.
64
+
65
+ Args:
66
+ num_embeddings (int): The size of vocab.
67
+ embedding_dim (int): The dimension of the model.
68
+ padding_idx (int): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
69
+ therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
70
+ i.e. it remains as a fixed "pad". None by default.
71
+ dtype (Optional[torch.dtype]): Data type. None by default.
72
+
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ num_embeddings: int,
78
+ embedding_dim: int,
79
+ *args,
80
+ padding_idx: int = None,
81
+ dtype: torch.dtype = None,
82
+ **kwargs,
83
+ ):
84
+ super().__init__()
85
+
86
+ self.num_embeddings = num_embeddings
87
+ self.embed_dim = embedding_dim
88
+ embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size
89
+
90
+ self.padding_idx = padding_idx
91
+ self.embed_args = args
92
+ self.embed_kwargs = kwargs
93
+
94
+ self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))
95
+
96
+ def forward(self, input_: Tensor) -> Tensor:
97
+ output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
98
+
99
+ output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)
100
+
101
+ if gpc.config.parallel.sequence_parallel:
102
+ output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)
103
+
104
+ return output
105
+
106
+
107
+ class ApplyRotaryEmbQKV_(torch.autograd.Function):
108
+ """
109
+ ApplyRotaryEmbQKV_
110
+ """
111
+
112
+ @staticmethod
113
+ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
114
+ """
115
+ qkv: (total, 3, nheads, headdim)
116
+ cos, sin: (seqlen, rotary_dim / 2)
117
+ cos_k, sin_k: (seqlen, rotary_dim / 2), optional
118
+ rotary_dim must be <= headdim
119
+ Apply rotary embedding *inplace* to the first rotary_dim of q and k.
120
+ """
121
+ _, three, _, headdim = qkv.shape
122
+ assert three == 3
123
+ rotary_seqlen, rotary_dim = cos.shape
124
+ rotary_dim *= 2
125
+ assert rotary_dim <= headdim
126
+ cos_k = cos if cos_k is None else cos_k
127
+ sin_k = sin if sin_k is None else sin_k
128
+ assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
129
+ q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
130
+ rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False)
131
+ k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
132
+ rotary_emb.apply_rotary(
133
+ k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False
134
+ )
135
+ ctx.save_for_backward(cos, sin, cos_k, sin_k)
136
+ return qkv
137
+
138
+ @staticmethod
139
+ def backward(ctx, dqkv):
140
+ cos, sin, cos_k, sin_k = ctx.saved_tensors
141
+ rotary_dim = cos.shape[-1]
142
+ rotary_dim *= 2
143
+ dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
144
+ rotary_emb.apply_rotary(
145
+ dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True
146
+ )
147
+ dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
148
+ rotary_emb.apply_rotary(
149
+ dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True
150
+ )
151
+ return dqkv, None, None, None, None
152
+
153
+
154
+ apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
155
+ legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
156
+ legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply
157
+
158
+
159
+ class RotaryEmbedding(torch.nn.Module):
160
+ """
161
+ The rotary position embeddings from RoFormer_ (Su et al.).
162
+ A crucial insight from the method is that the query and keys are
163
+ transformed by rotation matrices which depend on the relative positions.
164
+
165
+ Other implementations are available in the Rotary Transformer repo_ and in
166
+ GPT-NeoX_, GPT-NeoX was an inspiration
167
+
168
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
169
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
170
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
171
+
172
+ If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
173
+ A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
174
+ Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
175
+ """
176
+
177
+ def __init__(self, dim: int, base=10000, scale_base=0, device=None):
178
+ """ """
179
+ super().__init__()
180
+ # Generate and save the inverse frequency buffer (non trainable)
181
+ self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
182
+ self.scale_base = scale_base
183
+ self.scale = (
184
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
185
+ if scale_base > 0
186
+ else None
187
+ )
188
+
189
+ self._seq_len_cached = 0
190
+ self._cos_cached = None
191
+ self._sin_cached = None
192
+ self._cos_k_cached = None
193
+ self._sin_k_cached = None
194
+
195
+ def _update_cos_sin_cache(self, x, indexes):
196
+ """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
197
+ if not isinstance(indexes, int):
198
+ seqlen = indexes.max().item() + 1
199
+ else:
200
+ seqlen = indexes + 1 # eval_forward
201
+ # Reset the tables if the sequence length has changed,
202
+ # or if we're on a new device (possibly due to tracing for instance)
203
+ if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
204
+ self._seq_len_cached = seqlen
205
+ t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
206
+ # Don't do einsum, it converts fp32 to fp16
207
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
208
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device))
209
+ if self.scale is None:
210
+ self._cos_cached = torch.cos(freqs).to(x.dtype)
211
+ self._sin_cached = torch.sin(freqs).to(x.dtype)
212
+ else:
213
+ power = (
214
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
215
+ ) / self.scale_base
216
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
217
+ # We want the multiplication by scale to happen in fp32
218
+ self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
219
+ self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
220
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
221
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
222
+
223
+ def forward(self, qkv: torch.Tensor, **kwargs):
224
+ if kwargs.get("indexes", None) is not None:
225
+ return self._forward(qkv, kwargs.pop("indexes"))
226
+ if kwargs.get("inference_params", None) is not None:
227
+ return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset)
228
+ else:
229
+ return self._eval_forward(qkv)
230
+
231
+ def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
232
+ self._update_cos_sin_cache(qkv, indexes)
233
+ if self.scale is None:
234
+ return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
235
+ else:
236
+ return apply_rotary_emb_qkv_(
237
+ qkv,
238
+ self._cos_cached[indexes],
239
+ self._sin_cached[indexes],
240
+ self._cos_k_cached[indexes],
241
+ self._sin_k_cached[indexes],
242
+ )
243
+
244
+ def _eval_forward(self, qkv, seqlen_offset=0):
245
+ """
246
+ seqlen_offset: can be used in generation where the qkv being passed in is only the last
247
+ token in the batch.
248
+ """
249
+ self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1])
250
+ if self.scale is None:
251
+ return legacy_apply_rotary_embed_qkv(
252
+ qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
253
+ )
254
+ else:
255
+ return legacy_apply_rotary_embed_qkv(
256
+ qkv,
257
+ self._cos_cached[seqlen_offset:],
258
+ self._sin_cached[seqlen_offset:],
259
+ self._cos_k_cached[seqlen_offset:],
260
+ self._sin_k_cached[seqlen_offset:],
261
+ )
262
+
263
+ def _single_forward(self, x, indexes=0):
264
+ assert self.scale is None
265
+ self._update_cos_sin_cache(x, indexes)
266
+ x = x[None, ...]
267
+ ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
268
+ return ret
269
+
270
+ def _single_eval_forward(self, x, seqlen_offset=0):
271
+ assert self.scale is None
272
+ self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
273
+ return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
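
A minimal numerical sketch of the cos/sin tables that RotaryEmbedding._update_cos_sin_cache builds for the non-XPos case (scale_base == 0); dim, base and seqlen are illustrative values:

import torch

dim, base, seqlen = 8, 10000, 4
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # (dim/2,)
t = torch.arange(seqlen, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                       # (seqlen, dim/2), position times frequency
cos_cached, sin_cached = torch.cos(freqs), torch.sin(freqs)
# _forward then indexes these tables by token position and rotates q/k over pairs of head dims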
InternLM/internlm/model/linear.py ADDED
@@ -0,0 +1,201 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
9
+ from flash_attn.utils.distributed import all_reduce, reduce_scatter
10
+ from torch import nn
11
+
12
+ from internlm.core.context import ParallelMode
13
+ from internlm.core.context import global_context as gpc
14
+ from internlm.model.utils import fused_dense_func_torch
15
+
16
+
17
+ class ScaleColumnParallelLinear(nn.Linear):
18
+ """
19
+ ScaleColumnParallelLinear.
20
+
21
+ Args:
22
+ in_features (int): size of each input sample
23
+ out_features (int): size of each output sample
24
+ process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
25
+ bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
26
+ in the config.
27
+ sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
28
+ we do an all_gather of x before doing the matmul.
29
+ If not, then the input is already gathered.
30
+ device (Optional[Union[str, torch.device]]): The device will be used.
31
+ dtype (Optional[torch.dtype]): The type of data.
32
+ weight_scale (int): For training stability. 1 by default.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ in_features: int,
38
+ out_features: int,
39
+ process_group: Optional[torch.distributed.ProcessGroup],
40
+ bias: bool = True,
41
+ device: Optional[torch.device] = None,
42
+ dtype: Optional[torch.dtype] = None,
43
+ weight_scale: int = 1,
44
+ ) -> None:
45
+ world_size = torch.distributed.get_world_size(process_group)
46
+ if out_features % world_size != 0:
47
+ raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})")
48
+ super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype)
49
+ self.process_group = process_group
50
+ self.weight_scale = weight_scale
51
+
52
+ def forward(self, input): # pylint: disable=W0622
53
+ # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
54
+ # we do an all_gather of x before doing the matmul.
55
+ # If not, then the input is already gathered.
56
+ if self.weight_scale != 1:
57
+ weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
58
+ else:
59
+ weight = self.weight
60
+ return fused_dense_func_torch(
61
+ input,
62
+ weight,
63
+ self.bias,
64
+ process_group=self.process_group,
65
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
66
+ )
67
+
68
+
69
+ class RewardModelLinear(ScaleColumnParallelLinear):
70
+ """
71
+ RewardModelLinear.
72
+ Args:
73
+ in_features (int): size of each input sample
74
+ out_features (int): size of each output sample
75
+ process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
76
+ bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
77
+ in the config.
78
+ sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
79
+ we do an all_gather of x before doing the matmul.
80
+ If not, then the input is already gathered.
81
+ device (Optional[Union[str, torch.device]]): The device will be used.
82
+ dtype (Optional[torch.dtype]): The type of data.
83
+ weight_scale (int): For training stability. 1 by default.
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ in_features: int,
89
+ out_features: int,
90
+ process_group: Optional[torch.distributed.ProcessGroup],
91
+ bias: bool = True,
92
+ device: Optional[torch.device] = None,
93
+ dtype: Optional[torch.dtype] = None,
94
+ weight_scale: int = 1,
95
+ ) -> None:
96
+ super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale)
97
+ torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
98
+ if bias:
99
+ torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group)
100
+
101
+ def forward(self, input): # pylint: disable=W0622
102
+ # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
103
+ # we do an all_gather of x before doing the matmul.
104
+ # If not, then the input is already gathered.
105
+ if self.weight_scale != 1:
106
+ weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach()
107
+ else:
108
+ weight = self.weight
109
+ return fused_dense_func_torch(
110
+ input,
111
+ weight,
112
+ self.bias,
113
+ process_group=self.process_group,
114
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
115
+ )
116
+
117
+
118
+ class ColumnParallelLinearTorch(ColumnParallelLinear):
119
+ def forward(self, x):
120
+ # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
121
+ # we do an all_gather of x before doing the matmul.
122
+ # If not, then the input is already gathered.
123
+
124
+ return fused_dense_func_torch(
125
+ x, self.weight, self.bias, process_group=self.process_group, sequence_parallel=self.sequence_parallel
126
+ )
127
+
128
+
129
+ class RowParallelLinearTorch(RowParallelLinear):
130
+ def forward(self, x):
131
+ """
132
+ We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
133
+ a reduce_scatter of the result.
134
+ """
135
+ out = fused_dense_func_torch(x, self.weight, self.bias)
136
+ reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
137
+ return reduce_fn(out, self.process_group)
138
+
139
+
140
+ class FeedForward(nn.Module):
141
+ """
142
+ FeedForward.
143
+
144
+ Args:
145
+ in_features (int): size of each input sample
146
+ hidden_features (int): size of hidden state of FFN
147
+ out_features (int): size of each output sample
148
+ process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
149
+ bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
150
+ in the config.
151
+ device (Optional[Union[str, torch.device]]): The device will be used.
152
+ dtype (Optional[torch.dtype]): The type of data.
153
+ multiple_of (int): Round the hidden feature size up to a multiple of this value for efficient training. 256 by default.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ in_features: int,
159
+ hidden_features: int,
160
+ out_features: int = None,
161
+ process_group: Optional[torch.distributed.ProcessGroup] = None,
162
+ bias: bool = True,
163
+ device: Optional[torch.device] = None,
164
+ dtype: Optional[torch.dtype] = None,
165
+ multiple_of: int = 256,
166
+ ):
167
+ super().__init__()
168
+
169
+ hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
170
+
171
+ self.w1 = ColumnParallelLinearTorch(
172
+ in_features,
173
+ hidden_features,
174
+ process_group,
175
+ bias,
176
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
177
+ device=device,
178
+ dtype=dtype,
179
+ )
180
+ self.w2 = ColumnParallelLinearTorch(
181
+ in_features,
182
+ hidden_features,
183
+ process_group,
184
+ bias,
185
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
186
+ device=device,
187
+ dtype=dtype,
188
+ )
189
+ self.w3 = RowParallelLinearTorch(
190
+ hidden_features,
191
+ out_features,
192
+ process_group,
193
+ bias=bias,
194
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
195
+ device=device,
196
+ dtype=dtype,
197
+ )
198
+
199
+ def forward(self, x):
200
+ out = self.w3(F.silu(self.w1(x)) * self.w2(x))
201
+ return out
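
FeedForward above is the SwiGLU-style MLP, out = w3(silu(w1(x)) * w2(x)), with the hidden width rounded up to a multiple of `multiple_of` before the parallel linears are built. A quick standalone check of that rounding; the widths are illustrative:

multiple_of, hidden_features = 256, 7338
padded = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
assert padded == 7424  # the next multiple of 256 at or above 7338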
InternLM/internlm/model/loss.py ADDED
@@ -0,0 +1,81 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch.nn.functional as F
5
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
6
+ from torch import nn
7
+
8
+ from internlm.core.context import ParallelMode
9
+ from internlm.core.context import global_context as gpc
10
+
11
+
12
+ class FlashGPTLMLoss(nn.Module):
13
+ """
14
+ Loss function for flash GPT Language Model.
15
+ """
16
+
17
+ def __init__(self, parallel_output=True, label_smoothing=0):
18
+ super().__init__()
19
+
20
+ if label_smoothing is not None:
21
+ if label_smoothing != 0:
22
+ if gpc.is_rank_for_log():
23
+ print(f"use label_smoothing: {label_smoothing}")
24
+ else:
25
+ label_smoothing = 0
26
+ self.label_smoothing = label_smoothing
27
+
28
+ if parallel_output:
29
+ self.loss_fn = FlashCrossEntropyLoss(
30
+ reduction="mean",
31
+ inplace_backward=True,
32
+ process_group=gpc.get_group(ParallelMode.TENSOR),
33
+ label_smoothing=label_smoothing,
34
+ ) # The loss in this place is bound to the gather_output initialized by VocabParallelClassifier1D
35
+ else:
36
+ # Here, gather_output is enabled in the model so the logits are already gathered; use the ordinary loss
37
+ self.loss_fn = nn.CrossEntropyLoss(reduction="mean", label_smoothing=label_smoothing)
38
+
39
+ def forward(self, *args):
40
+ if len(args) == 3:
41
+ # residual is to match prenorm
42
+ logits, _, labels = args
43
+ elif len(args) == 2:
44
+ # When using postnorm
45
+ logits, labels = args
46
+ else:
47
+ raise RuntimeError(f"The number of criterion inputs are:{len(args)}")
48
+ shift_logits = logits.contiguous().view(-1, logits.size(-1))
49
+ shift_labels = labels.contiguous().view(-1)
50
+ loss = self.loss_fn(
51
+ shift_logits, shift_labels
52
+ ) # No need to handle ignore_index explicitly here: the loss is only computed over the
53
+ # valid label range, and -100 always falls outside that range, so it is excluded automatically
54
+
55
+ return loss
56
+
57
+
58
+ class KLDivLoss(nn.Module):
59
+ def __init__(self):
60
+ super().__init__()
61
+ self.temperature = gpc.config.kd_config.get('temperature', 1)
62
+ self.inverse = gpc.config.kd_config.get('inverse', False)
63
+
64
+ def forward(self, *args):
65
+ if len(args) == 3:
66
+ if self.inverse:
67
+ logits_teacher, logits_student, _ = args
68
+ else:
69
+ logits_student, logits_teacher, _ = args
70
+ else:
71
+ raise RuntimeError(f"The number of criterion inputs are:{len(args)}")
72
+
73
+ logits_teacher = logits_teacher.contiguous().view(-1, logits_teacher.size(-1))
74
+ logits_student = logits_student.contiguous().view(-1, logits_student.size(-1))
75
+
76
+ log_pred_student = F.log_softmax(logits_student / self.temperature, dim=1)
77
+ pred_teacher = F.softmax(logits_teacher / self.temperature, dim=1)
78
+ loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')
79
+ loss_kd *= self.temperature ** 2
80
+
81
+ return loss_kd
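
For reference, KLDivLoss above computes the standard temperature-scaled distillation objective, L_KD = T^2 * KL(softmax(z_teacher / T) || softmax(z_student / T)); the `inverse` flag only swaps which of the first two criterion inputs is treated as the student. A tiny shape-level sketch mirroring the forward pass, with random placeholder logits:

import torch
import torch.nn.functional as F

T = 2.0
z_student, z_teacher = torch.randn(4, 32000), torch.randn(4, 32000)
loss_kd = F.kl_div(
    F.log_softmax(z_student / T, dim=1),
    F.softmax(z_teacher / T, dim=1),
    reduction="batchmean",
) * T**2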
InternLM/internlm/model/metrics.py ADDED
@@ -0,0 +1,263 @@
1
+ from typing import List
2
+
3
+ import torch
4
+ from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
5
+ from torch_scatter import scatter
6
+
7
+ from internlm.core.context import ParallelMode
8
+ from internlm.core.context import global_context as gpc
9
+ from internlm.utils.parallel import is_no_pp_or_last_stage
10
+
11
+
12
+ class AccPerplex:
13
+ """
14
+ AccPerplex module for calculating model's accuracy and perplexity metrics.
15
+
16
+ Args:
17
+ device: The GPU device.
18
+ tp_pg: The tensor parallel process group.
19
+ dp_pg: The data parallel process group.
20
+ tokenizer: For calculating BPB.
21
+ dataset_types (List[str]): Various data types that will be used in the current training process,
22
+ such as ['en', 'cn', 'code']. The order of the List should be consistent with the type_id specified
23
+ in the dataset. Changed parameters need to be used in conjunction with set_current_type_ids().
24
+ """
25
+
26
+ def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str] = None):
27
+ self.device = device
28
+ self.right = torch.Tensor([0]).to(device=device)
29
+ self.total = torch.Tensor([0]).to(device=device)
30
+ self.total_log_probs = torch.Tensor([0]).to(device=device)
31
+ self.tp_pg = tp_pg
32
+ self.dp_pg = dp_pg
33
+ self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
34
+ self.tokenizer = tokenizer
35
+ self.total_bytes = torch.Tensor([0]).to(device=device).view(1)
36
+ self.batch_shift = 0
37
+ self.type_ids = None
38
+ if dataset_types is not None:
39
+ self.dataset_types = dataset_types
40
+ self.total_type_count = len(dataset_types)
41
+ self.ds_right = torch.zeros(self.total_type_count, dtype=torch.long, device=device)
42
+ self.ds_tokens = torch.zeros(self.total_type_count, dtype=torch.long, device=device)
43
+
44
+ self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types)
45
+
46
+ def set_current_type_ids(self, type_ids: torch.Tensor):
47
+ self.batch_shift = 0
48
+ self.type_ids = type_ids.cuda()
49
+
50
+ def __call__(self, logits, labels):
51
+ return self.update(logits, labels, type_ids=self.type_ids)
52
+
53
+ def update(self, logits, labels, type_ids=None):
54
+ if gpc.config.model.use_flash_attn:
55
+ micro_bsz = labels.size(0)
56
+ else:
57
+ micro_bsz = 1
58
+ if type_ids is not None:
59
+ type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
60
+ self.batch_shift += 1
61
+ self.loss_with_type_id.update(logits, labels, type_ids)
62
+
63
+ with torch.no_grad():
64
+ if isinstance(logits, (list, tuple)):
65
+ logits = logits[0]
66
+
67
+ logits = logits.detach().clone()
68
+ labels = labels.detach().clone()
69
+
70
+ if self.tokenizer: # need to calculate bits per bytes
71
+ sequences = self.tokenizer.decode_ids(labels.tolist())
72
+ self.total_bytes += sum(map(lambda x: len(x.encode("utf-8")), sequences))
73
+
74
+ shift_logits = logits.view(-1, logits.size(-1))
75
+ shift_labels = labels.view(-1)
76
+ # There is a shift according to the current rank, because the logits are split
77
+ pred_shift = self.tp_local_rank * logits.shape[-1]
78
+
79
+ logits_max = torch.max(shift_logits, dim=-1)[0]
80
+ torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=self.tp_pg)
81
+ # Determine whether the maximum value of the current local tensor is the global maximum value
82
+ logits_global = logits_max == torch.max(shift_logits, dim=-1)[0]
83
+
84
+ corrects = torch.logical_and(
85
+ (shift_labels == (shift_logits.argmax(dim=-1) + pred_shift)), logits_global
86
+ ).long()
87
+ mask = shift_labels.ne(-100).long()
88
+ if hasattr(self, "total_type_count"):
89
+ ds_acc = scatter(corrects, type_ids, dim=0, reduce="sum")
90
+ token_num_type = scatter(mask, type_ids, dim=0, reduce="sum")
91
+ if len(ds_acc) < self.total_type_count:
92
+ ds_acc = torch.cat([ds_acc, ds_acc.new_zeros(self.total_type_count - len(ds_acc))])
93
+ token_num_type = torch.cat(
94
+ [token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
95
+ )
96
+ self.ds_tokens += token_num_type
97
+ sync_tensor = ds_acc
98
+ torch.distributed.all_reduce(sync_tensor, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
99
+ self.ds_right += sync_tensor.view(-1)
100
+
101
+ acc = corrects.sum()
102
+ torch.distributed.all_reduce(acc, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
103
+ self.right += acc # Masked_fill is not needed here because -100 is not available anyway
104
+ self.total += mask.sum()
105
+
106
+ # Subtract the maximum value.
107
+ shift_logits = shift_logits.sub(logits_max.unsqueeze(dim=-1))
108
+
109
+ # Get the partition's vocab indices
110
+ partition_vocab_size = shift_logits.size()[-1]
111
+ vocab_start_index = partition_vocab_size * self.tp_local_rank
112
+ vocab_end_index = vocab_start_index + partition_vocab_size
113
+
114
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
115
+ target_mask = (shift_labels < vocab_start_index) | (shift_labels >= vocab_end_index)
116
+ masked_target = shift_labels - vocab_start_index
117
+ masked_target[target_mask] = 0
118
+
119
+ # Get predicted-logits = logits[target].
120
+ # For simplicity, we reshape the logits to a 2-D tensor with size
121
+ # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
122
+ logits_2d = shift_logits.view(-1, partition_vocab_size)
123
+ masked_target_1d = masked_target.view(-1)
124
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
125
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
126
+ predicted_logits_1d = predicted_logits_1d.clone().contiguous()
127
+ predicted_logits = predicted_logits_1d.view_as(shift_labels) # bsz x max_len
128
+ predicted_logits[target_mask] = 0.0
129
+ # All reduce is needed to get the chunks from other GPUs.
130
+ torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
131
+
132
+ pred_exp_logits = torch.exp(predicted_logits)
133
+ # Sum of exponential of logits along vocab dimension across all GPUs.
134
+ sum_exp_logits = torch.exp(shift_logits).sum(dim=-1)
135
+ torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
136
+
137
+ total_log_probs = -(pred_exp_logits / sum_exp_logits).log().masked_fill(shift_labels.eq(-100), 0).sum()
138
+ self.total_log_probs += total_log_probs
139
+
140
+ def get_metric(self, reset=True):
141
+ if is_no_pp_or_last_stage() and self.dp_pg is not None:
142
+ torch.distributed.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
143
+ torch.distributed.all_reduce(self.total, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
144
+ torch.distributed.all_reduce(self.total_log_probs, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
145
+ if hasattr(self, "total_type_count"):
146
+ torch.distributed.all_reduce(self.ds_right, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
147
+ torch.distributed.all_reduce(self.ds_tokens, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
148
+ if self.tokenizer:
149
+ torch.distributed.all_reduce(self.total_bytes, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
150
+
151
+ acc = round((self.right / self.total).item(), 4)
152
+ perplexity = round(torch.exp(self.total_log_probs / self.total).item(), 4)
153
+ bits_per_bytes = round((self.total_log_probs / self.total_bytes).item(), 4) if self.tokenizer else 0
154
+
155
+ if hasattr(self, "total_type_count"):
156
+ ds_acc = {}
157
+ ds_tokens = {}
158
+ for i in range(self.total_type_count):
159
+ ds_acc[f"acc/{self.dataset_types[i]}"] = round(
160
+ (self.ds_right[i].float() / (self.ds_tokens[i].float() + 1e-5)).item(), 4
161
+ )
162
+ ds_tokens[f"tokens/{self.dataset_types[i]}"] = self.ds_tokens[i].item()
163
+ if reset:
164
+ self.right.fill_(0)
165
+ self.total.fill_(0)
166
+ self.total_log_probs.fill_(0)
167
+ self.total_bytes.fill_(0)
168
+ if hasattr(self, "total_type_count"):
169
+ self.ds_right.fill_(0)
170
+ self.ds_tokens.fill_(0)
171
+ if self.tokenizer is not None:
172
+ res = {"acc": acc, "perplexity": perplexity, "BPB": bits_per_bytes}
173
+ else:
174
+ res = {"acc": acc, "perplexity": perplexity}
175
+ if hasattr(self, "total_type_count"):
176
+ res.update(ds_acc)
177
+ res.update(ds_tokens)
178
+
179
+ loss_res = self.loss_with_type_id.get_metric(reset)
180
+ res.update(loss_res)
181
+
182
+ return res
183
+
184
+
185
+ class LossWithTypeId:
186
+ """
187
+ Notice the loss value computed here may be not the same with the main info loss,
188
+ cause loss here is the reduced result of the data parallel.
189
+ """
190
+
191
+ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None:
192
+ self.device = device
193
+ self.dp_pg = dp_pg
194
+
195
+ self.loss = torch.Tensor([0.0]).to(device=device)
196
+ self.token_num = torch.Tensor([0.0]).to(device=device)
197
+
198
+ if dataset_types is not None:
199
+ self.dataset_types = dataset_types
200
+ self.total_type_count = len(dataset_types)
201
+ self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
202
+ self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
203
+
204
+ self.loss_fn = FlashCrossEntropyLoss(
205
+ reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR)
206
+ )
207
+
208
+ def update(self, logits, labels, type_ids=None):
209
+ with torch.no_grad():
210
+ if isinstance(logits, (list, tuple)):
211
+ logits = logits[0]
212
+ logits = logits.contiguous().view(-1, logits.size(-1))
213
+ labels = labels.contiguous().view(-1)
214
+ loss_list = self.loss_fn(logits, labels)
215
+
216
+ cond = labels != -100
217
+ real_loss_list = loss_list[cond]
218
+ self.loss += real_loss_list.sum()
219
+ self.token_num += real_loss_list.numel()
220
+
221
+ if hasattr(self, "total_type_count"):
222
+ type_ids = type_ids.contiguous().view(-1).to(self.device)
223
+ real_type_ids = type_ids[cond]
224
+
225
+ loss_list_type = scatter(real_loss_list, real_type_ids, dim=0, reduce="sum")
226
+ token_num_type = scatter(torch.ones_like(real_loss_list), real_type_ids, dim=0, reduce="sum")
227
+
228
+ if len(loss_list_type) < self.total_type_count:
229
+ loss_list_type = torch.cat(
230
+ [loss_list_type, loss_list_type.new_zeros(self.total_type_count - len(loss_list_type))]
231
+ )
232
+ token_num_type = torch.cat(
233
+ [token_num_type, token_num_type.new_zeros(self.total_type_count - len(token_num_type))]
234
+ )
235
+ self.ds_loss += loss_list_type
236
+ self.ds_token_num += token_num_type
237
+
238
+ def get_metric(self, reset=True):
239
+ if is_no_pp_or_last_stage() and self.dp_pg is not None:
240
+ torch.distributed.all_reduce(self.loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
241
+ torch.distributed.all_reduce(self.token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
242
+ if hasattr(self, "total_type_count"):
243
+ torch.distributed.all_reduce(self.ds_loss, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
244
+ torch.distributed.all_reduce(self.ds_token_num, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg)
245
+
246
+ loss = round((self.loss / self.token_num).item(), 4)
247
+ res = {
248
+ "loss_from_metric": loss,
249
+ }
250
+ if hasattr(self, "total_type_count"):
251
+ ds_loss = {}
252
+ for i in range(self.total_type_count):
253
+ ds_loss[f"loss/{self.dataset_types[i]}"] = round((self.ds_loss[i] / self.ds_token_num[i]).item(), 4)
254
+ res.update(ds_loss)
255
+
256
+ if reset:
257
+ self.loss.fill_(0.0)
258
+ self.token_num.fill_(0.0)
259
+ if hasattr(self, "total_type_count"):
260
+ self.ds_loss.fill_(0.0)
261
+ self.ds_token_num.fill_(0.0)
262
+
263
+ return res
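
In summary, the values reported by AccPerplex.get_metric() reduce to acc = right / total and perplexity = exp(total_log_probs / total), plus BPB = total_log_probs / total_bytes when a tokenizer is supplied, with all counters all-reduced over the data-parallel group first; LossWithTypeId adds loss_from_metric = loss / token_num and the per-dataset-type breakdowns. A toy illustration of the perplexity reduction, with made-up counter values:

import torch

total_log_probs, total_tokens = torch.tensor([1234.5]), torch.tensor([500.0])
perplexity = torch.exp(total_log_probs / total_tokens)  # exp of the mean per-token negative log-likelihood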
InternLM/internlm/model/modeling_internlm.py ADDED
@@ -0,0 +1,524 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import math
5
+ from typing import Optional
6
+
7
+ import torch
8
+ from flash_attn.modules.embedding import ParallelGPT2Embeddings
9
+ from flash_attn.modules.mlp import ParallelFusedMLP
10
+ from torch import nn
11
+
12
+ from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
13
+ from internlm.core.context.parallel_context import global_context as gpc
14
+ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
15
+ from internlm.model.embedding import Embedding1D, Embedding1DLVM
16
+ from internlm.model.linear import (
17
+ FeedForward,
18
+ RewardModelLinear,
19
+ ScaleColumnParallelLinear,
20
+ )
21
+ from internlm.model.multi_head_attention import MHA
22
+ from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm
23
+ from internlm.solver.pipeline_utils import partition_uniform
24
+ from internlm.utils.checkpoint import activation_checkpoint
25
+ from internlm.utils.common import filter_kwargs
26
+ from internlm.utils.logger import get_logger
27
+ from internlm.utils.registry import MODEL_INITIALIZER
28
+
29
+ MODEL_TYPE = "INTERNLM"
30
+
31
+ logger = get_logger(__file__)
32
+ RMSNorm = try_import_RMSNorm()
33
+
34
+
35
+ class PackedFlashBaseLayer1D(nn.Module):
36
+ """
37
+ 1D Packed Flash Base Layer.
38
+
39
+ Args:
40
+ hidden_size (int): The hidden size of model. 768 by default.
41
+ num_attention_heads (int): The number of attention heads. 12 by default.
42
+ mlp_ratio (int): The ratio of MLP layers. 4 by default.
43
+ attn_drop_rate (float): The dropout rate of attention module. 0 by default.
44
+ drop_rate (float): The dropout rate of the input hidden state. 0.0 by default.
45
+ dtype (torch.dtype): Type of data. torch.float by default.
46
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
47
+ checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
48
+ layer_idx (int): The index of current layer. 0 by default.
49
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
50
+ device (Optional[Union[str, torch.device]]): The device will be used.
51
+ norm_type (str): Use RMSNorm or LayerNorm. "rmsnorm" by default.
52
+ use_flash_attn (bool): Whether to use flash-attn. True by default.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ hidden_size: int = 768,
58
+ num_attention_heads: int = 12,
59
+ mlp_ratio: int = 4,
60
+ attn_drop_rate: float = 0,
61
+ drop_rate: float = 0.0,
62
+ dtype: torch.dtype = torch.float,
63
+ layer_norm_epsilon: float = 1e-6,
64
+ checkpoint: bool = False,
65
+ layer_idx: int = 0,
66
+ residual_in_fp32: bool = False,
67
+ device: Optional[torch.device] = None,
68
+ norm_type: str = "rmsnorm",
69
+ dropout_selective_checkpoint: bool = True,
70
+ use_scaled_init: bool = True,
71
+ use_swiglu: bool = True,
72
+ use_flash_attn: bool = True,
73
+ ):
74
+ super().__init__()
75
+ self.checkpoint = checkpoint
76
+ # dropout selective checkpoint can only be enabled when checkpoint is disabled.
77
+ self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
78
+ self.layer_idx = layer_idx
79
+ self.use_flash_attn = use_flash_attn
80
+
81
+ head_dim = hidden_size // num_attention_heads
82
+ self.mixer = MHA(
83
+ embed_dim=hidden_size,
84
+ num_heads=num_attention_heads,
85
+ process_group=gpc.get_group(ParallelMode.TENSOR),
86
+ dropout=attn_drop_rate,
87
+ softmax_scale=1 / math.sqrt(head_dim),
88
+ causal=True,
89
+ layer_idx=layer_idx,
90
+ rotary_emb_dim=head_dim,
91
+ rotary_emb_scale_base=0,
92
+ use_flash_attn=use_flash_attn,
93
+ device=device,
94
+ dtype=dtype,
95
+ )
96
+
97
+ self.dropout1 = nn.Dropout(drop_rate)
98
+ if norm_type == "rmsnorm":
99
+ self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
100
+ self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
101
+ else:
102
+ self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
103
+ self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
104
+
105
+ if use_swiglu:
106
+ self.mlp = FeedForward(
107
+ hidden_size,
108
+ int(hidden_size * mlp_ratio),
109
+ out_features=hidden_size,
110
+ process_group=gpc.get_group(ParallelMode.TENSOR),
111
+ bias=False,
112
+ device=device,
113
+ dtype=dtype,
114
+ )
115
+ else:
116
+ self.mlp = ParallelFusedMLP(
117
+ hidden_size,
118
+ int(hidden_size * mlp_ratio),
119
+ out_features=hidden_size,
120
+ activation="gelu_approx",
121
+ process_group=gpc.get_group(ParallelMode.TENSOR),
122
+ bias1=False,
123
+ bias2=False,
124
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
125
+ checkpoint_lvl=0,
126
+ heuristic="auto",
127
+ device=device,
128
+ dtype=dtype,
129
+ )
130
+ for _, param in self.mlp.named_parameters():
131
+ if gpc.get_world_size(ParallelMode.TENSOR) > 1:
132
+ setattr(param, IS_TENSOR_PARALLEL, True)
133
+ self.dropout2 = nn.Dropout(drop_rate)
134
+ self.use_swiglu = use_swiglu
135
+ self.use_scaled_init = use_scaled_init
136
+ self.residual_in_fp32 = residual_in_fp32 # only makes sense when using pre-norm
137
+ self.return_residual = False
138
+ self.reset_parameters()
139
+
140
+ def reset_parameters(self):
141
+ with torch.no_grad():
142
+ for name, param in self.mixer.named_parameters():
143
+ if param.ndim == 1:
144
+ param.data.zero_()
145
+ elif "Wqkv" in name:
146
+ normal_(std=0.006)(param.data)
147
+ elif self.use_scaled_init:
148
+ scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
149
+ else:
150
+ normal_(std=0.0015)(param.data)
151
+
152
+ for name, param in self.mlp.named_parameters():
153
+ if param.ndim == 1 and "bias" in name:
154
+ param.data.zero_()
155
+ elif self.use_swiglu:
156
+ if self.use_scaled_init and "w2" in name:
157
+ scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
158
+ else:
159
+ normal_(std=0.006 if "w1" in name or "w2" in name else 0.0015)(param.data)
160
+ else:
161
+ if self.use_scaled_init and "fc1" not in name:
162
+ scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
163
+ else:
164
+ normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)
165
+
166
+ def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
167
+ if self.checkpoint and self.training:
168
+ return activation_checkpoint(
169
+ self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
170
+ )
171
+ else:
172
+ return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
173
+
174
+ def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
175
+ r"""Pass the input through the encoder layer.
176
+
177
+ Args:
178
+ hidden_states: the sequence to the encoder layer (required).
179
+ residual: hidden_states = Attn/MLP(LN(residual))
180
+ cu_seqlens: 1d LongTensor of cumulative sequence lengths, len(cu_seqlens) = number of packed sequences + 1
181
+ indexes: a tensor with the same length as hidden_states, giving each token's position within its own sequence
182
+ """
183
+ mixer_kwargs = {
184
+ "cu_seqlens": cu_seqlens,
185
+ "max_seqlen": max_seqlen,
186
+ "indexes": indexes,
187
+ "inference_params": inference_params,
188
+ }
189
+
190
+ def _dropout_and_norm_attn(_hidden_states):
191
+ _dropped = self.dropout1(_hidden_states)
192
+ _residual = _dropped
193
+ _hidden_states = self.norm1(_residual.float())
194
+ return _residual, _hidden_states
195
+
196
+ if self.dropout_selective_checkpoint:
197
+ residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, hidden_states)
198
+ else:
199
+ residual, hidden_states = _dropout_and_norm_attn(hidden_states)
200
+
201
+ if self.residual_in_fp32:
202
+ residual = residual.to(torch.float32)
203
+
204
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
205
+
206
+ def _dropout_and_norm_ffn(_residual, _hidden_states):
207
+ _dropped = self.dropout2(_hidden_states)
208
+ _residual = (_dropped + _residual) if _residual is not None else _dropped
209
+ _hidden_states = self.norm2(_residual.float())
210
+ return _residual, _hidden_states
211
+
212
+ if self.dropout_selective_checkpoint:
213
+ residual, hidden_states = activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states)
214
+ else:
215
+ residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states)
216
+
217
+ if self.residual_in_fp32:
218
+ residual = residual.to(torch.float32)
219
+
220
+ hidden_states = self.mlp(hidden_states)
221
+
222
+ return hidden_states + residual
223
+
224
+
225
+ class PackedFlashInternLm1D(nn.Module):
226
+ """
227
+ 1D Packed Flash InternLm.
228
+
229
+ Args:
230
+ num_layers (int): The number of layer. 12 by default.
231
+ hidden_size (int): The size of hidden state. 768 by default.
232
+ num_attention_heads (int): The number of attention head. 12 by default.
233
+ vocab_size (int): The size of vocabulary. 50304 by default.
234
+ mlp_ratio (int): The ratio of MLP layers. 4 by default.
235
+ attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
236
+ drop_rate (float): The dropout rate of input hidden state. 0.0 by default.
237
+ dtype (torch.dtype): The type of data. torch.float by default.
238
+ checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
239
+ of layers. 0.0 by default.
240
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
241
+ first (bool): Whether this stage contains the input embedding layer. False by default.
242
+ last (bool): Whether this stage contains the output norm and head. False by default.
243
+ embed_split_hidden (bool): Split the embedding layer along the hidden dimension or the vocabulary dimension.
244
+ False by default.
245
+ embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
246
+ parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
247
+ start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
248
+ device (Optional[Union[str, torch.device]]): The device will be used. None by default.
249
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
250
+ norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
251
+ use_flash_attn (bool): Whether to use flash-attn. True by default.
252
+
253
+ """
254
+
255
+ def __init__(
256
+ self,
257
+ num_layers: int = 12,
258
+ hidden_size: int = 768,
259
+ num_attention_heads: int = 12,
260
+ vocab_size: int = 50304,
261
+ mlp_ratio: int = 4.0,
262
+ attn_drop_rate: float = 0.0,
263
+ drop_rate: float = 0.0,
264
+ dtype: torch.dtype = torch.float,
265
+ checkpoint: float = 0.0,
266
+ layer_norm_epsilon: float = 1e-5,
267
+ first: bool = False,
268
+ last: bool = False,
269
+ embed_split_hidden: bool = False,
270
+ embed_grad_scale: float = 0.1,
271
+ parallel_output: bool = True,
272
+ start_layer_idx: int = 0,
273
+ device: Optional[torch.device] = None,
274
+ residual_in_fp32: bool = False,
275
+ norm_type: str = "rmsnorm",
276
+ is_reward: bool = False,
277
+ dropout_selective_checkpoint: bool = True,
278
+ use_scaled_init: bool = True,
279
+ use_swiglu: bool = True,
280
+ use_flash_attn: bool = True,
281
+ lvm_config: dict = None,
282
+ ):
283
+ super().__init__()
284
+ self.lvm_config = lvm_config
285
+
286
+ checkpoint_layer_num = int(num_layers * checkpoint)
287
+
288
+ if is_reward:
289
+ head_cls = RewardModelLinear
290
+ else:
291
+ head_cls = ScaleColumnParallelLinear
292
+ if first:
293
+ if self.lvm_config is not None and self.lvm_config.get('enable', False):
294
+ self.embedding = Embedding1DLVM(**self.lvm_config.get('embedding_cfg'))
295
+ if self.embedding.embed_proj is not None:
296
+ for _, param in self.embedding.embed_proj.named_parameters():
297
+ normal_(std=0.0052)(param)
298
+ if gpc.get_world_size(ParallelMode.TENSOR) > 1:
299
+ setattr(param, IS_TENSOR_PARALLEL, True)
300
+ else:
301
+ if embed_split_hidden:
302
+ self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
303
+ else:
304
+ self.embedding = ParallelGPT2Embeddings(
305
+ embed_dim=hidden_size,
306
+ vocab_size=vocab_size,
307
+ max_position_embeddings=-1,
308
+ process_group=gpc.get_group(ParallelMode.TENSOR),
309
+ padding_idx=None,
310
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
311
+ device=device,
312
+ dtype=dtype,
313
+ )
314
+ for _, param in self.embedding.named_parameters():
315
+ normal_(std=0.0052)(param)
316
+ if gpc.get_world_size(ParallelMode.TENSOR) > 1:
317
+ setattr(param, IS_TENSOR_PARALLEL, True)
318
+ self.embed_grad_scale = embed_grad_scale
319
+ self.blocks = nn.ModuleList(
320
+ [
321
+ PackedFlashBaseLayer1D(
322
+ hidden_size=hidden_size,
323
+ num_attention_heads=num_attention_heads,
324
+ mlp_ratio=mlp_ratio,
325
+ attn_drop_rate=attn_drop_rate,
326
+ drop_rate=drop_rate,
327
+ dtype=dtype,
328
+ layer_norm_epsilon=layer_norm_epsilon,
329
+ checkpoint=lid < checkpoint_layer_num,
330
+ layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
331
+ residual_in_fp32=residual_in_fp32,
332
+ device=device,
333
+ norm_type=norm_type,
334
+ dropout_selective_checkpoint=dropout_selective_checkpoint,
335
+ use_scaled_init=use_scaled_init,
336
+ use_swiglu=use_swiglu,
337
+ use_flash_attn=use_flash_attn,
338
+ )
339
+ for lid in range(num_layers)
340
+ ]
341
+ )
342
+ if last:
343
+ if norm_type == "rmsnorm":
344
+ self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
345
+ else:
346
+ self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
347
+ self.head = head_cls(
348
+ in_features=hidden_size,
349
+ out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
350
+ process_group=gpc.get_group(ParallelMode.TENSOR),
351
+ bias=False,
352
+ device=device,
353
+ dtype=dtype,
354
+ weight_scale=embed_grad_scale,
355
+ )
356
+ for _, param in self.head.named_parameters():
357
+ normal_(std=0.0052)(param)
358
+ if gpc.get_world_size(ParallelMode.TENSOR) > 1:
359
+ setattr(param, IS_TENSOR_PARALLEL, True)
360
+ self.parallel_output = parallel_output
361
+
362
+ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
363
+ # attention_mask: compute attention on the places where the value is 1
364
+ if hasattr(self, "embedding"):
365
+ hidden_states = self.embedding(input_ids)
366
+ if self.embed_grad_scale != 1:
367
+ hidden_states = (
368
+ self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
369
+ )
370
+ if isinstance(cu_seqlens, list):
371
+ assert len(cu_seqlens) == 1
372
+ cu_seqlens = cu_seqlens[0].to(hidden_states.device)
373
+
374
+ if cu_seqlens is not None:
375
+ cu_seqlens = cu_seqlens.squeeze(0)
376
+ hidden_states = hidden_states.squeeze(0) # If cu_seqlens is passed in, the input is packed,
377
+ # so the batch dimension of size 1 is squeezed off directly.
378
+
379
+ if indexes is not None:
380
+ assert len(indexes) == 1
381
+ # The indexes are used to indicate the actual position IDs of each token in the packed input.
382
+ indexes = indexes[0]
383
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
384
+
385
+ for _, block in enumerate(self.blocks):
386
+ hidden_states = block(
387
+ hidden_states,
388
+ cu_seqlens=cu_seqlens,
389
+ indexes=indexes,
390
+ inference_params=inference_params,
391
+ max_seqlen=max_seqlen,
392
+ )
393
+
394
+ if hasattr(self, "norm"):
395
+ hidden_states = self.norm(hidden_states.float())
396
+ if hasattr(self, "head"):
397
+ hidden_states = self.head(hidden_states)
398
+
399
+ if not self.parallel_output:
400
+ hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
401
+ return hidden_states
402
+
403
+
404
+ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
405
+ """
406
+ build generic model 1d
407
+
408
+ Args:
409
+ num_layers (int): The number of layer.
410
+ num_chunks (int): The number of partitions in pipeline parallel.
411
+ device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default.
412
+
413
+ """
414
+ pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
415
+ pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
416
+
417
+ all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
418
+ parts = all_parts[pipeline_rank]
419
+ if gpc.is_rank_for_log():
420
+ logger.info(f"The layer sharding is {all_parts}.")
421
+
422
+ models = []
423
+
424
+ for start, end in parts:
425
+ kwargs["num_layers"] = end - start
426
+ kwargs["first"] = start == 0
427
+ # Mark this chunk as the last stage only if it ends at the final layer and the last partition is non-empty.
428
+ kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0
429
+ kwargs["device"] = device
430
+ kwargs["start_layer_idx"] = start
431
+ chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device)
432
+
433
+ models.append(chunk)
434
+ torch.distributed.barrier()
435
+ if len(models) == 1:
436
+ model = models[0]
437
+ else:
438
+ model = nn.ModuleList(models)
439
+
440
+ return model
441
+
442
+
443
+ @MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
444
+ def build_model_with_cfg(
445
+ num_chunks=1,
446
+ checkpoint=0.0,
447
+ dtype=torch.float,
448
+ embed_split_hidden=False,
449
+ num_layers=48,
450
+ hidden_size=2048,
451
+ vocab_size=50304,
452
+ embed_grad_scale=1,
453
+ parallel_output=True,
454
+ num_attention_heads=32,
455
+ mlp_ratio=4.0,
456
+ residual_in_fp32=False,
457
+ norm_type="rmsnorm",
458
+ drop_rate=0,
459
+ attn_drop_rate=0,
460
+ apply_post_layer_norm=False, # pylint: disable=W0613
461
+ layer_norm_epsilon=1e-5,
462
+ is_reward=False,
463
+ dropout_selective_checkpoint=True,
464
+ use_scaled_init: bool = True,
465
+ use_swiglu: bool = True,
466
+ use_flash_attn: bool = True,
467
+ lvm_config=None,
468
+ ):
469
+ """
470
+ Build model with config.
471
+
472
+ Args:
473
+ num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
474
+ checkpoint (float): The proportion of layers to checkpoint to save VRAM. 0.0 by default.
475
+ dtype (torch.dtype): The type of data. torch.float by default.
476
+ embed_split_hidden (bool): Split the embedding layer along the hidden dimension or the vocabulary dimension.
477
+ False by default.
478
+ num_layers (int): The number of layer. 48 by default.
479
+ hidden_size (int): The size of hidden state. 2048 by default.
480
+ vocab_size (int): The size of vocabulary. 50304 by default.
481
+ embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
482
+ parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
483
+ num_attention_heads (int): The number of attention head. 32 by default.
484
+ mlp_ratio (int): The ratio of MLP layers. 4.0 by default.
485
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily
486
+ because this parameter requires inconsistent data types to be passed between pipelines,
487
+ which requires significant modifications to internlm.
488
+ norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
489
+ drop_rate (float): The dropout rate of input hidden state. 0 by default.
490
+ attn_drop_rate (float): The dropout rate of attention module. 0 by default.
491
+ apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
492
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
493
+ is_reward (bool): Whether to use reward model. False by default.
494
+ dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
495
+ use_scaled_init (bool): Whether to use scaled init. True by default.
496
+ use_swiglu (bool): Whether to use swiglu. True by default.
497
+ use_flash_attn (bool): Whether to use flash-attn. True by default.
498
+
499
+ """
500
+
501
+ cfg = dict(
502
+ hidden_size=hidden_size,
503
+ num_attention_heads=num_attention_heads,
504
+ checkpoint=checkpoint,
505
+ dtype=dtype,
506
+ embed_split_hidden=embed_split_hidden,
507
+ vocab_size=vocab_size,
508
+ embed_grad_scale=embed_grad_scale,
509
+ parallel_output=parallel_output,
510
+ mlp_ratio=mlp_ratio,
511
+ residual_in_fp32=residual_in_fp32,
512
+ norm_type=norm_type,
513
+ drop_rate=drop_rate,
514
+ attn_drop_rate=attn_drop_rate,
515
+ layer_norm_epsilon=layer_norm_epsilon,
516
+ is_reward=is_reward,
517
+ dropout_selective_checkpoint=dropout_selective_checkpoint,
518
+ use_scaled_init=use_scaled_init,
519
+ use_swiglu=use_swiglu,
520
+ use_flash_attn=use_flash_attn,
521
+ lvm_config=lvm_config,
522
+ )
523
+
524
+ return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
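
For context on the (start, end) layer ranges consumed above, here is a simplified, self-contained sketch of uniform pipeline layer partitioning. It only mimics the output shape that partition_uniform feeds into _build_generic_model_1d (a list of (start, end) chunks per pipeline rank) and is not the repository's actual implementation:

def partition_uniform_sketch(num_layers, pipeline_size, num_chunks=1):
    # For every pipeline rank, return a list of (start, end) layer ranges,
    # one range per chunk, covering [0, num_layers) without overlap.
    parts = [[] for _ in range(pipeline_size)]
    chunk_size = num_layers // (pipeline_size * num_chunks)
    start = 0
    for chunk in range(num_chunks):
        for rank in range(pipeline_size):
            end = start + chunk_size
            if chunk == num_chunks - 1 and rank == pipeline_size - 1:
                end = num_layers  # hand any remainder to the final chunk
            parts[rank].append((start, end))
            start = end
    return parts

print(partition_uniform_sketch(12, pipeline_size=4))                # [[(0, 3)], [(3, 6)], [(6, 9)], [(9, 12)]]
print(partition_uniform_sketch(12, pipeline_size=2, num_chunks=2))  # [[(0, 3), (6, 9)], [(3, 6), (9, 12)]]
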
InternLM/internlm/model/modeling_vit.py ADDED
@@ -0,0 +1,527 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import math
5
+ from typing import Optional
6
+
7
+ import torch
8
+ from flash_attn.modules.embedding import ParallelGPT2Embeddings
9
+ from flash_attn.modules.mlp import ParallelFusedMLP
10
+ from torch import nn
11
+
12
+ from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
13
+ from internlm.core.context.parallel_context import global_context as gpc
14
+ from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
15
+ from internlm.model.embedding import Embedding1D, Embedding1DLVM
16
+ from internlm.model.linear import (
17
+ FeedForward,
18
+ RewardModelLinear,
19
+ ScaleColumnParallelLinear,
20
+ )
21
+ from internlm.model.multi_head_attention import MHA
22
+ from internlm.model.utils import gather_forward_split_backward, try_import_RMSNorm, try_import_LayerNorm
23
+ from internlm.solver.pipeline_utils import partition_uniform
24
+ from internlm.utils.checkpoint import activation_checkpoint
25
+ from internlm.utils.common import filter_kwargs
26
+ from internlm.utils.logger import get_logger
27
+ from internlm.utils.registry import MODEL_INITIALIZER
28
+
29
+ MODEL_TYPE = "ViT"
30
+
31
+ logger = get_logger(__file__)
32
+ RMSNorm = try_import_RMSNorm()
33
+ LayerNorm = try_import_LayerNorm()
34
+
35
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
36
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
37
+
38
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
39
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
40
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
41
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
42
+ 'survival rate' as the argument.
43
+
44
+ """
45
+ if drop_prob == 0. or not training:
46
+ return x
47
+ keep_prob = 1 - drop_prob
48
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
49
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
50
+ if keep_prob > 0.0 and scale_by_keep:
51
+ random_tensor.div_(keep_prob)
52
+ return x * random_tensor
53
+
54
+
55
+ class DropPath(nn.Module):
56
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
57
+ """
58
+ def __init__(self, drop_prob=None, scale_by_keep=True):
59
+ super(DropPath, self).__init__()
60
+ self.drop_prob = drop_prob
61
+ self.scale_by_keep = scale_by_keep
62
+
63
+ def forward(self, x):
64
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
65
+
66
+
67
+ class PackedFlashBaseLayer1D(nn.Module):
68
+ """
69
+ 1D Packed Flash Base Layer.
70
+
71
+ Args:
72
+ hidden_size (int): The hidden size of model. 768 by default.
73
+ num_attention_heads (int): The number of attention heads. 12 by default.
74
+ mlp_ratio (int): The ratio of MLP layers. 4 by default.
75
+ attn_drop_rate (float): The dropout rate of attention module. 0 by default.
76
+ drop_path_rate (float): The drop path rate of the input hidden state. 0.0 by default.
77
+ dtype (torch.dtype): Type of data. torch.float by default.
78
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default.
80
+ checkpoint (bool): Whether to use checkpointing to save VRAM. False by default.
80
+ layer_idx (int): The index of current layer. 0 by default.
81
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
82
+ device (Optional[Union[str, torch.device]]): The device will be used.
83
+ norm_type (str): Use RMSNorm or LayerNorm. "rmsnorm" by default.
84
+ use_flash_attn (bool): Whether to use flash-attn. True by default.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ hidden_size: int = 768,
90
+ num_attention_heads: int = 12,
91
+ mlp_ratio: int = 4,
92
+ mlp_bias: bool = False,
93
+ attn_drop_rate: float = 0,
94
+ drop_path_rate: float = 0.0,
95
+ dtype: torch.dtype = torch.float,
96
+ layer_norm_epsilon: float = 1e-6,
97
+ checkpoint: bool = False,
98
+ layer_idx: int = 0,
99
+ residual_in_fp32: bool = False,
100
+ device: Optional[torch.device] = None,
101
+ norm_type: str = "rmsnorm",
102
+ dropout_selective_checkpoint: bool = True,
103
+ use_scaled_init: bool = True,
104
+ use_swiglu: bool = True,
105
+ use_flash_attn: bool = True,
106
+ ):
107
+ super().__init__()
108
+ self.checkpoint = checkpoint
109
+ # dropout selective checkpoint can only be enabled when checkpoint is disabled.
110
+ self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
111
+ self.layer_idx = layer_idx
112
+ self.use_flash_attn = use_flash_attn
113
+
114
+ head_dim = hidden_size // num_attention_heads
115
+ self.mixer = MHA(
116
+ embed_dim=hidden_size,
117
+ num_heads=num_attention_heads,
118
+ process_group=gpc.get_group(ParallelMode.TENSOR),
119
+ dropout=attn_drop_rate,
120
+ softmax_scale=1 / math.sqrt(head_dim),
121
+ causal=True,
122
+ layer_idx=layer_idx,
123
+ rotary_emb_dim=head_dim,
124
+ rotary_emb_scale_base=0,
125
+ use_flash_attn=use_flash_attn,
126
+ device=device,
127
+ dtype=dtype,
128
+ )
129
+
130
+ self.dropout1 = DropPath(drop_path_rate)
131
+ if norm_type == "rmsnorm":
132
+ self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
133
+ self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon)
134
+ else:
135
+ self.norm1 = LayerNorm(hidden_size, eps=layer_norm_epsilon)
136
+ self.norm2 = LayerNorm(hidden_size, eps=layer_norm_epsilon)
137
+
138
+ self.mlp = ParallelFusedMLP(
139
+ hidden_size,
140
+ int(hidden_size * mlp_ratio),
141
+ out_features=hidden_size,
142
+ activation="gelu_approx",
143
+ process_group=gpc.get_group(ParallelMode.TENSOR),
144
+ bias1=mlp_bias,
145
+ bias2=mlp_bias,
146
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
147
+ checkpoint_lvl=0,
148
+ heuristic="auto",
149
+ device=device,
150
+ dtype=dtype,
151
+ )
152
+ for _, param in self.mlp.named_parameters():
153
+ if gpc.get_world_size(ParallelMode.TENSOR) > 1:
154
+ setattr(param, IS_TENSOR_PARALLEL, True)
155
+ self.dropout2 = DropPath(drop_path_rate)
156
+ self.use_swiglu = use_swiglu
157
+ self.use_scaled_init = use_scaled_init
158
+ self.residual_in_fp32 = residual_in_fp32 # only makes sense when using pre-norm
159
+ self.return_residual = False
160
+ self.reset_parameters()
161
+
162
+ def reset_parameters(self):
163
+ with torch.no_grad():
164
+ for name, param in self.mixer.named_parameters():
165
+ if param.ndim == 1:
166
+ param.data.zero_()
167
+ elif "Wqkv" in name:
168
+ normal_(std=0.006)(param.data)
169
+ elif self.use_scaled_init:
170
+ scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
171
+ else:
172
+ normal_(std=0.0015)(param.data)
173
+
174
+ for name, param in self.mlp.named_parameters():
175
+ if param.ndim == 1 and "bias" in name:
176
+ param.data.zero_()
177
+ elif self.use_swiglu:
178
+ if self.use_scaled_init and "w2" in name:
179
+ scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
180
+ else:
181
+ normal_(std=0.006 if "w1" in name or "w2" in name else 0.0015)(param.data)
182
+ else:
183
+ if self.use_scaled_init and "fc1" not in name:
184
+ scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
185
+ else:
186
+ normal_(std=0.006 if "fc1" in name else 0.0015)(param.data)
187
+
188
+ def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
189
+ if self.checkpoint and self.training:
190
+ return activation_checkpoint(
191
+ self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen
192
+ )
193
+ else:
194
+ return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen)
195
+
196
+ def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None):
197
+ r"""Pass the input through the encoder layer.
198
+
199
+ Args:
200
+ hidden_states: the sequence to the encoder layer (required).
201
+ residual: hidden_states = Attn/MLP(LN(residual))
202
+ cu_seqlens: 1d LongTensor of cumulative sequence lengths, len(cu_seqlens) = number of packed sequences + 1
203
+ indexes: a tensor with the same length as hidden_states, giving each token's position within its own sequence
204
+ """
205
+ mixer_kwargs = {
206
+ "cu_seqlens": cu_seqlens,
207
+ "max_seqlen": max_seqlen,
208
+ "indexes": indexes,
209
+ "inference_params": inference_params,
210
+ }
211
+
212
+ residual = hidden_states
213
+
214
+ hidden_states = self.norm1(residual.float())
215
+ hidden_states = self.mixer(hidden_states, **mixer_kwargs)
216
+ hidden_states = self.dropout1(hidden_states)
217
+
218
+ residual = residual + hidden_states
219
+
220
+ hidden_states = self.norm2(residual.float())
221
+ hidden_states = self.mlp(hidden_states)
222
+ hidden_states = self.dropout2(hidden_states)
223
+
224
+ return hidden_states + residual
225
+
226
+
227
+ class PackedFlashInternLm1D(nn.Module):
228
+ """
229
+ 1D Packed Flash InternLm.
230
+
231
+ Args:
232
+ num_layers (int): The number of layer. 12 by default.
233
+ hidden_size (int): The size of hidden state. 768 by default.
234
+ num_attention_heads (int): The number of attention head. 12 by default.
235
+ vocab_size (int): The size of vocabulary. 50304 by default.
236
+ mlp_ratio (int): The ratio of MLP layers. 4 by default.
237
+ attn_drop_rate (float): The dropout rate of attention module. 0.0 by default.
238
+ drop_path_rate (float): The drop path rate of input hidden state. 0.0 by default.
239
+ dtype (torch.dtype): The type of data. torch.float by default.
240
+ checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number
241
+ of layers. 0.0 by default.
242
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
243
+ first (bool): Whether this stage contains the input embedding layer. False by default.
244
+ last (bool): Whether this stage contains the output norm and head. False by default.
245
+ embed_split_hidden (bool): Split the embedding layer along the hidden dimension or the vocabulary dimension.
246
+ False by default.
247
+ embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
248
+ parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
249
+ start_layer_idx (int): The index of start layer in the pipeline. 0 by default.
250
+ device (Optional[Union[str, torch.device]]): The device will be used. None by default.
251
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
252
+ norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
253
+ use_flash_attn (bool): Whether to use flash-attn. True by default.
254
+
255
+ """
256
+
257
+ def __init__(
258
+ self,
259
+ num_layers: int = 12,
260
+ hidden_size: int = 768,
261
+ num_attention_heads: int = 12,
262
+ vocab_size: int = 50304,
263
+ mlp_ratio: int = 4.0,
264
+ mlp_bias: bool = False,
265
+ attn_drop_rate: float = 0.0,
266
+ drop_path_rate: float = 0.0,
267
+ dtype: torch.dtype = torch.float,
268
+ checkpoint: float = 0.0,
269
+ layer_norm_epsilon: float = 1e-5,
270
+ first: bool = False,
271
+ last: bool = False,
272
+ embed_split_hidden: bool = False,
273
+ embed_grad_scale: float = 0.1,
274
+ parallel_output: bool = True,
275
+ start_layer_idx: int = 0,
276
+ device: Optional[torch.device] = None,
277
+ residual_in_fp32: bool = False,
278
+ norm_type: str = "rmsnorm",
279
+ is_reward: bool = False,
280
+ dropout_selective_checkpoint: bool = True,
281
+ use_scaled_init: bool = True,
282
+ use_swiglu: bool = True,
283
+ use_flash_attn: bool = True,
284
+ lvm_config: dict = None,
285
+ ):
286
+ super().__init__()
287
+ self.lvm_config = lvm_config
288
+
289
+ checkpoint_layer_num = int(num_layers * checkpoint)
290
+
291
+ head_cls = ScaleColumnParallelLinear
292
+ if first:
293
+ if self.lvm_config is not None and self.lvm_config.get('enable', False):
294
+ self.embedding = Embedding1DLVM(**self.lvm_config.get('embedding_cfg'))
295
+ if self.embedding.embed_proj is not None:
296
+ for _, param in self.embedding.embed_proj.named_parameters():
297
+ normal_(std=0.0052)(param)
298
+ if gpc.get_world_size(ParallelMode.TENSOR) > 1:
299
+ setattr(param, IS_TENSOR_PARALLEL, True)
300
+ else:
301
+ if embed_split_hidden:
302
+ self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
303
+ else:
304
+ self.embedding = ParallelGPT2Embeddings(
305
+ embed_dim=hidden_size,
306
+ vocab_size=vocab_size,
307
+ max_position_embeddings=-1,
308
+ process_group=gpc.get_group(ParallelMode.TENSOR),
309
+ padding_idx=None,
310
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
311
+ device=device,
312
+ dtype=dtype,
313
+ )
314
+ for _, param in self.embedding.named_parameters():
315
+ normal_(std=0.0052)(param)
316
+ if gpc.get_world_size(ParallelMode.TENSOR) > 1:
317
+ setattr(param, IS_TENSOR_PARALLEL, True)
318
+ self.embed_grad_scale = embed_grad_scale
319
+ self.blocks = nn.ModuleList(
320
+ [
321
+ PackedFlashBaseLayer1D(
322
+ hidden_size=hidden_size,
323
+ num_attention_heads=num_attention_heads,
324
+ mlp_ratio=mlp_ratio,
325
+ mlp_bias=mlp_bias,
326
+ attn_drop_rate=attn_drop_rate,
327
+ drop_path_rate=drop_path_rate,
328
+ dtype=dtype,
329
+ layer_norm_epsilon=layer_norm_epsilon,
330
+ checkpoint=lid < checkpoint_layer_num,
331
+ layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation
332
+ residual_in_fp32=residual_in_fp32,
333
+ device=device,
334
+ norm_type=norm_type,
335
+ dropout_selective_checkpoint=dropout_selective_checkpoint,
336
+ use_scaled_init=use_scaled_init,
337
+ use_swiglu=use_swiglu,
338
+ use_flash_attn=use_flash_attn,
339
+ )
340
+ for lid in range(num_layers)
341
+ ]
342
+ )
343
+ if last:
344
+ if norm_type == "rmsnorm":
345
+ self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon)
346
+ else:
347
+ self.norm = LayerNorm(hidden_size, eps=layer_norm_epsilon)
348
+ self.head = head_cls(
349
+ in_features=hidden_size,
350
+ out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size,
351
+ process_group=gpc.get_group(ParallelMode.TENSOR),
352
+ bias=False,
353
+ device=device,
354
+ dtype=dtype,
355
+ weight_scale=embed_grad_scale,
356
+ )
357
+ for _, param in self.head.named_parameters():
358
+ normal_(std=0.0052)(param)
359
+ if gpc.get_world_size(ParallelMode.TENSOR) > 1:
360
+ setattr(param, IS_TENSOR_PARALLEL, True)
361
+ self.parallel_output = parallel_output
362
+
363
+ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
364
+ # attention_mask: compute attention on the places where the value is 1
365
+ if hasattr(self, "embedding"):
366
+ hidden_states = self.embedding(input_ids)
367
+ if self.embed_grad_scale != 1:
368
+ hidden_states = (
369
+ self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach()
370
+ )
371
+ if isinstance(cu_seqlens, list):
372
+ assert len(cu_seqlens) == 1
373
+ cu_seqlens = cu_seqlens[0].to(hidden_states.device)
374
+
375
+ if cu_seqlens is not None:
376
+ cu_seqlens = cu_seqlens.squeeze(0)
377
+ hidden_states = hidden_states.squeeze(0) # If cu_seqlens is passed in, the input is packed,
377
+ # so the batch dimension of size 1 is squeezed off directly.
379
+
380
+ if indexes is not None:
381
+ assert len(indexes) == 1
382
+ # The indexes are used to indicate the actual position IDs of each token in the packed input.
383
+ indexes = indexes[0]
384
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None
385
+
386
+ for _, block in enumerate(self.blocks):
387
+ hidden_states = block(
388
+ hidden_states,
389
+ cu_seqlens=cu_seqlens,
390
+ indexes=indexes,
391
+ inference_params=inference_params,
392
+ max_seqlen=max_seqlen,
393
+ )
394
+
395
+ if hasattr(self, "norm"):
396
+ hidden_states = self.norm(hidden_states.float())
397
+ if hasattr(self, "head"):
398
+ hidden_states = self.head(hidden_states)
399
+
400
+ if not self.parallel_output:
401
+ hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1)
402
+ return hidden_states
403
+
404
+
405
+ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
406
+ """
407
+ build generic model 1d
408
+
409
+ Args:
410
+ num_layers (int): The number of layer.
411
+ num_chunks (int): The number of partitions in pipeline parallel.
412
+ device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default.
413
+
414
+ """
415
+ pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
416
+ pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
417
+
418
+ all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
419
+ parts = all_parts[pipeline_rank]
420
+ if gpc.is_rank_for_log():
421
+ logger.info(f"The layer sharding is {all_parts}.")
422
+
423
+ models = []
424
+
425
+ for start, end in parts:
426
+ kwargs["num_layers"] = end - start
427
+ kwargs["first"] = start == 0
428
+ # Mark this chunk as the last stage only if it ends at the final layer and the last partition is non-empty.
429
+ kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0
430
+ kwargs["device"] = device
431
+ kwargs["start_layer_idx"] = start
432
+ chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device)
433
+
434
+ models.append(chunk)
435
+ torch.distributed.barrier()
436
+ if len(models) == 1:
437
+ model = models[0]
438
+ else:
439
+ model = nn.ModuleList(models)
440
+
441
+ return model
442
+
443
+
444
+ @MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE)
445
+ def build_vit_model_with_cfg(
446
+ num_chunks=1,
447
+ checkpoint=0.0,
448
+ dtype=torch.float,
449
+ embed_split_hidden=False,
450
+ num_layers=48,
451
+ hidden_size=2048,
452
+ vocab_size=50304,
453
+ embed_grad_scale=1,
454
+ parallel_output=True,
455
+ num_attention_heads=32,
456
+ mlp_ratio=4.0,
457
+ mlp_bias: bool = False,
458
+ residual_in_fp32=False,
459
+ norm_type="rmsnorm",
460
+ drop_path_rate=0,
461
+ attn_drop_rate=0,
462
+ apply_post_layer_norm=False, # pylint: disable=W0613
463
+ layer_norm_epsilon=1e-5,
464
+ is_reward=False,
465
+ dropout_selective_checkpoint=True,
466
+ use_scaled_init: bool = True,
467
+ use_swiglu: bool = True,
468
+ use_flash_attn: bool = True,
469
+ lvm_config=None,
470
+ ):
471
+ """
472
+ Build model with config.
473
+
474
+ Args:
475
+ num_chunks (int): The number of partitions in pipeline parallel. 1 by default.
476
+ checkpoint (float): The proportion of layers to checkpoint to save VRAM. 0.0 by default.
477
+ dtype (torch.dtype): The type of data. torch.float by default.
478
+ embed_split_hidden (bool): Split the embedding layer along the hidden dimension or the vocabulary dimension.
479
+ False by default.
480
+ num_layers (int): The number of layer. 48 by default.
481
+ hidden_size (int): The size of hidden state. 2048 by default.
482
+ vocab_size (int): The size of vocabulary. 50304 by default.
483
+ embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default.
484
+ parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default.
485
+ num_attention_heads (int): The number of attention head. 32 by default.
486
+ mlp_ratio (int): The ratio of MLP layers. 4.0 by default.
487
+ residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily
488
+ because this parameter requires inconsistent data types to be passed between pipelines,
489
+ which requires significant modifications to internlm.
490
+ norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
491
+ drop_path_rate (float): The drop path rate of the input hidden state. 0 by default.
492
+ attn_drop_rate (float): The dropout rate of attention module. 0 by default.
493
+ apply_post_layer_norm (bool): Whether to apply post layer norm. False by default.
494
+ layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default.
495
+ is_reward (bool): Whether to use reward model. False by default.
496
+ dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
497
+ use_scaled_init (bool): Whether to use scaled init. True by default.
498
+ use_swiglu (bool): Whether to use swiglu. True by default.
499
+ use_flash_attn (bool): Whether to use flash-attn. True by default.
500
+
501
+ """
502
+
503
+ cfg = dict(
504
+ hidden_size=hidden_size,
505
+ num_attention_heads=num_attention_heads,
506
+ checkpoint=checkpoint,
507
+ dtype=dtype,
508
+ embed_split_hidden=embed_split_hidden,
509
+ vocab_size=vocab_size,
510
+ embed_grad_scale=embed_grad_scale,
511
+ parallel_output=parallel_output,
512
+ mlp_ratio=mlp_ratio,
513
+ mlp_bias=mlp_bias,
514
+ residual_in_fp32=residual_in_fp32,
515
+ norm_type=norm_type,
516
+ drop_path_rate=drop_path_rate,
517
+ attn_drop_rate=attn_drop_rate,
518
+ layer_norm_epsilon=layer_norm_epsilon,
519
+ is_reward=is_reward,
520
+ dropout_selective_checkpoint=dropout_selective_checkpoint,
521
+ use_scaled_init=use_scaled_init,
522
+ use_swiglu=use_swiglu,
523
+ use_flash_attn=use_flash_attn,
524
+ lvm_config=lvm_config,
525
+ )
526
+
527
+ return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
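
As a quick, standalone illustration of the stochastic-depth behaviour used by the ViT blocks above: at evaluation time drop_path is the identity, while during training it zeroes whole samples and rescales the survivors by 1/keep_prob. The function below repeats the drop_path logic from this file so the snippet runs without the repository:

import torch

def drop_path(x, drop_prob=0.0, training=False, scale_by_keep=True):
    # Same logic as drop_path in modeling_vit.py.
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)          # one Bernoulli draw per sample
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor

x = torch.ones(4, 2, 3)                                   # (batch, seq, hidden)
print(drop_path(x, 0.5, training=False)[:, 0, 0])         # tensor([1., 1., 1., 1.]) -> identity at eval
print(drop_path(x, 0.5, training=True)[:, 0, 0])          # each entry is 0. or 2. (= 1 / keep_prob)
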
InternLM/internlm/model/multi_head_attention.py ADDED
@@ -0,0 +1,186 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from einops import rearrange
8
+ from flash_attn.modules.mha import (
9
+ CrossAttention,
10
+ FlashCrossAttention,
11
+ FlashSelfAttention,
12
+ SelfAttention,
13
+ _update_kv_cache,
14
+ )
15
+ from torch import nn
16
+
17
+ from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
18
+ from internlm.core.context import global_context as gpc
19
+ from internlm.model.embedding import RotaryEmbedding
20
+ from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch
21
+
22
+
23
+ class MHA(nn.Module):
24
+ """
25
+ Multi-head self-attention and cross-attention.
26
+
27
+ Args:
28
+ embed_dim (int): The dimension of the hidden state.
29
+ num_heads (int): The number of attention heads.
30
+ process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`.
31
+ bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and
32
+ output projection. True by default.
33
+ dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
34
+ softmax_scale (float): The temperature to use for the softmax attention.
35
+ causal (boolean): Whether to apply causal attention mask. False by default.
36
+ layer_idx (int): The index of current layer. None by default.
37
+ rotary_emb_dim (int): The dimension of the rotary embedding. 0 by default.
38
+ rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
39
+ XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
40
+ use_flash_attn (boolean): Whether to use flash attention or not. If False, the vanilla attention module will be used.
42
+ True by default.
42
+ sequence_parallel (boolean): If True, we're doing Tensor Parallel with sequence parallelism. An all_gather_raw
43
+ of x will be done before doing the matmul.
44
+ device (Optional[Union[str, torch.device]]): The device will be used.
45
+ dtype (Optional[torch.dtype]): The type of data.
46
+ use_flash_attn (bool): Whether to use flash-attn. True by default.
47
+
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ embed_dim: int,
53
+ num_heads: int,
54
+ process_group: Optional[torch.distributed.ProcessGroup],
55
+ dropout: float = 0.0,
56
+ softmax_scale: float = None,
57
+ causal: bool = False,
58
+ layer_idx: int = None,
59
+ rotary_emb_dim: int = 0,
60
+ rotary_emb_scale_base: int = 0,
61
+ use_flash_attn: bool = True,
62
+ device: Optional[torch.device] = None,
63
+ dtype: Optional[torch.dtype] = None,
64
+ ) -> None:
65
+ factory_kwargs = {"device": device, "dtype": dtype}
66
+ super().__init__()
67
+ self.embed_dim = embed_dim
68
+ self.causal = causal
69
+ self.layer_idx = layer_idx
70
+ self.rotary_emb_dim = rotary_emb_dim
71
+ self.use_flash_attn = use_flash_attn
72
+ self.num_heads = num_heads
73
+ assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
74
+ self.head_dim = self.embed_dim // num_heads
75
+
76
+ if self.rotary_emb_dim > 0:
77
+ self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
78
+
79
+ # notice here should change bias=True
80
+ self.Wqkv = ColumnParallelLinearTorch(
81
+ embed_dim,
82
+ 3 * embed_dim,
83
+ process_group,
84
+ bias=True,
85
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
86
+ **factory_kwargs,
87
+ ) # according to https://spaces.ac.cn/archives/9577
88
+
89
+ inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
90
+ inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
91
+ self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
92
+ self.inner_cross_attn = inner_cross_attn_cls(
93
+ causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
94
+ )
95
+
96
+ # output projection always have the bias (for now)
97
+ self.out_proj = RowParallelLinearTorch(
98
+ embed_dim,
99
+ embed_dim,
100
+ process_group,
101
+ sequence_parallel=gpc.config.parallel.sequence_parallel,
102
+ **factory_kwargs,
103
+ )
104
+ # need to assign tp attribute so that internlm know it is tensor parallel module
105
+ if gpc.get_world_size(ParallelMode.TENSOR) > 1:
106
+ for name in ["out_proj", "Wqkv"]:
107
+ for param in getattr(self, name).parameters():
108
+ setattr(param, IS_TENSOR_PARALLEL, True)
109
+
110
+ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
111
+ if kwargs.get("indexes", None) is not None:
112
+ return self._packed_forward(x=x, inference_params=inference_params, **kwargs)
113
+ else:
114
+ return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)
115
+
116
+ def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
117
+ """
118
+ Arguments:
119
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
120
+ If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
121
+ split x during sequence parallel, we split the batch * seqlen dimension
122
+ (in case batch is small).
123
+ """
124
+ qkv = self.Wqkv(x)
125
+ if seqlen is None:
126
+ qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
127
+ else:
128
+ qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
129
+
130
+ if self.rotary_emb_dim > 0:
131
+ kwargs["inference_params"] = inference_params
132
+ qkv = self.rotary_emb(qkv, **kwargs)
133
+
134
+ if inference_params is None:
135
+ if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
136
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
137
+ if qkv.dtype not in [torch.float16, torch.bfloat16]:
138
+ qkv = qkv.to(torch.bfloat16)
139
+ context = self.inner_attn(qkv).to(x.dtype)
140
+ else:
141
+ context = self.inner_attn(qkv)
142
+ else:
143
+ q = qkv[:, :, 0]
144
+ assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
145
+ kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
146
+ # If we're processing the prompt, causal=None (use self.causal).
147
+ # If we're decoding, then causal=False.
148
+ causal = None if inference_params.sequence_len_offset == 0 else False
149
+ context = self.inner_cross_attn(q, kv, causal=causal)
150
+
151
+ if seqlen is None:
152
+ context = rearrange(context, "b s h d -> b s (h d)")
153
+ else:
154
+ context = rearrange(context, "b s h d -> (b s) (h d)")
155
+
156
+ out = self.out_proj(context)
157
+ return out
158
+
159
+ def _packed_forward(self, x, inference_params=None, **kwargs):
160
+ """
161
+ Arguments:
162
+ x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
163
+ If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
164
+ split x during sequence parallel, we split the batch * seqlen dimension
165
+ (in case batch is small).
166
+ """
167
+ qkv = self.Wqkv(x) # total x hsz'
168
+ qkv = rearrange(qkv, "t (three h d) -> t three h d", three=3, d=self.head_dim) # total x 3 x n_head x d
169
+ qkv = self.rotary_emb(qkv, **kwargs)
170
+ kwargs.pop("indexes")
171
+
172
+ if inference_params is None:
173
+ if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
174
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
175
+ if qkv.dtype not in [torch.float16, torch.bfloat16]:
176
+ qkv = qkv.to(torch.bfloat16)
177
+ context = self.inner_attn(qkv, **kwargs).to(x.dtype)
178
+ else:
179
+ context = self.inner_attn(qkv, **kwargs)
180
+
181
+ else:
182
+ raise RuntimeError("Not support this right now")
183
+
184
+ context = rearrange(context, "b h d -> b (h d)") # recover the shape
185
+ out = self.out_proj(context)
186
+ return out
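
For readers unfamiliar with the packed (variable-length) layout that _packed_forward consumes, the sketch below shows how cu_seqlens, indexes and max_seqlen are typically derived from per-sequence lengths. The actual construction lives in the data pipeline, so treat this as illustrative only:

import torch

seqlens = torch.tensor([3, 5, 2])                             # three packed sequences
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(seqlens, dim=0).to(torch.int32)])
indexes = torch.cat([torch.arange(int(n)) for n in seqlens])  # position of each token in its sequence
max_seqlen = int(seqlens.max())

print(cu_seqlens)   # tensor([ 0,  3,  8, 10], dtype=torch.int32)
print(indexes)      # tensor([0, 1, 2, 0, 1, 2, 3, 4, 0, 1])
print(max_seqlen)   # 5
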
InternLM/internlm/model/muse/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ __version__ = "0.0.1"
17
+
18
+ from .modeling_taming_vqgan import VQGANModel
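
The VQGAN building blocks that follow (modeling_taming_vqgan.py) rely on two small spatial patterns: Upsample uses nearest-neighbour 2x interpolation followed by a 3x3 convolution, and Downsample pads only the right and bottom edges before a stride-2 convolution. A minimal shape check with illustrative tensors:

import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1, 8, 16, 16)

# Upsample pattern: nearest 2x interpolation, then a shape-preserving 3x3 conv.
up_conv = nn.Conv2d(8, 8, kernel_size=3, stride=1, padding=1)
y = up_conv(F.interpolate(x, scale_factor=2.0, mode="nearest"))
print(y.shape)  # torch.Size([1, 8, 32, 32])

# Downsample pattern: asymmetric (left=0, right=1, top=0, bottom=1) padding, then a stride-2 conv.
down_conv = nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=0)
z = down_conv(F.pad(y, (0, 1, 0, 1), mode="constant", value=0))
print(z.shape)  # torch.Size([1, 8, 16, 16])
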
InternLM/internlm/model/muse/modeling_taming_vqgan.py ADDED
@@ -0,0 +1,591 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Taming Transformers Authors and The HuggingFace Inc. team.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from functools import partial
17
+ from typing import Tuple
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ import torch.utils.checkpoint
22
+ from torch import nn
23
+
24
+ from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
25
+
26
+
27
+ class Upsample(nn.Module):
28
+ def __init__(self, in_channels: int, with_conv: bool):
29
+ super().__init__()
30
+ self.with_conv = with_conv
31
+ if self.with_conv:
32
+ self.conv = nn.Conv2d(
33
+ in_channels,
34
+ in_channels,
35
+ kernel_size=3,
36
+ stride=1,
37
+ padding=1,
38
+ )
39
+
40
+ def forward(self, hidden_states):
41
+ hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
42
+ if self.with_conv:
43
+ hidden_states = self.conv(hidden_states)
44
+ return hidden_states
45
+
46
+
47
+ class Downsample(nn.Module):
48
+ def __init__(self, in_channels: int, with_conv: bool):
49
+ super().__init__()
50
+
51
+ self.with_conv = with_conv
52
+ if self.with_conv:
53
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
54
+
55
+ def forward(self, hidden_states):
56
+ if self.with_conv:
57
+ pad = (0, 1, 0, 1) # pad height and width dim
58
+ hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
59
+ hidden_states = self.conv(hidden_states)
60
+ else:
61
+ hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=2, stride=2)
62
+ return hidden_states
63
+
64
+
65
+ class ResnetBlock(nn.Module):
66
+ def __init__(
67
+ self,
68
+ in_channels: int,
69
+ out_channels: int = None,
70
+ use_conv_shortcut: bool = False,
71
+ dropout_prob: float = 0.0,
72
+ ):
73
+ super().__init__()
74
+
75
+ self.in_channels = in_channels
76
+ self.out_channels = out_channels
77
+ self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
78
+ self.use_conv_shortcut = use_conv_shortcut
79
+
80
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
81
+ self.conv1 = nn.Conv2d(
82
+ self.in_channels,
83
+ self.out_channels_,
84
+ kernel_size=3,
85
+ stride=1,
86
+ padding=1,
87
+ )
88
+
89
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True)
90
+ self.dropout = nn.Dropout(dropout_prob)
91
+ self.conv2 = nn.Conv2d(
92
+ self.out_channels_,
93
+ self.out_channels_,
94
+ kernel_size=3,
95
+ stride=(1, 1),
96
+ padding=1,
97
+ )
98
+
99
+ if self.in_channels != self.out_channels_:
100
+ if use_conv_shortcut:
101
+ self.conv_shortcut = nn.Conv2d(
102
+ self.in_channels,
103
+ self.out_channels_,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1,
107
+ )
108
+ else:
109
+ self.nin_shortcut = nn.Conv2d(
110
+ self.in_channels,
111
+ self.out_channels_,
112
+ kernel_size=1,
113
+ stride=1,
114
+ padding=0,
115
+ )
116
+
117
+ def forward(self, hidden_states):
118
+ residual = hidden_states
119
+ hidden_states = self.norm1(hidden_states)
120
+ hidden_states = F.silu(hidden_states)
121
+ hidden_states = self.conv1(hidden_states)
122
+
123
+ hidden_states = self.norm2(hidden_states)
124
+ hidden_states = F.silu(hidden_states)
125
+ hidden_states = self.dropout(hidden_states)
126
+ hidden_states = self.conv2(hidden_states)
127
+
128
+ if self.in_channels != self.out_channels_:
129
+ if self.use_conv_shortcut:
130
+ residual = self.conv_shortcut(residual)
131
+ else:
132
+ residual = self.nin_shortcut(residual)
133
+
134
+ return hidden_states + residual
135
+
136
+
137
+ class AttnBlock(nn.Module):
138
+ def __init__(self, in_channels: int):
139
+ super().__init__()
140
+
141
+ self.in_channels = in_channels
142
+ conv = partial(nn.Conv2d, self.in_channels, self.in_channels, kernel_size=1, stride=1, padding=0)
143
+
144
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=self.in_channels, eps=1e-6, affine=True)
145
+ self.q, self.k, self.v = conv(), conv(), conv()
146
+ self.proj_out = conv()
147
+
148
+ def forward(self, hidden_states):
149
+ residual = hidden_states
150
+ hidden_states = self.norm(hidden_states)
151
+
152
+ query = self.q(hidden_states)
153
+ key = self.k(hidden_states)
154
+ value = self.v(hidden_states)
155
+
156
+ # compute attentions
157
+ batch, channels, height, width = query.shape
158
+ query = query.reshape((batch, channels, height * width))
159
+ query = query.permute(0, 2, 1) # (b, hw, c)
160
+ key = key.reshape((batch, channels, height * width))
161
+
162
+ attn_weights = torch.bmm(query, key) # b,hw,hw
163
+ attn_weights = attn_weights * (int(channels) ** -0.5)
164
+ attn_weights = nn.functional.softmax(attn_weights, dim=2)
165
+
166
+ # attend to values
167
+ value = value.reshape((batch, channels, height * width))
168
+ attn_weights = attn_weights.permute(0, 2, 1)
169
+ hidden_states = torch.bmm(value, attn_weights)
170
+ hidden_states = hidden_states.reshape((batch, channels, height, width))
171
+
172
+ hidden_states = self.proj_out(hidden_states)
173
+ hidden_states = hidden_states + residual
174
+ return hidden_states
175
+
176
+
177
+ class UpsamplingBlock(nn.Module):
178
+ def __init__(self, config, curr_res: int, block_idx: int):
179
+ super().__init__()
180
+
181
+ self.config = config
182
+ self.block_idx = block_idx
183
+ self.curr_res = curr_res
184
+
185
+ if self.block_idx == self.config.num_resolutions - 1:
186
+ block_in = self.config.hidden_channels * self.config.channel_mult[-1]
187
+ else:
188
+ block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1]
189
+
190
+ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
191
+
192
+ res_blocks = []
193
+ attn_blocks = []
194
+ for _ in range(self.config.num_res_blocks + 1):
195
+ res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
196
+ block_in = block_out
197
+ if self.curr_res in self.config.attn_resolutions:
198
+ attn_blocks.append(AttnBlock(block_in))
199
+
200
+ self.block = nn.ModuleList(res_blocks)
201
+ self.attn = nn.ModuleList(attn_blocks)
202
+
203
+ self.upsample = None
204
+ if self.block_idx != 0:
205
+ self.upsample = Upsample(block_in, self.config.resample_with_conv)
206
+
207
+ def forward(self, hidden_states):
208
+ for i, res_block in enumerate(self.block):
209
+ hidden_states = res_block(hidden_states)
210
+ if len(self.attn) > 1:
211
+ hidden_states = self.attn[i](hidden_states)
212
+
213
+ if self.upsample is not None:
214
+ hidden_states = self.upsample(hidden_states)
215
+
216
+ return hidden_states
217
+
218
+
219
+ class DownsamplingBlock(nn.Module):
220
+ def __init__(self, config, curr_res: int, block_idx: int):
221
+ super().__init__()
222
+
223
+ self.config = config
224
+ self.curr_res = curr_res
225
+ self.block_idx = block_idx
226
+
227
+ in_channel_mult = (1,) + tuple(self.config.channel_mult)
228
+ block_in = self.config.hidden_channels * in_channel_mult[self.block_idx]
229
+ block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
230
+
231
+ res_blocks = nn.ModuleList()
232
+ attn_blocks = nn.ModuleList()
233
+ for _ in range(self.config.num_res_blocks):
234
+ res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config.dropout))
235
+ block_in = block_out
236
+ if self.curr_res in self.config.attn_resolutions:
237
+ attn_blocks.append(AttnBlock(block_in))
238
+
239
+ self.block = res_blocks
240
+ self.attn = attn_blocks
241
+
242
+ self.downsample = None
243
+ if self.block_idx != self.config.num_resolutions - 1:
244
+ self.downsample = Downsample(block_in, self.config.resample_with_conv)
245
+
246
+ def forward(self, hidden_states):
247
+ for i, res_block in enumerate(self.block):
248
+ hidden_states = res_block(hidden_states)
249
+ if len(self.attn) > 1:
250
+ hidden_states = self.attn[i](hidden_states)
251
+
252
+ if self.downsample is not None:
253
+ hidden_states = self.downsample(hidden_states)
254
+
255
+ return hidden_states
256
+
257
+
258
+ class MidBlock(nn.Module):
259
+ def __init__(self, config, in_channels: int, no_attn: bool, dropout: float):
260
+ super().__init__()
261
+
262
+ self.config = config
263
+ self.in_channels = in_channels
264
+ self.no_attn = no_attn
265
+ self.dropout = dropout
266
+
267
+ self.block_1 = ResnetBlock(
268
+ self.in_channels,
269
+ self.in_channels,
270
+ dropout_prob=self.dropout,
271
+ )
272
+ if not no_attn:
273
+ self.attn_1 = AttnBlock(self.in_channels)
274
+ self.block_2 = ResnetBlock(
275
+ self.in_channels,
276
+ self.in_channels,
277
+ dropout_prob=self.dropout,
278
+ )
279
+
280
+ def forward(self, hidden_states):
281
+ hidden_states = self.block_1(hidden_states)
282
+ if not self.no_attn:
283
+ hidden_states = self.attn_1(hidden_states)
284
+ hidden_states = self.block_2(hidden_states)
285
+ return hidden_states
286
+
287
+
288
+ class Encoder(nn.Module):
289
+ def __init__(self, config):
290
+ super().__init__()
291
+
292
+ self.config = config
293
+
294
+ # downsampling
295
+ self.conv_in = nn.Conv2d(
296
+ self.config.num_channels,
297
+ self.config.hidden_channels,
298
+ kernel_size=3,
299
+ stride=1,
300
+ padding=1,
301
+ )
302
+
303
+ curr_res = self.config.resolution
304
+ downsample_blocks = []
305
+ for i_level in range(self.config.num_resolutions):
306
+ downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level))
307
+
308
+ if i_level != self.config.num_resolutions - 1:
309
+ curr_res = curr_res // 2
310
+ self.down = nn.ModuleList(downsample_blocks)
311
+
312
+ # middle
313
+ mid_channels = self.config.hidden_channels * self.config.channel_mult[-1]
314
+ self.mid = MidBlock(config, mid_channels, self.config.no_attn_mid_block, self.config.dropout)
315
+
316
+ # end
317
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True)
318
+ self.conv_out = nn.Conv2d(
319
+ mid_channels,
320
+ self.config.z_channels,
321
+ kernel_size=3,
322
+ stride=1,
323
+ padding=1,
324
+ )
325
+
326
+ def forward(self, pixel_values):
327
+ # downsampling
328
+ hidden_states = self.conv_in(pixel_values)
329
+ for block in self.down:
330
+ hidden_states = block(hidden_states)
331
+
332
+ # middle
333
+ hidden_states = self.mid(hidden_states)
334
+
335
+ # end
336
+ hidden_states = self.norm_out(hidden_states)
337
+ hidden_states = F.silu(hidden_states)
338
+ hidden_states = self.conv_out(hidden_states)
339
+
340
+ return hidden_states
341
+
342
+
343
+ class Decoder(nn.Module):
344
+ def __init__(self, config):
345
+ super().__init__()
346
+
347
+ self.config = config
348
+
349
+ # compute in_channel_mult, block_in and curr_res at lowest res
350
+ block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1]
351
+ curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
352
+ self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
353
+
354
+ # z to block_in
355
+ self.conv_in = nn.Conv2d(
356
+ self.config.z_channels,
357
+ block_in,
358
+ kernel_size=3,
359
+ stride=1,
360
+ padding=1,
361
+ )
362
+
363
+ # middle
364
+ self.mid = MidBlock(config, block_in, self.config.no_attn_mid_block, self.config.dropout)
365
+
366
+ # upsampling
367
+ upsample_blocks = []
368
+ for i_level in reversed(range(self.config.num_resolutions)):
369
+ upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level))
370
+ if i_level != 0:
371
+ curr_res = curr_res * 2
372
+ self.up = nn.ModuleList(list(reversed(upsample_blocks))) # reverse to get consistent order
373
+
374
+ # end
375
+ block_out = self.config.hidden_channels * self.config.channel_mult[0]
376
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True)
377
+ self.conv_out = nn.Conv2d(
378
+ block_out,
379
+ self.config.num_channels,
380
+ kernel_size=3,
381
+ stride=1,
382
+ padding=1,
383
+ )
384
+
385
+ def forward(self, hidden_states):
386
+ # z to block_in
387
+ hidden_states = self.conv_in(hidden_states)
388
+
389
+ # middle
390
+ hidden_states = self.mid(hidden_states)
391
+
392
+ # upsampling
393
+ for block in reversed(self.up):
394
+ hidden_states = block(hidden_states)
395
+
396
+ # end
397
+ hidden_states = self.norm_out(hidden_states)
398
+ hidden_states = F.silu(hidden_states)
399
+ hidden_states = self.conv_out(hidden_states)
400
+
401
+ return hidden_states
402
+
403
+
404
+ class VectorQuantizer(nn.Module):
405
+ """
406
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
407
+ Discretization bottleneck part of the VQ-VAE.
408
+ """
409
+
410
+ def __init__(self, num_embeddings, embedding_dim, commitment_cost):
411
+ r"""
412
+ Args:
413
+ num_embeddings: number of vectors in the quantized space.
414
+ embedding_dim: dimensionality of the tensors in the quantized space.
415
+ Inputs to the modules must be in this format as well.
416
+ commitment_cost: scalar which controls the weighting of the loss terms
417
+ (see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta).
418
+ """
419
+ super().__init__()
420
+
421
+ self.num_embeddings = num_embeddings
422
+ self.embedding_dim = embedding_dim
423
+ self.commitment_cost = commitment_cost
424
+
425
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
426
+ self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
427
+
428
+ def forward(self, hidden_states, return_loss=False):
429
+ """
430
+ Inputs the output of the encoder network z and maps it to a discrete one-hot vector that is the index of the
431
+ closest embedding vector e_j z (continuous) -> z_q (discrete) z.shape = (batch, channel, height, width)
432
+ quantization pipeline:
433
+ 1. get encoder input (B,C,H,W)
434
+ 2. flatten input to (B*H*W,C)
435
+ """
436
+ # reshape z -> (batch, height, width, channel) and flatten
437
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
438
+
439
+ distances = self.compute_distances(hidden_states)
440
+ min_encoding_indices = torch.argmin(distances, axis=1).unsqueeze(1)
441
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
442
+ min_encodings.scatter_(1, min_encoding_indices, 1)
443
+
444
+ # get quantized latent vectors
445
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)
446
+
447
+ # reshape to (batch, num_tokens)
448
+ min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
449
+
450
+ # compute loss for embedding
451
+ loss = None
452
+ if return_loss:
453
+ loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean(
454
+ (z_q - hidden_states.detach()) ** 2
455
+ )
456
+ # preserve gradients
457
+ z_q = hidden_states + (z_q - hidden_states).detach()
458
+
459
+ # reshape back to match original input shape
460
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
461
+
462
+ return z_q, min_encoding_indices, loss
463
+
464
+ def compute_distances(self, hidden_states):
465
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
466
+ hidden_states_flattened = hidden_states.reshape((-1, self.embedding_dim))
467
+ emb_weights = self.embedding.weight.t()
468
+
469
+ inputs_norm_sq = hidden_states_flattened.pow(2.0).sum(dim=1, keepdim=True)
470
+ codebook_t_norm_sq = emb_weights.pow(2.0).sum(dim=0, keepdim=True)
471
+ distances = torch.addmm(
472
+ inputs_norm_sq + codebook_t_norm_sq,
473
+ hidden_states_flattened,
474
+ emb_weights,
475
+ alpha=-2.0,
476
+ )
477
+ return distances
478
+
479
+ def get_codebook_entry(self, indices):
480
+ # indices are expected to be of shape (batch, num_tokens)
481
+ # get quantized latent vectors
482
+ batch, num_tokens = indices.shape
483
+ z_q = self.embedding(indices)
484
+ z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2)
485
+ return z_q
486
+
487
+ def get_codebook_entry_for_lvm(self, indices):
488
+ batch, num_tokens = indices.shape
489
+ z_q = self.embedding(indices)
490
+ z_q = z_q.reshape(batch, num_tokens, -1)
491
+ return z_q
492
+
493
+ # adapted from https://github.com/kakaobrain/rq-vae-transformer/blob/main/rqvae/models/rqvae/quantizations.py#L372
494
+ def get_soft_code(self, hidden_states, temp=1.0, stochastic=False):
495
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() # (batch, height, width, channel)
496
+ distances = self.compute_distances(hidden_states) # (batch * height * width, num_embeddings)
497
+
498
+ soft_code = F.softmax(-distances / temp, dim=-1) # (batch * height * width, num_embeddings)
499
+ if stochastic:
500
+ code = torch.multinomial(soft_code, 1) # (batch * height * width, 1)
501
+ else:
502
+ code = distances.argmin(dim=-1) # (batch * height * width)
503
+
504
+ code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width)
505
+ batch, num_tokens = code.shape
506
+ soft_code = soft_code.reshape(batch, num_tokens, -1) # (batch, height * width, num_embeddings)
507
+ return soft_code, code
508
+
509
+ def get_code(self, hidden_states):
510
+ # reshape z -> (batch, height, width, channel)
511
+ hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
512
+ distances = self.compute_distances(hidden_states)
513
+ indices = torch.argmin(distances, axis=1).unsqueeze(1)
514
+ indices = indices.reshape(hidden_states.shape[0], -1)
515
+ return indices
516
+
517
+
518
+ class VQGANModel(ModelMixin, ConfigMixin):
519
+ @register_to_config
520
+ def __init__(
521
+ self,
522
+ resolution: int = 256,
523
+ num_channels: int = 3,
524
+ hidden_channels: int = 128,
525
+ channel_mult: Tuple = (1, 1, 2, 2, 4),
526
+ num_res_blocks: int = 2,
527
+ attn_resolutions: Tuple = (16,),
528
+ no_attn_mid_block: bool = False,
529
+ z_channels: int = 256,
530
+ num_embeddings: int = 1024,
531
+ quantized_embed_dim: int = 256,
532
+ dropout: float = 0.0,
533
+ resample_with_conv: bool = True,
534
+ commitment_cost: float = 0.25,
535
+ ):
536
+ super().__init__()
537
+
538
+ self.config.num_resolutions = len(channel_mult)
539
+ self.config.reduction_factor = 2 ** (self.config.num_resolutions - 1)
540
+ self.config.latent_size = resolution // self.config.reduction_factor
541
+
542
+ self.encoder = Encoder(self.config)
543
+ self.decoder = Decoder(self.config)
544
+ self.quantize = VectorQuantizer(
545
+ self.config.num_embeddings, self.config.quantized_embed_dim, self.config.commitment_cost
546
+ )
547
+ self.quant_conv = nn.Conv2d(
548
+ self.config.z_channels,
549
+ self.config.quantized_embed_dim,
550
+ kernel_size=1,
551
+ )
552
+ self.post_quant_conv = nn.Conv2d(
553
+ self.config.quantized_embed_dim,
554
+ self.config.z_channels,
555
+ kernel_size=1,
556
+ )
557
+
558
+ def encode(self, pixel_values, return_loss=False):
559
+ hidden_states = self.encoder(pixel_values)
560
+ hidden_states = self.quant_conv(hidden_states)
561
+ quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
562
+ output = (quantized_states, codebook_indices)
563
+ if return_loss:
564
+ output = output + (codebook_loss,)
565
+ return output
566
+
567
+ def decode(self, quantized_states):
568
+ hidden_states = self.post_quant_conv(quantized_states)
569
+ reconstructed_pixel_values = self.decoder(hidden_states)
570
+ return reconstructed_pixel_values
571
+
572
+ def decode_code(self, codebook_indices):
573
+ quantized_states = self.quantize.get_codebook_entry(codebook_indices)
574
+ reconstructed_pixel_values = self.decode(quantized_states)
575
+ return reconstructed_pixel_values
576
+
577
+ def get_code(self, pixel_values):
578
+ hidden_states = self.encoder(pixel_values)
579
+ hidden_states = self.quant_conv(hidden_states)
580
+ codebook_indices = self.quantize.get_code(hidden_states)
581
+ return codebook_indices
582
+
583
+ def forward(self, pixel_values, return_loss=False):
584
+ hidden_states = self.encoder(pixel_values)
585
+ hidden_states = self.quant_conv(hidden_states)
586
+ quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
587
+ reconstructed_pixel_values = self.decode(quantized_states)
588
+ outputs = (reconstructed_pixel_values, quantized_states, codebook_indices)
589
+ if return_loss:
590
+ outputs = outputs + (codebook_loss,)
591
+ return outputs
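Taken together, `VQGANModel.forward` runs encoder → `quant_conv` → `VectorQuantizer` → `post_quant_conv` → decoder. Below is a minimal sketch of that round trip, assuming only the default constructor arguments shown above and that `register_to_config` from `modeling_utils.py` behaves like its diffusers counterpart; the weights are randomly initialized, so it is useful only for checking shapes.

```python
# Hedged sketch: exercises the encode -> quantize -> decode path with default hyperparameters.
# Shapes follow from resolution=256 and channel_mult=(1, 1, 2, 2, 4), i.e. a 16x16 latent grid.
import torch

from internlm.model.muse.modeling_taming_vqgan import VQGANModel

model = VQGANModel()  # defaults: 256x256 inputs, 1024-entry codebook, 256-dim codes
model.eval()

with torch.no_grad():
    pixels = torch.randn(2, 3, 256, 256)
    quantized, indices, commit_loss = model.encode(pixels, return_loss=True)
    print(quantized.shape)   # torch.Size([2, 256, 16, 16]) -- quantized latents
    print(indices.shape)     # torch.Size([2, 256])         -- one codebook id per latent position
    recon = model.decode(quantized)
    print(recon.shape)       # torch.Size([2, 3, 256, 256])
```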
InternLM/internlm/model/muse/modeling_utils.py ADDED
@@ -0,0 +1,1171 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ import inspect
18
+ import json
19
+ import os
20
+ from collections import OrderedDict
21
+ from functools import partial
22
+ from pathlib import PosixPath
23
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
24
+
25
+ import accelerate
26
+ import numpy as np
27
+ import torch
28
+ from accelerate.utils import set_module_tensor_to_device
29
+ from huggingface_hub import hf_hub_download
30
+ from huggingface_hub.utils import (
31
+ EntryNotFoundError,
32
+ RepositoryNotFoundError,
33
+ RevisionNotFoundError,
34
+ )
35
+ from requests import HTTPError
36
+ from torch import Tensor, device
37
+
38
+ from . import __version__
39
+ from internlm.utils.logger import get_logger
40
+
41
+ logger = get_logger(__file__)
42
+
43
+
44
+ hf_cache_home = os.path.expanduser(
45
+ os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
46
+ )
47
+ default_cache_path = os.path.join(hf_cache_home, "muse")
48
+
49
+
50
+ CONFIG_NAME = "config.json"
51
+ WEIGHTS_NAME = "pytorch_model.bin"
52
+ SAFETENSORS_WEIGHTS_NAME = "pytorch_model.safetensors"
53
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
54
+ MUSE_CACHE = default_cache_path
55
+ MUSE_DYNAMIC_MODULE_NAME = "muse_modules"
56
+ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
57
+
58
+
59
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
60
+
61
+
62
+ def get_parameter_device(parameter: torch.nn.Module):
63
+ try:
64
+ return next(parameter.parameters()).device
65
+ except StopIteration:
66
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
67
+
68
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
69
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
70
+ return tuples
71
+
72
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
73
+ first_tuple = next(gen)
74
+ return first_tuple[1].device
75
+
76
+
77
+ def get_parameter_dtype(parameter: torch.nn.Module):
78
+ try:
79
+ return next(parameter.parameters()).dtype
80
+ except StopIteration:
81
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
82
+
83
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
84
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
85
+ return tuples
86
+
87
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
88
+ first_tuple = next(gen)
89
+ return first_tuple[1].dtype
90
+
91
+
92
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
93
+ """
94
+ Reads a checkpoint file, returning properly formatted errors if they arise.
95
+ """
96
+ try:
97
+ if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
98
+ return torch.load(checkpoint_file, map_location="cpu")
99
+ except Exception as e:
100
+ try:
101
+ with open(checkpoint_file) as f:
102
+ if f.read().startswith("version"):
103
+ raise OSError(
104
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
105
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
106
+ "you cloned."
107
+ )
108
+ else:
109
+ raise ValueError(
110
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
111
+ "model. Make sure you have saved the model properly."
112
+ ) from e
113
+ except (UnicodeDecodeError, ValueError):
114
+ raise OSError(
115
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
116
+ f"at '{checkpoint_file}'. "
117
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
118
+ )
119
+
120
+
121
+ def _load_state_dict_into_model(model_to_load, state_dict):
122
+ # Convert old format to new format if needed from a PyTorch state_dict
123
+ # copy state_dict so _load_from_state_dict can modify it
124
+ state_dict = state_dict.copy()
125
+ error_msgs = []
126
+
127
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
128
+ # so we need to apply the function recursively.
129
+ def load(module: torch.nn.Module, prefix=""):
130
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
131
+ module._load_from_state_dict(*args)
132
+
133
+ for name, child in module._modules.items():
134
+ if child is not None:
135
+ load(child, prefix + name + ".")
136
+
137
+ load(model_to_load)
138
+
139
+ return error_msgs
140
+
141
+
142
+ def _get_model_file(
143
+ pretrained_model_name_or_path,
144
+ *,
145
+ weights_name,
146
+ subfolder,
147
+ cache_dir,
148
+ force_download,
149
+ proxies,
150
+ resume_download,
151
+ local_files_only,
152
+ use_auth_token,
153
+ user_agent,
154
+ revision,
155
+ ):
156
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
157
+ if os.path.isfile(pretrained_model_name_or_path):
158
+ return pretrained_model_name_or_path
159
+ elif os.path.isdir(pretrained_model_name_or_path):
160
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
161
+ # Load from a PyTorch checkpoint
162
+ model_file = os.path.join(pretrained_model_name_or_path, weights_name)
163
+ return model_file
164
+ elif subfolder is not None and os.path.isfile(
165
+ os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
166
+ ):
167
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
168
+ return model_file
169
+ else:
170
+ raise EnvironmentError(
171
+ f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
172
+ )
173
+ else:
174
+ try:
175
+ # Load from URL or cache if already cached
176
+ model_file = hf_hub_download(
177
+ pretrained_model_name_or_path,
178
+ filename=weights_name,
179
+ cache_dir=cache_dir,
180
+ force_download=force_download,
181
+ proxies=proxies,
182
+ resume_download=resume_download,
183
+ local_files_only=local_files_only,
184
+ use_auth_token=use_auth_token,
185
+ user_agent=user_agent,
186
+ subfolder=subfolder,
187
+ revision=revision,
188
+ )
189
+ return model_file
190
+
191
+ except RepositoryNotFoundError:
192
+ raise EnvironmentError(
193
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
194
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
195
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
196
+ "login`."
197
+ )
198
+ except RevisionNotFoundError:
199
+ raise EnvironmentError(
200
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
201
+ "this model name. Check the model page at "
202
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
203
+ )
204
+ except EntryNotFoundError:
205
+ raise EnvironmentError(
206
+ f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
207
+ )
208
+ except HTTPError as err:
209
+ raise EnvironmentError(
210
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
211
+ )
212
+ except ValueError:
213
+ raise EnvironmentError(
214
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
215
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
216
+ f" directory containing a file named {weights_name} or"
217
+ " \nCheck your internet connection or see how to run the library in"
218
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
219
+ )
220
+ except EnvironmentError:
221
+ raise EnvironmentError(
222
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
223
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
224
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
225
+ f"containing a file named {weights_name}"
226
+ )
227
+
228
+
229
+ class ModelMixin(torch.nn.Module):
230
+ r"""
231
+ Base class for all models.
232
+
233
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
234
+ and saving models.
235
+
236
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
237
+ [`~models.ModelMixin.save_pretrained`].
238
+ """
239
+ config_name = CONFIG_NAME
240
+ _automatically_saved_args = ["_version", "_class_name", "_name_or_path"]
241
+ _supports_gradient_checkpointing = False
242
+
243
+ def __init__(self):
244
+ super().__init__()
245
+
246
+ @property
247
+ def is_gradient_checkpointing(self) -> bool:
248
+ """
249
+ Whether gradient checkpointing is activated for this model or not.
250
+
251
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
252
+ activations".
253
+ """
254
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
255
+
256
+ def enable_gradient_checkpointing(self):
257
+ """
258
+ Activates gradient checkpointing for the current model.
259
+
260
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
261
+ activations".
262
+ """
263
+ if not self._supports_gradient_checkpointing:
264
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
265
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
266
+
267
+ def disable_gradient_checkpointing(self):
268
+ """
269
+ Deactivates gradient checkpointing for the current model.
270
+
271
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
272
+ activations".
273
+ """
274
+ if self._supports_gradient_checkpointing:
275
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
276
+
277
+ def set_use_memory_efficient_attention_xformers(
278
+ self, valid: bool, attention_op: Optional[Callable] = None
279
+ ) -> None:
280
+ # Recursively walk through all the children.
281
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
282
+ # gets the message
283
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
284
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
285
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
286
+
287
+ for child in module.children():
288
+ fn_recursive_set_mem_eff(child)
289
+
290
+ for module in self.children():
291
+ if isinstance(module, torch.nn.Module):
292
+ fn_recursive_set_mem_eff(module)
293
+
294
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
295
+ r"""
296
+ Enable memory efficient attention as implemented in xformers.
297
+
298
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
299
+ time. Speed up at training time is not guaranteed.
300
+
301
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
302
+ is used.
303
+
304
+ Parameters:
305
+ attention_op (`Callable`, *optional*):
306
+ Override the default `None` operator for use as `op` argument to the
307
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
308
+ function of xFormers.
309
+
310
+ Examples:
311
+
312
+ ```py
313
+ >>> import torch
314
+ >>> from diffusers import UNet2DConditionModel
315
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
316
+
317
+ >>> model = UNet2DConditionModel.from_pretrained(
318
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
319
+ ... )
320
+ >>> model = model.to("cuda")
321
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
322
+ ```
323
+ """
324
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
325
+
326
+ def disable_xformers_memory_efficient_attention(self):
327
+ r"""
328
+ Disable memory efficient attention as implemented in xformers.
329
+ """
330
+ self.set_use_memory_efficient_attention_xformers(False)
331
+
332
+ def save_pretrained(
333
+ self,
334
+ save_directory: Union[str, os.PathLike],
335
+ is_main_process: bool = True,
336
+ save_function: Callable = None,
337
+ state_dict: Optional[Dict[str, torch.Tensor]] = None,
338
+ ):
339
+ """
340
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
341
+ `[`~models.ModelMixin.from_pretrained`]` class method.
342
+
343
+ Arguments:
344
+ save_directory (`str` or `os.PathLike`):
345
+ Directory to which to save. Will be created if it doesn't exist.
346
+ is_main_process (`bool`, *optional*, defaults to `True`):
347
+ Whether the process calling this is the main process or not. Useful when in distributed training like
348
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
349
+ the main process to avoid race conditions.
350
+ save_function (`Callable`):
351
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
352
+ need to replace `torch.save` by another method. Can be configured with the environment variable
353
+ `DIFFUSERS_SAVE_MODE`.
354
+ state_dict (`Dict[str, torch.Tensor]`, *optional*):
355
+ The state dictionary to save. If `None`, the model's state dictionary will be saved.
356
+ """
357
+ if os.path.isfile(save_directory):
358
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
359
+ return
360
+
361
+ if save_function is None:
362
+ save_function = torch.save
363
+
364
+ os.makedirs(save_directory, exist_ok=True)
365
+
366
+ model_to_save = self
367
+
368
+ # Attach architecture to the config
369
+ # Save the config
370
+ if is_main_process:
371
+ model_to_save.save_config(save_directory)
372
+
373
+ # Save the model
374
+ if state_dict is None:
375
+ state_dict = model_to_save.state_dict()
376
+
377
+ weights_name = WEIGHTS_NAME
378
+
379
+ # Save the model
380
+ save_function(state_dict, os.path.join(save_directory, weights_name))
381
+
382
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
383
+
384
+ @classmethod
385
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
386
+ r"""
387
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
388
+
389
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
390
+ the model, you should first set it back in training mode with `model.train()`.
391
+
392
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
393
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
394
+ task.
395
+
396
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
397
+ weights are discarded.
398
+
399
+ Parameters:
400
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
401
+ Can be either:
402
+
403
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
404
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
405
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
406
+ `./my_model_directory/`.
407
+
408
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
409
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
410
+ standard cache should not be used.
411
+ torch_dtype (`str` or `torch.dtype`, *optional*):
412
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
413
+ will be automatically derived from the model's weights.
414
+ force_download (`bool`, *optional*, defaults to `False`):
415
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
416
+ cached versions if they exist.
417
+ resume_download (`bool`, *optional*, defaults to `False`):
418
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
419
+ file exists.
420
+ proxies (`Dict[str, str]`, *optional*):
421
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
422
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
423
+ output_loading_info(`bool`, *optional*, defaults to `False`):
424
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
425
+ local_files_only(`bool`, *optional*, defaults to `False`):
426
+ Whether or not to only look at local files (i.e., do not try to download the model).
427
+ use_auth_token (`str` or *bool*, *optional*):
428
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
429
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
430
+ revision (`str`, *optional*, defaults to `"main"`):
431
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
432
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
433
+ identifier allowed by git.
434
+ from_flax (`bool`, *optional*, defaults to `False`):
435
+ Load the model weights from a Flax checkpoint save file.
436
+ subfolder (`str`, *optional*, defaults to `""`):
437
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
438
+ huggingface.co or downloaded locally), you can specify the folder name here.
439
+
440
+ mirror (`str`, *optional*):
441
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
442
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
443
+ Please refer to the mirror site for more information.
444
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
445
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
446
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
447
+ same device.
448
+
449
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
450
+ more information about each option see [designing a device
451
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
452
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
453
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
454
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
455
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
456
+ setting this argument to `True` will raise an error.
457
+
458
+ <Tip>
459
+
460
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
461
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
462
+
463
+ </Tip>
464
+
465
+ <Tip>
466
+
467
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
468
+ this method in a firewalled environment.
469
+
470
+ </Tip>
471
+
472
+ """
473
+ cache_dir = kwargs.pop("cache_dir", MUSE_CACHE)
474
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
475
+ force_download = kwargs.pop("force_download", False)
476
+ resume_download = kwargs.pop("resume_download", False)
477
+ proxies = kwargs.pop("proxies", None)
478
+ output_loading_info = kwargs.pop("output_loading_info", False)
479
+ local_files_only = kwargs.pop("local_files_only", False) # TODO
480
+ use_auth_token = kwargs.pop("use_auth_token", None)
481
+ revision = kwargs.pop("revision", None)
482
+ torch_dtype = kwargs.pop("torch_dtype", None)
483
+ subfolder = kwargs.pop("subfolder", None)
484
+ device_map = kwargs.pop("device_map", None)
485
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
486
+
487
+ if low_cpu_mem_usage is False and device_map is not None:
488
+ raise ValueError(
489
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
490
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
491
+ )
492
+
493
+ user_agent = {
494
+ "diffusers": __version__,
495
+ "file_type": "model",
496
+ "framework": "pytorch",
497
+ }
498
+
499
+ # Load config if we don't provide a configuration
500
+ config_path = pretrained_model_name_or_path
501
+
502
+ # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
503
+ # Load model
504
+
505
+ model_file = None
506
+
507
+ if model_file is None:
508
+ model_file = _get_model_file(
509
+ pretrained_model_name_or_path,
510
+ weights_name=WEIGHTS_NAME,
511
+ cache_dir=cache_dir,
512
+ force_download=force_download,
513
+ resume_download=resume_download,
514
+ proxies=proxies,
515
+ local_files_only=local_files_only,
516
+ use_auth_token=use_auth_token,
517
+ revision=revision,
518
+ subfolder=subfolder,
519
+ user_agent=user_agent,
520
+ )
521
+
522
+ if low_cpu_mem_usage:
523
+ # Instantiate model with empty weights
524
+ with accelerate.init_empty_weights():
525
+ config, unused_kwargs = cls.load_config(
526
+ config_path,
527
+ cache_dir=cache_dir,
528
+ return_unused_kwargs=True,
529
+ force_download=force_download,
530
+ resume_download=resume_download,
531
+ proxies=proxies,
532
+ local_files_only=local_files_only,
533
+ use_auth_token=use_auth_token,
534
+ revision=revision,
535
+ subfolder=subfolder,
536
+ device_map=device_map,
537
+ **kwargs,
538
+ )
539
+ model = cls.from_config(config, **unused_kwargs)
540
+
541
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
542
+ if device_map is None:
543
+ param_device = "cpu"
544
+ state_dict = load_state_dict(model_file)
545
+ # move the params from meta device to cpu
546
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
547
+ if len(missing_keys) > 0:
548
+ raise ValueError(
549
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
550
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
551
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
552
+ " those weights or else make sure your checkpoint file is correct."
553
+ )
554
+
555
+ for param_name, param in state_dict.items():
556
+ accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
557
+ if accepts_dtype:
558
+ set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
559
+ else:
560
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
561
+ else: # else let accelerate handle loading and dispatching.
562
+ # Load weights and dispatch according to the device_map
563
+ # by default the device_map is None and the weights are loaded on the CPU
564
+ accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
565
+
566
+ loading_info = {
567
+ "missing_keys": [],
568
+ "unexpected_keys": [],
569
+ "mismatched_keys": [],
570
+ "error_msgs": [],
571
+ }
572
+ else:
573
+ config, unused_kwargs = cls.load_config(
574
+ config_path,
575
+ cache_dir=cache_dir,
576
+ return_unused_kwargs=True,
577
+ force_download=force_download,
578
+ resume_download=resume_download,
579
+ proxies=proxies,
580
+ local_files_only=local_files_only,
581
+ use_auth_token=use_auth_token,
582
+ revision=revision,
583
+ subfolder=subfolder,
584
+ device_map=device_map,
585
+ **kwargs,
586
+ )
587
+ model = cls.from_config(config, **unused_kwargs)
588
+
589
+ state_dict = load_state_dict(model_file)
590
+
591
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
592
+ model,
593
+ state_dict,
594
+ model_file,
595
+ pretrained_model_name_or_path,
596
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
597
+ )
598
+
599
+ loading_info = {
600
+ "missing_keys": missing_keys,
601
+ "unexpected_keys": unexpected_keys,
602
+ "mismatched_keys": mismatched_keys,
603
+ "error_msgs": error_msgs,
604
+ }
605
+
606
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
607
+ raise ValueError(
608
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
609
+ )
610
+ elif torch_dtype is not None:
611
+ model = model.to(torch_dtype)
612
+
613
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
614
+
615
+ # Set model in evaluation mode to deactivate DropOut modules by default
616
+ model.eval()
617
+ if output_loading_info:
618
+ return model, loading_info
619
+
620
+ return model
621
+
622
+ @classmethod
623
+ def _load_pretrained_model(
624
+ cls,
625
+ model,
626
+ state_dict,
627
+ resolved_archive_file,
628
+ pretrained_model_name_or_path,
629
+ ignore_mismatched_sizes=False,
630
+ ):
631
+ # Retrieve missing & unexpected_keys
632
+ model_state_dict = model.state_dict()
633
+ loaded_keys = [k for k in state_dict.keys()]
634
+
635
+ expected_keys = list(model_state_dict.keys())
636
+
637
+ original_loaded_keys = loaded_keys
638
+
639
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
640
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
641
+
642
+ # Make sure we are able to load base models as well as derived models (with heads)
643
+ model_to_load = model
644
+
645
+ def _find_mismatched_keys(
646
+ state_dict,
647
+ model_state_dict,
648
+ loaded_keys,
649
+ ignore_mismatched_sizes,
650
+ ):
651
+ mismatched_keys = []
652
+ if ignore_mismatched_sizes:
653
+ for checkpoint_key in loaded_keys:
654
+ model_key = checkpoint_key
655
+
656
+ if (
657
+ model_key in model_state_dict
658
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
659
+ ):
660
+ mismatched_keys.append(
661
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
662
+ )
663
+ del state_dict[checkpoint_key]
664
+ return mismatched_keys
665
+
666
+ if state_dict is not None:
667
+ # Whole checkpoint
668
+ mismatched_keys = _find_mismatched_keys(
669
+ state_dict,
670
+ model_state_dict,
671
+ original_loaded_keys,
672
+ ignore_mismatched_sizes,
673
+ )
674
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
675
+
676
+ if len(error_msgs) > 0:
677
+ error_msg = "\n\t".join(error_msgs)
678
+ if "size mismatch" in error_msg:
679
+ error_msg += (
680
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
681
+ )
682
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
683
+
684
+ if len(unexpected_keys) > 0:
685
+ logger.warning(
686
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
687
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
688
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
689
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
690
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
691
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
692
+ " identical (initializing a BertForSequenceClassification model from a"
693
+ " BertForSequenceClassification model)."
694
+ )
695
+ else:
696
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
697
+ if len(missing_keys) > 0:
698
+ logger.warning(
699
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
700
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
701
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
702
+ )
703
+ elif len(mismatched_keys) == 0:
704
+ logger.info(
705
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
706
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
707
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
708
+ " without further training."
709
+ )
710
+ if len(mismatched_keys) > 0:
711
+ mismatched_warning = "\n".join(
712
+ [
713
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
714
+ for key, shape1, shape2 in mismatched_keys
715
+ ]
716
+ )
717
+ logger.warning(
718
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
719
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
720
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
721
+ " able to use it for predictions and inference."
722
+ )
723
+
724
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
725
+
726
+ @property
727
+ def device(self) -> device:
728
+ """
729
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
730
+ device).
731
+ """
732
+ return get_parameter_device(self)
733
+
734
+ @property
735
+ def dtype(self) -> torch.dtype:
736
+ """
737
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
738
+ """
739
+ return get_parameter_dtype(self)
740
+
741
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
742
+ """
743
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
744
+
745
+ Args:
746
+ only_trainable (`bool`, *optional*, defaults to `False`):
747
+ Whether or not to return only the number of trainable parameters
748
+
749
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
750
+ Whether or not to return only the number of non-embeddings parameters
751
+
752
+ Returns:
753
+ `int`: The number of parameters.
754
+ """
755
+
756
+ if exclude_embeddings:
757
+ embedding_param_names = [
758
+ f"{name}.weight"
759
+ for name, module_type in self.named_modules()
760
+ if isinstance(module_type, torch.nn.Embedding)
761
+ ]
762
+ non_embedding_parameters = [
763
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
764
+ ]
765
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
766
+ else:
767
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
768
+
769
+
770
+ """ ConfigMixin base class and utilities."""
771
+
772
+
773
+ class FrozenDict(OrderedDict):
774
+ def __init__(self, *args, **kwargs):
775
+ super().__init__(*args, **kwargs)
776
+
777
+ for key, value in self.items():
778
+ setattr(self, key, value)
779
+
780
+ self.__frozen = True
781
+
782
+ def __delitem__(self, *args, **kwargs):
783
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
784
+
785
+ def setdefault(self, *args, **kwargs):
786
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
787
+
788
+ def pop(self, *args, **kwargs):
789
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
790
+
791
+ def update(self, *args, **kwargs):
792
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
793
+
794
+ def __setattr__(self, name, value):
795
+ if hasattr(self, "__frozen") and self.__frozen:
796
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
797
+ super().__setattr__(name, value)
798
+
799
+ def __setitem__(self, name, value):
800
+ if hasattr(self, "__frozen") and self.__frozen:
801
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
802
+ super().__setitem__(name, value)
803
+
804
+
805
+ class ConfigMixin:
806
+ r"""
807
+ Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
808
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
809
+ - [`~ConfigMixin.from_config`]
810
+ - [`~ConfigMixin.save_config`]
811
+
812
+ Class attributes:
813
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
814
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
815
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
816
+ overridden by subclass).
817
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
818
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the init function
819
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
820
+ subclass).
821
+ """
822
+ config_name = None
823
+ ignore_for_config = []
824
+ has_compatibles = False
825
+
826
+ _deprecated_kwargs = []
827
+
828
+ def register_to_config(self, **kwargs):
829
+ if self.config_name is None:
830
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
831
+ # Special case for `kwargs` used in deprecation warning added to schedulers
832
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
833
+ # or solve in a more general way.
834
+ kwargs.pop("kwargs", None)
835
+ for key, value in kwargs.items():
836
+ try:
837
+ setattr(self, key, value)
838
+ except AttributeError as err:
839
+ logger.error(f"Can't set {key} with value {value} for {self}")
840
+ raise err
841
+
842
+ if not hasattr(self, "_internal_dict"):
843
+ internal_dict = kwargs
844
+ else:
845
+ previous_dict = dict(self._internal_dict)
846
+ internal_dict = {**self._internal_dict, **kwargs}
847
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
848
+
849
+ self._internal_dict = FrozenDict(internal_dict)
850
+
851
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
852
+ """
853
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
854
+ [`~ConfigMixin.from_config`] class method.
855
+
856
+ Args:
857
+ save_directory (`str` or `os.PathLike`):
858
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
859
+ """
860
+ if os.path.isfile(save_directory):
861
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
862
+
863
+ os.makedirs(save_directory, exist_ok=True)
864
+
865
+ # If we save using the predefined names, we can load using `from_config`
866
+ output_config_file = os.path.join(save_directory, self.config_name)
867
+
868
+ self.to_json_file(output_config_file)
869
+ logger.info(f"Configuration saved in {output_config_file}")
870
+
871
+ @classmethod
872
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, **kwargs):
873
+ r"""
874
+ Instantiate a Python class from a config dictionary
875
+
876
+ Parameters:
877
+ config (`Dict[str, Any]`):
878
+ A config dictionary from which the Python class will be instantiated. Make sure to only load
879
+ configuration files of compatible classes.
880
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
881
+ Whether kwargs that are not consumed by the Python class should be returned or not.
882
+
883
+ kwargs (remaining dictionary of keyword arguments, *optional*):
884
+ Can be used to update the configuration object (after it has been loaded) and to initialize the Python class.
885
+ `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually
886
+ overwrite same named arguments of `config`.
887
+
888
+ Examples:
889
+
890
+ ```python
891
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
892
+
893
+ >>> # Download scheduler from huggingface.co and cache.
894
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
895
+
896
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
897
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
898
+
899
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
900
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
901
+ ```
902
+ """
903
+ # <===== TO BE REMOVED WITH DEPRECATION
904
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
905
+ if "pretrained_model_name_or_path" in kwargs:
906
+ config = kwargs.pop("pretrained_model_name_or_path")
907
+
908
+ if config is None:
909
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
910
+ # ======>
911
+
912
+ # Instantiate and return the model directly from the config dict
913
+ model = cls(**config)
914
+ return model
915
+
916
+ @classmethod
917
+ def load_config(
918
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
919
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
920
+ r"""
921
+ Load a configuration dictionary from a local file or directory, or from a model repo on huggingface.co
922
+
923
+ Parameters:
924
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
925
+ Can be either:
926
+
927
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
928
+ organization name, like `google/ddpm-celebahq-256`.
929
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
930
+ `./my_model_directory/`.
931
+
932
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
933
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
934
+ standard cache should not be used.
935
+ force_download (`bool`, *optional*, defaults to `False`):
936
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
937
+ cached versions if they exist.
938
+ resume_download (`bool`, *optional*, defaults to `False`):
939
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
940
+ file exists.
941
+ proxies (`Dict[str, str]`, *optional*):
942
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
943
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
944
+ output_loading_info(`bool`, *optional*, defaults to `False`):
945
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
946
+ local_files_only(`bool`, *optional*, defaults to `False`):
947
+ Whether or not to only look at local files (i.e., do not try to download the model).
948
+ use_auth_token (`str` or *bool*, *optional*):
949
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
950
+ when running `transformers-cli login` (stored in `~/.huggingface`).
951
+ revision (`str`, *optional*, defaults to `"main"`):
952
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
953
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
954
+ identifier allowed by git.
955
+ subfolder (`str`, *optional*, defaults to `""`):
956
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
957
+ huggingface.co or downloaded locally), you can specify the folder name here.
958
+
959
+ <Tip>
960
+
961
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
962
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
963
+
964
+ </Tip>
965
+
966
+ <Tip>
967
+
968
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
969
+ use this method in a firewalled environment.
970
+
971
+ </Tip>
972
+ """
973
+ cache_dir = kwargs.pop("cache_dir", MUSE_CACHE)
974
+ force_download = kwargs.pop("force_download", False)
975
+ resume_download = kwargs.pop("resume_download", False)
976
+ proxies = kwargs.pop("proxies", None)
977
+ use_auth_token = kwargs.pop("use_auth_token", None)
978
+ local_files_only = kwargs.pop("local_files_only", False)
979
+ revision = kwargs.pop("revision", None)
980
+ _ = kwargs.pop("mirror", None)
981
+ subfolder = kwargs.pop("subfolder", None)
982
+
983
+ user_agent = {"file_type": "config"}
984
+
985
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
986
+
987
+ if cls.config_name is None:
988
+ raise ValueError(
989
+ "`self.config_name` is not defined. Note that one should not load a config from "
990
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
991
+ )
992
+
993
+ if os.path.isfile(pretrained_model_name_or_path):
994
+ config_file = pretrained_model_name_or_path
995
+ elif os.path.isdir(pretrained_model_name_or_path):
996
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
997
+ # Load from a PyTorch checkpoint
998
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
999
+ elif subfolder is not None and os.path.isfile(
1000
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
1001
+ ):
1002
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
1003
+ else:
1004
+ raise EnvironmentError(
1005
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
1006
+ )
1007
+ else:
1008
+ try:
1009
+ # Load from URL or cache if already cached
1010
+ config_file = hf_hub_download(
1011
+ pretrained_model_name_or_path,
1012
+ filename=cls.config_name,
1013
+ cache_dir=cache_dir,
1014
+ force_download=force_download,
1015
+ proxies=proxies,
1016
+ resume_download=resume_download,
1017
+ local_files_only=local_files_only,
1018
+ use_auth_token=use_auth_token,
1019
+ user_agent=user_agent,
1020
+ subfolder=subfolder,
1021
+ revision=revision,
1022
+ )
1023
+
1024
+ except RepositoryNotFoundError:
1025
+ raise EnvironmentError(
1026
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
1027
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
1028
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
1029
+ " login`."
1030
+ )
1031
+ except RevisionNotFoundError:
1032
+ raise EnvironmentError(
1033
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
1034
+ " this model name. Check the model page at"
1035
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
1036
+ )
1037
+ except EntryNotFoundError:
1038
+ raise EnvironmentError(
1039
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
1040
+ )
1041
+ except HTTPError as err:
1042
+ raise EnvironmentError(
1043
+ "There was a specific connection error when trying to load"
1044
+ f" {pretrained_model_name_or_path}:\n{err}"
1045
+ )
1046
+ except ValueError:
1047
+ raise EnvironmentError(
1048
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
1049
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
1050
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
1051
+ " run the library in offline mode at"
1052
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
1053
+ )
1054
+ except EnvironmentError:
1055
+ raise EnvironmentError(
1056
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
1057
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
1058
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
1059
+ f"containing a {cls.config_name} file"
1060
+ )
1061
+
1062
+ try:
1063
+ # Load config dict
1064
+ config_dict = cls._dict_from_json_file(config_file)
1065
+ except (json.JSONDecodeError, UnicodeDecodeError):
1066
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
1067
+
1068
+ if return_unused_kwargs:
1069
+ return config_dict, kwargs
1070
+
1071
+ return config_dict
1072
+
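To make the branching above easier to follow, here is a condensed sketch of the local-path resolution order; `resolve_config_file` is a hypothetical helper and the `hf_hub_download` branch for remote repos is omitted.

```python
import os
from typing import Optional


def resolve_config_file(path: str, config_name: str, subfolder: Optional[str] = None) -> str:
    """Sketch of load_config's local-path resolution order (hub download omitted)."""
    if os.path.isfile(path):
        return path                                               # explicit config file path
    if os.path.isdir(path):
        candidate = os.path.join(path, config_name)
        if os.path.isfile(candidate):
            return candidate                                      # <dir>/<config_name>
        if subfolder is not None:
            nested = os.path.join(path, subfolder, config_name)
            if os.path.isfile(nested):
                return nested                                     # <dir>/<subfolder>/<config_name>
        raise EnvironmentError(f"No file named {config_name} found in directory {path}.")
    raise EnvironmentError(f"{path} is neither a file nor a directory; the real method falls back to hf_hub_download.")
```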
1073
+ @staticmethod
1074
+ def _get_init_keys(cls):
1075
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
1076
+
1077
+ @classmethod
1078
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
1079
+ with open(json_file, "r", encoding="utf-8") as reader:
1080
+ text = reader.read()
1081
+ return json.loads(text)
1082
+
1083
+ def __repr__(self):
1084
+ return f"{self.__class__.__name__} {self.to_json_string()}"
1085
+
1086
+ @property
1087
+ def config(self) -> Dict[str, Any]:
1088
+ """
1089
+ Returns the config of the class as a frozen dictionary
1090
+
1091
+ Returns:
1092
+ `Dict[str, Any]`: Config of the class.
1093
+ """
1094
+ return self._internal_dict
1095
+
1096
+ def to_json_string(self) -> str:
1097
+ """
1098
+ Serializes this instance to a JSON string.
1099
+
1100
+ Returns:
1101
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
1102
+ """
1103
+ config_dict = dict(self._internal_dict) if hasattr(self, "_internal_dict") else {}  # copy so the keys added below do not mutate the stored config
1104
+ config_dict["_class_name"] = self.__class__.__name__
1105
+ config_dict["_version"] = __version__
1106
+
1107
+ def to_json_saveable(value):
1108
+ if isinstance(value, np.ndarray):
1109
+ value = value.tolist()
1110
+ elif isinstance(value, PosixPath):
1111
+ value = str(value)
1112
+ return value
1113
+
1114
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
1115
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
1116
+
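The nested `to_json_saveable` helper exists because numpy arrays and filesystem paths are not JSON-serializable. A standalone sketch of the same conversion (using `pathlib.Path` here rather than `PosixPath`, purely for illustration):

```python
import json
from pathlib import Path

import numpy as np


def to_json_saveable(value):
    # Mirror of the nested helper above: arrays become lists, paths become strings.
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, Path):
        return str(value)
    return value


cfg = {"betas": np.array([0.1, 0.2]), "cache_dir": Path("/tmp/muse")}
print(json.dumps({k: to_json_saveable(v) for k, v in cfg.items()}, indent=2, sort_keys=True))
```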
1117
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
1118
+ """
1119
+ Save this instance to a JSON file.
1120
+
1121
+ Args:
1122
+ json_file_path (`str` or `os.PathLike`):
1123
+ Path to the JSON file in which this configuration instance's parameters will be saved.
1124
+ """
1125
+ with open(json_file_path, "w", encoding="utf-8") as writer:
1126
+ writer.write(self.to_json_string())
1127
+
1128
+
1129
+ def register_to_config(init):
1130
+ r"""
1131
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
1132
+ automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
1133
+ shouldn't be registered in the config, use the `ignore_for_config` class variable.
1134
+
1135
+ Warning: Once decorated, all private arguments (beginning with an underscore) are recorded in the config but are not passed on to the init!
1136
+ """
1137
+
1138
+ @functools.wraps(init)
1139
+ def inner_init(self, *args, **kwargs):
1140
+ # Ignore private kwargs in the init.
1141
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
1142
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
1143
+ if not isinstance(self, ConfigMixin):
1144
+ raise RuntimeError(
1145
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
1146
+ "not inherit from `ConfigMixin`."
1147
+ )
1148
+
1149
+ ignore = getattr(self, "ignore_for_config", [])
1150
+ # Get positional arguments aligned with kwargs
1151
+ new_kwargs = {}
1152
+ signature = inspect.signature(init)
1153
+ parameters = {
1154
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
1155
+ }
1156
+ for arg, name in zip(args, parameters.keys()):
1157
+ new_kwargs[name] = arg
1158
+
1159
+ # Then add all kwargs
1160
+ new_kwargs.update(
1161
+ {
1162
+ k: init_kwargs.get(k, default)
1163
+ for k, default in parameters.items()
1164
+ if k not in ignore and k not in new_kwargs
1165
+ }
1166
+ )
1167
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
1168
+ getattr(self, "register_to_config")(**new_kwargs)
1169
+ init(self, *args, **init_kwargs)
1170
+
1171
+ return inner_init
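Putting the pieces together, here is a minimal sketch of how `ConfigMixin`, `FrozenDict`, and the `register_to_config` decorator interact. `ToyScheduler` is hypothetical and assumes the definitions above (including the module's `__version__` and logger) are importable.

```python
import tempfile


class ToyScheduler(ConfigMixin):
    config_name = "toy_config.json"

    @register_to_config
    def __init__(self, num_steps: int = 10, beta: float = 0.1):
        self.num_steps = num_steps
        self.beta = beta


sched = ToyScheduler(num_steps=20)
print(sched.config)  # FrozenDict([('num_steps', 20), ('beta', 0.1)])

with tempfile.TemporaryDirectory() as tmp:
    sched.save_config(tmp)               # writes <tmp>/toy_config.json
    cfg = ToyScheduler.load_config(tmp)  # plain dict parsed back from the JSON file
    # from_config() calls cls(**config), so strip the private bookkeeping keys first.
    clone = ToyScheduler.from_config({k: v for k, v in cfg.items() if not k.startswith("_")})
```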
InternLM/internlm/model/norm.py ADDED
@@ -0,0 +1,46 @@
1
+ # adopted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm
2
+
3
+ import numbers
4
+
5
+ import torch
6
+ from torch.nn import init
7
+ from torch.nn.parameter import Parameter
8
+
9
+
10
+ def manual_rms_norm(my_input, normalized_shape, weight, eps):
11
+ # layer norm should always be calculated in float32
12
+ dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
13
+ variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
14
+ my_input = my_input * torch.rsqrt(variance + eps)
15
+
16
+ if weight is None:
17
+ return my_input
18
+
19
+ # model_hf into half-precision if necessary
20
+ if weight.dtype in [torch.float16, torch.bfloat16]:
21
+ my_input = my_input.to(weight.dtype)
22
+
23
+ return weight * my_input
24
+
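As a quick numerical sanity check (a sketch assuming `manual_rms_norm` above is importable), the function matches the closed form y = w * x / sqrt(mean(x^2) + eps), with the statistics computed in float32:

```python
import torch

x = torch.randn(2, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
eps = 1e-5

y = manual_rms_norm(x, (8,), w, eps)

# Closed form, also computed in float32 and cast back to the weight dtype.
expected = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).to(w.dtype) * w
assert torch.allclose(y, expected)
```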
25
+
26
+ class RMSNormTorch(torch.nn.Module):
27
+ """A custom PyTorch module for RMS normalization."""
28
+
29
+ def __init__(self, normalized_shape, eps=1e-5):
30
+ super().__init__()
31
+
32
+ if isinstance(normalized_shape, numbers.Integral):
33
+ normalized_shape = (normalized_shape,)
34
+ self.normalized_shape = torch.Size(normalized_shape)
35
+ self.eps = eps
36
+ self.weight = Parameter(torch.empty(*normalized_shape))
37
+ self.reset_parameters()
38
+
39
+ def forward(self, _input: torch.Tensor):
40
+ return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps)
41
+
42
+ def reset_parameters(self):
43
+ init.ones_(self.weight)
44
+
45
+ def extra_repr(self):
46
+ return "{normalized_shape}, eps={eps}, ".format(**self.__dict__)
InternLM/internlm/model/utils.py ADDED
@@ -0,0 +1,224 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from flash_attn.ops.fused_dense import FusedDenseFunc
9
+ from flash_attn.utils.distributed import (
10
+ all_gather_raw,
11
+ all_reduce_raw,
12
+ reduce_scatter_raw,
13
+ )
14
+ from torch import Tensor
15
+ from torch.cuda.amp import custom_bwd
16
+ from torch.distributed import ProcessGroup
17
+
18
+ from internlm.core.context import global_context as gpc
19
+ from internlm.utils.logger import get_logger
20
+
21
+ logger = get_logger(__file__)
22
+
23
+
24
+ def _split(input_, parallel_mode, dim=-1):
25
+ # skip if only one rank involved
26
+ world_size = gpc.get_world_size(parallel_mode)
27
+ if world_size == 1:
28
+ return input_
29
+
30
+ # Split along last dimension.
31
+ dim_size = input_.size(dim)
32
+ assert dim_size % world_size == 0, (
33
+ f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
34
+ f"cannot split tensor evenly"
35
+ )
36
+
37
+ tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
38
+ rank = gpc.get_local_rank(parallel_mode)
39
+ output = tensor_list[rank].contiguous()
40
+
41
+ return output
42
+
43
+
44
+ def _gather(input_, parallel_mode, dim=-1):
45
+ # skip if only one rank involved
46
+ world_size = gpc.get_world_size(parallel_mode)
47
+ if world_size == 1:
48
+ return input_
49
+
50
+ # all gather
51
+ rank = gpc.get_local_rank(parallel_mode)
52
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
53
+ tensor_list[rank] = input_
54
+ group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
55
+ torch.distributed.all_gather(tensor_list, input_, group=group)
56
+
57
+ # concat
58
+ output = torch.cat(tensor_list, dim=dim).contiguous()
59
+
60
+ return output
61
+
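What `_split` does per rank, without the distributed plumbing (a hypothetical single-process sketch; `_gather` is the inverse, concatenating the chunks back along the same dimension):

```python
import torch


def split_for_rank(x: torch.Tensor, rank: int, world_size: int, dim: int = -1) -> torch.Tensor:
    # Each rank keeps one contiguous, equally sized chunk of `dim`.
    assert x.size(dim) % world_size == 0
    return torch.split(x, x.size(dim) // world_size, dim=dim)[rank].contiguous()


x = torch.arange(12).reshape(2, 6)
print(split_for_rank(x, rank=0, world_size=2))  # columns 0..2
print(split_for_rank(x, rank=1, world_size=2))  # columns 3..5
```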
62
+
63
+ class _GatherForwardSplitBackward(torch.autograd.Function):
64
+ """Gather the input from model parallel region and concatenate.
65
+
66
+ Args:
67
+ input_: input matrix.
68
+ parallel_mode: parallel mode.
69
+ dim: dimension
70
+ """
71
+
72
+ @staticmethod
73
+ def symbolic(input_):
74
+ return _gather(input_, parallel_mode=None)
75
+
76
+ @staticmethod
77
+ def forward(ctx, input_, parallel_mode, dim):
78
+ ctx.mode = parallel_mode
79
+ ctx.dim = dim
80
+ return _gather(input_, parallel_mode, dim)
81
+
82
+ @staticmethod
83
+ def backward(ctx, grad_output):
84
+ return _split(grad_output, ctx.mode, ctx.dim), None, None
85
+
86
+
87
+ def gather_forward_split_backward(input_, parallel_mode, dim):
88
+ return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)
89
+
90
+
91
+ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
92
+ assert my_input.dtype == grad_output.dtype
93
+ grad_weight = torch.matmul(grad_output.t(), my_input)
94
+ grad_bias = grad_output.sum(dim=0) if has_d_bias else None
95
+ return grad_weight, grad_bias
96
+
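A small check (sketch) that `linear_bias_wgrad_torch` reproduces autograd's weight and bias gradients for y = x @ W^T + b:

```python
import torch

x = torch.randn(8, 16)
w = torch.randn(32, 16, requires_grad=True)
b = torch.zeros(32, requires_grad=True)
grad_out = torch.randn(8, 32)

(x @ w.t() + b).backward(grad_out)
grad_w, grad_b = linear_bias_wgrad_torch(x, grad_out, has_d_bias=True)
assert torch.allclose(grad_w, w.grad, atol=1e-5)
assert torch.allclose(grad_b, b.grad, atol=1e-5)
```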
97
+
98
+ # adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
99
+ class FusedDenseFuncTorch(FusedDenseFunc):
100
+ """A custom PyTorch module extending FusedDenseFunc."""
101
+
102
+ @staticmethod
103
+ @custom_bwd
104
+ def backward(ctx, grad_output, *args):
105
+ grad_output = grad_output.contiguous()
106
+ if ctx.return_residual:
107
+ (grad_input,) = args
108
+ grad_input = grad_input.contiguous()
109
+ process_group = ctx.process_group
110
+ sequence_parallel = ctx.sequence_parallel
111
+ if ctx.compute_weight_gradient:
112
+ x, weight = ctx.saved_tensors
113
+ if process_group is not None and sequence_parallel:
114
+ total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
115
+ else:
116
+ total_x = x
117
+ else:
118
+ (weight,) = ctx.saved_tensors
119
+ total_x = None
120
+ batch_shape = grad_output.shape[:-1]
121
+ batch_dim = batch_shape.numel()
122
+ grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
123
+ if ctx.needs_input_grad[0]:
124
+ if not ctx.return_residual:
125
+ grad_input = F.linear(grad_output, weight.t())
126
+ else:
127
+ grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight)
128
+ grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
129
+ if process_group is not None:
130
+ reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
131
+ grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
132
+ else:
133
+ grad_input = None
134
+ if ctx.needs_input_grad[1]:
135
+ assert ctx.compute_weight_gradient
136
+ if process_group is not None and sequence_parallel:
137
+ handle_x.wait()
138
+ # unlike flash_attn's fused kernel, the weight gradient here is computed in pure torch (no CUDA dependency).
139
+ grad_weight, grad_bias = linear_bias_wgrad_torch(
140
+ total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
141
+ )
142
+ else:
143
+ grad_weight = None
144
+ grad_bias = grad_output if ctx.needs_input_grad[2] else None
145
+ if process_group is not None and ctx.needs_input_grad[0]:
146
+ handle_grad_input.wait()
147
+ return grad_input, grad_weight, grad_bias, None, None, None
148
+
149
+
150
+ def fused_dense_func_torch(
151
+ x: Tensor,
152
+ weight: Tensor,
153
+ bias: Optional[Tensor] = None,
154
+ return_residual: bool = False,
155
+ process_group: Optional[ProcessGroup] = None,
156
+ sequence_parallel: bool = True,
157
+ ):
158
+ dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
159
+ x.dtype == torch.float32 and torch.is_autocast_enabled()
160
+ )
161
+ if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
162
+ return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
163
+ else:
164
+ return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel)
165
+
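The dispatch rule in `fused_dense_func_torch`, restated as a standalone predicate (a sketch; `use_fused_kernel` is hypothetical): the fused flash-attn kernel is used only when all tensors are on CUDA and the dtype is fp16/bf16 (or fp32 under autocast); otherwise the pure-torch backward above is used.

```python
import torch


def use_fused_kernel(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> bool:
    dtype_eligible = x.dtype in (torch.float16, torch.bfloat16) or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    return x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible


print(use_fused_kernel(torch.randn(4, 8), torch.randn(16, 8)))  # False: CPU fp32 falls back to torch
```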
166
+
167
+ class _SplitForwardGatherBackward(torch.autograd.Function):
168
+ """
169
+ Split the input and keep only the chunk corresponding to this rank.
170
+
171
+ Args:
172
+ input_: input matrix.
173
+ parallel_mode: parallel mode.
174
+ dim: dimension
175
+ """
176
+
177
+ @staticmethod
178
+ def symbolic(input_):
179
+ return _split(input_, parallel_mode=None)
180
+
181
+ @staticmethod
182
+ def forward(ctx, input_, parallel_mode, dim):
183
+ ctx.mode = parallel_mode
184
+ ctx.dim = dim
185
+ return _split(input_, parallel_mode, dim)
186
+
187
+ @staticmethod
188
+ def backward(ctx, grad_output):
189
+ return _gather(grad_output, ctx.mode, ctx.dim), None, None
190
+
191
+
192
+ def split_forward_gather_backward(input_, parallel_mode, dim):
193
+ return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim)
194
+
195
+
196
+ def try_import_RMSNorm():
197
+ """
198
+ Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
199
+
200
+ """
201
+ try:
202
+ from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm
203
+
204
+ return RMSNorm
205
+ except ModuleNotFoundError:
206
+ logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!")
207
+ from internlm.model.norm import RMSNormTorch as RMSNorm
208
+
209
+ return RMSNorm
210
+
211
+
212
+ def try_import_LayerNorm():
213
+ """
214
+ Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm
215
+
216
+ """
217
+ try:
218
+ from apex.normalization.fused_layer_norm import MixedFusedLayerNorm as LayerNorm
219
+
220
+ return LayerNorm
221
+ except ModuleNotFoundError:
222
+ import torch.nn as nn
223
+
224
+ return nn.LayerNorm
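Finally, a usage sketch of the two fallback helpers, assuming this repo (and its flash_attn dependency, which `internlm.model.utils` imports at module level) is on the Python path:

```python
from internlm.model.utils import try_import_LayerNorm, try_import_RMSNorm

RMSNorm = try_import_RMSNorm()      # apex MixedFusedRMSNorm if available, else RMSNormTorch
LayerNorm = try_import_LayerNorm()  # apex MixedFusedLayerNorm if available, else torch.nn.LayerNorm

norm1 = RMSNorm(4096, eps=1e-5)
norm2 = LayerNorm(4096)
print(type(norm1).__name__, type(norm2).__name__)
```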