VESSL Docs
Running Spot Instances
Vessl supports Amazon EC2 Spot Instances on Amazon Elastic Kubernetes Service (EKS). Spot instances offer attractive pricing compared to on-demand instances, especially for stateless, fault-tolerant container runs.
Be aware that spot instances are subject to interruption: a claimed spot instance is reclaimed with two minutes of notice when the capacity is needed elsewhere. Saving and loading the model at every epoch is therefore highly recommended. Fortunately, most ML toolkits, such as Fairseq and Detectron2, provide checkpointing that keeps the best-performing model. Refer to the following documents for more information about checkpointing:
  • PyTorch: Saving and Loading Models
  • TensorFlow: Save and Load Models
Refer to the example code in the Vessl GitHub repository.

1. Save Checkpoints

While training a model, you should save it periodically. The following PyTorch and Keras code compares validation accuracy and saves the best-performing model at each epoch. Note that the code also records the epoch count so that you can later restore it as the start_epoch value.
PyTorch
```python
import torch


def save_checkpoint(state, is_best, filename):
    if is_best:
        print("=> Saving a new best")
        torch.save(state, filename)
    else:
        print("=> Validation Accuracy did not improve")


for epoch in range(epochs):
    train(...)
    test_accuracy = ...

    test_accuracy = torch.FloatTensor([test_accuracy])
    is_best = bool(test_accuracy.numpy() > best_accuracy.numpy())
    best_accuracy = torch.FloatTensor(
        max(test_accuracy.numpy(), best_accuracy.numpy()))
    save_checkpoint({
        'epoch': start_epoch + epoch + 1,
        'state_dict': model.state_dict(),
        'best_accuracy': best_accuracy,
    }, is_best, checkpoint_file_path)
```
Keras
```python
import os

import tensorflow as tf
from keras.callbacks import ModelCheckpoint
from savvihub.keras import SavviHubCallback

checkpoint_path = os.path.join(args.checkpoint_path, 'checkpoints-{epoch:04d}.ckpt')
checkpoint_dir = os.path.dirname(checkpoint_path)

checkpoint_callback = ModelCheckpoint(
    checkpoint_path,
    monitor='val_accuracy',
    verbose=1,
    save_weights_only=True,
    mode='max',
    save_freq=args.save_model_freq,
)

# Compile model
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

model.save_weights(checkpoint_path.format(epoch=0))

model.fit(x_train, y_train,
          batch_size=args.batch_size,
          validation_data=(x_val, y_val),
          epochs=args.epochs,
          callbacks=[
              SavviHubCallback(
                  data_type='image',
                  validation_data=(x_val, y_val),
                  num_images=5,
                  start_epoch=start_epoch,
                  save_image=args.save_image,
              ),
              checkpoint_callback,
          ])
```

2. Load Checkpoints

When a spot instance is interrupted, your code is executed again from the beginning. To avoid losing progress, write code that loads the saved checkpoint before training resumes.
PyTorch
```python
import os

import torch


def load_checkpoint(checkpoint_file_path):
    print(f"=> Loading checkpoint '{checkpoint_file_path}' ...")
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_file_path)
    else:
        checkpoint = torch.load(checkpoint_file_path,
                                map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint.get('state_dict'))
    print(f"=> Loaded checkpoint (trained for {checkpoint.get('epoch')} epochs)")
    return checkpoint.get('epoch'), checkpoint.get('best_accuracy')


if os.path.exists(args.checkpoint_path) and os.path.isfile(checkpoint_file_path):
    start_epoch, best_accuracy = load_checkpoint(checkpoint_file_path)
else:
    print("=> No checkpoint found! Training from scratch")
    start_epoch, best_accuracy = 0, torch.FloatTensor([0])
    if not os.path.exists(args.checkpoint_path):
        print(f" [*] Make directories : {args.checkpoint_path}")
        os.makedirs(args.checkpoint_path)
```
Keras
```python
import os

import tensorflow as tf


def parse_epoch(file_path):
    # 'checkpoints-0042.ckpt' -> 42
    return int(os.path.splitext(os.path.basename(file_path))[0].split('-')[1])


checkpoint_path = os.path.join(args.checkpoint_path, 'checkpoints-{epoch:04d}.ckpt')
checkpoint_dir = os.path.dirname(checkpoint_path)
if os.path.exists(checkpoint_dir) and len(os.listdir(checkpoint_dir)) > 0:
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    print(f"=> Loading checkpoint '{latest}' ...")
    model.load_weights(latest)
    start_epoch = parse_epoch(latest)
    print(f'start_epoch:{start_epoch}')
else:
    start_epoch = 0
    if not os.path.exists(args.checkpoint_path):
        print(f" [*] Make directories : {args.checkpoint_path}")
        os.makedirs(args.checkpoint_path)
```
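As a quick sanity check, parse_epoch simply extracts the integer between the hyphen and the extension of a checkpoint filename, so a restored run knows how many epochs have already completed:

```python
import os

def parse_epoch(file_path):
    # 'checkpoints-0012.ckpt' -> 12
    return int(os.path.splitext(os.path.basename(file_path))[0].split('-')[1])

# The directory prefix here is just an example path.
epoch = parse_epoch('/workspace/ckpt/checkpoints-0012.ckpt')
print(epoch)  # prints 12
```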
The start_epoch value is also useful when logging metrics to the Vessl server. Without the offset, step numbers would restart from zero after a spot instance interruption and the metrics graph might break.
PyTorch
```python
import savvihub


def train(...):
    ...
    savvihub.log(
        step=epoch + start_epoch + 1,
        row={'loss': loss.item()}
    )
```
Keras
```python
from savvihub.keras import SavviHubCallback

model.fit(...,
          callbacks=[SavviHubCallback(
              ...,
              start_epoch=start_epoch,
              ...,
          )]
)
```
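To see the effect of the offset: if a run was interrupted after 10 completed epochs and restarts from the checkpoint, start_epoch is 10, so the first steps logged by the resumed run continue from 11 instead of colliding with the steps already recorded:

```python
start_epoch = 10  # value restored from the checkpoint after a restart

# Same formula as the step argument above: epoch + start_epoch + 1
steps = [epoch + start_epoch + 1 for epoch in range(3)]
print(steps)  # [11, 12, 13]
```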

3. Use the spot instance option

To use a spot instance on Vessl, check the Use Spot Instance checkbox. Every spot instance resource type also carries the suffix *.spot. More resource types will be added in the future.