ADD file via upload

This commit is contained in:
p87059431 2023-11-16 13:00:44 +08:00
parent 075ec5c526
commit d328e524b7
1 changed files with 113 additions and 0 deletions

113
secureGroupBy.py Normal file
View File

@ -0,0 +1,113 @@
#!/usr/bin/env python3
import multiprocessing
import sys
import time
import numpy as np
import tensorflow as tf
import csv
import os
from queryParse import parseSQL
from saveCipherTable import *
from loadCipherTable import *
from AST import *
from obliviousSort import *
from cache import *
from parameters import *
def secureGroupBy(id, ciphertable, attr = 0, having = None):
'''
This module will do the secure group by
'''
'''
Step1. sort the ciphertable according to attr
'''
import latticex.rosetta as rtt
'''
Step2. compute the SS of compare result
'''
compare_result = []
sess = tf.Session()
param = parameters()
cacheName = 'SQL2' + '_' + str(attr) + '_' + 'ASC' + '_' + str(param.dataScale)
if param.cacheTurn == True and ifCached(id, cacheName):
LOAD_START = time.time()
sorted_table = cacheLoad(id, cacheName, [len(ciphertable), len(ciphertable[0])])
LOAD_END = time.time()
print('Load cache total time:', LOAD_END - LOAD_START, 's')
else:
sorted_table = oblivious_odd_even_merge_sort(attr, ciphertable, 0)
# print(sess.run(rtt.SecureReveal(compare_result)))
CACHE_START = time.time()
if ifCached(id, cacheName) == False and param.cacheTurn == True:
cacheSave(id, sorted_table, cacheName)
CACHE_END = time.time()
print('Cache total time:', CACHE_END - CACHE_START, 's')
t1 = time.time()
for i in range(len(sorted_table) - 1):
compare_result.append(sess.run(rtt.SecureEqual(tf.convert_to_tensor([sorted_table[i][attr]]), tf.convert_to_tensor([sorted_table[i+1][attr]]))))
# tf.reset_default_graph()
compare_result = np.array(compare_result)
t2 = time.time()
print('Compare time:', t2 - t1, 's')
return sorted_table, compare_result
def test(id):
sys.argv.extend(["--node_id", "P{}".format(id)])
import latticex.rosetta as rtt
rtt.backend_log_to_stdout(False)
rtt.activate("SecureNN")
test_table = 'users/user0/S_user0_test.csv'
TIME_START = time.time()
if id == 0:
plaintext = np.loadtxt(open(test_table), delimiter = ",", skiprows = 1)
rtx = rtt.controller.PrivateDataset(["P0"]).load_X(plaintext)
rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
elif id == 1:
rd = np.random.RandomState(1789)
plaintext = rd.randint(0, 1, (1, 1))
rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
rty = rtt.controller.PrivateDataset(["P1"]).load_X(plaintext)
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
else:
rd = np.random.RandomState(1999)
plaintext = rd.randint(0, 1, (1, 1))
rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(plaintext)
session = tf.Session()
group_rtxdata = secureGroupBy(rtx, 0, None)
TIME_END = time.time()
print('Successfully group the array, total time:', TIME_END - TIME_START)
print('ciphertext:\n', group_rtxdata)
group_rtxdata_plaintext = session.run(rtt.SecureReveal(group_rtxdata))
tf.get_default_graph().finalize()
print(group_rtxdata_plaintext)
# p0 = multiprocessing.Process(target = test, args = (0,))
# p1 = multiprocessing.Process(target = test, args = (1,))
# p2 = multiprocessing.Process(target = test, args = (2,))
# p0.daemon = True
# p0.start()
# p1.daemon = True
# p1.start()
# p2.daemon = True
# p2.start()
# p0.join()
# p1.join()
# p2.join()