SMPCache/bitonic.py

266 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Created on Mon Jun 6 09:57:58 2022
@author: hanye
"""
from Compiler.types import *
from Compiler.library import *
from Compiler import ml
from Compiler.util import is_zero, tree_reduce
from Compiler import util
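# 1. Randomly generate an integer vector; pack each value together with its index into a node and push it into a list, returning a vector-like structure
# 2. Define the oblivious swap for nodes
# 3. Bitonic sort
# 4. Compute the top-K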
# Configure the fixed-point format used by every sfix value in this program.
# NOTE(review): in MP-SPDZ, set_precision(f, k) presumably takes the number of
# fractional bits and the total bit length — confirm against the version in use.
sfix.set_precision(16,57)
# NOTE(review): presumably the number of decimal digits shown when revealed
# fractional values are printed — confirm against the MP-SPDZ documentation.
print_float_precision(20)
class secure_topK_element():
    """Top-K selection over (index, value) rows inside an MP-SPDZ program.

    Every data set handled here is a two-column ``Matrix``: column 0 holds
    the row's index and column 1 holds its value, so one row is one "node".
    Selection is done by repeatedly sorting a window of rows with an
    odd-even merge sorting network and keeping the first ``K`` rows.
    """

    def __init__(self):
        pass

    @staticmethod
    def _next_power_of_two(n):
        """Return the smallest power of two that is >= ``n`` (0 for ``n`` < 1)."""
        if n < 1:
            return 0
        return 1 << (n - 1).bit_length()

    @staticmethod
    def _plan_batches(n_rows, K, theta):
        """Split ``n_rows`` input rows into the sorting rounds of the top-K scan.

        The first round sorts ``K * (1 + theta)`` rows; every later round
        sorts the current best ``K`` rows together with the next
        ``theta * K`` unseen rows, and a final shorter round takes whatever
        is left.  Non-integer products are truncated.

        Returns:
            ``(first_len, n_full, tail)`` where ``first_len`` is the size of
            the first round, ``n_full`` the number of full follow-up rounds
            and ``tail`` the number of rows left for the final round, so
            that ``first_len + n_full * (first_len - K) + tail == n_rows``.
            ``(n_rows, 0, 0)`` means a single sort of the whole input is
            equivalent (also used when a round could not make progress).
        """
        first_len = int(K * (1 + theta))
        fresh_per_round = first_len - K
        if fresh_per_round <= 0 or first_len >= n_rows:
            return n_rows, 0, 0
        n_full, tail = divmod(n_rows - first_len, fresh_per_round)
        return first_len, n_full, tail

    def Gen_Random_NodeList(self, nvals, value_type, lower=-10, upper=10):
        """Build an ``nvals`` x 2 matrix of nodes with random values.

        Args:
            nvals: number of rows (nodes) to generate.
            value_type: element type of the matrix; it must provide
                ``get_random(lower=..., upper=...)``.
            lower, upper: bounds forwarded to ``value_type.get_random``.

        Returns:
            ``Matrix(nvals, 2, value_type)`` whose row ``i`` is
            ``(i, random value)``.
        """
        Random_NodeList = Matrix(nvals, 2, value_type)

        @for_range(nvals)
        def _(i):
            node_value = value_type.get_random(lower=lower, upper=upper)
            node_index = sint(i)
            Random_NodeList[i][0] = node_index
            Random_NodeList[i][1] = node_value

        return Random_NodeList

    def cond_swap_with_bit(self, b, x, y):
        """Arithmetically select the order of ``x`` and ``y`` from bit ``b``.

        Returns ``(x, y)`` when ``b`` is 1 and ``(y, x)`` when ``b`` is 0,
        using only multiplications and additions so it also works on
        secret-shared operands.
        """
        bx = b * x
        by = b * y
        return bx + y - by, x - bx + by

    def cond_swap_forNode(self, InputVector, Dataflag, index1, index2, type_sort):
        """Run one comparator of the sorting network on two nodes, in place.

        Args:
            InputVector: two-column matrix of (index, value) rows.
            Dataflag: per-row flags; 0 marks a padding row, non-zero a real row.
            index1, index2: the two row positions being compared.
            type_sort: 0 puts the larger value at ``index1`` (descending),
                1 puts the smaller value at ``index1`` (ascending).

        Raises:
            ValueError: if ``type_sort`` is neither 0 nor 1.
        """
        # Validate before emitting anything so a bad argument fails clearly
        # instead of leaving the comparison bit undefined further down.
        if type_sort not in (0, 1):
            raise ValueError('type_sort must be 0 (descending) or 1 (ascending)')

        x_index = InputVector[index1][0]
        y_index = InputVector[index2][0]
        x_value = InputVector[index1][1]
        y_value = InputVector[index2][1]
        x_flag = Dataflag[index1]
        y_flag = Dataflag[index2]

        # Padding in front of a real row: move the real row forward
        # unconditionally, flags included.
        @if_e((x_flag == regint(0)) & (y_flag != regint(0)))
        def _():
            InputVector[index1][1] = y_value
            InputVector[index1][0] = y_index
            InputVector[index2][1] = x_value
            InputVector[index2][0] = x_index
            Dataflag[index1] = y_flag
            Dataflag[index2] = x_flag

        @else_
        def _():
            # If the second row is padding nothing has to move; otherwise
            # both rows are real and are ordered by a secret comparison.
            @if_(y_flag != regint(0))
            def _():
                if type_sort == 0:
                    b = x_value > y_value
                else:
                    b = x_value < y_value
                x_value_new, y_value_new = self.cond_swap_with_bit(b, x_value, y_value)
                x_index_new, y_index_new = self.cond_swap_with_bit(b, x_index, y_index)
                InputVector[index1][0] = x_index_new
                InputVector[index2][0] = y_index_new
                InputVector[index1][1] = x_value_new
                InputVector[index2][1] = y_value_new

    def _merge_rows_into_head(self, window, head, source, source_start, n_new, type_sort):
        """Sort ``head`` plus ``n_new`` rows of ``source`` and refresh ``head``.

        ``window`` is scratch space with ``len(head) + n_new`` rows.  The
        current head is copied to its front, rows
        ``source[source_start : source_start + n_new]`` behind it, the window
        is sorted and its first ``len(head)`` rows are written back to ``head``.
        """
        n_keep = len(head)

        @for_range_opt(n_keep)
        def _(j):
            window[j] = head[j]

        @for_range_opt(n_new)
        def _(j):
            window[j + n_keep] = source[j + source_start]

        ranked = self.odd_even_merge_sort(window, type_sort)

        @for_range_opt(n_keep)
        def _(j):
            head[j] = ranked[j]

    def get_secure_topK(self, DataVector, K, theta, type_sort, value_type):
        """Return the indices and values of the first ``K`` rows in sorted order.

        The input is scanned in rounds (see ``_plan_batches``): each round
        sorts the best ``K`` rows found so far together with the next block
        of unseen rows, so every input row takes part in exactly one round,
        including a final block that is shorter than the others.

        Args:
            DataVector: two-column matrix of (index, value) rows.
            K: number of rows to return; must not exceed the number of rows.
            theta: each follow-up round adds ``theta * K`` unseen rows.
            type_sort: 0 for descending order (largest first), 1 for ascending.
            value_type: element type used for the working matrices.

        Returns:
            ``(topk_index_arr, topk_value_arr)``: an ``Array(K, sint)`` of
            indices and an ``Array(K, sfix)`` of values.
        """
        n_rows = DataVector.sizes[0]
        assert n_rows >= K, 'K must not exceed the number of input rows'
        DataHead = Matrix(K, 2, value_type)
        first_len, n_full, tail = self._plan_batches(n_rows, K, theta)

        if first_len == n_rows:
            # One sort over the whole input is equivalent.
            ranked = self.odd_even_merge_sort(DataVector, type_sort)

            @for_range_opt(K)
            def _(j):
                DataHead[j] = ranked[j]
        else:
            fresh_per_round = first_len - K
            window = Matrix(first_len, 2, value_type)

            @for_range_opt(first_len)
            def _(j):
                window[j] = DataVector[j]

            ranked = self.odd_even_merge_sort(window, type_sort)

            @for_range_opt(K)
            def _(j):
                DataHead[j] = ranked[j]

            if n_full:
                # Runtime loop: the round body is emitted once and the row
                # offset is derived from the loop counter, so no state has
                # to be carried between iterations.
                @for_range(n_full)
                def _(r):
                    self._merge_rows_into_head(
                        window, DataHead, DataVector,
                        r * fresh_per_round + first_len, fresh_per_round,
                        type_sort)

            if tail:
                # Leftover rows that do not fill a whole round get their own,
                # smaller window so they are compared against the head too.
                tail_window = Matrix(K + tail, 2, value_type)
                self._merge_rows_into_head(
                    tail_window, DataHead, DataVector,
                    n_rows - tail, tail, type_sort)

        # Indices are returned as sint.
        topk_index_arr = Array(K, sint)
        topk_index_arr.assign_all(-1)
        # Values are returned as sfix.
        topk_value_arr = Array(K, sfix)
        topk_value_arr.assign_all(-1)

        @for_range_opt(K)
        def _(i):
            topk_index_arr[i] = DataHead[i][0]
            topk_value_arr[i] = DataHead[i][1]

        return topk_index_arr, topk_value_arr

    # modified from loopy_odd_even_merge_sort in library.py of MP-SPDZ
    def odd_even_merge_sort(self, InputVector, type_sort, sorted_length=1, n_parallel=1):
        """Sort (index, value) rows by value with an odd-even merge network.

        The rows are copied into a fresh matrix padded to a power-of-two
        length.  Padding rows carry flag 0 and are pushed behind real rows
        by ``cond_swap_forNode``; they are not removed from the result.

        Args:
            InputVector: two-column matrix of (index, value) rows.
            type_sort: 0 for descending, 1 for ascending order of the values.
            sorted_length: length of runs that are already sorted on entry.
            n_parallel: parallelism hint passed to ``for_range_parallel``.

        Returns:
            ``Matrix(padded_length, 2, sfix)``; callers read the rows they
            need from the front.
        """
        length = len(InputVector)
        length_should = self._next_power_of_two(length)

        # Column 0 is the index, column 1 the value.
        DataSort = Matrix(length_should, 2, sfix)
        Dataflag = Array(length_should, regint)

        @for_range_opt(length_should)
        def _(i):
            @if_e(i < length)
            def _():
                DataSort[i][0] = InputVector[i][0]
                DataSort[i][1] = InputVector[i][1]
                Dataflag[i] = 1

            @else_
            def _():
                DataSort[i][0] = sint(i)
                DataSort[i][1] = sint(0)
                Dataflag[i] = 0

        l = sorted_length
        num_keys = len(DataSort)
        while l < num_keys:
            l *= 2
            k = 1
            while k < l:
                k *= 2
                n_outer = num_keys // l
                n_inner = l // k
                n_innermost = 1 if k == 2 else k // 2 - 1

                @for_range_parallel(n_parallel // n_innermost // n_inner, n_outer)
                def loop(i):
                    @for_range_parallel(n_parallel // n_innermost, n_inner)
                    def inner(j):
                        base = i * l + j
                        step = l // k
                        if k == 2:
                            self.cond_swap_forNode(
                                DataSort, Dataflag, base, base + step, type_sort)
                        else:
                            @for_range_parallel(n_parallel, n_innermost)
                            def f(i_inner):
                                m1 = step + i_inner * 2 * step
                                m2 = m1 + base
                                self.cond_swap_forNode(
                                    DataSort, Dataflag, m2, m2 + step, type_sort)

        return DataSort