AWS SageMaker: Model Parallelism for an NLP GPT-2 Model


Hello all,

I am working on an NLP product based on GPT-2, and I am having some problems with data_pipeline.py from the AWS examples repo: https://github.com/aws/amazon-sagemaker-examples/blob/main/training/distributed_training/pytorch/model_parallel/gpt2/data_pipeline.py

Could you please explain in detail how to customize this script for my own dataset?

Below is my dataset class:

import torch
from torch.utils.data import Dataset

# SPECIAL_TOKENS (bos/sep/eos token strings) and MAXLEN (maximum sequence
# length) are defined elsewhere in the script.

class MyDataset(Dataset):

    def __init__(self, data, tokenizer, randomize=True):
        # `data` is a mapping whose values are [title, text, claims] triples.
        title, text, Claims = [], [], []
        for k, v in data.items():
            title.append(v[0])
            text.append(v[1])
            Claims.append(v[2])

        self.randomize = randomize
        self.tokenizer = tokenizer
        self.title     = title
        self.text      = text
        self.Claims    = Claims

    #---------------------------------------------#
    def __len__(self):
        return len(self.text)

    #---------------------------------------------#
    def __getitem__(self, i):
        # One training example: BOS + title + SEP + text + SEP + claims + EOS.
        prompt = (SPECIAL_TOKENS['bos_token'] + self.title[i] +
                  SPECIAL_TOKENS['sep_token'] + self.text[i] +
                  SPECIAL_TOKENS['sep_token'] + self.Claims[i] +
                  SPECIAL_TOKENS['eos_token'])

        # Tokenize with padding/truncation to a fixed length.
        encodings_dict = self.tokenizer(prompt,
                                        truncation=True,
                                        max_length=MAXLEN,
                                        padding="max_length")

        input_ids = encodings_dict['input_ids']
        attention_mask = encodings_dict['attention_mask']

        return {'label': torch.tensor(input_ids),
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(attention_mask)}

Thanks in advance!

Basem

Asked a year ago · 235 views
1 answer

I can't give you an exact method since I don't know your context, but here are some things you can do:

First, you will need to determine the format and structure of your data. For example, is it stored in a file or in a database? What are the fields or columns that are included in your data?
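For reference, the __init__ in your posted class assumes the data is already an in-memory dict whose values are [title, text, claims] triples. The keys and field values below are made-up placeholders, just to illustrate the expected shape:

data = {
    "doc-001": ["Some title", "Main body text ...", "Claims text ..."],
    "doc-002": ["Another title", "More body text ...", "More claims ..."],
}

If your raw data is not already in that shape, that is the first thing to adapt.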

Based on the structure of your data, you will need to modify the __init__ method of the MyDataset class to correctly parse and extract the relevant data. For example, if your data is stored in a file with three columns, you will need to modify the code to read in those columns and store them in the appropriate variables (e.g., title, text, and Claims).
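As a sketch, if your data sits in a CSV file with three columns, __init__ could read the file directly instead of taking a pre-built dict. The file path and the column names "title", "text", and "claims" are assumptions you would replace with your actual schema:

import csv
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, csv_path, tokenizer, randomize=True):
        # Read the three relevant columns from a CSV file.
        # "title", "text" and "claims" are placeholder column names.
        title, text, Claims = [], [], []
        with open(csv_path, newline="", encoding="utf-8") as f:
            for row in csv.DictReader(f):
                title.append(row["title"])
                text.append(row["text"])
                Claims.append(row["claims"])

        self.randomize = randomize
        self.tokenizer = tokenizer
        self.title     = title
        self.text      = text
        self.Claims    = Claims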

You will also need to modify the __getitem__ method of the MyDataset class to properly process and encode your data for the transformer model. For example, you may need to change how the input string is built to reflect the structure of your data (e.g., if your data has additional fields or columns that you want to include), and you may need to adjust the tokenizer call to use the specific tokenization and encoding options you want.
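As a sketch, assuming the same SPECIAL_TOKENS and MAXLEN constants from your post, __getitem__ could look like the following. It should call the tokenizer stored on the instance rather than a global one, and the key used for the labels ('label' vs. 'labels') has to match whatever the training script expects:

    def __getitem__(self, i):
        # Concatenate the fields into one prompt; adjust the order and the
        # separators to match how you want the model to see your data.
        prompt = (SPECIAL_TOKENS['bos_token'] + self.title[i] +
                  SPECIAL_TOKENS['sep_token'] + self.text[i] +
                  SPECIAL_TOKENS['sep_token'] + self.Claims[i] +
                  SPECIAL_TOKENS['eos_token'])

        encodings = self.tokenizer(prompt,
                                   truncation=True,
                                   max_length=MAXLEN,
                                   padding="max_length")

        input_ids = torch.tensor(encodings['input_ids'])
        attention_mask = torch.tensor(encodings['attention_mask'])

        # For causal LM fine-tuning the labels are usually the input_ids themselves.
        return {'label': input_ids,
                'input_ids': input_ids,
                'attention_mask': attention_mask}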

Finally, you may need to modify the __len__ method of the MyDataset class so that it accurately returns the number of examples in your dataset.
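Once those three methods match your data, a quick sanity check before wiring the class into data_pipeline.py is to build the dataset locally and pull one batch through a DataLoader. The special tokens and max length below are placeholders you would align with your training configuration:

import torch
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer

# Placeholder special tokens / sequence length; align with your training config.
SPECIAL_TOKENS = {'bos_token': '<|BOS|>', 'eos_token': '<|EOS|>',
                  'sep_token': '<|SEP|>', 'pad_token': '<|PAD|>'}
MAXLEN = 512

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens(SPECIAL_TOKENS)

dataset = MyDataset(data, tokenizer)   # `data` in whatever form your __init__ expects
loader = DataLoader(dataset, batch_size=2, shuffle=True)

batch = next(iter(loader))
print(batch['input_ids'].shape, batch['attention_mask'].shape)

If the shapes look right, you can then substitute this dataset for the one created in data_pipeline.py.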

SeanSi
Answered a year ago
