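"""Tests for checkpointing with the caffe2.python.checkpoint Job/JobRunner API.

Covers single- and multi-node checkpoint managers, resuming training from a
given epoch, loading model blobs from checkpoints, uploading checkpoints,
tolerating checkpoint-save failures, and running a download task group
between epoch_group and exit_group.
"""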
from caffe2.python.schema import Struct, ConstRecord
from caffe2.python import core, workspace, model_helper
from caffe2.python.session import LocalSession
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.checkpoint import (
    CheckpointManager, MultiNodeCheckpointManager, Job, JobRunner, epoch_limiter,
    UploadTaskGroupBuilder, db_name)
from caffe2.python.net_builder import ops
from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType, Cluster
from caffe2.python.test_util import TestCase
from caffe2.python.dataio import ReaderWithLimit

import numpy as np
import os
import shutil
import tempfile


def build_pipeline(node_id):
    with Node('trainer_%d' % node_id):
        with Job.current().init_group, Task():
            data_arr = Struct(('val', np.array(list(range(10)))))
            data = ConstRecord(ops, data_arr)
            ds = Dataset(data, name='dataset:%d' % node_id)
            full_reader = ds.reader(ops)
            total = ops.Const([100])

        def inc_total(rec):
            ops.Add([total, rec.val()], [total])

        epoch_reader = ReaderWithLimit(full_reader, num_iter=3)
        pipe(epoch_reader, processor=inc_total)
        Job.current().add_stop_condition(epoch_reader.data_finished())
    return [total]


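# Expected values of the 'total' blob after each epoch. A rough derivation,
# assuming the reader yields one record per iteration: 'total' starts at 100
# and each epoch adds up to num_iter=3 of the values 0..9, so
# 100 + (0+1+2) = 103, then + (3+4+5) = 115, then + (6+7+8) = 136, and the
# final epoch only finds the remaining record 9, giving 145.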
EXPECTED_TOTALS = [103, 115, 136, 145]


def local_copy_op(src, dest):
    def copy_op(inputs, outputs):
        shutil.copyfile(src, dest)
    return copy_op


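# Uploads checkpoint files by copying them into a local directory. The build()
# hook below follows UploadTaskGroupBuilder's interface and is presumably
# invoked by JobRunner once the checkpoints for an epoch have been written.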
class UploadToLocalFile(UploadTaskGroupBuilder):
    def __init__(self, dest_dir):
        self.dest_dir = dest_dir

    def build(self, epoch, checkpoint_manager):
        with TaskGroup(WorkspaceType.GLOBAL) as upload_task_group:
            for node, manager in checkpoint_manager._node_managers:
                with Node(str(node)), Task():
                    src_path = db_name(epoch, manager._node_name, manager._db_prefix)
                    dest_path = os.path.join(self.dest_dir, str(node))
                    ops.Python((local_copy_op,
                                [src_path, dest_path], {}))([], [])
        return upload_task_group


class TestCheckpoint(TestCase):
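    # Shared driver: trains a single-node job from scratch, then re-trains
    # while resuming from every epoch, and finally loads each epoch's
    # checkpoint directly, checking 'total' against EXPECTED_TOTALS each time.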
    def run_with(self, builder):
        with Cluster():
            with Job() as job:
                outputs = build_pipeline(node_id=0)
            output_fetcher = Task(step=core.Net('empty'), outputs=outputs)

            def fetch_total(session):
                session.run(output_fetcher)
                return output_fetcher.outputs()[0].fetch()

            session, checkpoint = builder()
            job.compile(LocalSession)
            num_epochs = JobRunner(job, checkpoint).train(session)
            self.assertEquals(num_epochs, len(EXPECTED_TOTALS))
            self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])

            for initial_epoch in range(1, num_epochs + 1):
                session, checkpoint = builder()
                JobRunner(
                    job,
                    checkpoint, resume_from_epoch=initial_epoch
                ).train(session)
                self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])

            for epoch in range(1, num_epochs + 1):
                session.run(checkpoint.load(epoch))
                self.assertEquals(fetch_total(session),
                                  EXPECTED_TOTALS[epoch - 1])

    def test_single_checkpoint(self):
        # test single node
        try:
            tmpdir = tempfile.mkdtemp()

            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = CheckpointManager(tmpdir, 'temp_node', 'minidb')
                return session, checkpoint

            self.run_with(builder)
        finally:
            shutil.rmtree(tmpdir)

        # test multi-node
        try:
            tmpdir = tempfile.mkdtemp()

            def builder():
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                return session, checkpoint

            self.run_with(builder)
        finally:
            shutil.rmtree(tmpdir)

    def test_ckpt_name_and_load_model_from_ckpts(self):
        try:
            num_nodes = 3
            tmpdir = tempfile.mkdtemp()
            # First, check that the checkpoint name generation mechanism is
            # correct.
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Cluster():
                with Job() as job:
                    for node_id in range(num_nodes):
                        build_pipeline(node_id)
                job.compile(LocalSession)
                checkpoint.init(job.nodes_to_checkpoint())

                for node_id in range(num_nodes):
                    epoch = 5
                    node_name = 'trainer_%d' % node_id
                    expected_db_name = tmpdir + '/' + node_name + '.5'
                    self.assertEquals(
                        checkpoint.get_ckpt_db_name(node_name, epoch),
                        expected_db_name)
            shutil.rmtree(tmpdir)

            # Next, check the mechanism for loading a model from checkpoints.
            tmpdir = tempfile.mkdtemp()
            workspace.ResetWorkspace()
            for node_id in range(num_nodes):
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                with Cluster():
                    with Job() as job:
                        build_pipeline(node_id)
                    job.compile(LocalSession)
                    job_runner = JobRunner(job, checkpoint)
                    num_epochs = job_runner.train(session)
                self.assertEquals(num_epochs, len(EXPECTED_TOTALS))

                # There are 17 global blobs after finishing up the job runner
                # (only blobs on init_group are checkpointed).
                self.assertEquals(len(ws.blobs), 17)

            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            self.assertEquals(len(ws.blobs), 0)
            model_blob_names = ['trainer_1/task_2/GivenTensorInt64Fill:0',
                                'trainer_2/task_2/GivenTensorInt64Fill:0']
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Cluster():
                with Job() as job:
                    for node_id in range(num_nodes):
                        build_pipeline(node_id)
                job.compile(LocalSession)
                job_runner = JobRunner(job, checkpoint)
                job_runner.load_blobs_from_checkpoints(
                    blob_names=model_blob_names, epoch=1, session=session)

                # Check that we can successfully load from checkpoints of epochs
                # 1 to 4, but not epoch 5.
                for epoch in range(1, 5):
                    self.assertTrue(
                        job_runner.load_blobs_from_checkpoints(
                            blob_names=model_blob_names, epoch=epoch,
                            session=session))
                    # Check that all the model blobs are loaded.
                    for blob_name in model_blob_names:
                        self.assertTrue(ws.has_blob(blob_name))
                        self.assertEquals(
                            ws.fetch_blob(blob_name),
                            np.array([EXPECTED_TOTALS[epoch - 1]]))
                self.assertFalse(
                    job_runner.load_blobs_from_checkpoints(
                        blob_names=model_blob_names, epoch=5, session=session))

        finally:
            shutil.rmtree(tmpdir)

    def test_upload_checkpoint(self):
        try:
            tmpdir = tempfile.mkdtemp()
            upload_dir = os.path.join(tmpdir, "upload")
            os.mkdir(upload_dir)
            num_nodes = 3

            # The uploaded files do not exist yet.
            for node_id in range(num_nodes):
                node_name = 'trainer_%d' % node_id
                upload_path = os.path.join(upload_dir, node_name)
                self.assertFalse(os.path.exists(upload_path))

            # Create and run the job runner.
            for node_id in range(num_nodes):
                ws = workspace.C.Workspace()
                session = LocalSession(ws)
                checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
                with Cluster():
                    with Job() as job:
                        build_pipeline(node_id)
                    job.compile(LocalSession)
                    local_upload_builder = UploadToLocalFile(upload_dir)
                    job_runner = JobRunner(
                        job, checkpoint,
                        upload_task_group_builder=local_upload_builder)
                    num_epochs = job_runner.train(session)
                    self.assertEquals(num_epochs, len(EXPECTED_TOTALS))

            # The uploaded files should exist now.
            for node_id in range(num_nodes):
                node_name = 'trainer_%d' % node_id
                upload_path = os.path.join(upload_dir, node_name)
                self.assertTrue(os.path.exists(upload_path))

        finally:
            shutil.rmtree(tmpdir)

    def test_ckpt_save_failure(self):
        num_nodes = 3
        # The goal of this test is to ensure that the job runs
        # successfully even if saving a checkpoint fails.
        # Hence tmpdir is a non-existent directory, to emulate a failure
        # while saving checkpoints.
        tmpdir = "/tmp/path_does_not_exist/"

        # Check that a checkpoint-saving failure does not cause job failure.
        workspace.ResetWorkspace()
        for node_id in range(num_nodes):
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')
            with Cluster():
                with Job() as job:
                    build_pipeline(node_id)
                job.compile(LocalSession)
                job_runner = JobRunner(job, checkpoint)
                num_epochs = job_runner.train(session)
            # Make sure all epochs are executed even though saving the
            # checkpoint failed: the failure should not abort the job.
            self.assertEquals(num_epochs, len(EXPECTED_TOTALS))

    def test_download_group_simple(self):
        """
        A simple test that ensures the download task group is executed
        between epoch_group and exit_group.
        """
        model = model_helper.ModelHelper(name="test_model")
        download_net = core.Net("download_net")

        for name in ["input1", "input2", "output", "download_result"]:
            model.param_init_net.ConstantFill([],
                                              [name],
                                              shape=[8, ],
                                              value=1.0,
                                              run_once=0)
        model.net.Add(["input1", "input2"], ["output"])
        download_net.Copy(["output"], ["download_result"])

        # All blob values are initialized to 1.0. After download_net executes,
        # we expect the download result to equal the training result.
        with Job() as job:
            with Node("trainer:0"):
                with job.init_group:
                    Task(step=model.param_init_net)
                with job.epoch_group:
                    with Task():
                        with ops.loop(1):
                            ops.net(model.net)
                with job.download_group:
                    Task(step=download_net)

                epoch_limiter(job, 1)

        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        job_runner = JobRunner(job)
        job_runner.train(session)

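        # Both inputs were filled with 1.0, so each of the 8 output elements
        # should be 1.0 + 1.0 = 2.0, and download_result should be a copy of
        # that output.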
        expected_result = np.full(8, 2.0).astype(np.float32)
        self.assertTrue(np.array_equal(expected_result,
                                       ws.fetch_blob("output")))
        self.assertTrue(np.array_equal(expected_result,
                                       ws.fetch_blob("download_result")))

    def test_reuse_checkpoint_manager(self):
        """
        A simple test that ensures we can reuse a MultiNodeCheckpointManager
        object.
        """
        try:
            tmpdir = tempfile.mkdtemp()
            ws = workspace.C.Workspace()
            session = LocalSession(ws)
            checkpoint = MultiNodeCheckpointManager(tmpdir, 'minidb')

            with Job() as job:
                outputs = build_pipeline(node_id=0)
            output_fetcher = Task(step=core.Net('empty'), outputs=outputs)
            job.compile(LocalSession)

            def fetch_total(session):
                session.run(output_fetcher)
                return output_fetcher.outputs()[0].fetch()

            num_epochs = JobRunner(job, checkpoint).train(session)
            for initial_epoch in range(1, num_epochs + 1):
                JobRunner(
                    job,
                    checkpoint,
                    resume_from_epoch=initial_epoch
                ).train(session)
                self.assertEquals(fetch_total(session), EXPECTED_TOTALS[-1])

        finally:
            shutil.rmtree(tmpdir)