SMPCache/secureGroupBy.py

113 lines
3.7 KiB
Python

#!/usr/bin/env python3
import multiprocessing
import sys
import time
import numpy as np
import tensorflow as tf
import csv
import os
from queryParse import parseSQL
from saveCipherTable import *
from loadCipherTable import *
from AST import *
from obliviousSort import *
from cache import *
from parameters import *
def secureGroupBy(id, ciphertable, attr = 0, having = None):
    '''
    Securely group the rows of ``ciphertable`` by column ``attr``.

    Step 1 obliviously sorts the table on column ``attr`` (or, when caching
    is enabled and a cached copy exists, loads the previously sorted table).
    Step 2 runs ``rtt.SecureEqual`` on every pair of adjacent rows of the
    sorted table, comparing their ``attr`` values; the results stay in
    secret-shared (cipher) form.

    :param id: party id; only forwarded to the cache helpers
        (``ifCached``, ``cacheLoad``, ``cacheSave``).
    :param ciphertable: table of secret-shared values, indexed as
        ``ciphertable[row][col]``.
    :param attr: index of the column to group by.
    :param having: currently unused; kept for interface compatibility.
    :return: tuple ``(sorted_table, compare_result)``. ``compare_result`` is
        a numpy array holding one ``SecureEqual`` output per adjacent pair of
        rows (``len(sorted_table) - 1`` entries, empty for fewer than two
        rows).
    '''
    # Imported inside the function, presumably so Rosetta is only loaded
    # after the caller has appended --node_id to sys.argv (test() does this
    # right before its own import) -- confirm.
    import latticex.rosetta as rtt

    param = parameters()
    use_cache = bool(param.cacheTurn)
    # Cache key built from the sort column, the 'ASC' order tag and the data scale.
    cacheName = 'SQL2' + '_' + str(attr) + '_' + 'ASC' + '_' + str(param.dataScale)

    # Step 1: obtain the table sorted on ``attr``.
    if use_cache and ifCached(id, cacheName):
        LOAD_START = time.time()
        sorted_table = cacheLoad(id, cacheName, [len(ciphertable), len(ciphertable[0])])
        LOAD_END = time.time()
        print('Load cache total time:', LOAD_END - LOAD_START, 's')
    else:
        sorted_table = oblivious_odd_even_merge_sort(attr, ciphertable, 0)
        CACHE_START = time.time()
        if use_cache and not ifCached(id, cacheName):
            cacheSave(id, sorted_table, cacheName)
        CACHE_END = time.time()
        print('Cache total time:', CACHE_END - CACHE_START, 's')

    # Step 2: secret-shared equality of each pair of adjacent rows on ``attr``.
    t1 = time.time()
    key_column = [row[attr] for row in sorted_table]
    compare_result = []
    # The session is closed on exit; results are plain numpy values from
    # sess.run, so nothing returned depends on it staying open.
    with tf.Session() as sess:
        for current, following in zip(key_column, key_column[1:]):
            # NOTE(review): each iteration adds new ops to the default graph,
            # so the graph grows linearly with the number of rows.
            equal_op = rtt.SecureEqual(tf.convert_to_tensor([current]), tf.convert_to_tensor([following]))
            compare_result.append(sess.run(equal_op))
    compare_result = np.array(compare_result)
    t2 = time.time()
    print('Compare time:', t2 - t1, 's')
    return sorted_table, compare_result
def test(id):
    '''
    Manual end-to-end run of ``secureGroupBy`` for MPC party ``id``.

    Party 0 supplies the table read from ``users/user0/S_user0_test.csv``
    (header row skipped); the other parties supply a 1x1 dummy array. The
    table is grouped on column 0, and the ciphertext result and its revealed
    plaintext are printed.

    :param id: party index; the launcher code in this file uses 0, 1 and 2.
    '''
    # Rosetta reads the node id from the command line, so it is put on
    # sys.argv before latticex.rosetta is imported.
    sys.argv.extend(["--node_id", "P{}".format(id)])
    import latticex.rosetta as rtt
    rtt.backend_log_to_stdout(False)
    rtt.activate("SecureNN")
    test_table = 'users/user0/S_user0_test.csv'
    TIME_START = time.time()
    if id == 0:
        # Pass the path so numpy opens and closes the file itself.
        plaintext = np.loadtxt(test_table, delimiter = ",", skiprows = 1)
    else:
        rd = np.random.RandomState(1789 if id == 1 else 1999)
        plaintext = rd.randint(0, 1, (1, 1))
    # Every party performs all three PrivateDataset loads in P0, P1, P2
    # order; only the owning party passes real data, the others pass None.
    datasets = [
        rtt.controller.PrivateDataset(["P{}".format(party)]).load_X(plaintext if party == id else None)
        for party in range(3)
    ]
    rtx = datasets[0]
    # Arguments are (party id, table, group-by column, having).
    group_rtxdata = secureGroupBy(id, rtx, 0, None)
    TIME_END = time.time()
    print('Successfully group the array, total time:', TIME_END - TIME_START)
    print('ciphertext:\n', group_rtxdata)
    sorted_table, compare_result = group_rtxdata
    with tf.Session() as session:
        # Reveal each returned part separately; SecureReveal cannot take the
        # (sorted_table, compare_result) tuple as a single tensor.
        # NOTE(review): assumes sorted_table is convertible to a tensor -- confirm.
        sorted_plaintext = session.run(rtt.SecureReveal(sorted_table))
        compare_plaintext = session.run(rtt.SecureReveal(compare_result))
    tf.get_default_graph().finalize()
    print(sorted_plaintext)
    print(compare_plaintext)
# p0 = multiprocessing.Process(target = test, args = (0,))
# p1 = multiprocessing.Process(target = test, args = (1,))
# p2 = multiprocessing.Process(target = test, args = (2,))
# p0.daemon = True
# p0.start()
# p1.daemon = True
# p1.start()
# p2.daemon = True
# p2.start()
# p0.join()
# p1.join()
# p2.join()