diff --git a/executeSecurePlan.py b/executeSecurePlan.py new file mode 100644 index 0000000..6184d81 --- /dev/null +++ b/executeSecurePlan.py @@ -0,0 +1,544 @@ +#!/usr/bin/env python3 + +import csv +import os +from pydoc import plain +import sys +import multiprocessing +from venv import logger +import numpy as np +import tensorflow as tf +import time +import random +import logging + +from queryParse import parseSQL +from saveCipherTable import * +from loadCipherTable import * +from AST import * +from obliviousSort import * +from cache import * +from parameters import * + +def judgeTypes(tuple): + ''' + When compute the result from the operand stack, we should judge the type of the operand: + + There are three types: + (1) Digit: e.g. '4', '50000' return 1 + (2) Column: e.g. 'ID', 'deposit3' return 2 + (3) Cipher result: e.g. [[b'ul\xcd\x11\x89\x9d\xba\xe1#'][b'\x9fIa\xcbj%8n#']...] return 3 + + ''' + if type(tuple) != str: + return 3 + else: + if tuple.isdigit() == True: + return 1 + else: + return 2 + +def computePostfix(id, postexp, cipherTable, header): + ''' + This function will compute the result according to postfix expressions + ''' + import latticex.rosetta as rtt + # rtt.activate("SecureNN") + operand = [] + + # i = 0 + flag = 0 + for token in postexp: + if token not in ['ALL', 'ANY', 'BETWEEN', 'LIKE', 'IN', 'OR', 'SOME', 'AND', 'NOT', 'op0', 'op1', 'op2', 'op3', 'op4', 'op5', 'op6', 'op7', 'op8', 'op9', '(', ')']: + ''' + If token is operand, push to the temp stack + ''' + operand.append(token) + else: + ''' + If token is operator, compute the intermediate results and push to the stack + ''' + roperand = operand.pop() + loperand = operand.pop() + + ''' + Judge the operand type and generate the operand(SS): + If operand is digit, generate the share of digit directly + If operand is string, select the cipher column from the ciphertable + If operand is ndarray, use it directly + ''' + param = parameters() + cacheFlag = cacheRule(id, loperand, roperand, token) and param.cacheTurn + if cacheFlag == True: + ''' + If the table has cached? + + True -> load the MPC cache directly; + False -> execute MPC + ''' + cachePath = cacheName(loperand, roperand, token) + if ifCached(id, cachePath) == True and param.cacheTurn == True: + operand.append(cacheLoad(id, cachePath, [1, param.dataScale])) + else: + for cidx in range(len(header)): + if header[cidx] == loperand: + loperand = cipherTable[:,cidx] + break + else: + continue + + for cidx in range(len(header)): + if header[cidx] == roperand: + roperand = cipherTable[:,cidx] + break + else: + continue + + rttres = secureOperator(id, loperand, roperand, token) + ''' + Judge if the rttres need to be cached + ''' + cacheSave(id, rttres, cachePath) + + operand.append(rttres) + + else: + if judgeTypes(loperand) == 1: + loperand = [[int(loperand)]] + if id == 0: + rtdata0 = rtt.controller.PrivateDataset(["P0"]).load_X(loperand) + rtdata1 = rtt.controller.PrivateDataset(["P1"]).load_X(None) + rtdata2 = rtt.controller.PrivateDataset(["P2"]).load_X(None) + if id == 1: + rtdata0 = rtt.controller.PrivateDataset(["P0"]).load_X(None) + rtdata1 = rtt.controller.PrivateDataset(["P1"]).load_X(loperand) + rtdata2 = rtt.controller.PrivateDataset(["P2"]).load_X(None) + if id == 2: + rtdata0 = rtt.controller.PrivateDataset(["P0"]).load_X(None) + rtdata1 = rtt.controller.PrivateDataset(["P1"]).load_X(None) + rtdata2 = rtt.controller.PrivateDataset(["P2"]).load_X(loperand) + loperand = rtdata0 + # loperand = tf.broadcast_to(rtdata0, [20,1]) + + elif judgeTypes(loperand) == 2: + for cidx in range(len(header)): + if header[cidx] == loperand: + loperand = cipherTable[:,cidx] + break + else: + continue + + if judgeTypes(roperand) == 1: + roperand = [[int(roperand)]] + if id == 0: + rtdata0 = rtt.controller.PrivateDataset(["P0"]).load_X(roperand) + rtdata1 = rtt.controller.PrivateDataset(["P1"]).load_X(None) + rtdata2 = rtt.controller.PrivateDataset(["P2"]).load_X(None) + if id == 1: + rtdata0 = rtt.controller.PrivateDataset(["P0"]).load_X(None) + rtdata1 = rtt.controller.PrivateDataset(["P1"]).load_X(roperand) + rtdata2 = rtt.controller.PrivateDataset(["P2"]).load_X(None) + if id == 2: + rtdata0 = rtt.controller.PrivateDataset(["P0"]).load_X(None) + rtdata1 = rtt.controller.PrivateDataset(["P1"]).load_X(None) + rtdata2 = rtt.controller.PrivateDataset(["P2"]).load_X(roperand) + roperand = rtdata0 + # roperand = tf.broadcast_to(rtdata0, [20,1]) + + elif judgeTypes(roperand) == 2: + for cidx in range(len(header)): + if header[cidx] == roperand: + roperand = cipherTable[:,cidx] + break + else: + continue + + rttres = secureOperator(id, loperand, roperand, token) + ''' + Judge if the rttres need to be cached + ''' + operand.append(rttres) + + return operand.pop() + +def secureOperator(id, leftOperand, rightOperand, op): + import latticex.rosetta as rtt + # rtt.activate("SecureNN") + ''' + After generate the rtt operand, this module will compute the result + Operator ReplaceCode + > op0 + < op1 + == op2 + >= op3 + <= op4 + <> op5 + + op6 + - op7 + * op8 + / op9 + ''' + + if op == 'op0': + rtres = rtt.SecureGreater(leftOperand, rightOperand) + elif op == 'op1': + rtres = rtt.SecureLess(leftOperand, rightOperand) + elif op == 'op2': + rtres = rtt.SecureEqual(leftOperand, rightOperand) + elif op == 'op3': + rtres = rtt.SecureGreaterEqual(leftOperand, rightOperand) + elif op == 'op4': + rtres = rtt.SecureLessEqual(leftOperand, rightOperand) + elif op == 'op5': + rtres = rtt.SecureLogicalNot(rtt.SecureEqual(leftOperand, rightOperand)) + elif op == 'op6': + rtres = rtt.SecureAdd(leftOperand, rightOperand) + elif op == 'op7': + rtres = rtt.SecureSub(leftOperand, rightOperand) + elif op == 'op8': + rtres = rtt.SecureMul(leftOperand, rightOperand) + elif op == 'op9': + rtres = rtt.SecureSecureFloorDiv(leftOperand, rightOperand) + elif op == 'AND': + rtres = rtt.SecureLogicalAnd(leftOperand, rightOperand) + elif op == 'OR': + rtres = rtt.SecureLogicalOr(leftOperand, rightOperand) + return rtres + +def executeSecurePlan(id, onlinelist, SQLs): + SQLidx = 0 + TIME_START = time.time() + param = parameters() + for SQL in SQLs: + plan = parseSQL(SQL) + ''' + This module will execute the secure plan generated by queryParse.py + + Parameters: + (1) onlinelist: Indicate the which user(s) is(are) online/offline. 0 represents offline and 1 represents online + (2) plan: The key-value datastructure generated by queryParse.py + ''' + + ''' + Step1. This module should identify which column to be selected + ''' + columns = [] + for item in plan['SELECT']: + if item[0] not in [',', '(', ')']: + columns.append(item[0]) + + ''' + Step2. Identify which table(plaintext) shoule be uploaded and selected + + table: attributes_source_name + + attribute: S/M, S-single, M-merge + source: user0/user1/user2/user3/... + name: table name + + e.g. S_user2_table0 + ''' + table = plan['FROM'][0][0] + table_attribute = table.split('_')[0] + table_source = table.split('_')[1] + table_name = table.split('_')[2] + + tablePath = 'users/{}/{}.csv'.format(table_source, table) + if os.path.exists(tablePath) == False: + print('Please check the tablename in SQL!') + exit(0) + + ''' + Step3. According to header, cipher_result and column, server will return the expected columns of ciphertext table + + Datastructures: + (1)table header is stored in header[], the data is recorded in plaintext[] + (2)select_column_idx: the columns index appear in 'SELECT' key + (3)where_column_idx: the columns index appear in 'WHERE' key + (4)cipher_column_idx: the union of select_column_idx and where_column_idx + ''' + sys.argv.extend(["--node_id", "P{}".format(id)]) + import latticex.rosetta as rtt + rtt.backend_log_to_stdout(False) + + rtt.activate("SecureNN") + + header = [] + with open(tablePath, 'r') as f: + reader = csv.reader(f) + header = list(reader)[0] + header = np.array(header) + + select_column_idx = [] + for column in columns: + for idx in range(len(header)): + if column == header[idx]: + select_column_idx.append(idx) + else: + continue + + select_delete_columns = [] + if columns[0] == '*': + select_delete_columns = [] + else: + for i in range(len(header)): + if i in select_column_idx: + continue + else: + select_delete_columns.append(i) + + + ''' + Parse and compute the 'ORDER BY' subquery + + Note. This module just support single attribute sort + + e.g. 'SELECT ID,AGE FROM TABLE ORDER BY DEPOSIT ASC|DESC' + ''' + if 'ORDER BY' in plan.keys(): + order_columns = [] + order_column_idx = [] + # if 'ORDER BY' in plan.keys(): + kv = plan['ORDER BY'][0][0].split(' ') + order_columns.append(kv[0]) + order_mode = kv[1] + + for column in order_columns: + for idx in range(len(header)): + if column == header[idx]: + order_column_idx.append(idx) + else: + continue + + cipher_column_idx = list(set(select_column_idx).union(set(order_column_idx))) + # print('order_columns_idx:', order_column_idx, 'select_column_idx:', select_column_idx, 'cipher_column_idx:', cipher_column_idx) + cipher_delete_column_idx = [] + for i in range(len(header)): + if i in cipher_column_idx: + continue + else: + cipher_delete_column_idx.append(i) + + ''' + Generate the cache path and judge whether to save/load to cache + + e.g. 'SELECT ID FROM TABLE ORDER BY DEPOSIT ASC' will be cached to server/P*/TABLE_ID_DEPOSIT_ASC.sdata + ''' + cacheName = table + '_' + for i in range(len(select_column_idx)): + cacheName += (header[i] + '_') + cacheName += (kv[0] + '_' + kv[1]) + + if ifCached(id, cacheName) == True and param.cacheTurn == True: + rtres = cacheLoad(id, cacheName, [param.dataScale, len(select_column_idx)]) + + else: + ''' + StepO-1. We define P0 to generate the data slices for plaintext + + After generating the secret share, P0,P1,P2 will hold the data slices + + Note: P1 and P2 generate random 1*1 matrix to ensure the socket run, just P0 generate the valid secret share + ''' + + if id == 0: + plaintext = np.loadtxt(open(tablePath), delimiter = ",", skiprows = 1) + plaintext = np.delete(plaintext, cipher_delete_column_idx, axis = 1) + rtx = rtt.controller.PrivateDataset(["P0"]).load_X(plaintext) + rty = rtt.controller.PrivateDataset(["P1"]).load_X(None) + rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None) + elif id == 1: + rd = np.random.RandomState(1789) + plaintext = rd.randint(0, 1, (1, 1)) + rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None) + rty = rtt.controller.PrivateDataset(["P1"]).load_X(plaintext) + rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None) + else: + rd = np.random.RandomState(1999) + plaintext = rd.randint(0, 1, (1, 1)) + rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None) + rty = rtt.controller.PrivateDataset(["P1"]).load_X(None) + rtz = rtt.controller.PrivateDataset(["P2"]).load_X(plaintext) + + header = np.delete(header, cipher_delete_column_idx) + + ''' + StepO-2. Execute oblivious odd even merge sort on cipher table + + order_mode: Check whether the sort mode is in ASC|DESC + + attr_idx: Sorting will be done according to attr + ''' + if order_mode == 'ASC': + sortMode = 0 + elif order_mode == 'DESC': + sortMode = 1 + else: + print('Please check your order mode, just support ASC|DESC!') + exit(-1) + + attr_idx = 0 + for attr_idx in range(len(header)): + ''' + Note. This module just support single attribute for sorting + ''' + if header[attr_idx] == order_columns[0]: + break + else: + attr_idx += 1 + + rtres = oblivious_odd_even_merge_sort(attr_idx, rtx, sortMode) + session = tf.Session() + sorted_rtxdata_plaintext = session.run(rtt.SecureReveal(rtres)) + print(sorted_rtxdata_plaintext) + + if ifCached(id, cacheName) == False and param.cacheTurn == True: + cacheSave(id, rtres, cacheName) + + # session = tf.Session() + # sorted_rtxdata_plaintext = session.run(rtt.SecureReveal(rtres)) + # print(sorted_rtxdata_plaintext) + + + ''' + Parse and compute the 'WHERE' subquery + ''' + if 'WHERE' in plan.keys(): + where_columns = [] + where_column_idx = [] + # if 'WHERE' in plan.keys(): + for item in plan['WHERE']: + for tuple in item: + where_columns.append(tuple) + # print(where_columns) + + for column in where_columns: + for idx in range(len(header)): + if column == header[idx]: + where_column_idx.append(idx) + else: + continue + + cipher_column_idx = list(set(select_column_idx).union(set(where_column_idx))) + cipher_delete_column_idx = [] + for i in range(len(header)): + if i in cipher_column_idx: + continue + else: + cipher_delete_column_idx.append(i) + + ''' + StepW-1. We define P0 to generate the data slices for plaintext + + After generating the secret share, P0,P1,P2 will hold the data slices + + Note: P1 and P2 generate random 1*1 matrix to ensure the socket run, just P0 generate the valid secret share + ''' + + if id == 0: + plaintext = np.loadtxt(open(tablePath), delimiter = ",", skiprows = 1) + plaintext = np.delete(plaintext, cipher_delete_column_idx, axis = 1) + rtx = rtt.controller.PrivateDataset(["P0"]).load_X(plaintext) + rty = rtt.controller.PrivateDataset(["P1"]).load_X(None) + rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None) + elif id == 1: + rd = np.random.RandomState(1789) + plaintext = rd.randint(0, 1, (1, 1)) + rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None) + rty = rtt.controller.PrivateDataset(["P1"]).load_X(plaintext) + rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None) + else: + rd = np.random.RandomState(1999) + plaintext = rd.randint(0, 1, (1, 1)) + rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None) + rty = rtt.controller.PrivateDataset(["P1"]).load_X(None) + rtz = rtt.controller.PrivateDataset(["P2"]).load_X(plaintext) + + header = np.delete(header, cipher_delete_column_idx) + + ''' + StepW-2. After generating the whole ciphertext table, we should parse the WHERE subquery + + This module aimed to transfer the WHERE subquery to logic AST + ''' + + opAST = infix2postfix(plan['WHERE']) + # print(opAST) + + select_ans_colidx = [] + select_ans = [] + + ''' + Choose the index of 'select' subquery + ''' + temp = 0 + for idx in range(len(columns)): + if columns[idx] == header[temp]: + select_ans_colidx.append(temp) + temp += 1 + else: + continue + + sess = tf.compat.v1.Session() + sess.run(tf.compat.v1.global_variables_initializer()) + + where_result = computePostfix(id, opAST, rtx, header) + where_result_new = tf.reshape(where_result[0], [param.dataScale, -1]) + + print('WHERE result:', sess.run(rtt.SecureReveal(where_result_new))) + print('rtx result:', sess.run(rtt.SecureReveal(rtx))) + + print('where_result\'s type:', type(where_result), 'rtx\'s type:', type(rtx)) + SQLidx += 1 + print('From ID:{} the {}th SQL\'s parse completed!'.format(id, SQLidx)) + + + TIME_END = time.time() + print('Successfully execute all the SQLs, total time:', TIME_END - TIME_START, 's') + rtt.deactivate() + + +ONLINE_LIST = [0,1,0,1] +SQLs = ["SELECT ID FROM S_user3_table0 WHERE (loan3 > 100000 AND (deposit3 < loan3)) AND (credit3 <= 3 OR credit3 >= 7)"] +# SQLs = ["SELECT ID FROM S_user3_table0 ORDER BY loan3 DESC", +# "SELECT ID FROM S_user3_table0 WHERE (loan3 > 100000 AND (deposit3 < loan3)) AND (credit3 <= 3 OR credit3 >= 7)", +# "SELECT ID FROM S_user3_table0 WHERE deposit3 < 5000000 AND deposit3 < loan3", +# "SELECT ID FROM S_user2_table0 WHERE credit2 >= 5", +# "SELECT ID FROM S_user1_table0 WHERE loan1 < 80000 OR loan1 >= 150000", +# "SELECT ID FROM S_user0_table0 WHERE AGE < 40", +# "SELECT ID FROM S_user3_table0 WHERE deposit3 < loan3", +# "SELECT ID FROM S_user3_table0 WHERE credit3 < 6 AND deposit3 < loan3", +# "SELECT ID FROM S_user3_table0 WHERE deposit3 < loan3 OR (ID > 100010)", +# "SELECT ID FROM S_user3_table0 WHERE (loan3 >= 450000) AND ((deposit3 < loan3) OR (deposit3 <= 30000))", +# "SELECT ID FROM S_user3_table0 WHERE (deposit3 < loan3) AND (loan3 >= 100000)", +# "SELECT ID FROM S_user1_table0 WHERE credit1 >= 5", +# "SELECT ID FROM S_user3_table0 WHERE (deposit3 < loan3) AND (credit3 < 7)", +# "SELECT ID FROM S_user0_table0 WHERE AGE > 60", +# "SELECT ID FROM S_user2_table0 WHERE deposit2 >= loan2", +# "SELECT ID FROM S_user2_table0 WHERE (deposit2 >= loan2) AND (credit > 4)", +# "SELECT ID FROM S_user2_table0 WHERE deposit2 >= loan2 OR ID < 100500", +# "SELECT ID FROM S_user2_table0 WHERE (credit < 3 OR credit > 8) AND (deposit2 >= loan2)", +# "SELECT ID FROM S_user2_table0 WHERE deposit2 >= loan2 OR deposit2 < 300000", +# "SELECT ID FROM S_user3_table0 WHERE deposit3 < loan3 OR loan3 >= 2500000", +# "SELECT ID FROM S_user3_table0 WHERE ID < 100015 AND deposit3 < loan3", +# "SELECT ID FROM S_user3_table0 ORDER BY loan3 DESC"] + +p0 = multiprocessing.Process(target = executeSecurePlan, args = (0, ONLINE_LIST, SQLs)) +p1 = multiprocessing.Process(target = executeSecurePlan, args = (1, ONLINE_LIST, SQLs)) +p2 = multiprocessing.Process(target = executeSecurePlan, args = (2, ONLINE_LIST, SQLs)) + +p0.daemon = True +p0.start() +p1.daemon = True +p1.start() +p2.daemon = True +p2.start() + +p0.join() +p1.join() +p2.join() + +# if __name__ == '__main__': +# SQL = "SELECT deposit3,credit3 FROM S_user3_table0 WHERE deposit3>200000 AND credit<=6" +# plan = parseSQL(SQL) +# executeSecurePlan([0,1,0,1], plan)