DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.
Parameters:
Name |
Type |
Description |
Default |
pad_token_id |
int
|
int: The tokenizer's pad_token_id.
|
0
|
label_pad_token_id |
int
|
int: The label used for masking.
|
-100
|
is_encoder_decoder |
Optional[bool]
|
Optional[bool]: Whether your model has an encoder_decoder architecture
|
False
|
Source code in src/python/easydel/trainer/dpo/utils.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
@dataclass
class DPODataCollatorWithPadding:
    r"""
    DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch.

    :param max_prompt_length: int: Upper bound applied to the last axis of the ``prompt_*`` entries
        when ``auto_fix_data`` is enabled.
    :param max_target_length: int: Upper bound applied to the last axis of the ``chosen_*`` and
        ``rejected_*`` entries when ``auto_fix_data`` is enabled.
    :param pad_token_id: int: The tokenizer's pad_token_id.
    :param label_pad_token_id: int: The label used for masking.
    :param is_encoder_decoder: Optional[bool]: Whether your model has an encoder_decoder architecture
    :param ids_to_pop_from_dataset: Optional[dict]: Keys to remove from the collated batch.
    :param auto_fix_data: bool: Whether to truncate the collated arrays to ``max_prompt_length`` /
        ``max_target_length``.
    """
    max_prompt_length: int
    max_target_length: int
    pad_token_id: int = 0
    label_pad_token_id: int = -100
    is_encoder_decoder: Optional[bool] = False
    ids_to_pop_from_dataset: Optional[dict] = None
    auto_fix_data: bool = True

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate a list of per-example dicts into a dict of 2-D arrays.

        Only keys ending in ``_input_ids``, ``_attention_mask``, ``_labels`` or ``_log_probs``
        are returned; every returned value is reshaped to ``(len(features), -1)``.

        :param features: List[Dict[str, Any]]: One dict per example; the key set is taken from
            the first example.
        :return: Dict[str, Any]: The padded (and optionally truncated) batch.
        :raises ValueError: If a sequence key does not match any known padding rule.
        """
        sequence_suffixes = ("_input_ids", "_attention_mask", "_labels")
        kept_suffixes = sequence_suffixes + ("_log_probs",)

        padded_batch = {}
        for k in features[0].keys():
            if k.endswith(sequence_suffixes):
                padded_batch[k] = self._pad_feature(k, features)
            elif k.endswith(("_logps", "_log_probs")):
                # `_log_probs` entries survive the filter below and get reshaped, so they
                # must be arrays (a plain list has no `.reshape`).
                # NOTE(review): keys ending in `_logps` are still filtered out below, exactly
                # as before -- confirm whether they were meant to be kept.
                padded_batch[k] = jnp.array([ex[k] for ex in features])
            # Any other key would be removed by the suffix filter below, so it is not collected.

        if self.ids_to_pop_from_dataset:
            for key in self.ids_to_pop_from_dataset:
                padded_batch.pop(key, None)

        for key in list(padded_batch.keys()):
            if not key.endswith(kept_suffixes):
                padded_batch.pop(key, None)

        for k in list(padded_batch.keys()):
            v = padded_batch[k]
            padded_batch[k] = v.reshape(v.shape[0], -1)

        if self.auto_fix_data:
            length_limits = (
                ("rejected_input_ids", self.max_target_length),
                ("rejected_attention_mask", self.max_target_length),
                ("rejected_labels", self.max_target_length),
                ("chosen_input_ids", self.max_target_length),
                ("chosen_attention_mask", self.max_target_length),
                ("chosen_labels", self.max_target_length),
                ("prompt_input_ids", self.max_prompt_length),
                ("prompt_attention_mask", self.max_prompt_length),
            )
            for name, limit in length_limits:
                # A key may be absent (not in the dataset, or removed through
                # `ids_to_pop_from_dataset`); truncate only what is actually there.
                if name in padded_batch:
                    padded_batch[name] = padded_batch[name][..., :limit]
        return padded_batch

    def _pad_feature(self, key: str, features: List[Dict[str, Any]]) -> Any:
        """Pad the sequences stored under ``key`` across ``features`` into one int32 array."""
        if self.is_encoder_decoder:
            to_pad = [jnp.array(ex[key], dtype="i4") for ex in features]
            if key.startswith("prompt") and key.endswith("input_ids"):
                padding_value = self.pad_token_id
            elif key.endswith("_attention_mask"):
                padding_value = 0
            elif key.startswith(("chosen", "rejected")) or "decoder" in key:
                padding_value = self.label_pad_token_id
            else:
                raise ValueError(f"Unexpected key in batch '{key}'")
            return pad_sequence(to_pad, batch_first=True, padding_value=padding_value).astype("i4")

        is_prompt = "prompt" in key
        # Prompt entries are reversed before padding and flipped back afterwards, which
        # places their padding at the start of the sequence instead of the end.
        to_pad = [
            jnp.array(ex[key][::-1] if is_prompt else ex[key], dtype="i4")
            for ex in features
        ]
        if key.endswith("_input_ids"):
            padding_value = self.pad_token_id
        elif key.endswith("_labels"):
            padding_value = self.label_pad_token_id
        elif key.endswith("_attention_mask"):
            padding_value = 0
        else:
            raise ValueError(f"Unexpected key in batch '{key}'")
        padded = pad_sequence(to_pad, batch_first=True, padding_value=padding_value).astype("i4")
        if is_prompt:
            padded = jnp.flip(padded, axis=1)
        return padded
|