545 lines
22 KiB
Python
545 lines
22 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import csv
|
|
import os
|
|
from pydoc import plain
|
|
import sys
|
|
import multiprocessing
|
|
from venv import logger
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import time
|
|
import random
|
|
import logging
|
|
|
|
from queryParse import parseSQL
|
|
from saveCipherTable import *
|
|
from loadCipherTable import *
|
|
from AST import *
|
|
from obliviousSort import *
|
|
from cache import *
|
|
from parameters import *
|
|
|
|
def judgeTypes(tuple):
    '''
    Classify an operand taken from the postfix operand stack.

    Parameters:
        tuple: the operand to classify. The name shadows the builtin; it is
               kept unchanged so that keyword callers continue to work.

    Returns:
        1 -- Digit: a string made only of digits, e.g. '4', '50000'
        2 -- Column: any other string, e.g. 'ID', 'deposit3'
        3 -- Cipher result: anything that is not a string,
             e.g. [[b'ul\\xcd\\x11\\x89\\x9d\\xba\\xe1#'][b'\\x9fIa\\xcbj%8n#']...]
    '''
    # isinstance() instead of an exact type() comparison: str subclasses
    # (numpy.str_ for instance) are strings too and must not be reported as
    # cipher results.
    if not isinstance(tuple, str):
        return 3
    # Note: a signed literal such as '-5' is not all digits and is therefore
    # classified as a column name, exactly as before.
    return 1 if tuple.isdigit() else 2
|
|
|
|
def _lookupCipherColumn(operand, cipherTable, header):
    '''
    Return the column of cipherTable whose header entry equals operand.

    The operand is handed back unchanged when no header entry matches it.
    '''
    for cidx in range(len(header)):
        if header[cidx] == operand:
            return cipherTable[:, cidx]
    return operand


def _shareConstant(id, literal):
    '''
    Turn a digit literal into a 1x1 private dataset.

    Every party issues load_X for P0, P1 and P2 in that order; only the call
    that belongs to the party's own id carries the value, the others pass
    None. The P0 result is the one used as operand.
    '''
    # Imported here so that expressions without constants do not need the
    # MPC backend at all.
    import latticex.rosetta as rtt
    value = [[int(literal)]]
    shares = [
        rtt.controller.PrivateDataset(["P{}".format(party)]).load_X(
            value if party == id else None)
        for party in range(3)
    ]
    return shares[0]


def _prepareOperand(id, operand, cipherTable, header):
    '''
    Convert a raw stack operand into something secureOperator can consume:
    digits are shared, column names are replaced by their cipher column,
    already computed results are used as they are.
    '''
    kind = judgeTypes(operand)
    if kind == 1:
        return _shareConstant(id, operand)
    if kind == 2:
        return _lookupCipherColumn(operand, cipherTable, header)
    return operand


def computePostfix(id, postexp, cipherTable, header):
    '''
    Compute the result of a postfix expression on the cipher table.

    Parameters:
        id:          party index of the calling process (0, 1 or 2)
        postexp:     iterable of tokens in postfix order
        cipherTable: the shared table; columns are taken with cipherTable[:, idx]
        header:      column names, position-aligned with cipherTable's columns

    Returns:
        The last value left on the operand stack. For an expression holding a
        single operand this is the token itself.

    Raises:
        IndexError: if an operator finds fewer than two operands on the stack,
                    or if postexp is empty.
    '''
    operators = ('ALL', 'ANY', 'BETWEEN', 'LIKE', 'IN', 'OR', 'SOME', 'AND', 'NOT',
                 'op0', 'op1', 'op2', 'op3', 'op4', 'op5', 'op6', 'op7', 'op8', 'op9',
                 '(', ')')
    operand = []

    for token in postexp:
        if token not in operators:
            # Operand: push it on the stack.
            operand.append(token)
            continue

        # Operator: combine the two topmost operands and push the result.
        roperand = operand.pop()
        loperand = operand.pop()

        param = parameters()
        cacheFlag = cacheRule(id, loperand, roperand, token) and param.cacheTurn
        # The explicit comparisons with True are kept on purpose: the return
        # types of the cache helpers are not visible from this file.
        if cacheFlag == True:
            cachePath = cacheName(loperand, roperand, token)
            if ifCached(id, cachePath) == True and param.cacheTurn == True:
                # Cache hit: load the stored MPC result instead of recomputing.
                operand.append(cacheLoad(id, cachePath, [1, param.dataScale]))
                continue

            # Cache miss: resolve both operands as columns, compute, then store.
            loperand = _lookupCipherColumn(loperand, cipherTable, header)
            roperand = _lookupCipherColumn(roperand, cipherTable, header)
            rttres = secureOperator(id, loperand, roperand, token)
            cacheSave(id, rttres, cachePath)
        else:
            # Left operand is prepared before the right one so that all
            # parties issue their load_X calls in the same order.
            loperand = _prepareOperand(id, loperand, cipherTable, header)
            roperand = _prepareOperand(id, roperand, cipherTable, header)
            rttres = secureOperator(id, loperand, roperand, token)

        operand.append(rttres)

    return operand.pop()
|
|
|
|
def secureOperator(id, leftOperand, rightOperand, op):
    '''
    Apply one secure binary operator to two prepared operands.

    Operator    ReplaceCode
    >           op0
    <           op1
    ==          op2
    >=          op3
    <=          op4
    <>          op5
    +           op6
    -           op7
    *           op8
    /           op9
    'AND' and 'OR' are passed through under their own names.

    Parameters:
        id:           party index; not used here, kept for the existing callers
        leftOperand:  left-hand operand
        rightOperand: right-hand operand
        op:           replace code of the operator (see table above)

    Returns:
        The result of the matching latticex.rosetta Secure* operation.

    Raises:
        ValueError: if op is not one of the codes listed above.
    '''
    # Replace code -> name of the rosetta operation taking (left, right).
    binaryOps = {
        'op0': 'SecureGreater',
        'op1': 'SecureLess',
        'op2': 'SecureEqual',
        'op3': 'SecureGreaterEqual',
        'op4': 'SecureLessEqual',
        'op6': 'SecureAdd',
        'op7': 'SecureSub',
        'op8': 'SecureMul',
        # Was spelled 'SecureSecureFloorDiv', which does not exist.
        'op9': 'SecureFloorDiv',
        'AND': 'SecureLogicalAnd',
        'OR': 'SecureLogicalOr',
    }

    # Reject unknown codes up front instead of failing later on an unbound
    # result variable.
    if op != 'op5' and op not in binaryOps:
        raise ValueError('Unsupported operator: {!r}'.format(op))

    import latticex.rosetta as rtt

    if op == 'op5':
        # "<>" is expressed as NOT(==).
        return rtt.SecureLogicalNot(rtt.SecureEqual(leftOperand, rightOperand))
    return getattr(rtt, binaryOps[op])(leftOperand, rightOperand)
|
|
|
|
def _columnIndices(names, header):
    '''
    Return the header positions matching the given names, grouped by name in
    the order the names are given.
    '''
    return [idx
            for name in names
            for idx in range(len(header))
            if name == header[idx]]


def _shareTable(rtt, id, tablePath, dropColumns):
    '''
    Feed the table into the three-party private dataset and return the P0 part.

    P0 loads the CSV body (header row skipped) without the columns listed in
    dropColumns. P1 and P2 only supply a 1x1 placeholder matrix so that all
    parties take part in the exchange.
    '''
    if id == 0:
        # Passing the path lets numpy open and close the file itself.
        # ndmin=2 keeps a one-row table two-dimensional for np.delete below.
        plaintext = np.loadtxt(tablePath, delimiter=",", skiprows=1, ndmin=2)
        plaintext = np.delete(plaintext, dropColumns, axis=1)
        owner = 0
    elif id == 1:
        # randint's upper bound is exclusive, so this is always [[0]].
        plaintext = np.random.RandomState(1789).randint(0, 1, (1, 1))
        owner = 1
    else:
        plaintext = np.random.RandomState(1999).randint(0, 1, (1, 1))
        owner = 2

    # Every party calls load_X for P0, P1, P2 in this order; only the call for
    # its own slot carries data.
    shares = [
        rtt.controller.PrivateDataset(["P{}".format(party)]).load_X(
            plaintext if party == owner else None)
        for party in range(3)
    ]
    return shares[0]


def _executeOrderBy(rtt, id, param, plan, table, tablePath, header, select_column_idx):
    '''
    Parse and compute the 'ORDER BY' subquery.

    Note. This module just support single attribute sort,
    e.g. 'SELECT ID,AGE FROM TABLE ORDER BY DEPOSIT ASC|DESC'

    Returns the header that is valid afterwards: the unchanged header on a
    cache hit, otherwise the header reduced to the shared columns.
    '''
    kv = plan['ORDER BY'][0][0].split(' ')
    order_columns = [kv[0]]
    order_mode = kv[1]

    order_column_idx = _columnIndices(order_columns, header)
    if not order_column_idx:
        # Without this check the sort below would run on an out-of-range index.
        print('Please check the ORDER BY column in SQL!')
        sys.exit(-1)

    cipher_column_idx = list(set(select_column_idx).union(set(order_column_idx)))
    cipher_delete_column_idx = [i for i in range(len(header))
                                if i not in cipher_column_idx]

    # Cache path, e.g. 'SELECT ID FROM TABLE ORDER BY DEPOSIT ASC' is cached
    # to server/P*/TABLE_ID_DEPOSIT_ASC.sdata
    # The selected columns are named via their own header index; indexing the
    # header with 0..k-1 named the first k columns of the table instead.
    cachePath = table + '_'
    for idx in select_column_idx:
        cachePath += (header[idx] + '_')
    cachePath += (kv[0] + '_' + kv[1])

    if ifCached(id, cachePath) == True and param.cacheTurn == True:
        # NOTE(review): the loaded result is not used any further at present.
        cacheLoad(id, cachePath, [param.dataScale, len(select_column_idx)])
        return header

    # StepO-1. P0 generates the data slices for the plaintext table.
    rtx = _shareTable(rtt, id, tablePath, cipher_delete_column_idx)
    header = np.delete(header, cipher_delete_column_idx)

    # StepO-2. Oblivious odd even merge sort on the cipher table.
    if order_mode == 'ASC':
        sortMode = 0
    elif order_mode == 'DESC':
        sortMode = 1
    else:
        print('Please check your order mode, just support ASC|DESC!')
        sys.exit(-1)

    # Position of the sort attribute inside the reduced header.
    attr_idx = next((i for i in range(len(header)) if header[i] == order_columns[0]),
                    len(header))

    rtres = oblivious_odd_even_merge_sort(attr_idx, rtx, sortMode)
    with tf.compat.v1.Session() as session:
        sorted_rtxdata_plaintext = session.run(rtt.SecureReveal(rtres))
    print(sorted_rtxdata_plaintext)

    if ifCached(id, cachePath) == False and param.cacheTurn == True:
        cacheSave(id, rtres, cachePath)

    return header


def _executeWhere(rtt, id, param, plan, tablePath, header, select_column_idx):
    '''
    Parse and compute the 'WHERE' subquery and print the revealed results.
    '''
    where_columns = [element for item in plan['WHERE'] for element in item]
    where_column_idx = _columnIndices(where_columns, header)

    cipher_column_idx = list(set(select_column_idx).union(set(where_column_idx)))
    cipher_delete_column_idx = [i for i in range(len(header))
                                if i not in cipher_column_idx]

    # StepW-1. P0 generates the data slices for the plaintext table.
    rtx = _shareTable(rtt, id, tablePath, cipher_delete_column_idx)
    header = np.delete(header, cipher_delete_column_idx)

    # StepW-2. Transfer the WHERE subquery to postfix form and evaluate it.
    opAST = infix2postfix(plan['WHERE'])

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())

        where_result = computePostfix(id, opAST, rtx, header)
        where_result_new = tf.reshape(where_result[0], [param.dataScale, -1])

        print('WHERE result:', sess.run(rtt.SecureReveal(where_result_new)))
        print('rtx result:', sess.run(rtt.SecureReveal(rtx)))

    print('where_result\'s type:', type(where_result), 'rtx\'s type:', type(rtx))


def executeSecurePlan(id, onlinelist, SQLs):
    '''
    Execute the secure plan generated by queryParse.py for every SQL in SQLs.

    Parameters:
        (1) id: party index of this process (0, 1 or 2); P0 owns the plaintext
        (2) onlinelist: Indicate which user(s) is(are) online/offline.
            0 represents offline and 1 represents online.
            It is currently not read by this function.
        (3) SQLs: list of SQL strings; each one is parsed with parseSQL into
            the key-value plan that is executed here

    The process exits if the table file named in the SQL does not exist, if
    the ORDER BY column is unknown or if the order mode is not ASC|DESC.
    '''
    SQLidx = 0
    TIME_START = time.time()
    param = parameters()
    rtt = None  # stays None when there is nothing to execute

    for SQL in SQLs:
        plan = parseSQL(SQL)

        # Step1. Identify which columns are selected.
        columns = [item[0] for item in plan['SELECT']
                   if item[0] not in [',', '(', ')']]

        # Step2. Identify the plaintext table.
        # Table names look like attribute_source_name, e.g. S_user2_table0:
        #   attribute: S/M, S-single, M-merge
        #   source:    user0/user1/user2/user3/...
        #   name:      table name
        table = plan['FROM'][0][0]
        table_source = table.split('_')[1]

        tablePath = 'users/{}/{}.csv'.format(table_source, table)
        if not os.path.exists(tablePath):
            print('Please check the tablename in SQL!')
            sys.exit(0)

        # Step3. Start the MPC backend for this party.
        # The node id is appended to sys.argv BEFORE rosetta is imported, as
        # in the original; it only needs to be appended once per process.
        if rtt is None:
            sys.argv.extend(["--node_id", "P{}".format(id)])
        import latticex.rosetta as rtt
        rtt.backend_log_to_stdout(False)

        rtt.activate("SecureNN")

        # Only the first CSV row (the header) is needed here.
        with open(tablePath, 'r', newline='') as f:
            first_row = next(csv.reader(f), None)
        if first_row is None:
            raise IndexError('table file {} is empty'.format(tablePath))
        header = np.array(first_row)

        # select_column_idx: the column indexes appearing in 'SELECT'
        select_column_idx = _columnIndices(columns, header)

        if 'ORDER BY' in plan:
            # The header may come back reduced; the WHERE step below then
            # works on that reduced header, as it did before.
            header = _executeOrderBy(rtt, id, param, plan, table, tablePath,
                                     header, select_column_idx)

        if 'WHERE' in plan:
            _executeWhere(rtt, id, param, plan, tablePath, header, select_column_idx)

        SQLidx += 1
        print('From ID:{} the {}th SQL\'s parse completed!'.format(id, SQLidx))

    TIME_END = time.time()
    print('Successfully execute all the SQLs, total time:', TIME_END - TIME_START, 's')
    # rtt only exists once at least one SQL has been processed.
    if rtt is not None:
        rtt.deactivate()
|
|
|
|
|
|
# Online/offline flag per user (0 = offline, 1 = online); passed to every party.
ONLINE_LIST = [0,1,0,1]
SQLs = ["SELECT ID FROM S_user3_table0 WHERE (loan3 > 100000 AND (deposit3 < loan3)) AND (credit3 <= 3 OR credit3 >= 7)"]
# Further example queries, kept for reference:
# SQLs = ["SELECT ID FROM S_user3_table0 ORDER BY loan3 DESC",
# "SELECT ID FROM S_user3_table0 WHERE (loan3 > 100000 AND (deposit3 < loan3)) AND (credit3 <= 3 OR credit3 >= 7)",
# "SELECT ID FROM S_user3_table0 WHERE deposit3 < 5000000 AND deposit3 < loan3",
# "SELECT ID FROM S_user2_table0 WHERE credit2 >= 5",
# "SELECT ID FROM S_user1_table0 WHERE loan1 < 80000 OR loan1 >= 150000",
# "SELECT ID FROM S_user0_table0 WHERE AGE < 40",
# "SELECT ID FROM S_user3_table0 WHERE deposit3 < loan3",
# "SELECT ID FROM S_user3_table0 WHERE credit3 < 6 AND deposit3 < loan3",
# "SELECT ID FROM S_user3_table0 WHERE deposit3 < loan3 OR (ID > 100010)",
# "SELECT ID FROM S_user3_table0 WHERE (loan3 >= 450000) AND ((deposit3 < loan3) OR (deposit3 <= 30000))",
# "SELECT ID FROM S_user3_table0 WHERE (deposit3 < loan3) AND (loan3 >= 100000)",
# "SELECT ID FROM S_user1_table0 WHERE credit1 >= 5",
# "SELECT ID FROM S_user3_table0 WHERE (deposit3 < loan3) AND (credit3 < 7)",
# "SELECT ID FROM S_user0_table0 WHERE AGE > 60",
# "SELECT ID FROM S_user2_table0 WHERE deposit2 >= loan2",
# "SELECT ID FROM S_user2_table0 WHERE (deposit2 >= loan2) AND (credit > 4)",
# "SELECT ID FROM S_user2_table0 WHERE deposit2 >= loan2 OR ID < 100500",
# "SELECT ID FROM S_user2_table0 WHERE (credit < 3 OR credit > 8) AND (deposit2 >= loan2)",
# "SELECT ID FROM S_user2_table0 WHERE deposit2 >= loan2 OR deposit2 < 300000",
# "SELECT ID FROM S_user3_table0 WHERE deposit3 < loan3 OR loan3 >= 2500000",
# "SELECT ID FROM S_user3_table0 WHERE ID < 100015 AND deposit3 < loan3",
# "SELECT ID FROM S_user3_table0 ORDER BY loan3 DESC"]


def _launchParties():
    '''
    Start one daemon process per party (P0, P1, P2), each running
    executeSecurePlan on the module-level SQLs, and wait for all of them.
    '''
    workers = [
        multiprocessing.Process(target = executeSecurePlan, args = (party, ONLINE_LIST, SQLs))
        for party in range(3)
    ]
    # Start all three before joining any of them.
    for worker in workers:
        worker.daemon = True
        worker.start()
    for worker in workers:
        worker.join()


# Guarded so that importing this module (which the multiprocessing 'spawn'
# start method does inside every child) does not launch the parties again.
if __name__ == '__main__':
    _launchParties()
|
|
|
|
# if __name__ == '__main__':
|
|
# SQL = "SELECT deposit3,credit3 FROM S_user3_table0 WHERE deposit3>200000 AND credit<=6"
|
|
# plan = parseSQL(SQL)
|
|
# executeSecurePlan([0,1,0,1], plan)
|