How does a custom SageMaker algorithm determine whether checkpointing is enabled?


According to the SageMaker environment variables documentation, an algorithm should save model artifacts to the folder specified by the SM_MODEL_DIR variable.

The SageMaker Containers documentation describes additional environment variables, including SM_OUTPUT_DATA_DIR for writing non-model training artifacts.

...but how does the algorithm determine whether checkpoint storage has been requested?

The "Use checkpoints in Amazon SageMaker" documentation only specifies a default local save path, and I can't find any environment variable that indicates whether checkpoints should be saved. I've seen code that checks whether the default local path exists, but I'm not sure that's reliable (i.e. present when checkpointing is requested, absent when it isn't).
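For reference, that path-existence check can be sketched as follows. This is a heuristic, not an official API: the assumption (from this question, not guaranteed by the docs) is that SageMaker creates the local checkpoint directory only when checkpointing is configured on the estimator. `/opt/ml/checkpoints` is the documented default local path.

```python
import os

# Default local checkpoint path from the SageMaker checkpointing docs.
DEFAULT_CHECKPOINT_DIR = "/opt/ml/checkpoints"

def checkpointing_requested(path=DEFAULT_CHECKPOINT_DIR):
    """Heuristic: assume the local checkpoint directory exists only
    when checkpointing was requested for the training job."""
    return os.path.isdir(path)

if checkpointing_requested():
    print("Checkpointing appears to be enabled")
```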

Parameterizing checkpointing would be nice, to avoid wasting EBS space (and precious IOPS) when it isn't needed; following the conventions of the other I/O (such as the model and data folders), I'd expect SageMaker to have a specific mechanism for passing this instruction, rather than just defining an algorithm hyperparameter?

EXPERT
asked 8 months ago · 41 views
1 Answer

Hello,

For custom SageMaker containers or deep learning frameworks, I tend to do the following. Here is an example I tried with PyTorch:

  • Entry point file:
# 1. Define a custom argument, say checkpointdir
parser.add_argument("--checkpointdir", help="The checkpoint dir", type=str,
                    default=None)
# 2. You can add additional params for checkpoint frequency etc.

# 3. Code for checkpointing
if args.checkpointdir is not None:
    # TODO: save model
  • Jupyter notebook example, SageMaker estimator:
# 1. Define local and remote variables for checkpoints
checkpoint_s3 = "s3://{}/{}/".format(bucket, "checkpoints")
localcheckpoint_dir = "/opt/ml/checkpoints/"

hyperparameters = {
    "batchsize": "8",
    "epochs": "1000",
    "learning_rate": .0001,
    "weight_decay": 5e-5,
    "momentum": .9,
    "patience": 20,
    "log-level": "INFO",
    "commit_id": commit_id,
    "model": "FasterRcnnFactory",
    "accumulation_steps": 8,
# 2. Define a hyperparameter for the checkpoint dir
    "checkpointdir": localcheckpoint_dir,
}

# In the SageMaker estimator, specify the local and remote checkpoint paths
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point='experiment_train.py',
    source_dir='src',
    dependencies=['src/datasets', 'src/evaluators', 'src/models'],
    role=role,
    framework_version="1.0.0",
    py_version='py3',
    git_config=git_config,
    image_name=docker_repo,
    train_instance_count=1,
    train_instance_type=instance_type,
# 3. The entry point file will pick up the checkpoint location from here
    hyperparameters=hyperparameters,
    output_path=s3_output_path,
    metric_definitions=metric_definitions,
    train_use_spot_instances=use_spot,
    train_max_run=train_max_run_secs,
    train_max_wait=max_wait_time_secs,
    base_job_name="object-detection",
# 4. SageMaker knows the checkpoints need to be periodically copied from
#    localcheckpoint_dir to the S3 location pointed to by checkpoint_s3
    checkpoint_s3_uri=checkpoint_s3,
    checkpoint_local_path=localcheckpoint_dir)
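Inside the entry point, the checkpoint-save step (the TODO above) could then look like this minimal sketch. The `save_checkpoint` helper, the `save_every` frequency, and the file name are illustrative assumptions, not part of the original answer; a real PyTorch entry point would use `torch.save` instead of `pickle`.

```python
import argparse
import os
import pickle

parser = argparse.ArgumentParser()
parser.add_argument("--checkpointdir", type=str, default=None)
args, _ = parser.parse_known_args()

def save_checkpoint(state, checkpointdir, name="checkpoint.pkl"):
    """Write a checkpoint into the local dir that SageMaker syncs to
    checkpoint_s3_uri. With PyTorch you would use torch.save instead."""
    os.makedirs(checkpointdir, exist_ok=True)
    path = os.path.join(checkpointdir, name)
    with open(path, "wb") as f:
        pickle.dump(state, f)
    return path

save_every = 10  # assumed checkpoint frequency
for epoch in range(100):
    # ... training step ...
    if args.checkpointdir is not None and epoch % save_every == 0:
        save_checkpoint({"epoch": epoch}, args.checkpointdir)
```

Because checkpointing is gated on `--checkpointdir`, omitting the hyperparameter (and the estimator's `checkpoint_s3_uri`) disables checkpoint writes entirely, which addresses the EBS/IOPS concern in the question.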
EXPERT
answered 8 months ago
