diff --git a/train_data.py b/train_data.py index 62dee6d..76cb2e8 100644 --- a/train_data.py +++ b/train_data.py @@ -84,7 +84,7 @@ class TrainTxt(ATrainData): r_b = '' return new_rows - def prepare_data(self, thd=-1): + def prepare_data(self, thd=-1, **kwargs): if os.path.isdir(self.dataset): rows = [] for filename in os.listdir(self.dataset): @@ -124,7 +124,7 @@ class TrainSAD(ATrainData): "attention_mask": result["attention_mask"][:-1], } - def prepare_data(self) -> None: + def prepare_data(self, **kwargs) -> None: data = load_dataset("json", data_files=self.dataset) if self.val_set_size > 0: