get_datasets(
    *,
    dataset_name: str | None = None,
    dataset_config_name: str | None = None,
    validation_split_percentage: float = 0,
    train_file: str | None = None,
    validation_file: str | None = None,
    keep_linebreaks: bool = True
) -> datasets.Dataset | datasets.DatasetDict
Get the datasets.
You can either provide your own CSV/JSON/TXT training and evaluation files
(see below) or just provide the name of one of the public datasets
available on the hub at https://huggingface.co/datasets/ (the dataset will
be downloaded automatically from the datasets Hub).

For CSV/JSON files, this script will use the column called 'text' or the
first column if no column called 'text' is found. You can easily tweak this
behavior (see below).

In distributed training, the load_dataset function guarantees that only one
local process can concurrently download the dataset.
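
For example, loading a public Hub dataset might look like the following (a
minimal sketch: 'wikitext' and its config name are illustrative choices, and
network access to the Hub is assumed):

    from llm.trainers.gpt.data import get_datasets

    # Download 'wikitext' from the Hub; if it shipped without a validation
    # split, 5% of the train split would be carved off as validation.
    raw = get_datasets(
        dataset_name='wikitext',
        dataset_config_name='wikitext-2-raw-v1',
        validation_split_percentage=5,
    )
    print(raw['train'])
    print(raw['validation'])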
Source code in llm/trainers/gpt/data.py
def get_datasets(
    *,
    dataset_name: str | None = None,
    dataset_config_name: str | None = None,
    validation_split_percentage: float = 0,
    train_file: str | None = None,
    validation_file: str | None = None,
    keep_linebreaks: bool = True,
) -> datasets.Dataset | datasets.DatasetDict:
"""Get the datasets.
You can either provide your own CSV/JSON/TXT training and evaluation files
(see below) or just provide the name of one of the public datasets
available on the hub at https://huggingface.co/datasets/ (the dataset will
be downloaded automatically from the datasets Hub).
For CSV/JSON files, this script will use the column called 'text' or the
first column if no column called 'text' is found. You can easily tweak this
behavior (see below).
In distributed training, the load_dataset function guarantee that only one
local process can concurrently.
"""
    if dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = datasets.load_dataset(dataset_name, dataset_config_name)
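        # If the hub dataset ships without a validation split, carve one out
        # of the train split using validation_split_percentage.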
        if 'validation' not in raw_datasets.keys():
            raw_datasets['validation'] = datasets.load_dataset(
                dataset_name,
                dataset_config_name,
                split=f'train[:{validation_split_percentage}%]',
            )
            raw_datasets['train'] = datasets.load_dataset(
                dataset_name,
                dataset_config_name,
                split=f'train[{validation_split_percentage}%:]',
            )
    elif train_file is not None:
        data_files = {}
        dataset_args = {}
        data_files['train'] = train_file
        if validation_file is not None:
            data_files['validation'] = validation_file
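        # Pick the datasets loader from the file extension; plain .txt files
        # use the 'text' loader, which accepts the keep_linebreaks option.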
        extension = train_file.split('.')[-1]
        if extension == 'txt':
            extension = 'text'
            dataset_args['keep_linebreaks'] = keep_linebreaks
        raw_datasets = datasets.load_dataset(
            extension,
            data_files=data_files,
            **dataset_args,
        )
        # If no validation split is provided, validation_split_percentage is
        # used to divide the train split.
        if 'validation' not in raw_datasets.keys():
            raw_datasets['validation'] = datasets.load_dataset(
                extension,
                data_files=data_files,
                split=f'train[:{validation_split_percentage}%]',
                **dataset_args,
            )
            raw_datasets['train'] = datasets.load_dataset(
                extension,
                data_files=data_files,
                split=f'train[{validation_split_percentage}%:]',
                **dataset_args,
            )
    else:
        raise ValueError('One of dataset_name or train_file must be provided.')
    return raw_datasets
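
For the local-file path, usage might look like this (a hedged sketch: the
file paths are hypothetical placeholders; keep_linebreaks only takes effect
for plain-text files, where it is forwarded to the 'text' loader):

    from llm.trainers.gpt.data import get_datasets

    # Hypothetical local files; for CSV/JSON inputs the 'text' column (or
    # the first column) would be used, per the docstring above.
    raw = get_datasets(
        train_file='data/train.txt',
        validation_file='data/valid.txt',
        keep_linebreaks=False,  # drop newlines while reading the .txt files
    )
    print(raw)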