SMPCache/SQL3.py

169 lines
5.2 KiB
Python

#!/usr/bin/env python3
import csv
import os
from pydoc import plain
import sys
import multiprocessing
from venv import logger
import numpy as np
import tensorflow as tf
import time
import random
import logging
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from queryParse import parseSQL
from saveCipherTable import *
from loadCipherTable import *
from AST import *
from obliviousSort import *
from secureGroupBy import *
from cache import *
from parameters import *
from executeSecurePlan import *
'''
In this py file, we test the SQL demo:
SELECT COUNT(*)
FROM store_sales INNER JOIN store_returns ON store_sales.PID = store_returns.PID
WHERE store_returns.ReturnDate - store_sales.SaleDate <= 10
Note: this test just for this SQL, deployer can parse your own SQL to use our MPC-Cache idea
'''
def test_demo3(id):
sys.argv.extend(["--node_id", "P{}".format(id)])
import latticex.rosetta as rtt
rtt.backend_log_to_stdout(False)
rtt.activate("SecureNN")
SQL3_START = time.time()
'''
Step1. load the ciphertable to P1 and P2
'''
if id == 0:
'''
Generate the 0, 1's SS
'''
plaintext = [[0], [1]]
rtx = rtt.controller.PrivateDataset(["P0"]).load_X(plaintext)
rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
elif id == 1:
test_table = 'users/user2/store_sales_1048576.csv'
plaintext = np.loadtxt(open(test_table), delimiter = ",")
rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
rty = rtt.controller.PrivateDataset(["P1"]).load_X(plaintext)
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
else:
test_table = 'users/user2/store_returns_FULL.csv'
plaintext = np.loadtxt(open(test_table), delimiter = ",")
'''
Note: Concat a one-SS column to joinTable in order to compute the COUNT(*) later
'''
plaintext = np.hstack((plaintext, np.ones((len(plaintext), 1))))
# print(len(plaintext[0]))
# print(plaintext)
rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(plaintext)
header = ['SaleDate', 'SaleTime', 'item', 'customer', 'cdemo', 'hdemo', 'addr', 'store', 'promo', 'ticket', 'quantity',
'wholeSaleCost', 'listPrice', 'salesPrice', 'extDiscount', 'extwholeSaleCost', 'extlistPrice', 'extsalesPrice',
'extTax', 'coupon', 'netPaid', 'netPaidIncTax', 'netProfit', 'zero', 'ReturnDate', 'one']
print('Successfully load the data')
# y = rtt.RttPlaceholder(tf.float32, shape=rty.shape)
# z = rtt.RttPlaceholder(tf.float32, shape=rtz.shape)
MPC_START = time.time()
sess = tf.Session()
'''
Step2. join two tables according to PID
Note: we just use the 2th table's return_date column and one-SS column
'''
param = parameters()
cacheName = 'sale' + '_JOIN_' + 'return' + '_' + '1048576'
if param.cacheTurn == True and ifCached(id, cacheName) == True:
LOAD_START = time.time()
joinTable = cacheLoad(id, cacheName, [max(len(rty), len(rtz)), len(header)])
LOAD_END = time.time()
print('JOIN CACHE LOAD TIME:', LOAD_END - LOAD_START, 's')
else:
JOIN_START = time.time()
joinTable = rtt.SecurePsi(rty, rtz, 2, 2, jointlist = [0, 21])
JOIN_END = time.time()
print('JOIN MPC TIME:', JOIN_END - JOIN_START, 's')
if param.cacheTurn == True and ifCached(id, cacheName) == False:
cacheSave(id, joinTable, cacheName)
psi_result = tf.reshape(joinTable[:,25], (-1, 1))
OTHER_START = time.time()
InnerJoinTable = rtt.SecureMul(joinTable, psi_result)
OTHER_END = time.time()
'''
Step3. compute the where_result
filter: 'store_returns.ReturnDate - store_sales.SaleDate <= 10'
'''
plan = parseSQL("SELECT COUNT(*) FROM table WHERE ReturnDate - SaleDate <= 10")
opAST = infix2postfix(plan['WHERE'])
where_result = computePostfix(id, opAST, InnerJoinTable, header)
CORE_MPC = time.time()
print('CORE MPC TIME:', CORE_MPC - MPC_START - (OTHER_END - OTHER_START), 's')
where_result = tf.transpose(where_result)
resultTable = rtt.SecureMul(InnerJoinTable, where_result)
count_result = rtt.SecureSum(resultTable[:,25])
SQL3_END = time.time()
print('MPC time:', SQL3_END - MPC_START, 's')
print('Total time:', SQL3_END - SQL3_START, 's')
# print('res:', sess.run(rtt.SecureReveal(count_result)))
# print(sess.run(rtt.SecureReveal(InnerJoinTable)))
# print(rty.shape)
# res = rtt.SecureReveal(joinTable)
# resp = sess.run(res)
# print(resp)
p0 = multiprocessing.Process(target = test_demo3, args = (0,))
p1 = multiprocessing.Process(target = test_demo3, args = (1,))
p2 = multiprocessing.Process(target = test_demo3, args = (2,))
p0.daemon = True
p0.start()
p1.daemon = True
p1.start()
p2.daemon = True
p2.start()
p0.join()
p1.join()
p2.join()