113 lines
3.7 KiB
Python
113 lines
3.7 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import multiprocessing
|
|
import sys
|
|
import time
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import csv
|
|
import os
|
|
|
|
from queryParse import parseSQL
|
|
from saveCipherTable import *
|
|
from loadCipherTable import *
|
|
from AST import *
|
|
from obliviousSort import *
|
|
from cache import *
|
|
from parameters import *
|
|
|
|
def secureGroupBy(id, ciphertable, attr = 0, having = None):
    '''
    Securely group the rows of ``ciphertable`` by column ``attr``.

    Step 1 obliviously sorts the table on ``attr``. If caching is enabled
    and a cached sorted table exists, that table is loaded instead. A freshly
    sorted table is saved to the cache when caching is enabled.
    Step 2 computes a secret-shared ``rtt.SecureEqual`` result for every pair
    of adjacent rows of the sorted table, compared on column ``attr``.
    Presumably the caller uses these bits to find group boundaries.

    Parameters
    ----------
    id : int
        Party index. It is only forwarded to ``ifCached``, ``cacheLoad`` and
        ``cacheSave``.
    ciphertable : 2-D indexable table
        Table to group, indexed as ``ciphertable[row][column]``. It presumably
        holds Rosetta ciphertext shares (TODO confirm against callers).
    attr : int
        Column index to sort and compare on.
    having : optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    tuple
        ``(sorted_table, compare_result)``.
        ``sorted_table`` is the table sorted on ``attr``.
        ``compare_result`` is a numpy array with one row per adjacent pair
        ``(i, i+1)``, holding the revealed-as-shares output of
        ``rtt.SecureEqual``. It is ``np.array([])`` when the table has fewer
        than two rows.
    '''
    import latticex.rosetta as rtt

    param = parameters()
    cacheName = 'SQL2' + '_' + str(attr) + '_' + 'ASC' + '_' + str(param.dataScale)

    # Step 1: obtain the table sorted on `attr`, from the cache when possible.
    if param.cacheTurn and ifCached(id, cacheName):
        LOAD_START = time.time()
        sorted_table = cacheLoad(id, cacheName, [len(ciphertable), len(ciphertable[0])])
        LOAD_END = time.time()
        print('Load cache total time:', LOAD_END - LOAD_START, 's')
    else:
        sorted_table = oblivious_odd_even_merge_sort(attr, ciphertable, 0)
        CACHE_START = time.time()
        # With caching on, reaching this branch means ifCached() returned
        # false, so there is no need to query the cache a second time.
        if param.cacheTurn:
            cacheSave(id, sorted_table, cacheName)
        CACHE_END = time.time()
        print('Cache total time:', CACHE_END - CACHE_START, 's')

    # Step 2: secret-shared equality of each adjacent pair on column `attr`.
    t1 = time.time()
    keys = [row[attr] for row in sorted_table]
    if len(keys) < 2:
        # No adjacent pairs; matches the result of the original empty loop.
        compare_result = np.array([])
    else:
        # Use one element-wise SecureEqual over all pairs instead of one op and
        # one sess.run per pair. The graph no longer grows with the row count,
        # and the protocol is invoked once rather than len(keys) - 1 times.
        equal_op = rtt.SecureEqual(tf.convert_to_tensor(keys[:-1]),
                                   tf.convert_to_tensor(keys[1:]))
        with tf.Session() as sess:
            # Reshape to one row per pair, the layout the per-pair loop produced.
            compare_result = np.asarray(sess.run(equal_op)).reshape(-1, 1)
    t2 = time.time()
    print('Compare time:', t2 - t1, 's')

    return sorted_table, compare_result
|
|
|
|
def test(id):
    '''
    Run the secureGroupBy demo as party ``P{id}`` of a three-party Rosetta
    SecureNN computation.

    Party 0 reads the plaintext table from ``users/user0/S_user0_test.csv``
    (header row skipped) and loads it as P0's private dataset. Parties 1 and
    2 supply a dummy 1x1 input; ``randint(0, 1, ...)`` always yields 0. The
    P0 table is then grouped on column 0. Both results of secureGroupBy, the
    sorted table and the adjacent-row equality results, are revealed and
    printed.

    NOTE: appends ``--node_id P{id}`` to ``sys.argv`` before importing
    Rosetta. Presumably Rosetta reads the node id from the command line,
    which would make this ordering necessary; confirm.
    '''
    sys.argv.extend(["--node_id", "P{}".format(id)])
    import latticex.rosetta as rtt
    rtt.backend_log_to_stdout(False)
    rtt.activate("SecureNN")

    test_table = 'users/user0/S_user0_test.csv'

    TIME_START = time.time()

    if id == 0:
        # Close the CSV file once it has been parsed.
        with open(test_table) as f:
            plaintext = np.loadtxt(f, delimiter = ",", skiprows = 1)
        rtx = rtt.controller.PrivateDataset(["P0"]).load_X(plaintext)
        rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
        rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
    elif id == 1:
        rd = np.random.RandomState(1789)
        plaintext = rd.randint(0, 1, (1, 1))
        rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
        rty = rtt.controller.PrivateDataset(["P1"]).load_X(plaintext)
        rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
    else:
        rd = np.random.RandomState(1999)
        plaintext = rd.randint(0, 1, (1, 1))
        rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
        rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
        rtz = rtt.controller.PrivateDataset(["P2"]).load_X(plaintext)

    # The signature is secureGroupBy(id, ciphertable, attr=0, having=None),
    # and it returns a (sorted_table, compare_result) pair.
    sorted_table, compare_result = secureGroupBy(id, rtx, 0)
    TIME_END = time.time()
    print('Successfully group the array, total time:', TIME_END - TIME_START)
    print('ciphertext:\n', sorted_table, compare_result)

    with tf.Session() as session:
        sorted_plaintext = session.run(rtt.SecureReveal(sorted_table))
        # secureGroupBy returns an empty array when there are no adjacent
        # pairs; there is nothing to reveal in that case.
        compare_plaintext = (session.run(rtt.SecureReveal(compare_result))
                             if compare_result.size else compare_result)
        tf.get_default_graph().finalize()
    print(sorted_plaintext)
    print(compare_plaintext)
|
|
|
|
|
|
# p0 = multiprocessing.Process(target = test, args = (0,))
|
|
# p1 = multiprocessing.Process(target = test, args = (1,))
|
|
# p2 = multiprocessing.Process(target = test, args = (2,))
|
|
|
|
# p0.daemon = True
|
|
# p0.start()
|
|
# p1.daemon = True
|
|
# p1.start()
|
|
# p2.daemon = True
|
|
# p2.start()
|
|
|
|
# p0.join()
|
|
# p1.join()
|
|
# p2.join() |