// Forked from nankaicyber/NKDBsec (325 lines, 11 KiB, C++).
//
// Created by app on 22-12-29.
//
|
|
|
|
|
|
#include <Python.h>
|
|
#include "numpy/arrayobject.h"
|
|
|
|
#include "modules/protocol/utility/include/_test_common.h"
|
|
#include "modules/protocol/utility/include/version_compat_utils.h"
|
|
#include "modules/iowrapper/include/io_manager.h"
|
|
#include "modules/protocol/mpc/snn/include/snn_protocol.h"
|
|
#include "modules/protocol/mpc/snn/include/snn_ops.h"
|
|
#include "python_export/dataset.h"
|
|
|
|
|
|
static rosetta::snn::SnnProtocol* context = nullptr;
|
|
static int myid = 0;
|
|
string node_id_0, node_id_1, node_id_2;
|
|
/**
 * Python entry point `snninit(party_id, config_json)`.
 *
 * Creates the IO channel for node "P<party_id>", constructs and initialises
 * the process-wide SnnProtocol stored in `context`, and caches the node ids
 * of parties 0..2 in the file-level `node_id_*` globals.
 *
 * @param self unused.
 * @param args tuple `(int party_id, str config_json)`.
 * @return the address of the protocol object as a Python int, or nullptr with
 *         a Python exception set when the arguments cannot be parsed or the
 *         protocol has already been initialised.
 */
static PyObject *
snn_init(PyObject *self, PyObject *args)
{
    int party_id = 0;
    const char *config_jsonc = nullptr;

    // Returning nullptr without an exception set makes CPython raise an
    // opaque SystemError, so report the double initialisation explicitly.
    if (context != nullptr) {
        PyErr_SetString(PyExc_RuntimeError, "snn protocol is already initialised");
        return nullptr;
    }

    // On failure PyArg_ParseTuple has already set the exception and the
    // output variables must not be used.
    if (!PyArg_ParseTuple(args, "is", &party_id, &config_jsonc)) {
        return nullptr;
    }

    std::string config_json(config_jsonc);
    myid = party_id;

    Logger::Get().log_to_stdout(false);
    // Optional logger tuning kept for reference:
    //   Logger::Get().set_filename("log/yy.log");
    //   Logger::Get().set_level(3);

    std::string node_id = "P" + std::to_string(party_id);
    IOManager::Instance()->CreateChannel("", node_id, config_json);

    context = new rosetta::snn::SnnProtocol();
    context->Init("");

    shared_ptr<NET_IO> net_io = context->GetNetHandler();
    node_id_0 = net_io->GetNodeId(0);
    node_id_1 = net_io->GetNodeId(1);
    node_id_2 = net_io->GetNodeId(2);

    // PyLong_FromVoidPtr keeps the full pointer width; a cast to long would
    // truncate it on platforms where long is narrower than a pointer.
    return PyLong_FromVoidPtr(context);
}
|
|
|
|
|
|
/**
 * Python entry point `privateinput(data_owner, data, ggtype=12)`.
 *
 * Copies the owner names out of `data_owner`, builds a DataSet bound to the
 * global protocol `context`, and returns whatever
 * `DataSet::private_input_x(data, ggtype)` produces.
 *
 * @param self   unused.
 * @param args   positional arguments (same order as the keywords).
 * @param kwargs keywords `data_owner` (list of str), `data`, `ggtype` (int).
 * @return the object returned by `private_input_x`, or nullptr with a Python
 *         exception set when the arguments are invalid or snninit has not
 *         been called.
 */
static PyObject *
privateinput(PyObject *self, PyObject *args, PyObject *kwargs) {
    PyObject *datas = nullptr;
    PyObject *ownerlist = nullptr;
    int ggtype = 12;

    // String literals are const; the const_cast only adapts to the
    // historical `char **` parameter type of the parsing API.
    static const char *kwnames[] = {"data_owner", "data", "ggtype", nullptr};
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OOi", const_cast<char **>(kwnames),
                                     &ownerlist, &datas, &ggtype)) {
        return nullptr;
    }

    if (context == nullptr) {
        PyErr_SetString(PyExc_RuntimeError, "snninit must be called before privateinput");
        return nullptr;
    }

    // Every argument is optional in the format string, so the owner list may
    // be missing; it may also be something other than a list.
    if (ownerlist == nullptr || !PyList_Check(ownerlist)) {
        PyErr_SetString(PyExc_TypeError, "data_owner must be a list of str");
        return nullptr;
    }

    const Py_ssize_t owner_count = PyList_Size(ownerlist);
    vector<string> data_owner;
    data_owner.reserve(static_cast<size_t>(owner_count));
    for (Py_ssize_t i = 0; i < owner_count; ++i) {
        Py_ssize_t name_len = 0;
        // PyList_GetItem returns a borrowed reference; no DECREF required.
        const char *name = PyUnicode_AsUTF8AndSize(PyList_GetItem(ownerlist, i), &name_len);
        if (name == nullptr) {
            return nullptr;  // element is not a str; exception already set
        }
        data_owner.emplace_back(name, static_cast<size_t>(name_len));
    }

    // `datas` is forwarded as-is (it may be null when the caller omitted it).
    // NOTE(review): the DataSet is never deleted. It is left that way because
    // it is not visible here whether the returned object keeps referring to
    // it; confirm ownership and release it if that is safe.
    auto dd = new DataSet(data_owner, 1, "datashare", context);
    return dd->private_input_x(datas, ggtype);
}
|
|
|
|
/**
 * Python entry point `reveal(data_revealer, data)`.
 *
 * Copies the raw share words of `data` into a vector, runs
 * `Reconstruct2PC(in, out, data_revealer)` followed by `synchronize` on the
 * global protocol with the GIL released, and converts every output word to a
 * double with `MpcTypeToFloat` using the context's FLOAT_PRECISION.
 *
 * @param self   unused.
 * @param args   positional arguments (same order as the keywords).
 * @param kwargs keywords `data_revealer` (list of str) and `data` (ndarray).
 * @return a new float64 ndarray with the same shape as `data`, or nullptr
 *         with a Python exception set on invalid arguments.
 */
static PyObject *
reveal(PyObject *self, PyObject *args, PyObject *kwargs) {
    msg_id_t msgid("reavealer");
    PyObject *datas = nullptr;
    PyObject *ownerlist = nullptr;

    static const char *kwnames[] = {"data_revealer", "data", nullptr};
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO", const_cast<char **>(kwnames),
                                     &ownerlist, &datas)) {
        return nullptr;
    }

    if (context == nullptr) {
        PyErr_SetString(PyExc_RuntimeError, "snninit must be called before reveal");
        return nullptr;
    }

    // Both arguments are optional in the format string, so either may be null.
    if (ownerlist == nullptr || !PyList_Check(ownerlist)) {
        PyErr_SetString(PyExc_TypeError, "data_revealer must be a list of str");
        return nullptr;
    }

    // The shares are copied with memcpy, so the array has to be a densely
    // packed, aligned, native-endian buffer of mpc_t-sized elements.
    if (datas == nullptr || !PyArray_Check(datas)) {
        PyErr_SetString(PyExc_TypeError, "data must be a NumPy array");
        return nullptr;
    }
    PyArrayObject *shares = reinterpret_cast<PyArrayObject *>(datas);
    if (!PyArray_ISCARRAY_RO(shares) ||
        static_cast<size_t>(PyArray_ITEMSIZE(shares)) != sizeof(mpc_t)) {
        PyErr_SetString(PyExc_ValueError,
                        "data must be a C-contiguous array of mpc_t-sized elements");
        return nullptr;
    }

    const Py_ssize_t owner_count = PyList_Size(ownerlist);
    vector<string> data_owner;
    data_owner.reserve(static_cast<size_t>(owner_count));
    for (Py_ssize_t i = 0; i < owner_count; ++i) {
        Py_ssize_t name_len = 0;
        // PyList_GetItem returns a borrowed reference; no DECREF required.
        const char *name = PyUnicode_AsUTF8AndSize(PyList_GetItem(ownerlist, i), &name_len);
        if (name == nullptr) {
            return nullptr;  // element is not a str; exception already set
        }
        data_owner.emplace_back(name, static_cast<size_t>(name_len));
    }

    const npy_intp count = PyArray_SIZE(shares);
    vector<mpc_t> in;
    vector<mpc_t> out;
    in.resize(count);
    out.resize(count);
    if (count > 0) {
        memcpy(in.data(), PyArray_DATA(shares), static_cast<size_t>(count) * sizeof(mpc_t));
    }

    PyObject *res = PyArray_SimpleNew(PyArray_NDIM(shares), PyArray_DIMS(shares), NPY_DOUBLE);
    if (res == nullptr) {
        return nullptr;
    }
    double *buf = static_cast<double *>(PyArray_DATA(reinterpret_cast<PyArrayObject *>(res)));

    // Only C++ state is touched between these macros, so the GIL can be
    // released while the protocol communicates.
    Py_BEGIN_ALLOW_THREADS
    context->GetInternal(msgid)->Reconstruct2PC(in, out, data_owner);
    context->GetInternal(msgid)->synchronize(msgid);
    Py_END_ALLOW_THREADS

    const int float_precision = context->GetMpcContext()->FLOAT_PRECISION;
    for (npy_intp i = 0; i < count; ++i) {
        buf[i] = MpcTypeToFloat(out[i], float_precision);
    }
    return res;
}
|
|
|
|
/**
 * Python entry point `revealfake(data_revealer, data)`.
 *
 * Copies the raw share words of `data` into a vector, runs
 * `Reconstruct2PC_pp(in, out, data_revealer)` followed by `synchronize` on
 * the global protocol with the GIL released, and returns the raw output
 * words without any fixed-point conversion. The output holds twice as many
 * words as the input.
 *
 * @param self   unused.
 * @param args   positional arguments (same order as the keywords).
 * @param kwargs keywords `data_revealer` (list of str) and `data`
 *               (1-D or 2-D ndarray).
 * @return a new uint64 ndarray of shape `(2, *data.shape)`, or nullptr with a
 *         Python exception set on invalid arguments.
 *         NOTE(review): what each of the two leading slices represents is
 *         defined by Reconstruct2PC_pp and is not visible here.
 */
static PyObject *
revealfake(PyObject *self, PyObject *args, PyObject *kwargs) {
    msg_id_t msgid("reavealer");
    PyObject *datas = nullptr;
    PyObject *ownerlist = nullptr;

    static const char *kwnames[] = {"data_revealer", "data", nullptr};
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO", const_cast<char **>(kwnames),
                                     &ownerlist, &datas)) {
        return nullptr;
    }

    if (context == nullptr) {
        PyErr_SetString(PyExc_RuntimeError, "snninit must be called before revealfake");
        return nullptr;
    }

    // Both arguments are optional in the format string, so either may be null.
    if (ownerlist == nullptr || !PyList_Check(ownerlist)) {
        PyErr_SetString(PyExc_TypeError, "data_revealer must be a list of str");
        return nullptr;
    }

    // The shares are copied with memcpy, so the array has to be a densely
    // packed, aligned, native-endian buffer of mpc_t-sized elements.
    if (datas == nullptr || !PyArray_Check(datas)) {
        PyErr_SetString(PyExc_TypeError, "data must be a NumPy array");
        return nullptr;
    }
    PyArrayObject *shares = reinterpret_cast<PyArrayObject *>(datas);
    if (!PyArray_ISCARRAY_RO(shares) ||
        static_cast<size_t>(PyArray_ITEMSIZE(shares)) != sizeof(mpc_t)) {
        PyErr_SetString(PyExc_ValueError,
                        "data must be a C-contiguous array of mpc_t-sized elements");
        return nullptr;
    }

    // Only 1-D and 2-D inputs have a defined result shape; anything else used
    // to leave the result pointer uninitialised.
    const int ndim = PyArray_NDIM(shares);
    if (ndim != 1 && ndim != 2) {
        PyErr_SetString(PyExc_ValueError, "data must be 1- or 2-dimensional");
        return nullptr;
    }

    const Py_ssize_t owner_count = PyList_Size(ownerlist);
    vector<string> data_owner;
    data_owner.reserve(static_cast<size_t>(owner_count));
    for (Py_ssize_t i = 0; i < owner_count; ++i) {
        Py_ssize_t name_len = 0;
        // PyList_GetItem returns a borrowed reference; no DECREF required.
        const char *name = PyUnicode_AsUTF8AndSize(PyList_GetItem(ownerlist, i), &name_len);
        if (name == nullptr) {
            return nullptr;  // element is not a str; exception already set
        }
        data_owner.emplace_back(name, static_cast<size_t>(name_len));
    }

    const npy_intp count = PyArray_SIZE(shares);
    vector<mpc_t> in;
    vector<mpc_t> out;
    in.resize(count);
    out.resize(count * 2);
    if (count > 0) {
        memcpy(in.data(), PyArray_DATA(shares), static_cast<size_t>(count) * sizeof(mpc_t));
    }

    // Result shape is (2, d0) for 1-D input and (2, d0, d1) for 2-D input.
    // The second input dimension is only read when it exists.
    npy_intp dims[3] = {2, PyArray_DIM(shares, 0), 0};
    if (ndim == 2) {
        dims[2] = PyArray_DIM(shares, 1);
    }
    PyObject *res = PyArray_SimpleNew(ndim + 1, dims, NPY_UINT64);
    if (res == nullptr) {
        return nullptr;
    }
    PyArrayObject *res_arr = reinterpret_cast<PyArrayObject *>(res);

    // Only C++ state is touched between these macros, so the GIL can be
    // released while the protocol communicates.
    Py_BEGIN_ALLOW_THREADS
    context->GetInternal(msgid)->Reconstruct2PC_pp(in, out, data_owner);
    context->GetInternal(msgid)->synchronize(msgid);
    Py_END_ALLOW_THREADS

    // The result holds 2 * count 8-byte words, i.e. 16 * count bytes.
    if (count > 0) {
        memcpy(PyArray_DATA(res_arr), out.data(), static_cast<size_t>(PyArray_NBYTES(res_arr)));
    }
    return res;
}
|
|
|
|
static PyObject *
|
|
opvector(PyObject *self, PyObject *args) {
|
|
msg_id_t msgid("opvec");
|
|
PyObject *x, *y = nullptr;
|
|
int optype;
|
|
PyArg_ParseTuple(args, "OOi", &x, &y, &optype);
|
|
npy_intp xsize = PyArray_Size(x);
|
|
npy_intp ysize = PyArray_Size(y);
|
|
if (xsize == ysize) {
|
|
vector<mpc_t> in1, in2, out;
|
|
in1.resize(xsize);
|
|
in2.resize(ysize);
|
|
out.resize(xsize);
|
|
memcpy(in1.data(), PyArray_DATA(x), xsize * 8);
|
|
memcpy(in2.data(), PyArray_DATA(y), xsize * 8);
|
|
PyObject *res = PyArray_SimpleNew(PyArray_NDIM(x), PyArray_DIMS(x), NPY_ULONGLONG);
|
|
auto *buf = (double *) PyArray_DATA(res);
|
|
int flag = 0;
|
|
Py_BEGIN_ALLOW_THREADS
|
|
switch (optype) {
|
|
case 1:
|
|
context->GetInternal(msgid)->Mul(in1, in2, out);
|
|
break;
|
|
case 2:
|
|
context->GetInternal(msgid)->Add(in1, in2, out);
|
|
break;
|
|
case 3:
|
|
context->GetInternal(msgid)->Less(in1, in2, out);
|
|
break;
|
|
case 4:
|
|
context->GetInternal(msgid)->Greater(in1, in2, out);
|
|
break;
|
|
case 5:
|
|
context->GetInternal(msgid)->Division(in1, in2, out);
|
|
break;
|
|
case 8:
|
|
context->GetInternal(msgid)->Equal(in1, in2, out);
|
|
break;
|
|
case 12:
|
|
context->GetInternal(msgid)->OR(in1, in2, out);
|
|
break;
|
|
case 14:
|
|
context->GetInternal(msgid)->AND(in1, in2, out);
|
|
break;
|
|
case 15:
|
|
context->GetInternal(msgid)->XOR(in1, in2, out);
|
|
break;
|
|
|
|
|
|
// OP_OR = 12
|
|
// OP_NOT = 13
|
|
// OP_AND = 14
|
|
// OP_XOR = 15
|
|
default:
|
|
flag = 1;
|
|
|
|
|
|
}
|
|
if (flag == 0)
|
|
memcpy(buf, out.data(), xsize * 8);
|
|
Py_END_ALLOW_THREADS
|
|
if (flag == 0)
|
|
return res;
|
|
else {
|
|
PyErr_SetString(PyExc_RuntimeError, "op unkown");
|
|
return NULL;
|
|
}
|
|
}
|
|
else if(ysize == 1){
|
|
vector<mpc_t> in1, in2, out;
|
|
in1.resize(xsize);
|
|
in2.resize(xsize);
|
|
out.resize(xsize);
|
|
memcpy(in1.data(), PyArray_DATA(x), xsize * 8);
|
|
for(int i =0; i < xsize; i ++){
|
|
in2[i] = ((mpc_t*)PyArray_DATA(y))[0];
|
|
}
|
|
PyObject *res = PyArray_SimpleNew(PyArray_NDIM(x), PyArray_DIMS(x), NPY_ULONGLONG);
|
|
auto *buf = (double *) PyArray_DATA(res);
|
|
int flag = 0;
|
|
if (xsize == 0)
|
|
return res;
|
|
Py_BEGIN_ALLOW_THREADS
|
|
switch (optype) {
|
|
case 1:
|
|
context->GetInternal(msgid)->Mul(in1, in2, out);
|
|
break;
|
|
case 2:
|
|
context->GetInternal(msgid)->Add(in1, in2, out);
|
|
break;
|
|
case 3:
|
|
context->GetInternal(msgid)->Less(in1, in2, out);
|
|
break;
|
|
case 4:
|
|
context->GetInternal(msgid)->Greater(in1, in2, out);
|
|
break;
|
|
case 5:
|
|
context->GetInternal(msgid)->Division(in1, in2, out);
|
|
break;
|
|
case 8:
|
|
context->GetInternal(msgid)->Equal(in1, in2, out);
|
|
break;
|
|
case 9:
|
|
context->GetInternal(msgid)->NotEqual(in1, in2, out);
|
|
break;
|
|
case 10:
|
|
context->GetInternal(msgid)->GreaterEqual(in1, in2, out);
|
|
break;
|
|
case 11:
|
|
context->GetInternal(msgid)->LessEqual(in1, in2, out);
|
|
break;
|
|
default:
|
|
flag = 1;
|
|
|
|
|
|
}
|
|
if (flag == 0)
|
|
memcpy(buf, out.data(), xsize * 8);
|
|
Py_END_ALLOW_THREADS
|
|
if (flag == 0)
|
|
return res;
|
|
else {
|
|
PyErr_SetString(PyExc_RuntimeError, "op unkown");
|
|
return NULL;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
PyErr_SetString(PyExc_RuntimeError, "wrong size");
|
|
return NULL;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Method table of the extension module.
//
// Functions that accept keyword arguments have the three-parameter
// PyCFunctionWithKeywords signature, while PyMethodDef::ml_meth is declared
// as PyCFunction. They are therefore cast to PyCFunction through
// void (*)(void), as the CPython documentation prescribes; the
// METH_KEYWORDS flag tells the interpreter to call them with three arguments.
static PyMethodDef SpamMethods[] = {
    {"snninit", snn_init, METH_VARARGS, "init agent"},
    {"privateinput",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(privateinput)),
     METH_VARARGS | METH_KEYWORDS, "privateinput"},
    {"opvector", opvector, METH_VARARGS, "opvector"},
    {"reveal",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(reveal)),
     METH_VARARGS | METH_KEYWORDS, "reveal"},
    {"revealfake",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(revealfake)),
     METH_VARARGS | METH_KEYWORDS, "revealfake"},
    {NULL, NULL, 0, NULL} /* Sentinel */
};
|
|
|
|
// Module definition handed to PyModule_Create() by PyInit_dbsec().
static struct PyModuleDef dbsecmodule = {
    PyModuleDef_HEAD_INIT,
    "dbsecmodule",  // module name
    "",             // module docstring (empty)
    -1,             // no per-interpreter state: the module keeps its state in globals
    SpamMethods     // method table
};
|
|
|
|
// Module initialisation function looked up by the interpreter when the
// `dbsec` extension is imported.
PyMODINIT_FUNC
PyInit_dbsec(void)
{
    // Load the NumPy C-API before any PyArray_* call is made.
    import_array();

    PyObject *module = PyModule_Create(&dbsecmodule);
    return module;
}
|
|
|
|
//PyMODINIT_FUNC
|
|
//initdbsec(void)
|
|
//{
|
|
//// import_array();
|
|
// return PyModule_Create(&dbsecmodule);
|
|
//}
|