188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
class DataCollatorForCompletionOnlyLM:
    """
    Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an
    'ignore_index' when they do not come from the assistant. This ensures that the loss is only
    calculated on the completion made by the assistant.

    Args:
        tokenizer: A tokenizer instance, or a name/path that is loaded with
            `AutoTokenizer.from_pretrained`.
        response_template: Text (or already-encoded token ids) that marks the start of a response.
        instruction_template: Text (or already-encoded token ids) that marks the start of an
            instruction. When given, every instruction/response turn is handled; when `None`, only
            the last occurrence of the response template is used.
        *args: Accepted for signature compatibility and ignored.
        mlm: When `False` (default) labels are a copy of `input_ids` with padding ignored. When
            `True` the inputs additionally go through `_whole_word_mask` / `jax_mask_tokens`.
        ignore_index: Label value written on every position that must not contribute to the loss.
        **kwargs: Accepted for signature compatibility and ignored.
    """

    def __init__(
        self,
        tokenizer: Union[str, "PreTrainedTokenizerBase"],
        response_template: Union[str, List[int]],
        instruction_template: Optional[Union[str, List[int]]] = None,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        **kwargs,
    ):
        if isinstance(tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.tokenizer = tokenizer

        # Templates given as text are encoded without special tokens so that they can be matched
        # as a plain sub-sequence of the collated ids; token-id lists are used as they are.
        self.instruction_template = instruction_template
        if isinstance(instruction_template, str):
            self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
        else:
            self.instruction_token_ids = instruction_template

        self.response_template = response_template
        if isinstance(response_template, str):
            self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
        else:
            self.response_token_ids = response_template

        if not mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
            warnings.warn(
                "The pad_token_id and eos_token_id values of this tokenizer are identical. "
                "If you are planning for multi-turn training, "
                "it can result in the model continuously generating questions and answers without eos token. "
                "To avoid this, set the pad_token_id to a different value."
            )

        # Kept so that `jax_call` can skip the masked-LM corruption for causal LM training.
        self.mlm = mlm
        self.ignore_index = ignore_index

    def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
        """
        Pick whole words to mask out of a list of word-piece tokens.

        Tokens starting with `##` are attached to the preceding candidate so that a word is always
        selected as a unit; `[CLS]` and `[SEP]` are never candidates. About 15% of the tokens
        (at least 1, at most `max_predictions`) are selected, in random order.

        Args:
            input_tokens: Token strings of one example.
            max_predictions: Upper bound on the number of selected token positions.

        Returns:
            A list with one 0/1 entry per input token, 1 marking a selected position.

        Raises:
            ValueError: If the internal bookkeeping of selected positions becomes inconsistent.
        """
        if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
            warnings.warn(
                "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
                "Please refer to the documentation for more information."
            )

        # Group token positions into words: a "##" piece continues the previous word.
        cand_indexes = []
        for i, token in enumerate(input_tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue
            if len(cand_indexes) >= 1 and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        random.shuffle(cand_indexes)
        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * 0.15))))

        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # Skip a word that would push the selection over budget; a shorter one may still fit.
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            if any(index in covered_indexes for index in index_set):
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)

        if len(covered_indexes) != len(masked_lms):
            raise ValueError("Length of covered_indexes is not equal to length of masked_lms.")
        return [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]

    def jax_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: Array of token ids. It is modified in place.
            special_tokens_mask: Array marking positions that must never be masked. When `None` it
                is computed with `tokenizer.get_special_tokens_mask`.

        Returns:
            `(inputs, labels)`; `labels` holds the original id on masked positions and
            `ignore_index` everywhere else.
        """
        labels = np.copy(inputs)
        probability_matrix = np.full(labels.shape, 0.15)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
                for val in labels.tolist()
            ]
            special_tokens_mask = np.array(special_tokens_mask, dtype=bool)
        else:
            special_tokens_mask = special_tokens_mask.astype(bool)

        probability_matrix[special_tokens_mask] = 0
        masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
        # Loss is only computed on masked positions.
        labels[~masked_indices] = self.ignore_index

        # 80% of the masked positions are replaced by the mask token.
        indices_replaced = np.random.binomial(1, 0.8, size=labels.shape).astype(bool) & masked_indices
        inputs[indices_replaced] = self.tokenizer.mask_token_id

        # Half of the remaining masked positions (10% overall) get a random token id;
        # the rest keep their original id.
        indices_random = (
            np.random.binomial(1, 0.5, size=labels.shape).astype(bool) & masked_indices & ~indices_replaced
        )
        random_words = np.random.randint(
            low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
        )
        inputs[indices_random] = random_words
        return inputs, labels

    def jax_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """
        Collate `examples` into a batch with `input_ids` and `labels`.

        Args:
            examples: Either mappings holding an `"input_ids"` entry or bare token-id sequences.

        Returns:
            A dict with `"input_ids"` and `"labels"`. With `mlm=False` the labels are a copy of the
            collated ids with padding positions set to `ignore_index`; with `mlm=True` both are
            produced by `jax_mask_tokens`.
        """
        if isinstance(examples[0], Mapping):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        # NOTE(review): `_collate_batch` is defined outside this class; presumably it pads the
        # sequences into one rectangular array - confirm.
        batch_input = _collate_batch(input_ids, self.tokenizer)

        if not self.mlm:
            # Causal LM: inputs stay untouched, every non-padding token is a prediction target.
            labels = np.copy(batch_input)
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is not None:
                labels[labels == pad_token_id] = self.ignore_index
            return {"input_ids": batch_input, "labels": labels}

        mask_labels = []
        for e in examples:
            ref_tokens = []
            for ida in tolist(e["input_ids"]):
                token = self.tokenizer._convert_id_to_token(ida)
                ref_tokens.append(token)

            # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜,欢] -> [喜,##欢]
            if "chinese_ref" in e:
                ref_pos = tolist(e["chinese_ref"])
                len_seq = len(e["input_ids"])
                for i in range(len_seq):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))

        batch_mask = _collate_batch(mask_labels, self.tokenizer)
        # NOTE(review): the whole-word mask is handed over as `special_tokens_mask`, so the
        # selected words are excluded from masking rather than masked - confirm this is intended.
        inputs, labels = self.jax_mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}

    @staticmethod
    def _find_template_starts(row: Any, template_ids: List[int]) -> List[int]:
        """Return, in ascending order, every index of `row` at which `template_ids` occurs contiguously."""
        template = list(template_ids)
        width = len(template)
        return [
            int(start)
            for start in np.where(row == template[0])[0]
            if row[start:start + width].tolist() == template
        ]

    def _warn_missing_key(self, kind: str, template: Any, token_ids: Any) -> None:
        """Warn that the `kind` ("response" / "instruction") template is absent from an example."""
        warnings.warn(
            f"Could not find {kind} key `{template}` in the "
            f"following instance: {self.tokenizer.decode(token_ids)} "
            f"This instance will be ignored in loss calculation. "
            f"Note, if this happens often, consider increasing the `max_seq_length`."
        )

    def _mask_non_completion_tokens(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overwrite, in place, every label of `batch` that is not part of a completion.

        Rows in which a required template cannot be found are ignored completely (all labels set
        to `ignore_index`) and a warning is emitted.
        """
        labels = batch["labels"]
        response_width = len(self.response_token_ids)

        for i in range(len(labels)):
            response_starts = self._find_template_starts(labels[i], self.response_token_ids)
            if not response_starts:
                self._warn_missing_key("response", self.response_template, batch["input_ids"][i])
                labels[i, :] = self.ignore_index
                continue

            if self.instruction_template is None:
                # Single turn: everything up to the end of the last response key is prompt.
                labels[i, :response_starts[-1] + response_width] = self.ignore_index
                continue

            human_starts = self._find_template_starts(labels[i], self.instruction_token_ids)
            if not human_starts:
                self._warn_missing_key("instruction", self.instruction_template, batch["input_ids"][i])
                labels[i, :] = self.ignore_index
                continue

            response_ends = [start + response_width for start in response_starts]
            # The row opens with a response that has no visible instruction before it:
            # treat position 0 as the start of that (truncated) instruction.
            if human_starts[0] > response_ends[0]:
                human_starts = [0] + human_starts

            for turn, (start, end) in enumerate(zip(human_starts, response_ends)):
                if turn != 0:
                    labels[i, start:end] = self.ignore_index
                else:
                    # Also covers anything that precedes the first instruction key.
                    labels[i, :end] = self.ignore_index

            # A trailing instruction without a response carries no completion tokens.
            if len(response_ends) < len(human_starts):
                labels[i, human_starts[-1]:] = self.ignore_index

        return batch

    def __call__(
        self,
        examples: List[Union[List[int], Any, Dict[str, Any]]]
    ) -> Dict[str, Any]:
        """
        Collate `examples` and restrict the loss to the completion tokens.

        Args:
            examples: Either mappings holding an `"input_ids"` entry or bare token-id sequences.

        Returns:
            A dict with `"input_ids"` and `"labels"`, where every label outside a completion is
            `ignore_index`.
        """
        batch = self.jax_call(examples)
        return self._mask_non_completion_tokens(batch)
|