ADD file via upload
This commit is contained in:
parent
2aec389635
commit
6f94f94af5
|
@ -0,0 +1,200 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import multiprocessing
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import csv
|
||||
import os
|
||||
|
||||
def obviousAssign(cond, x, y):
|
||||
'''
|
||||
This module will assign the secret share to a variable
|
||||
|
||||
Return cond * x + y - cond * y
|
||||
|
||||
(1) If cond is 0-share, will return y-share
|
||||
(2) If cond is 1-share, will return x-share
|
||||
'''
|
||||
import latticex.rosetta as rtt
|
||||
|
||||
return rtt.SecureAdd(rtt.SecureMul(cond, rtt.SecureSub(x, y)), y)
|
||||
|
||||
def secureMax(ciphertable, attr):
|
||||
'''
|
||||
This module will return the maximum value of ciphertable[:,attr] in secret share form
|
||||
'''
|
||||
import latticex.rosetta as rtt
|
||||
|
||||
tableLen = len(ciphertable)
|
||||
|
||||
if tableLen == 1:
|
||||
return ciphertable[0][attr]
|
||||
|
||||
ciphercolumn = ciphertable[:,attr]
|
||||
# print(ciphercolumn)
|
||||
maxSS = np.array([ciphercolumn[0]])
|
||||
|
||||
sess = tf.Session()
|
||||
for i in range(1, tableLen):
|
||||
'''
|
||||
Traversing the array and extracts the maximum value obviously
|
||||
'''
|
||||
b = rtt.SecureGreater(np.array([ciphercolumn[i]]), maxSS)
|
||||
maxSS = obviousAssign(b, np.array([ciphercolumn[i]]), maxSS)
|
||||
tf.reset_default_graph()
|
||||
sess.run(maxSS)
|
||||
|
||||
return maxSS
|
||||
|
||||
|
||||
def secureMin(ciphertable, attr):
|
||||
'''
|
||||
This module will return the minimum value of ciphertable[:,attr] in secret share form
|
||||
'''
|
||||
import latticex.rosetta as rtt
|
||||
|
||||
tableLen = len(ciphertable)
|
||||
|
||||
if tableLen == 1:
|
||||
return ciphertable[0][attr]
|
||||
|
||||
ciphercolumn = ciphertable[:,attr]
|
||||
minSS = np.array([ciphercolumn[0]])
|
||||
|
||||
sess = tf.Session()
|
||||
for i in range(1, tableLen):
|
||||
'''
|
||||
Traversing the array and extracts the minimum value obviously
|
||||
'''
|
||||
b = rtt.SecureLess(np.array([ciphercolumn[i]]), minSS)
|
||||
minSS = obviousAssign(b, np.array([ciphercolumn[i]]), minSS)
|
||||
tf.reset_default_graph()
|
||||
sess.run(minSS)
|
||||
|
||||
return minSS
|
||||
|
||||
def secureSum(ciphertable, attr):
|
||||
'''
|
||||
This module will return the sum of the target records
|
||||
|
||||
e.g. 'SELECT SUM(AGE) FROM TABLE' will return the sum of AGE
|
||||
'''
|
||||
import latticex.rosetta as rtt
|
||||
|
||||
tableLen = len(ciphertable)
|
||||
|
||||
if tableLen == 1:
|
||||
return ciphertable[0][attr]
|
||||
|
||||
ciphercolumn = ciphertable[:,attr]
|
||||
sumSS = np.array(ciphercolumn[0])
|
||||
|
||||
sess = tf.Session()
|
||||
for i in range(1, tableLen):
|
||||
sumSS = rtt.SecureAdd(sumSS, np.array(ciphercolumn[i]))
|
||||
sess.run(sumSS)
|
||||
tf.reset_default_graph()
|
||||
|
||||
|
||||
return sumSS
|
||||
|
||||
def secureAVG(id, ciphertable, attr):
|
||||
'''
|
||||
This module will return the sum of the target records
|
||||
|
||||
e.g. 'SELECT AVG(loan) FROM TABLE' will return the average value of loan
|
||||
'''
|
||||
import latticex.rosetta as rtt
|
||||
|
||||
tableLen = len(ciphertable)
|
||||
|
||||
sumSS = secureSum(ciphertable, attr)
|
||||
|
||||
if id == 0:
|
||||
rtx = rtt.controller.PrivateDataset(["P0"]).load_X([[tableLen]])
|
||||
rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
|
||||
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
|
||||
elif id == 1:
|
||||
rd = np.random.RandomState(1789)
|
||||
plaintext = rd.randint(0, 1, (1, 1))
|
||||
rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
|
||||
rty = rtt.controller.PrivateDataset(["P1"]).load_X(plaintext)
|
||||
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
|
||||
else:
|
||||
rd = np.random.RandomState(1999)
|
||||
plaintext = rd.randint(0, 1, (1, 1))
|
||||
rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
|
||||
rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
|
||||
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(plaintext)
|
||||
|
||||
avgSS = rtt.SecureDivide(sumSS, rtx)
|
||||
tf.reset_default_graph()
|
||||
|
||||
return avgSS
|
||||
|
||||
# def test(id):
|
||||
# sys.argv.extend(["--node_id", "P{}".format(id)])
|
||||
# import latticex.rosetta as rtt
|
||||
# rtt.activate("SecureNN")
|
||||
|
||||
# rtt.backend_log_to_stdout(False)
|
||||
|
||||
# test_table = 'users/user1/S_user1_table0.csv'
|
||||
|
||||
# TIME_START = time.time()
|
||||
|
||||
# if id == 0:
|
||||
# plaintext = np.loadtxt(open(test_table), delimiter = ",", skiprows = 1)
|
||||
# rtx = rtt.controller.PrivateDataset(["P0"]).load_X(plaintext)
|
||||
# rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
|
||||
# rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
|
||||
# elif id == 1:
|
||||
# rd = np.random.RandomState(1789)
|
||||
# plaintext = rd.randint(0, 1, (1, 1))
|
||||
# rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
|
||||
# rty = rtt.controller.PrivateDataset(["P1"]).load_X(plaintext)
|
||||
# rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
|
||||
# else:
|
||||
# rd = np.random.RandomState(1999)
|
||||
# plaintext = rd.randint(0, 1, (1, 1))
|
||||
# rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
|
||||
# rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
|
||||
# rtz = rtt.controller.PrivateDataset(["P2"]).load_X(plaintext)
|
||||
|
||||
# session = tf.Session()
|
||||
# # smax = secureMax(rtx, 1)
|
||||
# # print('Successfully compute the max')
|
||||
|
||||
# # smin = secureMin(rtx, 1)
|
||||
# # print('Successfully compute the min')
|
||||
|
||||
# # ssum = secureSum(rtx, 1)
|
||||
# # print('Successfully compute the sum')
|
||||
|
||||
# savg = secureAVG(id, rtx, 3)
|
||||
# print('Successfully compute the avg')
|
||||
|
||||
# TIME_END = time.time()
|
||||
# print('Successfully test all the secure aggregate, total time:', TIME_END - TIME_START)
|
||||
|
||||
# plaintext = session.run(rtt.SecureReveal(savg))
|
||||
# print('min:', plaintext)
|
||||
|
||||
|
||||
|
||||
# p0 = multiprocessing.Process(target = test, args = (0,))
|
||||
# p1 = multiprocessing.Process(target = test, args = (1,))
|
||||
# p2 = multiprocessing.Process(target = test, args = (2,))
|
||||
|
||||
# p0.daemon = True
|
||||
# p0.start()
|
||||
# p1.daemon = True
|
||||
# p1.start()
|
||||
# p2.daemon = True
|
||||
# p2.start()
|
||||
|
||||
# p0.join()
|
||||
# p1.join()
|
||||
# p2.join()
|
Loading…
Reference in New Issue