forked from p83651209/CPM-9G-8B
160 lines
5.5 KiB
Python
160 lines
5.5 KiB
Python
|
import os
|
|||
|
import random
|
|||
|
|
|||
|
import bmtrain as bmt
|
|||
|
import numpy as np
|
|||
|
|
|||
|
|
|||
|
class BitSet:
    """A growable set of non-negative integers backed by a numpy bool array.

    Member ``i`` is recorded as ``self.bitset[i] == True``. Each slot is one
    numpy bool (one byte), so memory grows with the largest value stored,
    not with the number of members.
    """

    def __init__(self, size=1024**2):
        # Capacity: values 0 .. size-1 can be stored without growing.
        self.size = size
        self.bitset = np.zeros(self.size, dtype=bool)

    def _ensure_capacity(self, num):
        """Grow the backing array so that index ``num`` is valid.

        The new size is the larger of ``num + 1`` and twice the current size,
        so repeated growth does not reallocate on every insert.
        """
        if num >= self.size:
            new_size = max(num + 1, self.size * 2)
            new_bitset = np.zeros(new_size, dtype=bool)
            new_bitset[: self.size] = self.bitset
            self.bitset = new_bitset
            self.size = new_size
            bmt.print_rank("enlarge size to {}".format(self.size))

    def add(self, num):
        """Insert ``num`` into the set, growing the storage if needed.

        Raises:
            ValueError: if ``num`` is negative. numpy would otherwise treat a
                negative index as counting from the end of the array and
                silently set an unrelated bit.
        """
        if num < 0:
            raise ValueError(
                "BitSet only stores non-negative integers, got {}".format(num)
            )
        self._ensure_capacity(num)
        self.bitset[num] = True

    def remove(self, num):
        """Remove ``num`` from the set; values that are not members are ignored."""
        # Negative values can never be members; indexing with one would clear
        # a bit counted from the end of the array.
        if 0 <= num < self.size:
            self.bitset[num] = False

    def contains(self, num):
        """Return True if ``num`` is a member of the set."""
        # The lower bound guards against numpy's negative (from-the-end) indexing.
        return bool(0 <= num < self.size and self.bitset[num])

    def __contains__(self, num):
        return self.contains(num)

    def update(self, iterable_or_bitset):
        """Add every member of another BitSet, or every int of an iterable.

        A BitSet argument is merged with a single vectorized OR; an iterable is
        added element by element through ``add``.
        """
        if isinstance(iterable_or_bitset, BitSet):
            other = iterable_or_bitset
            # The highest index that must fit is other.size - 1. Requesting
            # other.size would double the capacity of equal-sized bitsets.
            self._ensure_capacity(other.size - 1)
            self.bitset[: other.size] |= other.bitset
        else:
            for num in iterable_or_bitset:
                self.add(num)

    def __sub__(self, other):
        """Return a new BitSet holding the members of ``self`` not in ``other``.

        The result's capacity is the larger of the two operands' capacities.
        Works for either operand being the larger one.
        """
        result = BitSet(max(self.size, other.size))
        result.bitset[: self.size] = self.bitset
        # Only the common prefix can hold bits that are set in both operands.
        overlap = min(self.size, other.size)
        result.bitset[:overlap] &= ~other.bitset[:overlap]
        return result

    def __isub__(self, other):
        """In-place difference: clear every bit of ``self`` that is set in ``other``."""
        # Bits beyond the shorter array cannot be shared, so only the common
        # prefix needs to be touched.
        min_size = min(self.size, other.size)
        self.bitset[:min_size] &= ~other.bitset[:min_size]
        return self

    def __str__(self):
        """Return ``BitSet(i, j, ...)`` listing the indices of all set bits."""
        true_indices = np.where(self.bitset)[0]
        indices_str = ", ".join(map(str, true_indices))
        return f"BitSet({indices_str})"

    def __len__(self):
        """Return the number of members (set bits) as a Python int."""
        return int(np.count_nonzero(self.bitset))

    def capacity(self):
        """Return the current capacity (length of the backing array)."""
        return self.size

    def density(self):
        """Return members / capacity, or 0.0 for a zero-capacity set."""
        if self.size == 0:
            return 0.0
        return len(self) / self.size

    def memory_usage(self):
        """Return the backing array's size as a human-readable B/KB/MB/GB string."""
        bytes_usage = self.bitset.nbytes
        if bytes_usage < 1024:
            return f"{bytes_usage} B"
        elif bytes_usage < 1024**2:
            return f"{bytes_usage / 1024:.2f} KB"
        elif bytes_usage < 1024**3:
            return f"{bytes_usage / 1024**2:.2f} MB"
        else:
            return f"{bytes_usage / 1024**3:.2f} GB"

    def to_list(self):
        """Return the indices of all set bits, ascending, as Python ints."""
        return np.flatnonzero(self.bitset).tolist()

    def save(self, filename):
        """Save the bitset to ``<filename>.<random 64-bit int>.npy``.

        Missing parent directories are created. Returns only the base name of
        the written file, without any directory prefix.
        """
        # SystemRandom reads the OS entropy source, so the suffix neither
        # depends on nor advances the global ``random`` state (which training
        # code may seed identically in every process).
        suffix = random.SystemRandom().randint(0, 2**64 - 1)
        filename_with_suffix = filename + ".{}.npy".format(suffix)
        dirname = os.path.dirname(filename_with_suffix)
        # A bare file name has an empty directory part; os.makedirs("") raises.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        np.save(filename_with_suffix, self.bitset)
        return os.path.basename(filename_with_suffix)  # base name only (no prefix), for the transfer project

    @classmethod
    def load(cls, filename_with_suffix):
        """Load an array written by ``save`` and wrap it in a new instance."""
        bitset_array = np.load(filename_with_suffix)
        # Build an empty instance and attach the loaded array instead of
        # allocating a full-size zero array only to throw it away.
        bitset = cls(0)
        bitset.bitset = bitset_array
        bitset.size = bitset_array.size
        return bitset
|
|||
|
|
|||
|
|
|||
|
def bitset_diff(normal_set, bitset):
    """Collect the elements of ``normal_set`` that ``bitset`` does not contain.

    Args:
        normal_set: any iterable of candidate values.
        bitset: an object exposing ``contains(value)``, such as a BitSet.

    Returns:
        A new ``set`` with every candidate for which ``bitset.contains`` is falsy.
    """
    missing = set()
    for candidate in normal_set:
        if bitset.contains(candidate):
            continue
        missing.add(candidate)
    return missing
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Demo: build two small bitsets and exercise difference, in-place
    # difference, update, and the reporting helpers.
    left = BitSet(1024)
    left.update([100, 200, 300, 1023])

    right = BitSet(1024)
    right.update([100, 400, 1023])

    difference = left - right
    # Expected output, one per line: False, True, True, False
    for probe in (100, 200, 300, 1023):
        print(probe in difference)

    left -= right
    print(difference)  # BitSet(200, 300)
    print(left)  # BitSet(200, 300)
    print(right)  # BitSet(100, 400, 1023)

    # A 1024**3-slot bitset: one numpy bool per slot, so ~1 GiB of memory.
    huge = BitSet(1024**3)
    print(len(huge), huge.capacity(), huge.density(), left.density())
    print("BitSet memory usage:", huge.memory_usage())

    print(bitset_diff({100, 200}, right))

    left.update(right)
    huge.add(52260134)
    right.update(huge)
    print(left)  # BitSet(100, 200, 300, 400, 1023)
    print(right)  # BitSet(100, 400, 1023, 52260134)
|