From 540f3b7c4aa2386ba034a2c8f1618aa9be369632 Mon Sep 17 00:00:00 2001 From: bhetherman Date: Sat, 5 Mar 2022 09:03:28 +0000 Subject: [PATCH] updates to refactor branch --- build.sh | 0 cluster.py | 34 +++++++--- csrc/CUDAGraph.cpp | 2 +- csrc/JobContext.cpp | 70 +++++++++++++++----- csrc/JobContext.h | 13 ++-- csrc/communication.h | 4 +- csrc/dataset.cpp | 110 +++++++++++++++++--------------- csrc/dataset.h | 131 +++++++++++++++++++++++++------------- csrc/rpcService.cpp | 1 + csrc/runnableModule.cpp | 118 ++++++++++++++++++++++++++++------ csrc/runnableModule.h | 25 ++++++-- csrc/tracer.h | 6 ++ jobDescription.py | 21 +++--- parallelizationPlanner.py | 60 +++++++++++++---- 14 files changed, 423 insertions(+), 172 deletions(-) mode change 100644 => 100755 build.sh mode change 100644 => 100755 parallelizationPlanner.py diff --git a/build.sh b/build.sh old mode 100644 new mode 100755 diff --git a/cluster.py b/cluster.py index cc2065e..9489526 100644 --- a/cluster.py +++ b/cluster.py @@ -28,6 +28,7 @@ import grpc import runtime_pb2 import runtime_pb2_grpc +import os # import examples.vgg as vgg # TODO: this is used for debugging. Remove this later. @@ -79,7 +80,7 @@ def initCommNCCL(self, message, msgType, groupId, members): response = self.stub.InitCommNCCL(runtime_pb2.InitCommNCCLMsg( message=message, msg_type=msgType, group_id=groupId, members=members)) print("received: " + response.message) - return response.group_id; + return response.group_id def initCommGRPC(self, rankToIpMap): rankToIpMapInJson = json.dumps(rankToIpMap) @@ -103,10 +104,10 @@ def __init__(self, address: str, port: int, device: int, userId: str, sshKeyPath self.serverId = None self.proxy = None self.isCpp = isCpp - self.is_local = address == "127.0.0.1" + self.is_local = (address == "127.0.0.1" or address == "localhost") self.process = None - def getProxy(self, maxRetry = 180): + def getProxy(self, maxRetry = 360): if self.proxy != None: # print("getProxy() returned from cached proxy value.") return self.proxy @@ -200,11 +201,13 @@ def upSync(self, localPath, remotePath): class ClusterCoordinator(xmlrpc.server.SimpleXMLRPCServer): """ GPU cluster coordinator. It accepts training jobs from clients and schedule them to runtimes. """ - def __init__(self, addrToBind: str, portToBind: int, locations: List[Location], workDir: str, be_batch_size: int): + def __init__(self, addrToBind: str, portToBind: int, locations: List[Location], workDir: str, be_batch_size: int, localhost:bool): super(ClusterCoordinator, self).__init__((addrToBind, portToBind)) self.myAddr = addrToBind self.myPort = portToBind self.locations = locations + self.worldSize = len(locations) + self.localhost = localhost self.workDir = workDir self.processes = [] # from subprocess calls used for launching runtime. self.nextTagStartOffset = 1 @@ -233,7 +236,7 @@ def export_poke(self): return 'Returned from poke at %s' % self.myAddr def export_scheduleTraining(self, jobName: str, trainingJobInJSON: str, runbe): - job = TrainingJob("test", None, None, 0, 0, "") + job = TrainingJob("test", None, None, 0, 0, 0, "") job.loadJSON(trainingJobInJSON) print("received job") @@ -257,11 +260,15 @@ def export_scheduleTraining(self, jobName: str, trainingJobInJSON: str, runbe): if len(self.locations) < gpusUsed: return "Not enough servers available. %d gpus available while %d needed" % (len(self.locations), gpusUsed) + lfn = "NLL" + if "gpt2" in jobName: + lfn = "CrossEntropyLoss" jobParams = { "run_with_be": runbe, "nr_gpus": gpusUsed, "cifar_training": "cifar" in jobName, - "lossfn": "CrossEntropyLoss" if "gpt2" in jobName else "NLL", + "lossfn": lfn, + "epochsToTrain": 100 } jobParamsInJson = json.dumps(jobParams) @@ -277,7 +284,7 @@ def requestScheduleTraining(proxy, jobInJson): thread.start() waitthreads(threadList) - self.ongoingJobs[jobName] = {"iterTime": 0, "gpuMsec": 0, "gpusUsed": gpusUsed, "gpusFinished": 0, "globalBatchSize": job.globalBatchSize} + self.ongoingJobs[jobName] = {"iterTime": 0, "gpuMsec": 0, "gpusUsed": gpusUsed, "gpusFinished": 0, "globalBatchSize": job.globalBatchSize, "lossfn": job.lossfn} self.ongoingJobs[jobName].update({"beImagesPerIter": 0.0, "idleMsPerIter": 0.0}) # for rank in range(gpusUsed): @@ -518,7 +525,11 @@ def parse_args(): help="To launch CPP version runtimes.") parser.add_argument('--manualLaunch', default=False, action='store_true', help="Do not runtimes automatically. Primarily for using gdb on runtime processes.") - parser.add_argument("--logdir", type=str, default="", help="Full path of log directory") + parser.add_argument("--localhost", type=str, default=True, + help="Run cluster on local host only") + parser.add_argument("--logdir", type=str, default=os.getcwd(), + help="Run cluster on local host only") + # parser.add_argument("--logdir", type=str, default="", help="Full path of log directory") # For installing nsys.. (with other cuda toolkit..) # wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin # sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600 @@ -542,12 +553,15 @@ def main(): for deviceConfig in serverConfig["deviceList"]: rankToIpMap[str(len(locations))] = serverConfig["addr"] + ":" + str(deviceConfig["port"]) commGrpRanksWorld.append(len(locations)) - locations.append(Location(serverConfig["addr"], deviceConfig["port"], deviceConfig["device"], serverConfig["userId"], serverConfig["sshKeyPath"], args.cpp)) + if args.localhost: + locations.append(Location(serverConfig["addr"], deviceConfig["port"], deviceConfig["device"], serverConfig["userId"], None, args.cpp)) + else: + locations.append(Location(serverConfig["addr"], deviceConfig["port"], deviceConfig["device"], serverConfig["userId"], None, args.cpp)) addrToBindCombo = re.split('[-:]', args.addrToBind) addrToBind = addrToBindCombo[0] portToBind = int(addrToBindCombo[1]) - coordinator = ClusterCoordinator(addrToBind, portToBind, locations, clusterConfig["workDir"], args.be_batch_size) + coordinator = ClusterCoordinator(addrToBind, portToBind, locations, clusterConfig["workDir"], args.be_batch_size, args.localhost) if args.install: coordinator.installPackages() diff --git a/csrc/CUDAGraph.cpp b/csrc/CUDAGraph.cpp index 31a6ba9..f5b3e6f 100644 --- a/csrc/CUDAGraph.cpp +++ b/csrc/CUDAGraph.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include diff --git a/csrc/JobContext.cpp b/csrc/JobContext.cpp index ac5f223..4a826df 100644 --- a/csrc/JobContext.cpp +++ b/csrc/JobContext.cpp @@ -67,13 +67,13 @@ JobContext::JobContext(std::unique_ptr modelIn, if (job_params.contains("epochs_to_train")) epochsToTrain = job_params["epochs_to_train"].get(); - train_dataset_.reset( - Dataset::fromName(dset, rtctx->rank, model->globalBatchSize, - model->initialBatchSizes, model->sampleIndices, 2000)); - eval_dataset_.reset( - Dataset::fromName(dset + "_eval", rtctx->rank, model->globalBatchSize, - model->initialBatchSizes, model->sampleIndices, 10)); - dataset_pipeline_.reset(new DatasetPipelineWrapper(train_dataset_)); + train_dataset_ = Dataset::fromName(dset, rtctx->rank, model->globalBatchSize, + model->initialBatchSizes, model->sampleIndices, 2000, rtctx->worldSize); + + eval_dataset_ = Dataset::fromName(dset + "_eval", rtctx->rank, model->globalBatchSize, + model->initialBatchSizes, model->sampleIndices, 10, rtctx->worldSize); + + dataset_pipeline_ = std::make_shared(train_dataset_); if (!rtctx->use_fg_graph) iters_before_graph_capture = itersToTrain * epochsToTrain; @@ -158,11 +158,12 @@ void JobContext::StepOne(bool *iter_done) { job_done_ = true; return; } - ++totiters; + if (!graphCapture) + ++totiters; } } -void JobContext::Test() { +void JobContext::Test(int64_t curEpoch) { double total = 0.0; torch::Tensor correct = torch::zeros({1}).to(at::kLong).to(rtctx->c10dev); @@ -176,12 +177,12 @@ void JobContext::Test() { total += model->GetGlobalBatchSize(); auto batch = eval_dataset_->getNextThisRank(); - torch::Tensor input = batch.data; + torch::Tensor input = batch["data"]; if (input.defined()) input = input.to(rtctx->c10dev); auto output = Infer(input); if (output.defined() && output.nbytes() > 0) { auto pred = output.argmax(1); - correct += pred.eq(batch.target.to(rtctx->c10dev)).sum(); + correct += pred.eq(batch["target"].to(rtctx->c10dev)).sum(); } DP_LOG(DEBUG, "Evaluate iteration %lu/%lu\n", ++i, eval_dataset_->GetItersPerEpoch()); @@ -204,29 +205,62 @@ void JobContext::Test() { } torch::Tensor JobContext::Infer(torch::Tensor input) { + bool will_do_graph_capture = false; + if(totiters == iters_before_graph_capture && !model->has_graph) + will_do_graph_capture = true; + torch::NoGradGuard guard; model->SetEval(); - model->SetInputsTargets(input, {}); + model->SetInputsTargets(input); FinishIteration(); + + if(will_do_graph_capture){ + //model doesnt actuall run when the graph is captured + model->SetInputsTargets(input); + FinishIteration(); + } return model->getOutput(); } -void JobContext::Train(torch::Tensor input, torch::Tensor target) { +void JobContext::Train(torch::Tensor input, torch::Tensor target, torch::Tensor weights) { + bool will_do_graph_capture = false; + if(totiters == iters_before_graph_capture && !model->has_graph) + will_do_graph_capture = true; + model->SetTrain(); - model->SetInputsTargets(input, target); + model->SetInputsTargets(input, target, weights); FinishIteration(); + + if(will_do_graph_capture){ + //model doesnt actuall run when the graph is captured + model->SetInputsTargets(input, target, weights); + FinishIteration(); + } } -void JobContext::TrainOneEpoch() { +void JobContext::TrainOneEpoch(int64_t curEpochh) { dataset_pipeline_->Reset(); + model->ResetAvgLoss(); size_t i = 0; if (iters_before_graph_capture < totiters && rtctx->use_fg_graph) iters_before_graph_capture = totiters + 5; while (!dataset_pipeline_->IsDone() && !job_done_) { auto batch = dataset_pipeline_->getNextThisRank(); - Train(batch.data, batch.target); + if(batch.find("data") == batch.end()) + continue; + + Train(batch["data"], batch["target"], batch["weight"]); DP_LOG(DEBUG, "Training iteration %lu/%lu\n", ++i, dataset_pipeline_->GetItersPerEpoch()); + + if (nr_gpus_ > 1) { + rtctx->torch_stream.synchronize(); // sync before calling into NCCL + commHandler->comm_start(); + auto tmp_sync = torch::zeros({1}).to(rtctx->c10dev);; + commHandler->all_reduce(tmp_sync, c10d::ReduceOp::SUM); + commHandler->comm_end(); + commHandler->sync(); + } } double loss = model->GetAvgLoss(); DP_LOG(DEBUG, "Epoch done. Loss %.2f", loss); @@ -251,3 +285,7 @@ void JobContext::FinishIteration() { StepOne(&iter_done); } while (!iter_done && !job_done_); } + +size_t JobContext::getTrainItersPerEpoch(){ + return train_dataset_->GetItersPerEpoch(); +}; \ No newline at end of file diff --git a/csrc/JobContext.h b/csrc/JobContext.h index 3ef2c03..3a801ef 100644 --- a/csrc/JobContext.h +++ b/csrc/JobContext.h @@ -49,10 +49,12 @@ class JobContext { torch::Tensor Infer(torch::Tensor input); /* Run one training iteration with this input/target */ - void Train(torch::Tensor input, torch::Tensor target); + void Train(torch::Tensor input, + torch::Tensor target, + torch::Tensor weights = {}); /* Test the model on the test dataset */ - void Test(); + void Test(int64_t curEpoch = -1); /* Advance one step through the the model */ void StepOne(bool *iter_done); @@ -61,16 +63,19 @@ class JobContext { void FinishIteration(); /* Train one full epoch */ - void TrainOneEpoch(); + void TrainOneEpoch(int64_t curEpoch = -1); void printJobStatistics(); - std::unique_ptr model; + std::shared_ptr model; std::string name; std::shared_ptr commHandler; std::chrono::time_point start, end; uint64_t be_img_start, be_img_end; + size_t getTrainItersPerEpoch(); + size_t getWarmupIters(){return warmupIters;}; + private: // Params bool run_with_be_{false}; diff --git a/csrc/communication.h b/csrc/communication.h index b35f27d..31e5ccb 100644 --- a/csrc/communication.h +++ b/csrc/communication.h @@ -29,6 +29,8 @@ #include "rpcService.h" #include "utils.h" +#define SIZE_16_MiB 16*1024*1024 + using json = nlohmann::json; class CommunicationHandler { @@ -130,4 +132,4 @@ class CommunicationHandlerGRPC : public CommunicationHandler { std::unordered_map > clientPool; }; -#endif +#endif \ No newline at end of file diff --git a/csrc/dataset.cpp b/csrc/dataset.cpp index b7cc700..0efc4d9 100644 --- a/csrc/dataset.cpp +++ b/csrc/dataset.cpp @@ -9,47 +9,6 @@ ABSL_FLAG(std::string, cifar_dataset, "/home/friedj/mlsf/multimodel/data/cifar-10-batches-bin/", ""); -class FakeDataset : public Dataset { - public: - FakeDataset(size_t rank, long globalBatchSize, - std::vector initialBatchSizes, - std::vector sampleIndices, - std::function()> gen, - size_t images_per_epoch); - torch::data::Example<> getNext() override; - bool IsDone() override; - void Reset() override; - size_t GetItersPerEpoch() override; - - private: - size_t batches_per_epoch_; - size_t ctr_{0}; - std::vector> cached_; -}; - -class CifarDataset : public Dataset { - public: - CifarDataset(size_t rank, long globalBatchSize, - std::vector initialBatchSizes, - std::vector sampleIndices, bool is_eval); - torch::data::Example<> getNext() override; - bool IsDone() override; - void Reset() override; - size_t GetItersPerEpoch() override; - - private: - c10::optional>> cur_iter; - size_t batches_per_epoch_; - - std::unique_ptr>, - torch::data::transforms::Stack>>, - torch::data::samplers::SequentialSampler>> - loader; -}; - FakeDataset::FakeDataset(size_t rank, long globalBatchSize, std::vector initialBatchSizes, std::vector sampleIndices, @@ -64,13 +23,36 @@ size_t FakeDataset::GetItersPerEpoch() { return batches_per_epoch_; }; bool FakeDataset::IsDone() { return ctr_ >= batches_per_epoch_; } -torch::data::Example<> FakeDataset::getNext() { +std::map FakeDataset::getNext() { assert(!IsDone()); - return cached_[ctr_++ % cached_.size()]; + torch::data::Example<> vals = cached_[ctr_++ % cached_.size()]; + return {{std::string("data"), vals.data}, {std::string("target"), vals.target}}; } void FakeDataset::Reset() { ctr_ = 0; } +std::map FakeDataset::getNextThisRank(){ + std::map rtn; + auto ex = getNext(); + if (initialBatchSizes_.at(rank_)) + rtn["data"] = ex["data"].split_with_sizes(initialBatchSizes_)[rank_]; + + if (sampleIndices_.size()){ + std::vector spl(globalBatchSize_, 1); + auto splitsamples = + ex["target"].split_with_sizes(spl); // TODO make this clean.... + std::vector samplesOrdered; + for (auto &s : sampleIndices_) + samplesOrdered.push_back(splitsamples.at(s)); + rtn["target"] = torch::cat(samplesOrdered); + } + return rtn; +} + + + + + CifarDataset::CifarDataset(size_t rank, long globalBatchSize, std::vector initialBatchSizes, std::vector sampleIndices, bool is_eval) @@ -97,25 +79,53 @@ bool CifarDataset::IsDone() { return false; } -torch::data::Example<> CifarDataset::getNext() { +std::map CifarDataset::getNextThisRank(){ + std::map rtn; + auto ex = getNext(); + + // torch::Tensor data, target; + if (initialBatchSizes_.at(rank_)) + rtn["data"] = ex["data"].split_with_sizes(initialBatchSizes_)[rank_]; + + if (sampleIndices_.size()){ + std::vector spl(globalBatchSize_, 1); + auto splitsamples = + ex["target"].split_with_sizes(spl); // TODO make this clean.... + std::vector samplesOrdered; + for (auto &s : sampleIndices_) + samplesOrdered.push_back(splitsamples.at(s)); + rtn["target"] = torch::cat(samplesOrdered); + } + return rtn; +} + +std::map CifarDataset::getNext() +{ assert(!IsDone()); auto cur_example = *cur_iter.value(); cur_iter = ++cur_iter.value(); - return cur_example; + return {{std::string("data"), cur_example.data}, {std::string("target"), cur_example.target}}; } size_t CifarDataset::GetItersPerEpoch() { return batches_per_epoch_; }; void CifarDataset::Reset() { cur_iter = loader->begin(); } -Dataset *Dataset::fromName(std::string name, size_t rank, long globalBatchSize, + +std::shared_ptr Dataset::fromName(std::string name, size_t rank, long globalBatchSize, std::vector initialBatchSizes, std::vector sampleIndices, - size_t fake_train_iters_per_epoch) { + size_t fake_train_iters_per_epoch, + size_t worldSize) { bool eval = name.find("eval") != std::string::npos; if (name.find("cifar") != std::string::npos) - return new CifarDataset(rank, globalBatchSize, initialBatchSizes, + return std::make_shared(rank, globalBatchSize, initialBatchSizes, sampleIndices, eval); +#ifdef ENABLE_STREAMING_DATASET + else if (name.find("anvil") != std::string::npos) + return std::make_shared(rank, globalBatchSize, initialBatchSizes, + sampleIndices, eval, worldSize); +#endif long fake_images = globalBatchSize * fake_train_iters_per_epoch; @@ -131,7 +141,7 @@ Dataset *Dataset::fromName(std::string name, size_t rank, long globalBatchSize, {globalBatchSize, 1024}, topts); return torch::data::Example<>(data, target); }; - return new FakeDataset(rank, globalBatchSize, initialBatchSizes, + return std::make_shared(rank, globalBatchSize, initialBatchSizes, sampleIndices, gen, eval ? 1000 : fake_images); } @@ -145,6 +155,6 @@ Dataset *Dataset::fromName(std::string name, size_t rank, long globalBatchSize, torch::randint(/*low=*/0, /*high=*/1000, {globalBatchSize}, targetOpts); return torch::data::Example<>(data, target); }; - return new FakeDataset(rank, globalBatchSize, initialBatchSizes, + return std::make_shared(rank, globalBatchSize, initialBatchSizes, sampleIndices, gen, eval ? 1000 : fake_images); } \ No newline at end of file diff --git a/csrc/dataset.h b/csrc/dataset.h index 976d223..d8213cf 100644 --- a/csrc/dataset.h +++ b/csrc/dataset.h @@ -6,46 +6,35 @@ #include "runtime.h" #include "utils.h" +#include "cifar10.h" + class Dataset { public: - virtual torch::data::Example<> getNext() = 0; + + Dataset(size_t rank, long globalBatchSize, + std::vector initialBatchSizes, + std::vector sampleIndices) + : rank_(rank), + globalBatchSize_(globalBatchSize), + initialBatchSizes_(initialBatchSizes), + sampleIndices_(sampleIndices){} + virtual ~Dataset() {} + virtual std::map getNext() = 0; virtual size_t GetItersPerEpoch() = 0; virtual bool IsDone() = 0; virtual void Reset() = 0; - static Dataset *fromName(std::string name, size_t rank, long globalBatchSize, + virtual std::map getNextThisRank() = 0; + + static std::shared_ptr fromName(std::string name, size_t rank, long globalBatchSize, std::vector initialBatchSizes, std::vector sampleIndices, - size_t fake_train_iters_per_epoch); - - torch::data::Example<> getNextThisRank() { - auto ex = getNext(); - - torch::Tensor data, target; - if (initialBatchSizes_.at(rank_)) - data = ex.data.split_with_sizes(initialBatchSizes_)[rank_]; - - if (sampleIndices_.size()) { - std::vector spl(globalBatchSize_, 1); - auto splitsamples = - ex.target.split_with_sizes(spl); // TODO make this clean.... - std::vector samplesOrdered; - for (auto &s : sampleIndices_) - samplesOrdered.push_back(splitsamples.at(s)); - target = torch::cat(samplesOrdered); - } - return {data, target}; - } + size_t fake_train_iters_per_epoch, + size_t worldSize = 1); + + protected: long globalBatchSize_; - Dataset(size_t rank, long globalBatchSize, - std::vector initialBatchSizes, std::vector sampleIndices) - : globalBatchSize_(globalBatchSize), - rank_(rank), - initialBatchSizes_(initialBatchSizes), - sampleIndices_(sampleIndices){}; - - private: size_t rank_; std::vector initialBatchSizes_; std::vector sampleIndices_; @@ -96,8 +85,8 @@ class DatasetPipelineWrapper { public: DatasetPipelineWrapper(std::shared_ptr dataset) : dataset_(dataset) { auto next_sample = dataset_->getNextThisRank(); - data_pipeline_.reset(new TensorPipeline(next_sample.data)); - target_pipeline_.reset(new TensorPipeline(next_sample.target)); + for ( const auto &[key, value]: next_sample) + pipelines_[key].reset(new TensorPipeline(value)); } bool IsDone() { return is_done_; } @@ -108,27 +97,81 @@ class DatasetPipelineWrapper { dataset_->Reset(); is_done_ = false; auto next_sample = dataset_->getNextThisRank(); - data_pipeline_->SupplyNext(next_sample.data); - target_pipeline_->SupplyNext(next_sample.target); + for ( const auto &[key, pipe]: pipelines_) + pipe->SupplyNext(next_sample[key]); } - torch::data::Example<> getNextThisRank() { + std::map getNextThisRank() { assert(!is_done_); + std::map rtn_vals; if (dataset_->IsDone()) { - auto data = data_pipeline_->GetNext({}); - auto target = target_pipeline_->GetNext({}); + for ( const auto &[key, pipe]: pipelines_) + rtn_vals[key] = pipe->GetNext({}); is_done_ = true; - return {data, target}; + return rtn_vals; } - auto next_sample = dataset_->getNextThisRank(); - auto data = data_pipeline_->GetNext(next_sample.data); - auto target = target_pipeline_->GetNext(next_sample.target); - return {data, target}; + else{ + auto next_sample = dataset_->getNextThisRank(); + if(dataset_->IsDone()) + is_done_ = true; + if (next_sample.size() > 0) + for ( const auto &[key, pipe]: pipelines_) + rtn_vals[key] = pipe->GetNext(next_sample[key]); + } + return rtn_vals; } private: std::shared_ptr dataset_; - std::unique_ptr data_pipeline_; - std::unique_ptr target_pipeline_; + std::map> pipelines_; bool is_done_{false}; +}; + + + +class FakeDataset : public Dataset { + public: + FakeDataset(size_t rank, long globalBatchSize, + std::vector initialBatchSizes, + std::vector sampleIndices, + std::function()> gen, + size_t images_per_epoch); + std::map getNextThisRank() override; + std::map getNext() override; + bool IsDone() override; + void Reset() override; + size_t GetItersPerEpoch() override; + ~FakeDataset() {cached_.clear();}; + + private: + size_t batches_per_epoch_; + size_t ctr_{0}; + std::vector> cached_; +}; + + + +class CifarDataset : public Dataset { + public: + CifarDataset(size_t rank, long globalBatchSize, + std::vector initialBatchSizes, + std::vector sampleIndices, bool is_eval); + std::map getNext(); + bool IsDone(); + void Reset(); + size_t GetItersPerEpoch(); + std::map getNextThisRank(); + ~CifarDataset() {}; + + private: + c10::optional>> cur_iter; + size_t batches_per_epoch_; + + std::unique_ptr>, + torch::data::transforms::Stack>>, + torch::data::samplers::SequentialSampler>> + loader; }; \ No newline at end of file diff --git a/csrc/rpcService.cpp b/csrc/rpcService.cpp index b0a7a97..47b4c56 100644 --- a/csrc/rpcService.cpp +++ b/csrc/rpcService.cpp @@ -228,6 +228,7 @@ std::unique_ptr RuntimeServiceImpl::parseAndCreateTrainingTask( commHandler = std::make_shared( name, worldSize, tensorTags, rank, jobRankToGlobalRank); } + DP_LOG(DEBUG, "commHandler constructed."); json jobParams = json::parse(request->job_meta_params_in_json()); DP_LOG(DEBUG, "parsed jobParams into json"); diff --git a/csrc/runnableModule.cpp b/csrc/runnableModule.cpp index 277faf2..d0fac69 100644 --- a/csrc/runnableModule.cpp +++ b/csrc/runnableModule.cpp @@ -115,6 +115,12 @@ RunnableModule::RunnableModule( auto layer = std::make_shared(module, specialModule, id, layerIsActive, doLocalGradSync); + if (layer->active){ + for (const auto& params : layer->module.parameters()) { + parameters.push_back(params); + } + } + layers.push_back(layer); layer->commGroupKey = RankVecToKey(ldsc["gpuAssignment"]); @@ -236,7 +242,8 @@ void RunnableModule::SetupOptimizer() { } void RunnableModule::SetInputsTargets(torch::Tensor input, - torch::Tensor target) { + torch::Tensor target, + torch::Tensor weight) { assert(state == JobState::INIT); size_t input_nb = input.defined() ? input.nbytes() : 0; @@ -273,9 +280,27 @@ void RunnableModule::SetInputsTargets(torch::Tensor input, target.record_stream(rtctx->torch_stream); } + size_t weight_nb = weight.defined() ? weight.nbytes() : 0; + + if (weight_nb) { + /* should already be on device */ + assert(weight.is_cuda()); + + if (!weight_buf.defined() || weight_nb != weight_buf.nbytes()) { + /* reallocate target buffer */ + assert(!has_graph); + weight_buf = torch::empty(weight.sizes(), weight.options()); + assert(weight_buf.is_cuda()); + } + + weight_buf.copy_(weight, /*non_blocking=*/true); + weight.record_stream(rtctx->torch_stream); + } + /* enqueue input to first layer */ layers[0]->tensors_in[0] = input_buf; fpTargets = target_buf; + fpWeights = weight_buf; } torch::Tensor Layer::DoForward(bool captureLayer) { @@ -582,8 +607,9 @@ int RunnableModule::AdvanceTraining(bool doGraphCapture, bool layerProfile) { layerQ.clear(); layerQ.push_back(layers[0].get()); - - nr_iters_++; + + if(!doGraphCapture) + nr_iters_++; if (layers[0]->active && !layers[0]->tensors_in[0].defined()) assert("MISSING INPUT TO FIRST LAYER!" && false); @@ -596,11 +622,12 @@ int RunnableModule::AdvanceTraining(bool doGraphCapture, bool layerProfile) { graph_recording = true; c10::cuda::device_synchronize(); graph_mempool = DeepPool::graph_pool_handle(); - maingraph.capture_begin(graph_mempool); + fw_graph.capture_begin(graph_mempool); commHandler->precapture(); } else if (has_graph) { /* skip to forward phase */ state = JobState::FORWARD; + TimerRecordStage("zero"); return 0; } @@ -616,8 +643,16 @@ int RunnableModule::AdvanceTraining(bool doGraphCapture, bool layerProfile) { if (has_graph) { DP_LOG(DEBUG, "Replay iter."); - fullgraph->Launch(rtctx->torch_stream); - state = JobState::FINISH; + main_fw_graph->Launch(rtctx->torch_stream); + TimerRecordStage("forward"); + TimerRecordStage("loss"); + if (!isTrain_){ + TimerRecordStage("backward"); + TimerRecordStage("step"); + state = JobState::FINISH; + } + else + state = JobState::BACKWARD; return 0; } @@ -625,12 +660,25 @@ int RunnableModule::AdvanceTraining(bool doGraphCapture, bool layerProfile) { if (status == COMPLETED) { TimerRecordStage("forward"); + if (doGraphCapture) { + sync_manager_.Join(); + commHandler->postcapture(); + fw_graph.capture_end(); + if(isTrain_){ + bw_graph.capture_begin(graph_mempool); + commHandler->precapture(); + } + } if (!isTrain_) { for (auto& layer : layers) { layer->status = LayerStatus::PENDING_FP; layer->nr_current_depedencies = layer->prevLayers.size(); } DP_LOG(DEBUG, "Foward pass is completed."); + + TimerRecordStage("loss"); + TimerRecordStage("backward"); + TimerRecordStage("step"); state = JobState::FINISH; } else { DP_LOG(DEBUG, "Foward pass is completed. Calculating loss soon."); @@ -649,6 +697,14 @@ int RunnableModule::AdvanceTraining(bool doGraphCapture, bool layerProfile) { } else if (state == JobState::BACKWARD) { DP_LOG(DEBUG, "JobState::BACKWARD."); + if (has_graph) { + DP_LOG(DEBUG, "Replay iter."); + main_bw_graph->Launch(rtctx->torch_stream); + state = JobState::STEP; + TimerRecordStage("backward"); + return 0; + } + JobStatus status = backwardAStep(layerProfile); if (status == COMPLETED) { TimerRecordStage("backward"); @@ -662,7 +718,7 @@ int RunnableModule::AdvanceTraining(bool doGraphCapture, bool layerProfile) { if (doGraphCapture) { sync_manager_.Join(); commHandler->postcapture(); - maingraph.capture_end(); + bw_graph.capture_end(); syncgraph.capture_begin(graph_mempool); } @@ -671,13 +727,22 @@ int RunnableModule::AdvanceTraining(bool doGraphCapture, bool layerProfile) { if (doGraphCapture) { syncgraph.capture_end(); - stepgraph.capture_begin(graph_mempool); } - - TimerRecordStage("sync"); + state = JobState::STEP; } else if (state == JobState::STEP) { DP_LOG(DEBUG, "JobState::STEP"); + if (has_graph) { + step_graph->Launch(rtctx->torch_stream); + state = JobState::FINISH; + TimerRecordStage("step"); + return 0; + } + if (doGraphCapture) { + stepgraph.capture_begin(graph_mempool); + commHandler->precapture(); + } + optimizer->step(); TimerRecordStage("step"); state = JobState::FINISH; @@ -690,25 +755,38 @@ int RunnableModule::AdvanceTraining(bool doGraphCapture, bool layerProfile) { if (!isTrain_) { commHandler->postcapture(); - maingraph.capture_end(); - auto maingraph_e = - GraphPieces::GraphToExecs(maingraph.getGRAPH(), maingraphsplit); - fullgraph = GraphPieces::MergePieces({maingraph_e}); + auto mainfwgraph_e = + GraphPieces::GraphToExecs(fw_graph.getGRAPH(), maingraphsplit); + main_fw_graph = GraphPieces::MergePieces({mainfwgraph_e}); } else { + commHandler->postcapture(); stepgraph.capture_end(); - auto maingraph_e = - GraphPieces::GraphToExecs(maingraph.getGRAPH(), maingraphsplit); + auto mainfwgraph_e = + GraphPieces::GraphToExecs(fw_graph.getGRAPH(), maingraphsplit); + main_fw_graph = GraphPieces::MergePieces({mainfwgraph_e}); + + auto mainbwgraph_e = + GraphPieces::GraphToExecs(bw_graph.getGRAPH(), maingraphsplit); auto syncgraph_e = GraphPieces::GraphToExecs(syncgraph.getGRAPH(), -1.0); + main_bw_graph = + GraphPieces::MergePieces({mainbwgraph_e, syncgraph_e}); auto stepgraph_e = GraphPieces::GraphToExecs(stepgraph.getGRAPH(), stepgraphsplit); - fullgraph = - GraphPieces::MergePieces({maingraph_e, syncgraph_e, stepgraph_e}); + step_graph = + GraphPieces::MergePieces({stepgraph_e}); } has_graph = true; graph_recording = false; DP_LOG(NOTICE, "Ending capture."); + + //record zero for timers + TimerRecordStage("zero"); + TimerRecordStage("forward"); + TimerRecordStage("loss"); + TimerRecordStage("backward"); + TimerRecordStage("step"); } state = JobState::INIT; @@ -723,11 +801,11 @@ int RunnableModule::AdvanceTraining(bool doGraphCapture, bool layerProfile) { * Reset timers for profiling each layer. Happens every iteration. */ void RunnableModule::resetTimers() { - if (has_graph || graph_recording) return; + if (graph_recording) return; if (rtctx->profile_stage_time) timers.SaveAndReset(); - if (rtctx->profile_layer_times_timers) { + if (!has_graph && rtctx->profile_layer_times_timers) { layerts_fwd.SaveAndReset(); layerts_bwd.SaveAndReset(); } diff --git a/csrc/runnableModule.h b/csrc/runnableModule.h index e62ffe6..445618e 100644 --- a/csrc/runnableModule.h +++ b/csrc/runnableModule.h @@ -162,7 +162,7 @@ class RunnableModule { torch::Tensor getOutput() { return fpOutput; } - void SetInputsTargets(torch::Tensor input, torch::Tensor target = {}); + void SetInputsTargets(torch::Tensor input, torch::Tensor target = {}, torch::Tensor weights = {}); const auto& GetTimers() { return timers; } @@ -172,6 +172,11 @@ class RunnableModule { return loss_tracker_.item().toDouble() / static_cast(nr_iters_); } + void ResetAvgLoss() { + loss_tracker_ *= 0; + nr_iters_ = 0; + } + private: friend struct Layer; friend class JobContext; @@ -198,7 +203,7 @@ class RunnableModule { } inline void TimerRecordStage(std::string name) { - if (rtctx->profile_stage_time && !has_graph && !graph_recording) + if (rtctx->profile_stage_time && !graph_recording) timers.Record(name); } @@ -215,13 +220,16 @@ class RunnableModule { GradientSyncManager sync_manager_; // Topologically sorted list of layers. std::vector> layers; + std::vector parameters; std::unique_ptr optimizer; //////////////////////////////////////////// // Context for tracking partial progress. //////////////////////////////////////////// std::deque layerQ; torch::Tensor fpTargets; + torch::Tensor fpWeights; torch::Tensor fpOutput; + torch::Tensor fpLossResult; LossFunctions lossfn_; torch::Tensor loss_tracker_; @@ -232,16 +240,19 @@ class RunnableModule { bool backwards_did_sync{false}; bool has_graph{false}; bool graph_recording{false}; - torch::Tensor input_buf, target_buf; + torch::Tensor input_buf, target_buf, weight_buf; - std::shared_ptr fullgraph; - DeepPool::CUDAGraph maingraph, syncgraph, stepgraph; + std::shared_ptr main_fw_graph, main_bw_graph, step_graph; + DeepPool::CUDAGraph fw_graph, bw_graph, syncgraph, stepgraph; at::cuda::MempoolId_t graph_mempool; void ResetGraphs() { - fullgraph.reset(); + main_fw_graph.reset(); + main_bw_graph.reset(); + step_graph.reset(); has_graph = false; - maingraph = DeepPool::CUDAGraph(); + fw_graph = DeepPool::CUDAGraph(); + bw_graph = DeepPool::CUDAGraph(); syncgraph = DeepPool::CUDAGraph(); stepgraph = DeepPool::CUDAGraph(); rtctx->torch_stream diff --git a/csrc/tracer.h b/csrc/tracer.h index 4a7e837..7e217f5 100644 --- a/csrc/tracer.h +++ b/csrc/tracer.h @@ -33,9 +33,15 @@ class CpuTimer { } uint64_t avgMicros() { + if (count == 0) + return 0; return RAMCloud::Cycles::toMicroseconds(totalCycles / count); } + double avgMs() { + return (double)this->avgMicros()/1000.0; + } + const char* name; uint64_t lastStartTick{0}; uint64_t totalCycles{0}; diff --git a/jobDescription.py b/jobDescription.py index 06b19ae..59d2ad7 100644 --- a/jobDescription.py +++ b/jobDescription.py @@ -22,6 +22,7 @@ from collections import defaultdict import copy import sys +import hashlib class TensorProperties: def __init__(self, tensor: torch.Tensor = None): @@ -108,12 +109,14 @@ def getRandomInputs(self, batchsize, device="cuda"): return fakeInputs def getModuleId(self): - import hashlib - m = hashlib.sha256() - m.update(json.dumps([str(a) for a in self.getInputShapes()], separators=('_', '-')).encode("utf-8")) - return self.name +\ - json.dumps(self.params, sort_keys=True, separators=('_', '-')) +\ - m.hexdigest() + if hasattr(self.module, 'weight'): + h = hashlib.md5() + h.update(self.module.weight.cpu().detach().numpy()) + return self.name +\ + json.dumps(self.params, sort_keys=True, separators=('_', '-'))+str(h.hexdigest()) + else: + return self.name +\ + json.dumps(self.params, sort_keys=True, separators=('_', '-')) def scriptModule(self): if not self.moduleSavedLocation: @@ -187,11 +190,12 @@ def dumpForJSON(self): class TrainingJob: - def __init__(self, name: str, layers: List[Layer], layerConfigs: List[tuple], globalBatchSize: int, maxGpusUsed: int, datasetDir: str): + def __init__(self, name: str, layers: List[Layer], layerConfigs: List[tuple], globalBatchSize: int, lossfn: int, maxGpusUsed: int, datasetDir: str): self.name = name self.layers = layers self.layerConfigs = layerConfigs self.globalBatchSize = globalBatchSize + self.lossfn = lossfn self.maxGpusUsed = maxGpusUsed self.datasetDir = datasetDir self.bytesPerParam = 4 @@ -202,6 +206,7 @@ def __init__(self, name: str, layers: List[Layer], layerConfigs: List[tuple], gl def loadJSON(self, jobInJson: str): job = json.loads(jobInJson) self.globalBatchSize = job["globalBatchSize"] + self.lossfn = job["lossfn"] self.maxGpusUsed = job["maxGpusUsed"] self.layers = [] self.layerConfigs = [] @@ -250,7 +255,7 @@ def dumpInJSON(self, layers: List[Layer] = None, layerConfigs: list = None): prop = l.dumpForJSON() prop["config"] = config allProps.append(prop) - fullDesc = {"globalBatchSize": self.globalBatchSize, "maxGpusUsed": self.maxGpusUsed, "layers": allProps} + fullDesc = {"globalBatchSize": self.globalBatchSize, "lossfn": self.lossfn, "maxGpusUsed": self.maxGpusUsed, "layers": allProps} # return json.dumps(fullDesc, indent=1, sort_keys=False) return json.dumps(fullDesc, sort_keys=False) diff --git a/parallelizationPlanner.py b/parallelizationPlanner.py old mode 100644 new mode 100755 index f7b95b4..f595779 --- a/parallelizationPlanner.py +++ b/parallelizationPlanner.py @@ -45,11 +45,12 @@ class SearchContext: - def __init__(self, totalGpus: int, globalBatch: int, amplificationLimit: float = 2.0, + def __init__(self, totalGpus: int, globalBatch: int, lossfn: int, amplificationLimit: float = 2.0, dataParallelBaseline = False, sampleSplit=True, spatialSplit=False, filterSplit=False, doNotBench = False): self.totalGpus = totalGpus self.globalBatch = globalBatch + self.lossfn = lossfn self.amplificationLimit = amplificationLimit self.dataParallelBaseline = dataParallelBaseline self.sampleSplit = sampleSplit @@ -68,6 +69,7 @@ def __init__(self, profiler: GpuProfiler, netBw = 1.25E4, verbose=False, gpuProf # self.NET_LATENCY = 400 #40 self.verbose = verbose self.layerProfileCache = {} + # if gpuProfileLoc != None and False: if gpuProfileLocSub != None: self.loadGpuProfile(gpuProfileLocSub) if gpuProfileLoc != None: @@ -105,7 +107,7 @@ def queryFwBwTime(self, layer, config: tuple): p = GpuProfiler("cuda") return p.queryFwBwTime(layer, config) - def generateModuleDescription(self, layerConfigs: list, globalBatch: int): + def generateModuleDescription(self, layerConfigs: list, globalBatch: int, lossfn: int): # gpuTimeSum = 0 # profiler.start() maxGpuUsed = 0 @@ -119,7 +121,7 @@ def generateModuleDescription(self, layerConfigs: list, globalBatch: int): # print("gpuTimeSum: ", gpuTimeSum) # profiler.stop() - return TrainingJob("test", self.layers, layerConfigs, globalBatch, maxGpuUsed, "na") + return TrainingJob("test", self.layers, layerConfigs, globalBatch, lossfn, maxGpuUsed, "na") # job.dumpSingleRunnableModule(15) @@ -479,6 +481,39 @@ def AvgPool2d(self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, return module + def BatchNorm2d(self, num_features: int, eps: float = 1e-05, momentum: float = 0.1, + affine: bool = True, track_running_stats: bool = True, + custom_previous_layers: list = None): + module = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=affine, + track_running_stats=track_running_stats) + + if custom_previous_layers == None and len(self.layers) > 0: + custom_previous_layers = [self.layers[-1]] + layer = Layer(module, "batchNorm2d", + {"num_features": num_features, "eps": eps, "momentum": momentum}, + prevLayers = custom_previous_layers) + self.layers.append(layer) + + return module + + def ConvTranspose2d(self, in_channels: int, out_channels: int, kernel_size: int, + stride: int = 1, padding: int = 0, output_padding: int = 0, groups: int = 1, + bias: bool = True, dilation: int = 1, padding_mode='zeros', + custom_previous_layers: list = None): + module = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, + padding=padding, output_padding=output_padding, groups=groups, bias=bias, + dilation=dilation, padding_mode=padding_mode) + + if custom_previous_layers == None and len(self.layers) > 0: + custom_previous_layers = [self.layers[-1]] + layer = Layer(module, "convTranspose2d", + {"in_channels": in_channels, "out_channels": out_channels, "kernel_size": kernel_size, + "stride": stride, "bias": bias}, + prevLayers = custom_previous_layers) + self.layers.append(layer) + + return module + def Linear(self, in_features: int, out_features: int, bias: bool = True, custom_previous_layers: list = None): module = nn.Linear(in_features, out_features, bias) @@ -1098,7 +1133,7 @@ def displayMultiChainResult(self, endLayer, endConfig: tuple, t, bestJoiningOpti (branch, idx, config, endTime) = schedule[i] print("%sLayer(%d, %2d) config: %15s done at %d" % (" "*55*branch, branch, idx, config, endTime)) - def searchBestSplits(self, totalGpus: int, globalBatch: int = 16, amplificationLimit: float = 2.0, dataParallelBaseline = False, spatialSplit=False): + def searchBestSplits(self, totalGpus: int, globalBatch: int = 16, amplificationLimit: float = 2.0, dataParallelBaseline = False, spatialSplit=False, lossfn=0): t = [[] for i in range(len(self.layers))] # [layer] = list of (config, cumulativeTime, prevConfigIndex) initialConfigs = [] @@ -1318,7 +1353,7 @@ def searchBestSplits(self, totalGpus: int, globalBatch: int = 16, amplificationL gpuUsecSum )) # print () - moduleDesc = self.generateModuleDescription([t[i][bestConfigChain[i]][0] for i in range(len(bestConfigChain))], globalBatch) + moduleDesc = self.generateModuleDescription([t[i][bestConfigChain[i]][0] for i in range(len(bestConfigChain))], globalBatch, lossfn) return (moduleDesc, finalTime / 1000., gpuUsecSum / 1000.) def searchBestSplitsV2(self, totalGpus: int, globalBatch: int = 16, useZhihaoAlgo = False): @@ -1502,7 +1537,7 @@ def searchLinear(self, preStartLayer, preStartConfig, startLayer, ctx: SearchCon gpuTime = self.benchGpuTime(layer, config, ctx=ctx) syncTime = self.calcSyncTime(layer, config, ctx) - if layer == startLayer: + if layer == startLayer or len(layer.prevLayers) == 0: if preStartLayer != None: cumulativeTime, prevLayerTime, prevConfigOfPrev, _, prevMpIdleTime = preStartLayer.t[preStartConfig] activationTime, activationSizeMatrix = self.calcInputXfer(preStartLayer, layer, preStartConfig, config) @@ -1756,7 +1791,7 @@ def to_gpuTimeline(self, name, totalGpus, dataParallelBaseline=False): plt.savefig("gpuTimeline.pdf") """ Generate a simple DP only plan, or use randomMode to randomly distribute layers """ - def JustDoDP(self, totalGpus: int, globalBatch: int, per_layer_rand_prob: float = 0.0): + def JustDoDP(self, totalGpus: int, globalBatch: int, per_layer_rand_prob: float = 0.0, lossfn: int = 0): randomMode = per_layer_rand_prob > 0.0 if randomMode: random.seed(0) lastCfg = 0 @@ -1816,12 +1851,12 @@ def JustDoDP(self, totalGpus: int, globalBatch: int, per_layer_rand_prob: float gpusUsed = totalGpus gpuTime = 1 finalTime = max(cumulativeTime, finalTime) - moduleDesc = TrainingJob("test", self.layers, [layer.bestCfg for layer in self.layers], globalBatch, totalGpus, "na") + moduleDesc = TrainingJob("test", self.layers, [layer.bestCfg for layer in self.layers], globalBatch, lossfn, totalGpus, "na") return (moduleDesc, dpTime / 1000., 0 / 1000., totalGpus) - def searchBestSplitsV3(self, totalGpus: int, globalBatch: int = 16, amplificationLimit: float = 2.0, dataParallelBaseline = False, sampleSplit=True, spatialSplit=False, filterSplit=False): + def searchBestSplitsV3(self, totalGpus: int, globalBatch: int = 16, lossfn: int = 0, amplificationLimit: float = 2.0, dataParallelBaseline = False, sampleSplit=True, spatialSplit=False, filterSplit=False): """ Parallelization strategy findiing for DeepPool. """ - ctx = SearchContext(totalGpus, globalBatch, amplificationLimit, dataParallelBaseline, sampleSplit=sampleSplit, spatialSplit=spatialSplit, filterSplit=filterSplit) + ctx = SearchContext(totalGpus, globalBatch, lossfn, amplificationLimit, dataParallelBaseline, sampleSplit=sampleSplit, spatialSplit=spatialSplit, filterSplit=filterSplit) ctx.doNotBench = totalGpus == 1 finalLayer = self.searchLinear(None, None, self.layers[0], ctx) @@ -2020,6 +2055,9 @@ def gpuAssignBranch(joiningLayer, availableRanks): if layer.name in ["conv2d"]: print("%11s (b=%2d, w=%3d, h=%3d, c=%4d, f=%4d) => " % (layer.name, *layer.initCfg), end="") print("(b=%2d, w=%3d, h=%3d, c=%4d, f=%4d) " % layer.bestCfg, end="") + elif len(layer.initCfg) == 2: + print("%11s (b=%2d, seq=%6d) => " % (layer.name, *layer.initCfg), end="") + print("(b=%2d, seq=%6d) " % layer.bestCfg, end="") elif len(layer.initCfg) == 3: #layer.name in ["linear", "ReLU1d"]: print("%11s (b=%2d, in=%6d, out=%6d) => " % (layer.name, *layer.initCfg), end="") print("(b=%2d, in=%6d, out=%6d) " % layer.bestCfg, end="") @@ -2039,7 +2077,7 @@ def gpuAssignBranch(joiningLayer, availableRanks): % (finalTime/1000, gpuUsecSum/1000, dpTime/1000, maxGpusUsed)) # moduleDesc = self.generateModuleDescription([layer.bestCfg for layer in self.layers], ctx.globalBatch) - moduleDesc = TrainingJob("test", self.layers, [layer.bestCfg for layer in self.layers], ctx.globalBatch, maxGpusUsed, "na") + moduleDesc = TrainingJob("test", self.layers, [layer.bestCfg for layer in self.layers], ctx.globalBatch, ctx.lossfn, maxGpusUsed, "na") # moduleDesc = None if ctx.dataParallelBaseline: return (moduleDesc, dpTime / 1000., gpuUsecSum / 1000., ctx.totalGpus)