# SMPCache/executeSecurePlan.py
# (Header copied from a repository web view: 545 lines, 22 KiB, Python;
#  revision dated 2023-11-16 12:58:26 +08:00.)
#!/usr/bin/env python3
import csv
import os
from pydoc import plain
import sys
import multiprocessing
from venv import logger
import numpy as np
import tensorflow as tf
import time
import random
import logging
from queryParse import parseSQL
from saveCipherTable import *
from loadCipherTable import *
from AST import *
from obliviousSort import *
from cache import *
from parameters import *
def judgeTypes(tuple):
    '''
    Classify one operand popped from the postfix operand stack.

    Return codes:
    (1) Digit literal, e.g. '4', '50000'            -> 1
    (2) Column name, e.g. 'ID', 'deposit3'           -> 2
    (3) Any value whose type is not exactly str,
        e.g. an intermediate cipher result           -> 3

    Only str.isdigit() is used to spot literals, so strings such as '-5'
    or '1.5' are classified as column names (2).
    '''
    if type(tuple) is not str:
        return 3
    return 1 if tuple.isdigit() else 2
# Tokens that computePostfix treats as operators; every other token is an operand.
# op0..op9 are the placeholder codes for comparison/arithmetic operators
# (the mapping is listed in secureOperator's docstring).
_POSTFIX_OPERATORS = (
    'ALL', 'ANY', 'BETWEEN', 'LIKE', 'IN', 'OR', 'SOME', 'AND', 'NOT',
    'op0', 'op1', 'op2', 'op3', 'op4', 'op5', 'op6', 'op7', 'op8', 'op9',
    '(', ')',
)


def _lookupColumn(name, cipherTable, header):
    '''
    Return the column of cipherTable whose header entry equals `name`
    (first match wins). If no header entry matches, `name` is returned unchanged.
    '''
    for cidx in range(len(header)):
        if header[cidx] == name:
            return cipherTable[:, cidx]
    return name


def _shareConstant(id, literal):
    '''
    Turn a digit literal from the WHERE clause into a Rosetta shared value.

    Every party issues PrivateDataset(...).load_X for P0, P1 and P2 in that
    order. Party `id` passes the constant as its own private input and None
    for the other two. The dataset owned by P0 is returned; every party feeds
    the same constant, so P0's copy is as good as any.
    Raises ValueError if `id` is not 0, 1 or 2.
    '''
    import latticex.rosetta as rtt
    if id not in (0, 1, 2):
        raise ValueError('party id must be 0, 1 or 2, got {!r}'.format(id))
    value = [[int(literal)]]
    shared = [
        rtt.controller.PrivateDataset(["P{}".format(party)]).load_X(value if party == id else None)
        for party in range(3)
    ]
    return shared[0]


def _resolveOperand(id, token, cipherTable, header):
    '''
    Prepare a raw postfix operand for secureOperator:
    digit literal -> shared constant, column name -> cipher column,
    anything else (an intermediate result) -> returned as-is.
    '''
    kind = judgeTypes(token)
    if kind == 1:
        return _shareConstant(id, token)
    if kind == 2:
        return _lookupColumn(token, cipherTable, header)
    return token


def computePostfix(id, postexp, cipherTable, header):
    '''
    Evaluate a postfix WHERE expression over the secret-shared table.

    Parameters:
    (1) id: index of this party (0, 1 or 2)
    (2) postexp: postfix token list (operators are the tokens in _POSTFIX_OPERATORS)
    (3) cipherTable: shared table; column j corresponds to header[j]
    (4) header: column names of cipherTable
    Returns the value left on top of the operand stack.
    Raises ValueError when an operator finds fewer than two operands or the
    expression is empty.

    When cacheRule() accepts an operation and param.cacheTurn is set, its
    result is loaded from the MPC cache if present, otherwise computed and saved.
    '''
    param = parameters()
    operand = []
    for token in postexp:
        if token not in _POSTFIX_OPERATORS:
            # Operand: resolution is deferred until an operator consumes it.
            operand.append(token)
            continue
        if len(operand) < 2:
            raise ValueError(
                'Malformed postfix expression: operator {!r} needs two operands'.format(token))
        roperand = operand.pop()
        loperand = operand.pop()
        if cacheRule(id, loperand, roperand, token) and param.cacheTurn:
            cachePath = cacheName(loperand, roperand, token)
            if ifCached(id, cachePath):
                operand.append(cacheLoad(id, cachePath, [1, param.dataScale]))
            else:
                # The cached path only maps column names to cipher columns.
                loperand = _lookupColumn(loperand, cipherTable, header)
                roperand = _lookupColumn(roperand, cipherTable, header)
                rttres = secureOperator(id, loperand, roperand, token)
                cacheSave(id, rttres, cachePath)
                operand.append(rttres)
        else:
            # Left operand first, then right: every party must issue the
            # sharing calls in the same order.
            loperand = _resolveOperand(id, loperand, cipherTable, header)
            roperand = _resolveOperand(id, roperand, cipherTable, header)
            operand.append(secureOperator(id, loperand, roperand, token))
    if not operand:
        raise ValueError('Empty postfix expression')
    return operand.pop()
def secureOperator(id, leftOperand, rightOperand, op):
    '''
    Apply one secure (Rosetta) binary operator to two resolved operands.

    Operator  ReplaceCode  Rosetta op
    >         op0          SecureGreater
    <         op1          SecureLess
    ==        op2          SecureEqual
    >=        op3          SecureGreaterEqual
    <=        op4          SecureLessEqual
    <>        op5          SecureLogicalNot(SecureEqual)
    +         op6          SecureAdd
    -         op7          SecureSub
    *         op8          SecureMul
    /         op9          SecureFloorDiv
    AND       AND          SecureLogicalAnd
    OR        OR           SecureLogicalOr

    `id` (this party's index) is accepted for interface compatibility and is unused.
    Raises ValueError for any other opcode (e.g. 'NOT', 'LIKE', 'IN', 'BETWEEN').
    '''
    supported = ('op0', 'op1', 'op2', 'op3', 'op4', 'op5', 'op6', 'op7', 'op8', 'op9', 'AND', 'OR')
    if op not in supported:
        raise ValueError('Unsupported secure operator: {!r}'.format(op))
    import latticex.rosetta as rtt
    # rtt.activate("SecureNN")
    if op == 'op0':
        return rtt.SecureGreater(leftOperand, rightOperand)
    if op == 'op1':
        return rtt.SecureLess(leftOperand, rightOperand)
    if op == 'op2':
        return rtt.SecureEqual(leftOperand, rightOperand)
    if op == 'op3':
        return rtt.SecureGreaterEqual(leftOperand, rightOperand)
    if op == 'op4':
        return rtt.SecureLessEqual(leftOperand, rightOperand)
    if op == 'op5':
        return rtt.SecureLogicalNot(rtt.SecureEqual(leftOperand, rightOperand))
    if op == 'op6':
        return rtt.SecureAdd(leftOperand, rightOperand)
    if op == 'op7':
        return rtt.SecureSub(leftOperand, rightOperand)
    if op == 'op8':
        return rtt.SecureMul(leftOperand, rightOperand)
    if op == 'op9':
        # The previous name, SecureSecureFloorDiv, is not a Rosetta op.
        return rtt.SecureFloorDiv(leftOperand, rightOperand)
    if op == 'AND':
        return rtt.SecureLogicalAnd(leftOperand, rightOperand)
    return rtt.SecureLogicalOr(leftOperand, rightOperand)  # op == 'OR'
def _columnIndices(names, header):
    '''Return the header indices matching each name, in `names` order (duplicates kept).'''
    return [idx for name in names for idx in range(len(header)) if name == header[idx]]


def _complementIndices(keep, width):
    '''Return the indices in range(width) that are not in `keep`.'''
    keep = set(keep)
    return [i for i in range(width) if i not in keep]


def _shareTable(id, tablePath, deleteColumns):
    '''
    Secret-share the plaintext table held by P0 among P0, P1 and P2.

    P0 loads the CSV (skipping its header row), drops the columns listed in
    `deleteColumns`, and supplies the result as its private input. P1 and P2
    supply a dummy 1x1 matrix so that every party issues the three load_X
    calls (P0, P1, P2) in the same order. Returns the dataset owned by P0.
    '''
    import latticex.rosetta as rtt
    if id == 0:
        # ndmin=2 keeps single-row / single-column tables two-dimensional,
        # so the column deletion along axis 1 still works.
        plaintext = np.loadtxt(tablePath, delimiter=",", skiprows=1, ndmin=2)
        plaintext = np.delete(plaintext, deleteColumns, axis=1)
        inputs = (plaintext, None, None)
    elif id == 1:
        inputs = (None, np.random.RandomState(1789).randint(0, 1, (1, 1)), None)
    else:
        inputs = (None, None, np.random.RandomState(1999).randint(0, 1, (1, 1)))
    shared = [rtt.controller.PrivateDataset(["P{}".format(party)]).load_X(inputs[party])
              for party in range(3)]
    return shared[0]


def _runOrderBy(id, plan, table, tablePath, header, select_column_idx, param):
    '''
    Execute a single-attribute 'ORDER BY' clause and print the revealed result.
    e.g. 'SELECT ID,AGE FROM TABLE ORDER BY DEPOSIT ASC|DESC'

    The selected columns plus the sort column are secret-shared and sorted
    with oblivious odd-even merge sort. When param.cacheTurn is set, the
    sorted shares are saved under the key
    <table>_<selected column names>_<sort column>_<ASC|DESC>
    (e.g. TABLE_ID_DEPOSIT_ASC) and loaded from there on later runs.
    `header` is the full table header and is not modified.
    '''
    import latticex.rosetta as rtt
    # The clause reads '<column> [ASC|DESC]'; only one sort attribute is supported.
    kv = plan['ORDER BY'][0][0].split()
    order_column = kv[0]
    order_mode = kv[1] if len(kv) > 1 else 'ASC'  # SQL's default order is ascending
    if order_mode == 'ASC':
        sortMode = 0
    elif order_mode == 'DESC':
        sortMode = 1
    else:
        print('Please check your order mode, just support ASC|DESC!')
        sys.exit(-1)

    cipher_column_idx = set(select_column_idx) | set(_columnIndices([order_column], header))
    cipher_delete_column_idx = _complementIndices(cipher_column_idx, len(header))

    # Key on the names of the SELECTed columns, so different SELECT lists
    # never share a cache entry.
    cachePath = '_'.join([table] + [header[i] for i in select_column_idx]
                         + [order_column, order_mode])

    if ifCached(id, cachePath) and param.cacheTurn:
        # Assumes cacheLoad's last argument is the stored shape:
        # presumably param.dataScale rows by one column per shared column
        # (the saved table includes the sort column). TODO confirm.
        rtres = cacheLoad(id, cachePath, [param.dataScale, len(cipher_column_idx)])
    else:
        rtx = _shareTable(id, tablePath, cipher_delete_column_idx)
        order_header = np.delete(header, cipher_delete_column_idx)
        attr_positions = [i for i, name in enumerate(order_header) if name == order_column]
        if not attr_positions:
            print('Please check the ORDER BY column in SQL!')
            sys.exit(-1)
        rtres = oblivious_odd_even_merge_sort(attr_positions[0], rtx, sortMode)

    with tf.Session() as session:
        sorted_rtxdata_plaintext = session.run(rtt.SecureReveal(rtres))
    print(sorted_rtxdata_plaintext)
    if not ifCached(id, cachePath) and param.cacheTurn:
        cacheSave(id, rtres, cachePath)


def _runWhere(id, plan, tablePath, header, select_column_idx, param):
    '''
    Evaluate the 'WHERE' predicate over the secret-shared table and print the revealed results.

    Only the selected columns and the columns mentioned in WHERE are shared.
    `header` is the full table header and is not modified.
    '''
    import latticex.rosetta as rtt
    where_columns = [token for item in plan['WHERE'] for token in item]
    cipher_column_idx = set(select_column_idx) | set(_columnIndices(where_columns, header))
    cipher_delete_column_idx = _complementIndices(cipher_column_idx, len(header))

    # StepW-1. Share the needed columns (P0 holds the plaintext).
    rtx = _shareTable(id, tablePath, cipher_delete_column_idx)
    where_header = np.delete(header, cipher_delete_column_idx)

    # StepW-2. Convert the WHERE clause to postfix and evaluate it securely.
    opAST = infix2postfix(plan['WHERE'])
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        where_result = computePostfix(id, opAST, rtx, where_header)
        where_result_new = tf.reshape(where_result[0], [param.dataScale, -1])
        print('WHERE result:', sess.run(rtt.SecureReveal(where_result_new)))
        # NOTE(review): this reveals every shared column to all parties;
        # looks like debug output.
        print('rtx result:', sess.run(rtt.SecureReveal(rtx)))
    print('where_result\'s type:', type(where_result), 'rtx\'s type:', type(rtx))


def executeSecurePlan(id, onlinelist, SQLs):
    '''
    Execute each SQL statement in SQLs as party P<id> of a 3-party SecureNN session.

    Parameters:
    (1) id: index of this party (0, 1 or 2); P0 reads the plaintext tables
    (2) onlinelist: which user(s) are online (0 offline, 1 online); currently unused
    (3) SQLs: list of SQL strings, each parsed with parseSQL

    Supported form: SELECT <columns|*> FROM <attribute_source_name>
    with an optional single-attribute ORDER BY and/or a WHERE predicate.
    The table is read from users/<source>/<table>.csv, e.g. S_user2_table0
    (attribute S-single / M-merge, source user0/user1/...).
    Results are revealed and printed on every party.
    '''
    SQLidx = 0
    TIME_START = time.time()
    param = parameters()
    rtt = None  # bound by the first query; stays None if SQLs is empty
    for SQL in SQLs:
        plan = parseSQL(SQL)

        # Step1. Columns named in SELECT (separators and parentheses skipped).
        columns = [item[0] for item in plan['SELECT'] if item[0] not in [',', '(', ')']]

        # Step2. Locate the plaintext table.
        table = plan['FROM'][0][0]
        table_source = table.split('_')[1]
        tablePath = 'users/{}/{}.csv'.format(table_source, table)
        if not os.path.exists(tablePath):
            print('Please check the tablename in SQL!')
            sys.exit(1)

        # Step3. Join the MPC session. The party is passed to Rosetta as
        # --node_id on sys.argv; append it only once, before the first import.
        if rtt is None:
            sys.argv.extend(["--node_id", "P{}".format(id)])
            import latticex.rosetta as rtt
        rtt.backend_log_to_stdout(False)
        rtt.activate("SecureNN")

        with open(tablePath, 'r') as f:
            header = np.array(next(csv.reader(f)))

        if columns and columns[0] == '*':
            # SELECT * keeps every column of the table.
            select_column_idx = list(range(len(header)))
        else:
            select_column_idx = _columnIndices(columns, header)

        if 'ORDER BY' in plan:
            _runOrderBy(id, plan, table, tablePath, header, select_column_idx, param)
        if 'WHERE' in plan:
            _runWhere(id, plan, tablePath, header, select_column_idx, param)

        SQLidx += 1
        print('From ID:{} the {}th SQL\'s parse completed!'.format(id, SQLidx))
    TIME_END = time.time()
    print('Successfully execute all the SQLs, total time:', TIME_END - TIME_START, 's')
    if rtt is not None:
        rtt.deactivate()
# Online status per user (0 offline, 1 online), passed to every party.
ONLINE_LIST = [0, 1, 0, 1]
SQLs = ["SELECT ID FROM S_user3_table0 WHERE (loan3 > 100000 AND (deposit3 < loan3)) AND (credit3 <= 3 OR credit3 >= 7)"]
# Other sample queries that can be swapped in for SQLs:
# SQLs = ["SELECT ID FROM S_user3_table0 ORDER BY loan3 DESC",
# "SELECT ID FROM S_user3_table0 WHERE (loan3 > 100000 AND (deposit3 < loan3)) AND (credit3 <= 3 OR credit3 >= 7)",
# "SELECT ID FROM S_user3_table0 WHERE deposit3 < 5000000 AND deposit3 < loan3",
# "SELECT ID FROM S_user2_table0 WHERE credit2 >= 5",
# "SELECT ID FROM S_user1_table0 WHERE loan1 < 80000 OR loan1 >= 150000",
# "SELECT ID FROM S_user0_table0 WHERE AGE < 40",
# "SELECT ID FROM S_user3_table0 WHERE deposit3 < loan3",
# "SELECT ID FROM S_user3_table0 WHERE credit3 < 6 AND deposit3 < loan3",
# "SELECT ID FROM S_user3_table0 WHERE deposit3 < loan3 OR (ID > 100010)",
# "SELECT ID FROM S_user3_table0 WHERE (loan3 >= 450000) AND ((deposit3 < loan3) OR (deposit3 <= 30000))",
# "SELECT ID FROM S_user3_table0 WHERE (deposit3 < loan3) AND (loan3 >= 100000)",
# "SELECT ID FROM S_user1_table0 WHERE credit1 >= 5",
# "SELECT ID FROM S_user3_table0 WHERE (deposit3 < loan3) AND (credit3 < 7)",
# "SELECT ID FROM S_user0_table0 WHERE AGE > 60",
# "SELECT ID FROM S_user2_table0 WHERE deposit2 >= loan2",
# "SELECT ID FROM S_user2_table0 WHERE (deposit2 >= loan2) AND (credit > 4)",
# "SELECT ID FROM S_user2_table0 WHERE deposit2 >= loan2 OR ID < 100500",
# "SELECT ID FROM S_user2_table0 WHERE (credit < 3 OR credit > 8) AND (deposit2 >= loan2)",
# "SELECT ID FROM S_user2_table0 WHERE deposit2 >= loan2 OR deposit2 < 300000",
# "SELECT ID FROM S_user3_table0 WHERE deposit3 < loan3 OR loan3 >= 2500000",
# "SELECT ID FROM S_user3_table0 WHERE ID < 100015 AND deposit3 < loan3",
# "SELECT ID FROM S_user3_table0 ORDER BY loan3 DESC"]

if __name__ == '__main__':
    # One process per MPC party (P0, P1, P2), all running the same query list.
    # The guard keeps child processes (spawn start method) and importers from
    # re-launching the parties.
    processes = [
        multiprocessing.Process(target=executeSecurePlan, args=(party, ONLINE_LIST, SQLs))
        for party in range(3)
    ]
    for proc in processes:
        proc.daemon = True
        proc.start()
    for proc in processes:
        proc.join()