CPM-9G-8B/FM_9G/fm9g/utils/bitset.py

160 lines
5.5 KiB
Python
Raw Normal View History

2024-07-15 14:27:10 +08:00
import os
import random
import bmtrain as bmt
import numpy as np
class BitSet:
    """Growable set of non-negative integers backed by a NumPy boolean array.

    Position ``i`` of ``self.bitset`` is ``True`` exactly when ``i`` is a
    member. The array holds one element per representable value and is
    enlarged on demand by :meth:`_ensure_capacity`.
    """

    def __init__(self, size=1024**2):
        """Create an empty set able to hold the values ``0 .. size - 1``."""
        self.size = size
        self.bitset = np.zeros(self.size, dtype=bool)

    def _ensure_capacity(self, num):
        """Grow the backing array, if needed, so that index ``num`` is valid."""
        if num < self.size:
            return
        # Grow to at least double the current size so a run of increasing
        # adds does not reallocate on every call.
        new_size = max(num + 1, self.size * 2)
        new_bitset = np.zeros(new_size, dtype=bool)
        new_bitset[: self.size] = self.bitset
        self.bitset = new_bitset
        self.size = new_size
        bmt.print_rank("enlarge size to {}".format(self.size))

    def add(self, num):
        """Insert ``num`` into the set, enlarging the storage when required.

        Raises:
            ValueError: if ``num`` is negative. A negative index would wrap
                around and silently mark an unrelated value as present.
        """
        if num < 0:
            raise ValueError(f"BitSet only stores non-negative integers, got {num}")
        self._ensure_capacity(num)
        self.bitset[num] = True

    def remove(self, num):
        """Remove ``num`` from the set; values outside the capacity are ignored."""
        if 0 <= num < self.size:
            self.bitset[num] = False

    def contains(self, num):
        """Return ``True`` if ``num`` is a member of the set, else ``False``."""
        # The lower bound stops negative numbers from indexing from the end.
        return bool(0 <= num < self.size and self.bitset[num])

    def __contains__(self, num):
        """Support the ``num in bitset`` syntax."""
        return self.contains(num)

    def update(self, iterable_or_bitset):
        """Add every element of an iterable, or of another BitSet, to this set."""
        if isinstance(iterable_or_bitset, BitSet):
            other = iterable_or_bitset
            # The largest index ``other`` can hold is ``size - 1``; asking for
            # ``size`` would enlarge this set even when both are equally big.
            self._ensure_capacity(other.size - 1)
            self.bitset[: other.size] |= other.bitset
        else:
            for num in iterable_or_bitset:
                self.add(num)

    def __sub__(self, other):
        """Return a new BitSet with the members of ``self`` not in ``other``.

        The result has capacity ``max(self.size, other.size)``.
        """
        result = BitSet(max(self.size, other.size))
        result.bitset[: self.size] = self.bitset
        # Only the range both sets cover can hold common members; slicing to
        # it keeps the operands the same length whichever set is larger.
        overlap = min(self.size, other.size)
        result.bitset[:overlap] &= ~other.bitset[:overlap]
        return result

    def __isub__(self, other):
        """Remove every member of ``other`` from this set in place."""
        min_size = min(self.size, other.size)
        self.bitset[:min_size] &= ~other.bitset[:min_size]
        return self

    def __str__(self):
        """Return ``BitSet(i, j, ...)`` listing the members in ascending order."""
        true_indices = np.where(self.bitset)[0]
        indices_str = ", ".join(map(str, true_indices))
        return f"BitSet({indices_str})"

    def __len__(self):
        """Return the number of members as a plain ``int``."""
        return int(self.bitset.sum())

    def capacity(self):
        """Return the number of values the current storage can represent."""
        return self.size

    def density(self):
        """Return the fraction of the capacity in use (0.0 if capacity is 0)."""
        if self.size == 0:
            return 0.0
        return len(self) / self.size

    def memory_usage(self):
        """Return the size of the backing array as a string in B, KB, MB or GB."""
        bytes_usage = self.bitset.nbytes
        if bytes_usage < 1024:
            return f"{bytes_usage} B"
        elif bytes_usage < 1024**2:
            return f"{bytes_usage / 1024:.2f} KB"
        elif bytes_usage < 1024**3:
            return f"{bytes_usage / 1024**2:.2f} MB"
        else:
            return f"{bytes_usage / 1024**3:.2f} GB"

    def to_list(self):
        """Return the members as an ascending list of Python ints."""
        return np.where(self.bitset)[0].tolist()

    def save(self, filename):
        """Write the set to ``<filename>.<random number>.npy``.

        The parent directory is created when the path names one.

        Returns:
            The base name (no directory part) of the file that was written.
        """
        # The random component keeps separate saves under one prefix from
        # overwriting each other.
        suffix = random.randint(0, 2**64 - 1)
        filename_with_suffix = filename + ".{}.npy".format(suffix)
        dirname = os.path.dirname(filename_with_suffix)
        # ``os.makedirs("")`` raises, so skip it for a bare file name.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        np.save(filename_with_suffix, self.bitset)
        return os.path.basename(filename_with_suffix)

    @classmethod
    def load(cls, filename_with_suffix):
        """Create a BitSet from a ``.npy`` file written by :meth:`save`."""
        bitset_array = np.load(filename_with_suffix)
        # Start from an empty instance so no throwaway zero array of the
        # file's size is allocated before the loaded array replaces it.
        bitset = cls(0)
        bitset.bitset = bitset_array
        bitset.size = bitset_array.size
        return bitset
def bitset_diff(normal_set, bitset):
    """Return the set of elements of ``normal_set`` that ``bitset`` lacks.

    ``normal_set`` may be any iterable; ``bitset`` only needs a
    ``contains(elem)`` method.
    """
    missing = set()
    for candidate in normal_set:
        if bitset.contains(candidate):
            continue
        missing.add(candidate)
    return missing
if __name__ == "__main__":
    # Usage example: build two small sets and subtract one from the other.
    first = BitSet(1024)
    first.update([100, 200, 300, 1023])
    second = BitSet(1024)
    second.update([100, 400, 1023])

    difference = first - second
    # Expected output, in order: False, True, True, False
    for probe in (100, 200, 300, 1023):
        print(probe in difference)

    # The in-place subtraction should leave `first` equal to `difference`.
    first -= second
    print(difference)  # BitSet(200, 300)
    print(first)  # BitSet(200, 300)
    print(second)  # BitSet(100, 400, 1023)

    # A set with 1024**3 slots, used to show the size and memory helpers.
    huge = BitSet(1024**3)
    print(len(huge), huge.capacity(), huge.density(), first.density())
    print("BitSet memory usage:", huge.memory_usage())

    print(bitset_diff({100, 200}, second))

    # Merge BitSets into each other, including one larger than the target.
    first.update(second)
    huge.add(52260134)
    second.update(huge)
    print(first)  # BitSet(100, 200, 300, 400, 1023)
    print(second)  # BitSet(100, 400, 1023, 52260134)