182 lines
5.5 KiB
Python
182 lines
5.5 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import csv
|
|
import os
|
|
from pydoc import plain
|
|
import sys
|
|
import multiprocessing
|
|
from venv import logger
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import time
|
|
import random
|
|
import logging
|
|
|
|
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
|
|
|
|
from queryParse import parseSQL
|
|
from saveCipherTable import *
|
|
from loadCipherTable import *
|
|
from AST import *
|
|
from obliviousSort import *
|
|
# from secureGroupBy import *
|
|
from cache import *
|
|
from parameters import *
|
|
from executeSecurePlan import *
|
|
|
|
'''
|
|
This file tests the SQL demo on (part of) the TPC-H lineitem table:
|
|
|
|
SELECT
|
|
sum(l_extendedprice * l_discount) as revenue
|
|
FROM
|
|
lineitem
|
|
WHERE
|
|
(l_shipdate >= 1995-01-01 AND l_shipdate < 1998-12-01)
|
|
AND (l_discount >= 0.05 AND l_discount <= 0.07)
|
|
AND (l_quantity < 24)
|
|
|
|
Note: this test covers only this one SQL query; deployers can parse their own SQL to use our MPC-Cache idea.
|
|
'''
|
|
|
|
def test_demo1(id):
    """Run one party's side of the secure WHERE-clause demo on lineitem.

    The function appends ``--node_id P<id>`` to ``sys.argv``, activates the
    Rosetta "SecureNN" protocol, loads the three private datasets (one per
    party), builds the postfix plan for the WHERE clause

        (l_shipdate >= 1995-01-01 AND l_shipdate < 1998-12-01)
        AND (l_discount >= 0.05 AND l_discount <= 0.07)
        AND (l_quantity < 24)

    and evaluates it with ``computePostfix`` on party P1's table. It then
    prints the MPC time and the total time. The aggregation step
    ``sum(l_extendedprice * l_discount)`` is currently disabled, so nothing
    is revealed or returned.

    Args:
        id: Party index. 0 and 1 select P0 and P1; any other value is
            treated as P2. (The name shadows the builtin ``id``; it is kept
            so existing callers keep working.)

    Returns:
        None.
    """
    sys.argv.extend(["--node_id", "P{}".format(id)])
    # Imported only after --node_id has been appended to sys.argv.
    # NOTE(review): presumably Rosetta reads the node id from sys.argv -- confirm.
    import latticex.rosetta as rtt
    rtt.backend_log_to_stdout(False)
    rtt.activate("SecureNN")

    # Not used by the active code path below; kept because the disabled
    # aggregation step needs a session and creating one may have side effects.
    sess = tf.Session()

    header = ['l_orderkey', 'l_partkey', 'l_suppkey', 'l_linenumber',
              'l_quantity', 'l_extendedprice', 'l_discount', 'l_tax',
              'l_shipdate', 'l_commitdate', 'l_receiptdate']
    # The original assigned this timer twice; only this (later) value was used.
    SQL1_START = time.time()

    # Step 1. Select the records matching the WHERE clause.
    # NOTE(review): 'op1' encodes "<" for l_shipdate and l_quantity, but the
    # documented predicate for l_discount is "<= 0.07" -- confirm the op code.
    SQL_SUB1 = [['l_shipdate', 'op3', '19950101'], ['AND'], ['l_shipdate', 'op1', '19981201']]
    SQL_SUB2 = [['l_discount', 'op3', '0.05'], ['AND'], ['l_discount', 'op1', '0.07']]
    SQL_SUB3 = [['l_quantity', 'op1', '24']]
    SQL_PLAN = [['l_shipdate', 'op3', '19950101'], ['AND'], ['l_shipdate', 'op1', '19981201'], ['AND'],
                ['l_discount', 'op3', '0.05'], ['AND'], ['l_discount', 'op1', '0.07'], ['AND'],
                ['l_quantity', 'op1', '24']]

    # Only the combined plan (opAST) is evaluated below. The per-clause plans
    # are still built because infix2postfix lives in another module and its
    # side effects cannot be verified from this file.
    opAST1 = infix2postfix(SQL_SUB1)
    opAST2 = infix2postfix(SQL_SUB2)
    opAST3 = infix2postfix(SQL_SUB3)
    opAST = infix2postfix(SQL_PLAN)

    if id == 1:
        # P1 owns the table; the path is relative to the working directory.
        test_table = 'users/user1/lineitem_4096.csv'
        # Open in a context manager so the file handle is closed after loading.
        with open(test_table) as table_file:
            plaintext = np.loadtxt(table_file, delimiter=",", skiprows=1)
    else:
        # P0 and P2 contribute the constants 0 and 1
        # (original note: "Generate the 1's SS").
        plaintext = [[0], [1]]

    # Every party issues the same three load_X calls in the same order
    # (P0, P1, P2) and supplies data only for the dataset it owns.
    owner_index = id if id in (0, 1) else 2
    rtx, rty, rtz = [
        rtt.controller.PrivateDataset([owner]).load_X(
            plaintext if index == owner_index else None
        )
        for index, owner in enumerate(("P0", "P1", "P2"))
    ]

    print('Successfully load the data')

    MPC_START = time.time()

    where_result = computePostfix(id, opAST, rty, header)

    # Step 2 (currently disabled): compute sum(l_extendedprice * l_discount)
    # over the selected rows. Reference sketch from the original code:
    #   result = sess.run(rtt.SecureMul(where_result, rty))
    #   revenue = rtt.SecureMul(result[:, 5], result[:, 6])
    #   ssum = rtt.SecureSum(revenue)
    #   print(sess.run(rtt.SecureReveal(ssum)))

    SQL1_END = time.time()

    print('MPC time:', SQL1_END - MPC_START, 's')
    print('Total time:', SQL1_END - SQL1_START, 's')
|
|
|
|
# Launch the three MPC parties (P0, P1, P2) as separate local processes.
# The __main__ guard keeps the module importable: without it, importing this
# file (which the "spawn" start method does in every child) would start the
# processes again.
if __name__ == "__main__":
    parties = [
        multiprocessing.Process(target=test_demo1, args=(party_id,))
        for party_id in range(3)
    ]

    for party in parties:
        # Daemon processes are terminated if the parent exits first.
        party.daemon = True
        party.start()

    # Wait for all three parties to finish.
    for party in parties:
        party.join()