// // Created by app on 22-12-29. // #include #include "numpy/arrayobject.h" #include "modules/protocol/utility/include/_test_common.h" #include "modules/protocol/utility/include/version_compat_utils.h" #include "modules/iowrapper/include/io_manager.h" #include "modules/protocol/mpc/snn/include/snn_protocol.h" #include "modules/protocol/mpc/snn/include/snn_ops.h" #include "python_export/dataset.h" static rosetta::snn::SnnProtocol* context = nullptr; static int myid = 0; string node_id_0, node_id_1, node_id_2; static PyObject * snn_init(PyObject *self, PyObject *args) { int party_id; char *config_jsonc; if(context != nullptr)return nullptr; PyArg_ParseTuple(args, "is", &party_id, &config_jsonc); std::string config_json(config_jsonc); myid = party_id; Logger::Get().log_to_stdout(false); // Logger::Get().set_filename("log/yy.log"); // Logger::Get().set_level(3); // rosetta_old_conf_parse(node_id, config_json, party_id, "CONFIG.json"); std::string node_id = "P" + std::to_string(party_id); IOManager::Instance()->CreateChannel("", node_id, config_json); context = new rosetta::snn::SnnProtocol(); context->Init(""); shared_ptr net_io = context->GetNetHandler(); node_id_0 = net_io->GetNodeId(0); \ node_id_1 = net_io->GetNodeId(1); \ node_id_2 = net_io->GetNodeId(2); // shared_ptr net_io = snn0.GetNetHandler(); // string node_id_0 = net_io->GetNodeId(0); // string node_id_1 = net_io->GetNodeId(1); // string node_id_2 = net_io->GetNodeId(2); // vector reveal_receivers = {"P0", "P1", "P2"}; // rosetta::attr_type reveal_attr; // reveal_attr["receive_parties"] = receiver_parties_pack(reveal_receivers); return PyLong_FromLong((long)context); } static PyObject * privateinput(PyObject *self, PyObject *args, PyObject *kwargs) { PyObject *datas = nullptr; PyObject *ownerlist = nullptr; int ggtype = 12; vector data_owner; char *kwlist[] = {"data_owner","data", "ggtype",nullptr}; PyArg_ParseTupleAndKeywords(args, kwargs, "|OOi", kwlist, &ownerlist, &datas, &ggtype); data_owner.resize(PyList_Size(ownerlist)); Py_ssize_t strlen; char * strbuf; for(int i = 0; i < PyList_Size(ownerlist); i++){ strbuf = PyUnicode_AsUTF8AndSize(PyList_GetItem(ownerlist, i), &strlen); data_owner[i].resize(strlen); memcpy((void *) data_owner[i].data(), strbuf, strlen); } auto dd = new DataSet(data_owner, 1, "datashare", context); PyObject *res = dd->private_input_x(datas, ggtype); return res; } static PyObject * reveal(PyObject *self, PyObject *args, PyObject *kwargs) { msg_id_t msgid("reavealer"); PyObject *datas = nullptr; PyObject *ownerlist = nullptr; vector data_owner; char *kwlist[] = {"data_revealer","data", nullptr}; PyArg_ParseTupleAndKeywords(args, kwargs, "|OO", kwlist, &ownerlist, &datas); data_owner.resize(PyList_Size(ownerlist)); Py_ssize_t strlen; char * strbuf; for(int i = 0; i < PyList_Size(ownerlist); i++){ strbuf = PyUnicode_AsUTF8AndSize(PyList_GetItem(ownerlist, i), &strlen); data_owner[i].resize(strlen); memcpy((void *) data_owner[i].data(), strbuf, strlen); } ssize_t size = PyArray_Size(datas); vector in; vector out; in.resize(size); out.resize(size); memcpy(in.data(), PyArray_DATA(datas), size*8); PyObject *res = PyArray_SimpleNew(PyArray_NDIM(datas), PyArray_DIMS(datas), NPY_DOUBLE); auto *buf = (double *)PyArray_DATA(res); Py_BEGIN_ALLOW_THREADS context->GetInternal(msgid)->Reconstruct2PC(in, out, data_owner); context->GetInternal(msgid)->synchronize(msgid); Py_END_ALLOW_THREADS const int float_precision = context->GetMpcContext()->FLOAT_PRECISION; for (size_t i = 0; i < size; ++i) { buf[i] = MpcTypeToFloat(out[i], float_precision); } return res; } static PyObject * revealfake(PyObject *self, PyObject *args, PyObject *kwargs) { msg_id_t msgid("reavealer"); PyObject *datas = nullptr; PyObject *ownerlist = nullptr; vector data_owner; char *kwlist[] = {"data_revealer","data", nullptr}; PyArg_ParseTupleAndKeywords(args, kwargs, "|OO", kwlist, &ownerlist, &datas); data_owner.resize(PyList_Size(ownerlist)); Py_ssize_t strlen; char * strbuf; for(int i = 0; i < PyList_Size(ownerlist); i++){ strbuf = PyUnicode_AsUTF8AndSize(PyList_GetItem(ownerlist, i), &strlen); data_owner[i].resize(strlen); memcpy((void *) data_owner[i].data(), strbuf, strlen); } ssize_t size = PyArray_Size(datas); vector in; vector out; in.resize(size); out.resize(size*2); memcpy(in.data(), PyArray_DATA(datas), size*8); long dims[3] = {PyArray_DIM(datas,0), PyArray_DIM(datas,1), 2}; PyObject *res; if (PyArray_NDIM(datas) == 1){ dims[0] = 2; dims[1] = PyArray_DIM(datas,0); res = PyArray_SimpleNew(2, dims, NPY_UINT64); } else if (PyArray_NDIM(datas) == 2){ dims[0] = 2; dims[1] = PyArray_DIM(datas,0); dims[2] = PyArray_DIM(datas,1); res = PyArray_SimpleNew(3, dims, NPY_UINT64); } auto *buf = (unsigned long*)PyArray_DATA(res); Py_BEGIN_ALLOW_THREADS context->GetInternal(msgid)->Reconstruct2PC_pp(in, out, data_owner); context->GetInternal(msgid)->synchronize(msgid); Py_END_ALLOW_THREADS memcpy(buf, out.data(), 16*size); // for (size_t i = 0; i < size; ++i) { // buf[i] = out[i]; // buf[i*2+1] = out[i+size]; // } return res; } static PyObject * opvector(PyObject *self, PyObject *args) { msg_id_t msgid("opvec"); PyObject *x, *y = nullptr; int optype; PyArg_ParseTuple(args, "OOi", &x, &y, &optype); npy_intp xsize = PyArray_Size(x); npy_intp ysize = PyArray_Size(y); if (xsize == ysize) { vector in1, in2, out; in1.resize(xsize); in2.resize(ysize); out.resize(xsize); memcpy(in1.data(), PyArray_DATA(x), xsize * 8); memcpy(in2.data(), PyArray_DATA(y), xsize * 8); PyObject *res = PyArray_SimpleNew(PyArray_NDIM(x), PyArray_DIMS(x), NPY_ULONGLONG); auto *buf = (double *) PyArray_DATA(res); int flag = 0; Py_BEGIN_ALLOW_THREADS switch (optype) { case 1: context->GetInternal(msgid)->Mul(in1, in2, out); break; case 2: context->GetInternal(msgid)->Add(in1, in2, out); break; case 3: context->GetInternal(msgid)->Less(in1, in2, out); break; case 4: context->GetInternal(msgid)->Greater(in1, in2, out); break; case 5: context->GetInternal(msgid)->Division(in1, in2, out); break; case 8: context->GetInternal(msgid)->Equal(in1, in2, out); break; case 12: context->GetInternal(msgid)->OR(in1, in2, out); break; case 14: context->GetInternal(msgid)->AND(in1, in2, out); break; case 15: context->GetInternal(msgid)->XOR(in1, in2, out); break; // OP_OR = 12 // OP_NOT = 13 // OP_AND = 14 // OP_XOR = 15 default: flag = 1; } if (flag == 0) memcpy(buf, out.data(), xsize * 8); Py_END_ALLOW_THREADS if (flag == 0) return res; else { PyErr_SetString(PyExc_RuntimeError, "op unkown"); return NULL; } } else if(ysize == 1){ vector in1, in2, out; in1.resize(xsize); in2.resize(xsize); out.resize(xsize); memcpy(in1.data(), PyArray_DATA(x), xsize * 8); for(int i =0; i < xsize; i ++){ in2[i] = ((mpc_t*)PyArray_DATA(y))[0]; } PyObject *res = PyArray_SimpleNew(PyArray_NDIM(x), PyArray_DIMS(x), NPY_ULONGLONG); auto *buf = (double *) PyArray_DATA(res); int flag = 0; if (xsize == 0) return res; Py_BEGIN_ALLOW_THREADS switch (optype) { case 1: context->GetInternal(msgid)->Mul(in1, in2, out); break; case 2: context->GetInternal(msgid)->Add(in1, in2, out); break; case 3: context->GetInternal(msgid)->Less(in1, in2, out); break; case 4: context->GetInternal(msgid)->Greater(in1, in2, out); break; case 5: context->GetInternal(msgid)->Division(in1, in2, out); break; case 8: context->GetInternal(msgid)->Equal(in1, in2, out); break; case 9: context->GetInternal(msgid)->NotEqual(in1, in2, out); break; case 10: context->GetInternal(msgid)->GreaterEqual(in1, in2, out); break; case 11: context->GetInternal(msgid)->LessEqual(in1, in2, out); break; default: flag = 1; } if (flag == 0) memcpy(buf, out.data(), xsize * 8); Py_END_ALLOW_THREADS if (flag == 0) return res; else { PyErr_SetString(PyExc_RuntimeError, "op unkown"); return NULL; } } else { PyErr_SetString(PyExc_RuntimeError, "wrong size"); return NULL; } } static PyMethodDef SpamMethods[] = { {"snninit", snn_init, METH_VARARGS,"init agent"}, {"privateinput", (PyCFunctionWithKeywords)privateinput, METH_VARARGS|METH_KEYWORDS,"privateinput"}, {"opvector", opvector, METH_VARARGS,"opvector"}, {"reveal", (PyCFunctionWithKeywords)reveal, METH_VARARGS|METH_KEYWORDS,"reveal"}, {"revealfake", (PyCFunctionWithKeywords)revealfake, METH_VARARGS|METH_KEYWORDS,"revealfake"}, {NULL, NULL, 0, NULL} /* Sentinel */ }; static struct PyModuleDef dbsecmodule = { PyModuleDef_HEAD_INIT, "dbsecmodule", /* name of module */ "", /* module documentation, may be NULL */ -1, /* size of per-interpreter state of the module, or -1 if the module keeps state in global variables. */ SpamMethods }; PyMODINIT_FUNC PyInit_dbsec(void) { import_array(); return PyModule_Create(&dbsecmodule); } //PyMODINIT_FUNC //initdbsec(void) //{ //// import_array(); // return PyModule_Create(&dbsecmodule); //}