-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtokenize_data.py
More file actions
39 lines (29 loc) · 1023 Bytes
/
tokenize_data.py
File metadata and controls
39 lines (29 loc) · 1023 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer
import torch.utils.data
import os
# Load the pretrained tokenizer from a local directory.
# NOTE: `tokenizer` and `max_length` are module-level globals read by
# tokenize_function() below — renaming them would break that function.
tokenizer = AutoTokenizer.from_pretrained('opt-seq-pubmed-tokenizer')
max_length = 1024
def tokenize_function(examples):
    """Tokenize the "text" field of a batch, truncating to `max_length` tokens.

    Reads the module-level `tokenizer` and `max_length`; suitable as the
    callable passed to `datasets.Dataset.map`.
    """
    encoded = tokenizer(
        examples["text"],
        max_length=max_length,
        truncation=True,
    )
    return encoded
cache_dir = './tmp'  # Cache directory for the raw-dataset cache files
data_files = {'train': ["data/trn.csv"],
              'test': ["data/val.csv"]}  # NOTE(review): 'test' split is fed from val.csv — confirm intended
extension = 'csv'
# Load the train/validation CSVs; each is expected to expose a "text" column
# (tokenize_function reads examples["text"]).
raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=cache_dir)

preprocessing_num_workers = 4
overwrite_cache = False
removed_columns = ['text']  # raw text is no longer needed once tokenized

tokenized_datasets = raw_datasets.map(
    tokenize_function,
    # Batched mode passes a list of strings per call, which the HF tokenizer
    # accepts natively — identical output, far fewer Python-level calls.
    batched=True,
    num_proc=preprocessing_num_workers,
    remove_columns=removed_columns,
    load_from_cache_file=not overwrite_cache,
    desc="Running tokenizer on dataset",
)
tokenized_datasets.save_to_disk('tokenized_data')

# Optional cleanup of the raw-dataset cache (prefer shutil.rmtree('./tmp',
# ignore_errors=True) over shelling out if re-enabled):
#os.system("rm -rf ./tmp")
# To reload later: ds = load_from_disk('tokenized_data')