AWS SageMaker: Model Parallelism for an NLP GPT-2 Model


Hello everyone,

I am working on an NLP product based on GPT-2, and I am having some problems with data_pipeline.py, which is included in the AWS examples repo: https://github.com/aws/amazon-sagemaker-examples/blob/main/training/distributed_training/pytorch/model_parallel/gpt2/data_pipeline.py

Could you please explain in detail how to customize this script for my own dataset?

Below is my dataset class:

import torch
from torch.utils.data import Dataset

# SPECIAL_TOKENS (a dict with 'bos_token', 'sep_token' and 'eos_token')
# and MAXLEN are defined elsewhere in the script (not shown here).

class MyDataset(Dataset):

    def __init__(self, data, tokenizer, randomize=True):
        # data maps an ID to a (title, text, claims) sequence
        title, text, Claims = [], [], []
        for v in data.values():
            title.append(v[0])
            text.append(v[1])
            Claims.append(v[2])

        self.randomize = randomize
        self.tokenizer = tokenizer
        self.title = title
        self.text = text
        self.Claims = Claims

    #---------------------------------------------#
    def __len__(self):
        return len(self.text)

    #---------------------------------------------#
    def __getitem__(self, i):
        # Layout: <bos> title <sep> text <sep> claims <eos>
        prompt = SPECIAL_TOKENS['bos_token'] + self.title[i] + \
                 SPECIAL_TOKENS['sep_token'] + self.text[i] + \
                 SPECIAL_TOKENS['sep_token'] + self.Claims[i] + \
                 SPECIAL_TOKENS['eos_token']

        # Use the tokenizer stored on the instance, not a global variable
        encodings_dict = self.tokenizer(prompt,
                                        truncation=True,
                                        max_length=MAXLEN,
                                        padding="max_length")

        input_ids = encodings_dict['input_ids']
        attention_mask = encodings_dict['attention_mask']

        # For causal language modeling the labels are the input IDs themselves
        return {'label': torch.tensor(input_ids),
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(attention_mask)}

Thanks in advance!

Basem

Asked 1 year ago, 228 views
1 Answer

I can't give you an exact method without knowing your context, but here are some things you can do:

First, you will need to determine the format and structure of your data. For example, is it stored in a file or in a database? What are the fields or columns that are included in your data?
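A quick way to answer these questions for a CSV export is to read its header and count its rows. The sketch below uses only Python's standard `csv` module and assumes your data is a CSV file with a header row (for a database you would run the equivalent query instead); `describe_csv` is a hypothetical helper name.

```python
import csv
from typing import Iterable, List, Tuple


def describe_csv(lines: Iterable[str]) -> Tuple[List[str], int]:
    """Return the header (column names) and the number of data rows.

    `lines` can be an open file object or any iterable of CSV text lines.
    """
    reader = csv.DictReader(lines)
    # Accessing fieldnames reads the header row; an empty input gives None
    fieldnames = list(reader.fieldnames or [])
    # Each remaining row is one data record
    n_rows = sum(1 for _ in reader)
    return fieldnames, n_rows
```

For example, `with open("my_data.csv", newline="", encoding="utf-8") as f: print(describe_csv(f))`, where `my_data.csv` is a placeholder for your own file.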

Based on the structure of your data, you will need to modify the __init__ method of the MyDataset class to correctly parse and extract the relevant data. For example, if your data is stored in a file with three columns, you will need to modify the code to read these columns and store them in the appropriate variables (e.g., title, text, and Claims).
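The existing `__init__` iterates over `data.items()` and reads `v[0]`, `v[1]` and `v[2]` as title, text and claims. One way to keep that constructor unchanged is to convert your file into a dict of exactly that shape first. This is a minimal sketch, assuming a CSV file whose column names are `title`, `text` and `claims` (hypothetical names; use your own); rows with a missing or empty field are skipped.

```python
import csv
from typing import Dict, Iterable, Sequence, Tuple

# Hypothetical column names; replace them with the columns in your own file.
COLUMNS = ("title", "text", "claims")


def rows_to_dataset_dict(
    lines: Iterable[str], columns: Sequence[str] = COLUMNS
) -> Dict[int, Tuple[str, ...]]:
    """Build {row_index: (title, text, claims)}, the shape MyDataset.__init__ reads."""
    reader = csv.DictReader(lines)
    missing = [c for c in columns if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"CSV is missing columns: {missing}")

    data: Dict[int, Tuple[str, ...]] = {}
    for row in reader:
        # Short rows get None for absent fields; treat that like an empty string
        values = tuple((row.get(c) or "").strip() for c in columns)
        if all(values):
            data[len(data)] = values  # contiguous keys 0..n-1
    return data
```

The result can then be passed straight in, e.g. `MyDataset(rows_to_dataset_dict(f), tokenizer)`.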

You will also need to modify the __getitem__ method of the MyDataset class to properly process and encode your data for the transformer model. For example, you may need to change how the input string is built to reflect the structure of your data (e.g., if your data has additional fields or columns that you want to include). You may also need to adjust the tokenizer call to use the specific tokenization and encoding options that you want.
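Building the input string from a list of fields makes it easy to add columns later. The sketch below joins any number of fields with the separator token and calls the tokenizer with the same options as the question (`truncation=True`, `max_length`, `padding="max_length"`, which are standard Hugging Face tokenizer arguments). `build_input` and `encode_example` are hypothetical helper names, and they return plain lists so the logic can be checked without PyTorch. Note that padding requires a pad token; GPT-2's tokenizer has none by default, so one has to be added (e.g. with `tokenizer.add_special_tokens(...)`, followed by `model.resize_token_embeddings(len(tokenizer))`).

```python
from typing import Callable, Dict, List, Mapping, Sequence


def build_input(fields: Sequence[str], special_tokens: Mapping[str, str]) -> str:
    """Join fields as <bos> f1 <sep> f2 <sep> ... <eos>."""
    sep = special_tokens["sep_token"]
    return special_tokens["bos_token"] + sep.join(fields) + special_tokens["eos_token"]


def encode_example(
    fields: Sequence[str],
    tokenizer: Callable,
    special_tokens: Mapping[str, str],
    max_length: int,
) -> Dict[str, List[int]]:
    """Tokenize one example with the same options as MyDataset.__getitem__."""
    enc = tokenizer(
        build_input(fields, special_tokens),
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    # Labels mirror input_ids, as in the question's code
    return {
        "label": enc["input_ids"],
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
    }
```

With three fields, `build_input` produces the same string as the question's `__getitem__`. Inside the dataset you could call `encode_example([self.title[i], self.text[i], self.Claims[i]], self.tokenizer, SPECIAL_TOKENS, MAXLEN)` and wrap each list in `torch.tensor`; a fourth column is then just one more list element.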

Finally, you may need to modify the __len__ method of the MyDataset class so that it returns the actual number of examples in your dataset.
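`__len__` must match what `__init__` actually stored. If `__init__` drops malformed records, counting the raw input would overstate the size, and a sampler asking for the last indices would hit an `IndexError` in `__getitem__`. The class below is a hypothetical, torch-free stand-in for MyDataset's storage logic (in practice you would keep subclassing `torch.utils.data.Dataset`).

```python
from typing import Dict, Sequence, Tuple


class FieldStore:
    """Hypothetical stand-in for MyDataset's storage logic.

    Records without exactly three fields are skipped in __init__,
    so __len__ counts what was kept, not the size of the raw input.
    """

    def __init__(self, data: Dict[object, Sequence[str]]):
        self.title, self.text, self.claims = [], [], []
        for fields in data.values():
            if len(fields) != 3:
                continue  # skip malformed records
            t, x, c = fields
            self.title.append(t)
            self.text.append(x)
            self.claims.append(c)
        # The three lists must stay aligned so index i refers to one example
        assert len(self.title) == len(self.text) == len(self.claims)

    def __len__(self) -> int:
        return len(self.text)

    def __getitem__(self, i: int) -> Tuple[str, str, str]:
        return self.title[i], self.text[i], self.claims[i]
```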

SeanSi
Answered 1 year ago
