NKDBsec/python_export/dataset.h

538 lines
17 KiB
C++

// ==============================================================================
// Copyright 2020 The LatticeX Foundation
// This file is part of the Rosetta library.
//
// The Rosetta library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The Rosetta library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the Rosetta library. If not, see <http://www.gnu.org/licenses/>.
// ==============================================================================
#pragma once
#include <cstdio>
#include <string>
#include <vector>
#include <iostream>
#include <algorithm>
#include <mutex>
#include <fstream>
#include <Python.h>
#include <stdint.h>
#include "modules/common/include/utils/rtt_logger.h"
#include "modules/iowrapper/include/io_manager.h"
#include "modules/iowrapper/include/io_wrapper.h"
#include "modules/protocol/utility/include/util.h"
using namespace std;
using np_str_t = std::array<char, 33>; // at most 33 bytes
class DataSet {
vector<string> data_owner_;
int owner_index_ = -1;
string label_owner_ = "";
int dataset_type_ = -1;
string node_id_ = "";
string task_id_ = "";
vector<string> data_nodes_;
// -1, have not checked; 0, have checked, but not ok; 1, have checked, but ok
int args_checked_ok_ = -1;
std::string args_check_errmsg;
rosetta::snn::SnnProtocol* context;
enum DatasetType {
SampleAligned = 1,
FeatureAligned = 2,
};
public:
DataSet(const vector<string>& data_owner, int dataset_type, const string& task_id, rosetta::snn::SnnProtocol* snnpr)
: data_owner_(data_owner), dataset_type_(dataset_type), task_id_(task_id), context(snnpr) {
std::sort(data_owner_.begin(), data_owner_.end());
stringstream ss;
ss << "data owner(";
for (auto& owner : data_owner_) {
ss << owner << ",";
}
ss << "), label owner(" << label_owner_ << "), dataset_type(" << dataset_type_ << ")";
if (!task_id_.empty()) {
ss << ", task id(" << task_id_ << ")";
}
log_info << ss.str() ;
}
public:
PyObject * private_input_x(PyObject *input, int ggtype) {
__check();
return private_dataset_input_2d_X(input, ggtype);
}
private:
void __check() {
node_id_ = context->GetNetHandler()->GetCurrentNodeId();
for (int i = 0; i < data_owner_.size(); i++) {
if (data_owner_[i] == node_id_) {
owner_index_ = i;
}
}
__check_args(args_check_errmsg);
if (args_checked_ok_ == 0) {
throw invalid_argument("Invalid_argument - " + args_check_errmsg);
}
}
void __check_args(std::string& errmsg) {
if (args_checked_ok_ != -1)
return;
// locally check
int local_check_ok = 1;
data_nodes_ = context->GetNetHandler()->GetDataNodes();
/// owner \in {0,1,2}
errmsg = "locally check:";
if (data_owner_.size() == 0) {
errmsg = errmsg + " data owner size() == 0.";
log_error << errmsg ;
local_check_ok = 0;
} else {
for (auto& owner : data_owner_) {
if (node_id_ == owner && std::find(data_nodes_.begin(), data_nodes_.end(), owner) == data_nodes_.end()) {
errmsg = errmsg + " invalid data owner(" + owner + ").";
log_error << errmsg;
local_check_ok = 0;
}
}
}
/// label_owner \in {0,1,2}
if (node_id_ == label_owner_ && std::find(data_nodes_.begin(), data_nodes_.end(), label_owner_) == data_nodes_.end()) {
errmsg = " invalid label owner(" + label_owner_ + ").";
log_error << errmsg ;
local_check_ok = 0;
}
msg_id_t msg__check_args("__check_args");
auto netio = context->GetNetHandler();
{
// 1. check "locally check" is ok
vector<int> checked_ok(data_owner_.size(), 0);
log_debug << "owner index:" << owner_index_ << " local check ok:" << local_check_ok ;
if (owner_index_ >= 0)
checked_ok[owner_index_] = local_check_ok;
sync_check(netio, data_owner_, checked_ok);
for (int i = 0; i < checked_ok.size(); i++) {
if (checked_ok[i] == 0) {
log_error << "data owner " << data_owner_[i] << " check error" ;
args_checked_ok_ = 0;
return;
} else {
log_debug << "data owner " << data_owner_[i] << " check ok" ;
}
}
}
errmsg = "data owner size check:";
{
// 2. check owner size
vector<int> data_owner_size(data_owner_.size(), 0);
if (owner_index_ >= 0)
data_owner_size[owner_index_] = data_owner_.size();
sync_check(netio, data_owner_, data_owner_size);
int owner_size = data_owner_size[0];
for (int i = 1; i < data_owner_size.size(); i++) { // at least one of Pi is not equal to other(s) (if not ok)
if(data_owner_size.size() == 1)
break;
if (owner_size != data_owner_size[i]) {
errmsg = errmsg + " invalid data owner size: " + data_owner_[0] + "(" + std::to_string(owner_size) + ") " + data_owner_[i] + "(" +
std::to_string(data_owner_size[i]) + ").";
args_checked_ok_ = 0;
return;
}
}
}
errmsg = "all check:";
{
// 3. check all arguments
vector<vector<string>> all_args(data_owner_.size(), vector<string>(data_owner_.size() + 2));
all_args[0] = data_owner_;
all_args[0][data_owner_.size()] = std::to_string(dataset_type_);
all_args[0][data_owner_.size() + 1] = label_owner_;
for (int i = 1; i < all_args.size(); i++) {
all_args[i] = all_args[0];
}
sync_check(netio, data_owner_, all_args);
for (int i = 1; i < all_args.size(); i++) {
if(all_args.size() == 1)
break;
for (int j = 0; j < all_args[0].size(); j++) {
if (all_args[0][j] != all_args[i][j]) {
errmsg = errmsg + " invalid in check all args[j=" + std::to_string(j) + ", 0(" + all_args[0][j] +"), " + std::to_string(i) + "(" + all_args[i][j] + ")]";
args_checked_ok_ = 0;
return;
}
}
}
}
errmsg = "ok.";
args_checked_ok_ = 1;
}
// csv inputs(X)
PyObject * private_dataset_input_2d_X(PyObject *input, int ggtype) {
log_debug << "DataSet, private_dataset_input_2d_X." ;
msg_id_t msgid("datasetinput");
auto netio = context->GetNetHandler();
////////////////////////////////
ssize_t ndim = PyArray_NDIM(input);
ssize_t size = PyArray_Size(input);
ssize_t dimsize[2];
{
// check ndim
// check size
vector<int> ndims(data_owner_.size(), 0);
if (owner_index_ >= 0) {
ndims[owner_index_] = ndim;
}
sync_check(netio, data_owner_, ndims);
for (int i = 0; i < ndims.size(); i++) {
if (ndims[i] != 2) {
throw runtime_error(data_owner_[i] + " has data, but ndim[" + to_string(ndims[i]) + "] != 2");
}
}
}
//! @todo optimize[yl]
if (dataset_type_ == DatasetType::SampleAligned) {
// initialize
vector<double> valuesX;
int n, d;
n = d = 0;
vector<int> dAll;
vector<int> nAll;
dAll.resize(data_owner_.size(), 0);
nAll.resize(data_owner_.size(), 0);
if (std::find(data_owner_.begin(), data_owner_.end(), node_id_) != data_owner_.end()) {
n = PyArray_DIM(input, 0);
d = PyArray_DIM(input,1);
//
valuesX.resize(size);
auto *buf = (double *)PyArray_DATA(input);
double *hstride = nullptr;
for (int i = 0; i < n; i++) {
hstride = buf + i*d;
for (int j = 0; j < d; j++) {
valuesX[i * d + j] = *(hstride + j);
}
}
}
// sync shape (n, d) and valid check
{
for (int i = 0; i < data_owner_.size(); i++) {
if (data_owner_[i] == node_id_) {
dAll[i] = d;
nAll[i] = n;
break;
}
}
sync_d(netio, data_owner_, dAll);
sync_d(netio, data_owner_, nAll);
d = 0;
for (int i = 0; i < dAll.size(); i++) {
d += dAll[i];
}
n = nAll[0];
for (int i = 1; i < nAll.size(); i++) {
if (n != nAll[i]) {
log_error << "n is not the same" ;
}
}
log_info << "shape:(n,d) --> "
<< "(" << n << ", " << d << ")" ;
}
// get sharings
std::vector<std::vector<mpc_t>> SS(data_owner_.size());
for (int i = 0; i < SS.size(); i++) {
SS[i].resize(n * dAll[i]);
}
Py_BEGIN_ALLOW_THREADS;
for (int i = 0; i < data_owner_.size(); i++) {
vector<double> valuesXX;
if (data_owner_[i] == node_id_) {
valuesXX = valuesX;
} else {
valuesXX.resize(n * dAll[i], 0);
}
log_debug << "valuesXX size:" << valuesXX.size() << "SS size:" << SS[i].size();
if(ggtype == 8)
context->GetInternal(msgid)->PrivateInput(data_owner_[i], *((vector<mpc_t> *) &valuesXX), SS[i]);
else if (ggtype == 12)
context->GetInternal(msgid)->PrivateInput(data_owner_[i], valuesXX, SS[i]);
else{
printf("warning wrongtype: %d\n", ggtype);
context->GetInternal(msgid)->PrivateInput(data_owner_[i], valuesXX, SS[i]);
}
}
Py_END_ALLOW_THREADS;
npy_intp dims[2]={n,d};
PyObject *res = PyArray_SimpleNew(2, dims, NPY_ULONGLONG);
auto buf = (npy_ulonglong *)PyArray_DATA(res);
// set result. combine partA,partB,partC for each party
{
int offset = 0;
for (int i = 0; i < data_owner_.size(); i++) {
int tmp_d = dAll[i];
for (int j = 0; j < n; j++) {
for (int k = 0; k < tmp_d; k++) {
buf[j*d+k+offset] = SS[i][j * tmp_d + k];
}
}
offset += tmp_d;
}
}
msg_id_t msg_sync_with("private_dataset_input_2d_X");
Py_BEGIN_ALLOW_THREADS;
netio->sync_with(msg_sync_with);
Py_END_ALLOW_THREADS;
return res;
}
//! @todo optimize[yl]
if (dataset_type_ == DatasetType::FeatureAligned) {
// initialize
vector<double> valuesX;
int n, d;
n = d = 0;
vector<int> dAll;
vector<int> nAll;
dAll.resize(data_owner_.size(), 0);
nAll.resize(data_owner_.size(), 0);
if (std::find(data_owner_.begin(), data_owner_.end(), node_id_) != data_owner_.end()) {
n = PyArray_DIM(input, 0);
d = PyArray_DIM(input,1);
//
valuesX.resize(size);
auto *buf = (double *)PyArray_DATA(input);
double *hstride = nullptr;
for (int i = 0; i < n; i++) {
hstride = buf + i*d;
for (int j = 0; j < d; j++) {
valuesX[i * d + j] = *(hstride + j);
}
}
}
// sync shape (n, d) and valid check
{
for (int i = 0; i < data_owner_.size(); i++) {
if (data_owner_[i] == node_id_) {
dAll[i] = d;
nAll[i] = n;
log_info << "my shape: ("<< n << "," << d <<")" ;
break;
}
}
sync_d(netio, data_owner_, dAll);
sync_d(netio, data_owner_, nAll);
for (int i = 0; i < data_owner_.size(); i++) {
log_info << data_owner_[i] << " shape:(" << nAll[i] << ", " << dAll[i] << ")" ;
}
n = 0;
for (int i = 0; i < nAll.size(); i++) {
n += nAll[i];
}
d = dAll[0];
for (int i = 1; i < dAll.size(); i++) {
if (d != dAll[i]) {
log_error << "d is not the same" ;
}
}
log_info << "shape:(n,d) --> "
<< "(" << n << ", " << d << ")" ;
}
// get sharings!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
vector<vector<mpc_t>> SS(data_owner_.size());
for (int i = 0; i < SS.size(); i++) {
SS[i].resize(nAll[i] * d);
}
Py_BEGIN_ALLOW_THREADS;
for (int i = 0; i < data_owner_.size(); i++) {
vector<double> valuesXX;
if (data_owner_[i] == node_id_) {
valuesXX = valuesX;
} else {
valuesXX.resize(nAll[i] * d, 0);
}
log_debug << "valuesXX size:" << valuesXX.size() << "SS size:" << SS[i].size() ;
context->GetInternal(msgid)->PrivateInput(data_owner_[i], valuesXX, SS[i]);
}
Py_END_ALLOW_THREADS;
// set result. combine partA,partB,partC for each party
npy_intp dims[2]={n,d};
PyObject *res = PyArray_SimpleNew(2, dims, NPY_ULONGLONG);
auto buf = (npy_ulonglong *)PyArray_DATA(res);
{
int offset = 0;
for (int i = 0; i < data_owner_.size(); i++) {
int tmp_n = nAll[i];
for (int j = 0; j < tmp_n; j++) {
for (int k = 0; k < d; k++) {
buf[(j + offset)*d+ k] = SS[i][j * d + k];
}
}
offset += tmp_n;
}
}
msg_id_t msg_sync_with("private_dataset_input_2d_X");
Py_BEGIN_ALLOW_THREADS;
netio->sync_with(msg_sync_with);
Py_END_ALLOW_THREADS;
return res;
}
return nullptr;
}
private:
void sync_d(shared_ptr<NET_IO>& netio, const vector<string> &owners, vector<int>& d) {
Py_BEGIN_ALLOW_THREADS;
msg_id_t msgid("sync_d vector"); // temp
vector<string> non_computation_nodes = netio->GetNonComputationNodes();
string node_id = netio->GetCurrentNodeId();
int partyNum = netio->GetPartyId(node_id);
for (int i = 0; i < owners.size(); i++) {
//log_info << "owner:" << owners[i] << "node id: " << node_id_ ;
string node_a = netio->GetNodeId(PARTY_A);
string node_b = netio->GetNodeId(PARTY_B);
string node_c = netio->GetNodeId(PARTY_C);
if (owners[i] == node_id_) {
if (node_id_ != node_a) {
netio->send(node_a, (const char*)&d[i], sizeof(d[i]), msgid);
//log_info << "send to " << node_a << " data:" << d[i] ;
}
if (node_id_ != node_b) {
netio->send(node_b, (const char*)&d[i], sizeof(d[i]), msgid);
//log_info << "send to " << node_b << " data:" << d[i] ;
}
if (node_id_ != node_c) {
netio->send(node_c, (const char*)&d[i], sizeof(d[i]), msgid);
//log_info << "send to " << node_c << " data:" << d[i] ;
}
} else if (PRIMARY || HELPER) {
netio->recv(owners[i], (char*)&d[i], sizeof(d[i]), msgid);
// log_info << "recv from " << owners[i] << " data:" << d[i] ;
} else if (std::find(non_computation_nodes.begin(), non_computation_nodes.end(), node_id_) != non_computation_nodes.end()) {
netio->recv(node_c, (char*)&d[i], sizeof(d[i]), msgid);
//log_info << "recv from " << node_c << " data:" << d[i] ;
}
if (HELPER) {
for (auto iter = non_computation_nodes.begin(); iter != non_computation_nodes.end(); iter++) {
if (*iter != owners[i] && *iter != node_a && *iter != node_b && *iter != node_c) {
netio->send(*iter, (const char*)&d[i], sizeof(d[i]), msgid);
//log_info << "send to " << *iter << " data:" << d[i] ;
}
}
}
}
Py_END_ALLOW_THREADS;
}
void sync_l(shared_ptr<NET_IO>& netio, vector<int>&d) {
sync_d(netio, data_owner_, d);
}
void sync_l(shared_ptr<NET_IO>& netio, int& label) {
vector<string> label_owner{label_owner_};
vector<int> labels{label};
sync_d(netio, label_owner, labels);
label = labels[0];
}
template<class T>
void sync_check(shared_ptr<NET_IO>& netio, const vector<string>& owners, vector<vector<T>> &d) {
Py_BEGIN_ALLOW_THREADS;
msg_id_t msgid("sync check for dataset");
string node_c = netio->GetNodeId(PARTY_C);
for (int i = 0; i < owners.size(); i++) {
log_debug << "checknum:" << i ;
if (node_id_ == owners[i]) {
if (node_id_ != node_c) {
//log_debug << "send to " << node_c << " size:" << d[i].size() ;
netio->send(node_c, d[i], d[i].size(), msgid);
}
}
if (node_id_ == node_c) {
if (node_id_ != owners[i]) {
//log_debug << "recv from " << owners[i] << " size:" << d[i].size() ;
netio->recv(owners[i], d[i], d[i].size(), msgid);
//log_debug << "after recv from " << owners[i] << " size:" << d[i].size() ;
}
vector<string> nodes = netio->GetConnectedNodes();
for (int j = 0; j < nodes.size(); j++) {
if (nodes[j] != owners[i]) {
//log_debug << "send to " << nodes[j] << " size:" << d[i].size() ;
netio->send(nodes[j], d[i], d[i].size(), msgid);
}
}
} else if (node_id_ != owners[i]) {
//log_debug << "recv from " << node_c << " size:" << d[i].size() ;
netio->recv(node_c, d[i], d[i].size(), msgid);
}
}
log_debug << "checkfinish" ;
Py_END_ALLOW_THREADS;
}
void sync_check(shared_ptr<NET_IO>& netio, const vector<string>& owners, vector<int> &d) {
vector<vector<int>> ds(d.size());
for (int i = 0; i < ds.size(); i++) {
ds[i].push_back(d[i]);
}
log_debug << "owners size:" << owners.size() << " ds size:" << ds.size() ;
sync_check(netio, owners, ds);
for (int i = 0; i < ds.size(); i++) {
d[i] = ds[i][0];
}
}
};