diff --git a/Finetune4bConfig.py b/Finetune4bConfig.py
index a8a33bf..09c535c 100644
--- a/Finetune4bConfig.py
+++ b/Finetune4bConfig.py
@@ -15,7 +15,7 @@ class Finetune4bConfig:
         warmup_steps: int, save_steps: int, save_total_limit: int, logging_steps: int,
         checkpoint: bool, skip: bool, verbose: bool,
         txt_row_thd: int, use_eos_token: bool, groupsize: int,
-        local_rank: int,
+        local_rank: int, flash_attention: bool
     ):
         """
         Args:
@@ -48,6 +48,7 @@ class Finetune4bConfig:
            use_eos_token (bool): Use Eos token instead of padding with 0
            groupsize (int): Group size of V2 model, use -1 to load V1 model
            local_rank (int): local rank if using torch.distributed.launch
+           flash_attention (bool): Enables flash attention
        """
        self.dataset = dataset
        self.ds_type = ds_type
@@ -84,6 +85,7 @@ class Finetune4bConfig:
        if self.ddp:
            self.gradient_accumulation_steps = self.gradient_accumulation_steps // self.world_size
        self.groupsize = groupsize
+       self.flash_attention = flash_attention

    def __str__(self) -> str:
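
A minimal sketch of how the new flash_attention flag could be wired from a command-line entry point into Finetune4bConfig. The argument parser below is a hypothetical illustration, not the repository's actual training script; the other constructor arguments are omitted, and it is assumed the flag defaults to False and is forwarded as the new keyword.

import argparse

# Hypothetical CLI wiring; only the flash-attention flag is shown, the
# remaining Finetune4bConfig arguments are omitted for brevity.
parser = argparse.ArgumentParser(description="4-bit finetune launcher (sketch)")
parser.add_argument("--flash-attention", action="store_true",
                    help="Enables flash attention (maps to Finetune4bConfig.flash_attention)")
args = parser.parse_args(["--flash-attention"])

# In the real entry point this value would be forwarded alongside the other
# constructor arguments, e.g. Finetune4bConfig(..., flash_attention=args.flash_attention).
print(args.flash_attention)  # True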