Skip to content

Commit

Permalink
need to remove cpu code in gates
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei Dai committed May 21, 2018
1 parent 6cb1976 commit 10a5390
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 22 deletions.
2 changes: 2 additions & 0 deletions cufhe/include/cufhe.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ struct Ctxt {
~Ctxt();
LWESample* lwe_sample_;
MemoryDeleter lwe_sample_deleter_;
LWESample* lwe_sample_device_;
MemoryDeleter lwe_sample_device_deleter_;
};

/** Plaintext is in {0, 1}. */
Expand Down
2 changes: 1 addition & 1 deletion cufhe/include/cufhe_gpu.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public:
/** Default-constructs a wrapper; the handle st_ is not assigned here. Call Create() before use. */
inline Stream() {}
/** Wraps the default stream. Only id == 0 is accepted (enforced via Assert); st_ is set to 0. */
inline Stream(int id) { Assert(id == 0); st_ = 0; }
/** Does not destroy the underlying CUDA stream; call Destroy() explicitly. */
inline ~Stream() {}
/**
 * Creates the underlying CUDA stream with the cudaStreamNonBlocking flag.
 * Per the CUDA runtime documentation, such a stream does not implicitly
 * synchronize with the legacy default stream (stream 0), so callers that
 * rely on ordering against stream 0 must synchronize explicitly.
 * NOTE(review): the cudaError_t returned by the runtime call is ignored.
 */
inline void Create() { cudaStreamCreateWithFlags(&this->st_, cudaStreamNonBlocking); }
/** Destroys the underlying CUDA stream. The returned cudaError_t is ignored. */
inline void Destroy() { cudaStreamDestroy(this->st_); }
/** Returns the raw CUDA stream handle held by this wrapper. */
inline cudaStream_t st() { return st_; }
private:
Expand Down
9 changes: 9 additions & 0 deletions cufhe/lib/cufhe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,16 @@ PubKey::~PubKey() {

/**
 * Releases the two LWE sample objects held by this ciphertext
 * (lwe_sample_ and lwe_sample_device_) together with their data buffers.
 *
 * For each sample: the data buffer is handed to the matching deleter, the
 * sample's data pointer is cleared, the sample object itself is deleted,
 * and the member pointer and deleter are reset to nullptr.
 *
 * A null sample pointer or a null deleter is skipped instead of being
 * dereferenced or invoked, so a partially initialised Ctxt can be destroyed
 * safely.
 */
Ctxt::~Ctxt() {
  if (lwe_sample_ != nullptr) {
    // Free the buffer through its own deleter before dropping the pointer.
    if (lwe_sample_deleter_ != nullptr)
      lwe_sample_deleter_(lwe_sample_->data());
    lwe_sample_->set_data(nullptr);
    delete lwe_sample_;
    lwe_sample_ = nullptr;
  }
  lwe_sample_deleter_ = nullptr;

  if (lwe_sample_device_ != nullptr) {
    // Same sequence for the second (device-side) sample.
    if (lwe_sample_device_deleter_ != nullptr)
      lwe_sample_device_deleter_(lwe_sample_device_->data());
    lwe_sample_device_->set_data(nullptr);
    delete lwe_sample_device_;
    lwe_sample_device_ = nullptr;
  }
  lwe_sample_device_deleter_ = nullptr;
}

void SetSeed(uint32_t seed) {
Expand Down
18 changes: 17 additions & 1 deletion cufhe/lib/cufhe_gates_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,23 @@ void Nand(Ctxt& out,
out.lwe_sample_->data()[i] = 0 - in0.lwe_sample_->data()[i]
- in1.lwe_sample_->data()[i];
out.lwe_sample_->b() += fix;
Bootstrap(out.lwe_sample_, out.lwe_sample_, mu, st.st());
cudaMemcpyAsync(out.lwe_sample_device_->data(),
out.lwe_sample_->data(),
out.lwe_sample_->SizeData(),
cudaMemcpyHostToDevice,
st.st());
//Bootstrap(out.lwe_sample_, out.lwe_sample_, mu, st.st());
Bootstrap(out.lwe_sample_device_, out.lwe_sample_device_, mu, st.st());
cudaMemcpyAsync(out.lwe_sample_->data(),
out.lwe_sample_device_->data(),
out.lwe_sample_->SizeData(),
cudaMemcpyDeviceToHost,
st.st());
cudaEvent_t end_of_gate;
cudaEventCreate(&end_of_gate);
cudaEventRecord(end_of_gate, st.st());
cudaStreamWaitEvent(0, end_of_gate, 0);
//Synchronize();
}

void Or(Ctxt& out,
Expand Down
12 changes: 9 additions & 3 deletions cufhe/lib/cufhe_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,19 @@ namespace cufhe {
uint32_t cnt = 0;

/**
 * Constructs a ciphertext holding two LWE samples of dimension
 * param->lwe_n_ (taken from GetDefaultParam()):
 *   - lwe_sample_        : data buffer obtained from AllocatorCPU
 *   - lwe_sample_device_ : data buffer obtained from AllocatorGPU
 * Each buffer's deleter is stored next to it so the destructor can release
 * it with the matching allocator.
 *
 * The two buffers are allocated separately rather than through
 * AllocatorBoth; exactly one allocation is made per sample so nothing is
 * overwritten (and leaked) before it is stored.
 *
 * @param is_alias  Not read by this constructor.
 *                  NOTE(review): confirm whether alias ciphertexts are meant
 *                  to skip allocation.
 */
Ctxt::Ctxt(bool is_alias) {
  std::pair<void*, MemoryDeleter> pair;
  Param* param = GetDefaultParam();

  // Host-side sample.
  lwe_sample_ = new LWESample(param->lwe_n_);
  lwe_sample_deleter_ = nullptr;
  pair = AllocatorCPU::New(lwe_sample_->SizeMalloc());
  lwe_sample_->set_data((LWESample::PointerType)pair.first);
  lwe_sample_deleter_ = pair.second;

  // Device-side sample with the same dimension.
  lwe_sample_device_ = new LWESample(param->lwe_n_);
  lwe_sample_device_deleter_ = nullptr;
  pair = AllocatorGPU::New(lwe_sample_device_->SizeMalloc());
  lwe_sample_device_->set_data((LWESample::PointerType)pair.first);
  lwe_sample_device_deleter_ = pair.second;
}

} // namespace cufhe
33 changes: 16 additions & 17 deletions cufhe/test/test_api_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,31 +44,27 @@ void XorCheck(Ptxt& out, const Ptxt& in0, const Ptxt& in1) {
}

int main() {
cudaSetDevice(0);
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
uint32_t kNumSMs = prop.multiProcessorCount;
uint32_t kNumTests = kNumSMs * 32;
uint32_t kNumTests = kNumSMs * 8;
uint32_t kNumLevels = 8;

SetSeed(); // set random seed

PriKey pri_key_old; // private key
PubKey pub_key_old; // public key
PriKey pri_key; // private key
PubKey pub_key; // public key
Ptxt* pt = new Ptxt[2 * kNumTests];
Ctxt* ct = new Ctxt[2 * kNumTests];
Synchronize();
bool correct;

cout<< "------ Key Generation ------" <<endl;
KeyGen(pub_key_old, pri_key_old);
KeyGen(pub_key, pri_key);
// Alternatively ...
// PriKeyGen(pri_key);
// PubKeyGen(pub_key, pri_key);
WritePriKeyToFile(pri_key_old, "pri_key.txt");
WritePubKeyToFile(pub_key_old, "pub_key.txt");
PriKey pri_key; // private key
PubKey pub_key; // public key
ReadPriKeyFromFile(pri_key, "pri_key.txt");
ReadPubKeyFromFile(pub_key, "pub_key.txt");

cout<< "------ Test Encryption/Decryption ------" <<endl;
cout<< "Number of tests:\t" << kNumTests <<endl;
Expand Down Expand Up @@ -112,36 +108,39 @@ int main() {

// Here, pass streams to gates for parallel gates.
for (int i = 0; i < kNumTests; i ++)
Nand(ct[i], ct[i], ct[i + kNumTests], st[i % kNumSMs]);
for (int j = 0; j < kNumLevels; j ++)
Nand(ct[i], ct[i], ct[i + kNumTests], st[i % kNumSMs]);
Synchronize();

cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&et, start, stop);
cout<< et / kNumTests << " ms / gate" <<endl;
cout<< et / kNumTests / kNumLevels << " ms / gate" <<endl;
cudaEventDestroy(start);
cudaEventDestroy(stop);

int cnt_failures = 0;
for (int i = 0; i < kNumTests; i ++) {
NandCheck(pt[i + kNumTests], pt[i], pt[i + kNumTests]);
Decrypt(pt[i], ct[i], pri_key);
for (int j = 0; j < kNumLevels; j ++)
NandCheck(pt[i], pt[i], pt[i + kNumTests]);
Decrypt(pt[i + kNumTests], ct[i], pri_key);
if (pt[i + kNumTests].message_ != pt[i].message_) {
correct = false;
cnt_failures += 1;
//std::cout<< "Fail at iteration: " << i <<std::endl;
}
}
if (correct)
cout<< "PASS" <<endl;
else
cout<< "FAIL:\t" << cnt_failures << "/" << kNumTests <<endl;
for (int i = 0; i < kNumSMs; i ++)
st[i].Destroy();
delete [] st;

cout<< "------ Cleaning Data on GPU(s) ------" <<endl;
CleanUp(); // essential to clean and deallocate data
delete [] ct;
delete [] pt;
for (int i = 0; i < kNumSMs; i ++)
st[i].Destroy();
delete [] st;
return 0;
}

0 comments on commit 10a5390

Please sign in to comment.