NKDBsec/python_export/ttdy.cpp

325 lines
11 KiB
C++

//
// Created by app on 22-12-29.
//
#include <Python.h>
#include "numpy/arrayobject.h"
#include "modules/protocol/utility/include/_test_common.h"
#include "modules/protocol/utility/include/version_compat_utils.h"
#include "modules/iowrapper/include/io_manager.h"
#include "modules/protocol/mpc/snn/include/snn_protocol.h"
#include "modules/protocol/mpc/snn/include/snn_ops.h"
#include "python_export/dataset.h"
// Process-wide state shared by every function exported from this module.
// `context` stays null until snninit() creates the protocol; `myid` records
// the party id that was passed to snninit().
static rosetta::snn::SnnProtocol *context{nullptr};
static int myid{0};
// Node ids of parties 0, 1 and 2, filled in by snninit() from the
// protocol's network handler.
std::string node_id_0;
std::string node_id_1;
std::string node_id_2;
// snninit(party_id, config_json) -> int
//
// One-time setup for this process: creates the I/O channel for node
// "P<party_id>" from the JSON configuration string, constructs and
// initialises the shared SnnProtocol, and caches the node ids of parties
// 0, 1 and 2.  Returns the protocol pointer as a Python int.  Raises
// RuntimeError if the protocol has already been created.
static PyObject *
snn_init(PyObject *self, PyObject *args)
{
    if (context != nullptr) {
        // Returning NULL without an exception set makes CPython raise an
        // opaque SystemError, so report the actual problem instead.
        PyErr_SetString(PyExc_RuntimeError, "snninit: protocol is already initialized");
        return nullptr;
    }
    int party_id = 0;
    const char *config_jsonc = nullptr;
    if (!PyArg_ParseTuple(args, "is", &party_id, &config_jsonc))
        return nullptr;  // PyArg_ParseTuple has already set the exception
    std::string config_json(config_jsonc);
    myid = party_id;
    Logger::Get().log_to_stdout(false);
    // Disabled alternatives kept for reference:
    //   Logger::Get().set_filename("log/yy.log");
    //   Logger::Get().set_level(3);
    std::string node_id = "P" + std::to_string(party_id);
    IOManager::Instance()->CreateChannel("", node_id, config_json);
    context = new rosetta::snn::SnnProtocol();
    context->Init("");
    shared_ptr<NET_IO> net_io = context->GetNetHandler();
    node_id_0 = net_io->GetNodeId(0);
    node_id_1 = net_io->GetNodeId(1);
    node_id_2 = net_io->GetNodeId(2);
    // Pointer-sized conversion: a cast to `long` truncates pointers on
    // platforms where long is 32 bits (e.g. 64-bit Windows).
    return PyLong_FromVoidPtr(context);
}
static PyObject *
privateinput(PyObject *self, PyObject *args, PyObject *kwargs) {
PyObject *datas = nullptr;
PyObject *ownerlist = nullptr;
int ggtype = 12;
vector<string> data_owner;
char *kwlist[] = {"data_owner","data", "ggtype",nullptr};
PyArg_ParseTupleAndKeywords(args, kwargs, "|OOi", kwlist, &ownerlist, &datas, &ggtype);
data_owner.resize(PyList_Size(ownerlist));
Py_ssize_t strlen;
char * strbuf;
for(int i = 0; i < PyList_Size(ownerlist); i++){
strbuf = PyUnicode_AsUTF8AndSize(PyList_GetItem(ownerlist, i), &strlen);
data_owner[i].resize(strlen);
memcpy((void *) data_owner[i].data(), strbuf, strlen);
}
auto dd = new DataSet(data_owner, 1, "datashare", context);
PyObject *res = dd->private_input_x(datas, ggtype);
return res;
}
static PyObject *
reveal(PyObject *self, PyObject *args, PyObject *kwargs) {
msg_id_t msgid("reavealer");
PyObject *datas = nullptr;
PyObject *ownerlist = nullptr;
vector<string> data_owner;
char *kwlist[] = {"data_revealer","data", nullptr};
PyArg_ParseTupleAndKeywords(args, kwargs, "|OO", kwlist, &ownerlist, &datas);
data_owner.resize(PyList_Size(ownerlist));
Py_ssize_t strlen;
char * strbuf;
for(int i = 0; i < PyList_Size(ownerlist); i++){
strbuf = PyUnicode_AsUTF8AndSize(PyList_GetItem(ownerlist, i), &strlen);
data_owner[i].resize(strlen);
memcpy((void *) data_owner[i].data(), strbuf, strlen);
}
ssize_t size = PyArray_Size(datas);
vector<mpc_t> in;
vector<mpc_t> out;
in.resize(size);
out.resize(size);
memcpy(in.data(), PyArray_DATA(datas), size*8);
PyObject *res = PyArray_SimpleNew(PyArray_NDIM(datas), PyArray_DIMS(datas), NPY_DOUBLE);
auto *buf = (double *)PyArray_DATA(res);
Py_BEGIN_ALLOW_THREADS
context->GetInternal(msgid)->Reconstruct2PC(in, out, data_owner);
context->GetInternal(msgid)->synchronize(msgid);
Py_END_ALLOW_THREADS
const int float_precision = context->GetMpcContext()->FLOAT_PRECISION;
for (size_t i = 0; i < size; ++i) {
buf[i] = MpcTypeToFloat(out[i], float_precision);
}
return res;
}
static PyObject *
revealfake(PyObject *self, PyObject *args, PyObject *kwargs) {
msg_id_t msgid("reavealer");
PyObject *datas = nullptr;
PyObject *ownerlist = nullptr;
vector<string> data_owner;
char *kwlist[] = {"data_revealer","data", nullptr};
PyArg_ParseTupleAndKeywords(args, kwargs, "|OO", kwlist, &ownerlist, &datas);
data_owner.resize(PyList_Size(ownerlist));
Py_ssize_t strlen;
char * strbuf;
for(int i = 0; i < PyList_Size(ownerlist); i++){
strbuf = PyUnicode_AsUTF8AndSize(PyList_GetItem(ownerlist, i), &strlen);
data_owner[i].resize(strlen);
memcpy((void *) data_owner[i].data(), strbuf, strlen);
}
ssize_t size = PyArray_Size(datas);
vector<mpc_t> in;
vector<mpc_t> out;
in.resize(size);
out.resize(size*2);
memcpy(in.data(), PyArray_DATA(datas), size*8);
long dims[3] = {PyArray_DIM(datas,0), PyArray_DIM(datas,1), 2};
PyObject *res;
if (PyArray_NDIM(datas) == 1){
dims[0] = 2;
dims[1] = PyArray_DIM(datas,0);
res = PyArray_SimpleNew(2, dims, NPY_UINT64);
}
else if (PyArray_NDIM(datas) == 2){
dims[0] = 2;
dims[1] = PyArray_DIM(datas,0);
dims[2] = PyArray_DIM(datas,1);
res = PyArray_SimpleNew(3, dims, NPY_UINT64);
}
auto *buf = (unsigned long*)PyArray_DATA(res);
Py_BEGIN_ALLOW_THREADS
context->GetInternal(msgid)->Reconstruct2PC_pp(in, out, data_owner);
context->GetInternal(msgid)->synchronize(msgid);
Py_END_ALLOW_THREADS
memcpy(buf, out.data(), 16*size);
// for (size_t i = 0; i < size; ++i) {
// buf[i] = out[i];
// buf[i*2+1] = out[i+size];
// }
return res;
}
static PyObject *
opvector(PyObject *self, PyObject *args) {
msg_id_t msgid("opvec");
PyObject *x, *y = nullptr;
int optype;
PyArg_ParseTuple(args, "OOi", &x, &y, &optype);
npy_intp xsize = PyArray_Size(x);
npy_intp ysize = PyArray_Size(y);
if (xsize == ysize) {
vector<mpc_t> in1, in2, out;
in1.resize(xsize);
in2.resize(ysize);
out.resize(xsize);
memcpy(in1.data(), PyArray_DATA(x), xsize * 8);
memcpy(in2.data(), PyArray_DATA(y), xsize * 8);
PyObject *res = PyArray_SimpleNew(PyArray_NDIM(x), PyArray_DIMS(x), NPY_ULONGLONG);
auto *buf = (double *) PyArray_DATA(res);
int flag = 0;
Py_BEGIN_ALLOW_THREADS
switch (optype) {
case 1:
context->GetInternal(msgid)->Mul(in1, in2, out);
break;
case 2:
context->GetInternal(msgid)->Add(in1, in2, out);
break;
case 3:
context->GetInternal(msgid)->Less(in1, in2, out);
break;
case 4:
context->GetInternal(msgid)->Greater(in1, in2, out);
break;
case 5:
context->GetInternal(msgid)->Division(in1, in2, out);
break;
case 8:
context->GetInternal(msgid)->Equal(in1, in2, out);
break;
case 12:
context->GetInternal(msgid)->OR(in1, in2, out);
break;
case 14:
context->GetInternal(msgid)->AND(in1, in2, out);
break;
case 15:
context->GetInternal(msgid)->XOR(in1, in2, out);
break;
// OP_OR = 12
// OP_NOT = 13
// OP_AND = 14
// OP_XOR = 15
default:
flag = 1;
}
if (flag == 0)
memcpy(buf, out.data(), xsize * 8);
Py_END_ALLOW_THREADS
if (flag == 0)
return res;
else {
PyErr_SetString(PyExc_RuntimeError, "op unkown");
return NULL;
}
}
else if(ysize == 1){
vector<mpc_t> in1, in2, out;
in1.resize(xsize);
in2.resize(xsize);
out.resize(xsize);
memcpy(in1.data(), PyArray_DATA(x), xsize * 8);
for(int i =0; i < xsize; i ++){
in2[i] = ((mpc_t*)PyArray_DATA(y))[0];
}
PyObject *res = PyArray_SimpleNew(PyArray_NDIM(x), PyArray_DIMS(x), NPY_ULONGLONG);
auto *buf = (double *) PyArray_DATA(res);
int flag = 0;
if (xsize == 0)
return res;
Py_BEGIN_ALLOW_THREADS
switch (optype) {
case 1:
context->GetInternal(msgid)->Mul(in1, in2, out);
break;
case 2:
context->GetInternal(msgid)->Add(in1, in2, out);
break;
case 3:
context->GetInternal(msgid)->Less(in1, in2, out);
break;
case 4:
context->GetInternal(msgid)->Greater(in1, in2, out);
break;
case 5:
context->GetInternal(msgid)->Division(in1, in2, out);
break;
case 8:
context->GetInternal(msgid)->Equal(in1, in2, out);
break;
case 9:
context->GetInternal(msgid)->NotEqual(in1, in2, out);
break;
case 10:
context->GetInternal(msgid)->GreaterEqual(in1, in2, out);
break;
case 11:
context->GetInternal(msgid)->LessEqual(in1, in2, out);
break;
default:
flag = 1;
}
if (flag == 0)
memcpy(buf, out.data(), xsize * 8);
Py_END_ALLOW_THREADS
if (flag == 0)
return res;
else {
PyErr_SetString(PyExc_RuntimeError, "op unkown");
return NULL;
}
}
else
{
PyErr_SetString(PyExc_RuntimeError, "wrong size");
return NULL;
}
}
// Method table for the module.  The keyword-accepting functions have the
// PyCFunctionWithKeywords signature but the ml_meth slot is a PyCFunction;
// C++ has no implicit conversion between these function-pointer types, so
// they are cast through void(*)(void) as the CPython documentation
// recommends.  METH_KEYWORDS tells the interpreter the real signature.
static PyMethodDef SpamMethods[] = {
    {"snninit", snn_init, METH_VARARGS, "init agent"},
    {"privateinput",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(privateinput)),
     METH_VARARGS | METH_KEYWORDS, "privateinput"},
    {"opvector", opvector, METH_VARARGS, "opvector"},
    {"reveal",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(reveal)),
     METH_VARARGS | METH_KEYWORDS, "reveal"},
    {"revealfake",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(revealfake)),
     METH_VARARGS | METH_KEYWORDS, "revealfake"},
    {nullptr, nullptr, 0, nullptr} /* Sentinel */
};
// Module definition.  The init function is PyInit_dbsec, so the module is
// imported as "dbsec"; m_name must match that, because CPython uses it as
// the module's __name__ and as the __module__ of the functions registered
// from SpamMethods.
static struct PyModuleDef dbsecmodule = {
    PyModuleDef_HEAD_INIT,
    "dbsec",     /* m_name: must match the PyInit_<name> suffix */
    "",          /* m_doc: module documentation, may be NULL */
    -1,          /* m_size: -1 means the module keeps its state in globals
                    (context, myid, node_id_*) and has no per-interpreter state */
    SpamMethods  /* m_methods */
};
// Entry point CPython looks up when executing `import dbsec`.
// import_array() loads NumPy's C API; per the NumPy docs it is a macro that
// returns NULL from this function (with ImportError set) if that fails, so
// it has to run before the module object is created.
PyMODINIT_FUNC
PyInit_dbsec(void)
{
    import_array();
    PyObject *module = PyModule_Create(&dbsecmodule);
    return module;
}
//PyMODINIT_FUNC
//initdbsec(void)
//{
//// import_array();
// return PyModule_Create(&dbsecmodule);
//}