ADD file via upload

This commit is contained in:
p87059431 2023-11-16 12:59:28 +08:00
parent 2aec389635
commit 6f94f94af5
1 changed files with 200 additions and 0 deletions

200
secureAggregate.py Normal file
View File

@ -0,0 +1,200 @@
#!/usr/bin/env python3
import multiprocessing
import sys
import time
import numpy as np
import tensorflow as tf
import csv
import os
def obviousAssign(cond, x, y):
'''
This module will assign the secret share to a variable
Return cond * x + y - cond * y
(1) If cond is 0-share, will return y-share
(2) If cond is 1-share, will return x-share
'''
import latticex.rosetta as rtt
return rtt.SecureAdd(rtt.SecureMul(cond, rtt.SecureSub(x, y)), y)
def secureMax(ciphertable, attr):
'''
This module will return the maximum value of ciphertable[:,attr] in secret share form
'''
import latticex.rosetta as rtt
tableLen = len(ciphertable)
if tableLen == 1:
return ciphertable[0][attr]
ciphercolumn = ciphertable[:,attr]
# print(ciphercolumn)
maxSS = np.array([ciphercolumn[0]])
sess = tf.Session()
for i in range(1, tableLen):
'''
Traversing the array and extracts the maximum value obviously
'''
b = rtt.SecureGreater(np.array([ciphercolumn[i]]), maxSS)
maxSS = obviousAssign(b, np.array([ciphercolumn[i]]), maxSS)
tf.reset_default_graph()
sess.run(maxSS)
return maxSS
def secureMin(ciphertable, attr):
'''
This module will return the minimum value of ciphertable[:,attr] in secret share form
'''
import latticex.rosetta as rtt
tableLen = len(ciphertable)
if tableLen == 1:
return ciphertable[0][attr]
ciphercolumn = ciphertable[:,attr]
minSS = np.array([ciphercolumn[0]])
sess = tf.Session()
for i in range(1, tableLen):
'''
Traversing the array and extracts the minimum value obviously
'''
b = rtt.SecureLess(np.array([ciphercolumn[i]]), minSS)
minSS = obviousAssign(b, np.array([ciphercolumn[i]]), minSS)
tf.reset_default_graph()
sess.run(minSS)
return minSS
def secureSum(ciphertable, attr):
'''
This module will return the sum of the target records
e.g. 'SELECT SUM(AGE) FROM TABLE' will return the sum of AGE
'''
import latticex.rosetta as rtt
tableLen = len(ciphertable)
if tableLen == 1:
return ciphertable[0][attr]
ciphercolumn = ciphertable[:,attr]
sumSS = np.array(ciphercolumn[0])
sess = tf.Session()
for i in range(1, tableLen):
sumSS = rtt.SecureAdd(sumSS, np.array(ciphercolumn[i]))
sess.run(sumSS)
tf.reset_default_graph()
return sumSS
def secureAVG(id, ciphertable, attr):
'''
This module will return the sum of the target records
e.g. 'SELECT AVG(loan) FROM TABLE' will return the average value of loan
'''
import latticex.rosetta as rtt
tableLen = len(ciphertable)
sumSS = secureSum(ciphertable, attr)
if id == 0:
rtx = rtt.controller.PrivateDataset(["P0"]).load_X([[tableLen]])
rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
elif id == 1:
rd = np.random.RandomState(1789)
plaintext = rd.randint(0, 1, (1, 1))
rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
rty = rtt.controller.PrivateDataset(["P1"]).load_X(plaintext)
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
else:
rd = np.random.RandomState(1999)
plaintext = rd.randint(0, 1, (1, 1))
rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
rtz = rtt.controller.PrivateDataset(["P2"]).load_X(plaintext)
avgSS = rtt.SecureDivide(sumSS, rtx)
tf.reset_default_graph()
return avgSS
# def test(id):
# sys.argv.extend(["--node_id", "P{}".format(id)])
# import latticex.rosetta as rtt
# rtt.activate("SecureNN")
# rtt.backend_log_to_stdout(False)
# test_table = 'users/user1/S_user1_table0.csv'
# TIME_START = time.time()
# if id == 0:
# plaintext = np.loadtxt(open(test_table), delimiter = ",", skiprows = 1)
# rtx = rtt.controller.PrivateDataset(["P0"]).load_X(plaintext)
# rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
# rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
# elif id == 1:
# rd = np.random.RandomState(1789)
# plaintext = rd.randint(0, 1, (1, 1))
# rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
# rty = rtt.controller.PrivateDataset(["P1"]).load_X(plaintext)
# rtz = rtt.controller.PrivateDataset(["P2"]).load_X(None)
# else:
# rd = np.random.RandomState(1999)
# plaintext = rd.randint(0, 1, (1, 1))
# rtx = rtt.controller.PrivateDataset(["P0"]).load_X(None)
# rty = rtt.controller.PrivateDataset(["P1"]).load_X(None)
# rtz = rtt.controller.PrivateDataset(["P2"]).load_X(plaintext)
# session = tf.Session()
# # smax = secureMax(rtx, 1)
# # print('Successfully compute the max')
# # smin = secureMin(rtx, 1)
# # print('Successfully compute the min')
# # ssum = secureSum(rtx, 1)
# # print('Successfully compute the sum')
# savg = secureAVG(id, rtx, 3)
# print('Successfully compute the avg')
# TIME_END = time.time()
# print('Successfully test all the secure aggregate, total time:', TIME_END - TIME_START)
# plaintext = session.run(rtt.SecureReveal(savg))
# print('min:', plaintext)
# p0 = multiprocessing.Process(target = test, args = (0,))
# p1 = multiprocessing.Process(target = test, args = (1,))
# p2 = multiprocessing.Process(target = test, args = (2,))
# p0.daemon = True
# p0.start()
# p1.daemon = True
# p1.start()
# p2.daemon = True
# p2.start()
# p0.join()
# p1.join()
# p2.join()