CPM-9G-8B/FM_9G/fm9g/utils/bitset.py

160 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import random
import bmtrain as bmt
import numpy as np
class BitSet:
    """Growable set of non-negative integers backed by a NumPy boolean array.

    Index ``i`` of ``self.bitset`` is ``True`` exactly when ``i`` is a member.
    ``self.size`` is the current capacity (length of the backing array); the
    array is enlarged on demand when a number beyond the capacity is added.
    """

    def __init__(self, size=1024**2):
        """Create an empty bitset able to hold the numbers ``0 .. size - 1``."""
        self.size = size
        self.bitset = np.zeros(self.size, dtype=bool)

    def _ensure_capacity(self, num):
        """Grow the backing array so that index ``num`` is addressable."""
        if num >= self.size:
            # Grow at least geometrically so repeated adds stay amortized cheap.
            new_size = max(num + 1, self.size * 2)
            new_bitset = np.zeros(new_size, dtype=bool)
            new_bitset[: self.size] = self.bitset
            self.bitset = new_bitset
            self.size = new_size
            bmt.print_rank("enlarge size to {}".format(self.size))

    def add(self, num):
        """Add ``num`` to the bitset, enlarging the storage if required.

        Raises:
            ValueError: if ``num`` is negative. A negative index would
                otherwise wrap around and silently set an unrelated bit.
        """
        if num < 0:
            raise ValueError("BitSet only stores non-negative integers, got {}".format(num))
        self._ensure_capacity(num)
        self.bitset[num] = True

    def remove(self, num):
        """Remove ``num`` from the bitset; numbers outside the range are ignored."""
        # The lower bound prevents negative indices from wrapping around.
        if 0 <= num < self.size:
            self.bitset[num] = False

    def contains(self, num):
        """Return ``True`` if ``num`` is a member of the bitset."""
        # Negative or out-of-capacity numbers can never be members.
        return bool(0 <= num < self.size and self.bitset[num])

    def __contains__(self, num):
        """Support ``num in bitset``."""
        return self.contains(num)

    def update(self, iterable_or_bitset):
        """Add every element of an iterable, or of another ``BitSet``, in place."""
        if isinstance(iterable_or_bitset, BitSet):
            other = iterable_or_bitset
            # The highest index needed is ``other.size - 1``; asking for
            # ``other.size`` would double the storage even when both bitsets
            # already have the same capacity.
            self._ensure_capacity(other.size - 1)
            # Vectorized union over the range covered by ``other``.
            self.bitset[: other.size] |= other.bitset
        else:
            # Generic iterable: add the elements one by one.
            for num in iterable_or_bitset:
                self.add(num)

    def __sub__(self, other):
        """Return a new ``BitSet`` with the elements of ``self`` not in ``other``.

        The result has capacity ``max(self.size, other.size)``. The two
        operands may have different capacities.
        """
        result = BitSet(max(self.size, other.size))
        result.bitset[: self.size] = self.bitset
        # Only the overlapping prefix can be affected by ``other``; slicing
        # both sides to it avoids a shape mismatch when ``other`` is smaller.
        overlap = min(self.size, other.size)
        result.bitset[:overlap] &= ~other.bitset[:overlap]
        return result

    def __isub__(self, other):
        """Remove every element of ``other`` from ``self`` in place."""
        # Only the overlapping prefix of the two arrays can be affected.
        min_size = min(self.size, other.size)
        self.bitset[:min_size] &= ~other.bitset[:min_size]
        return self

    def __str__(self):
        """Return ``BitSet(i, j, ...)`` listing the member indices in order."""
        true_indices = np.flatnonzero(self.bitset)
        indices_str = ", ".join(map(str, true_indices))
        return f"BitSet({indices_str})"

    def __len__(self):
        """Return the number of members as a plain Python ``int``."""
        return int(np.count_nonzero(self.bitset))

    def capacity(self):
        """Return the current capacity (length of the backing array)."""
        return self.size

    def density(self):
        """Return the fraction of the capacity in use (``0.0`` when capacity is 0)."""
        if self.size == 0:
            return 0.0
        return len(self) / self.size

    def memory_usage(self):
        """Return the backing array's memory footprint as a string in B, KB, MB or GB."""
        bytes_usage = self.bitset.nbytes
        if bytes_usage < 1024:
            return f"{bytes_usage} B"
        elif bytes_usage < 1024**2:
            return f"{bytes_usage / 1024:.2f} KB"
        elif bytes_usage < 1024**3:
            return f"{bytes_usage / 1024**2:.2f} MB"
        else:
            return f"{bytes_usage / 1024**3:.2f} GB"

    def to_list(self):
        """Return a list with the index of every member, in ascending order."""
        return list(np.flatnonzero(self.bitset))

    def save(self, filename):
        """Save the bitset to ``<filename>.<random number>.npy``.

        The parent directory is created when ``filename`` has one.

        Returns:
            The base name of the written file (without the directory part),
            so callers can relocate the containing directory.
        """

        def random_hash():
            """Return a random 64-bit integer used as a unique-ish suffix."""
            return random.randint(0, 2**64 - 1)

        filename_with_suffix = filename + ".{}.npy".format(random_hash())
        dirname = os.path.dirname(filename_with_suffix)
        # ``os.makedirs("")`` raises, so skip it for files in the current directory.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        np.save(filename_with_suffix, self.bitset)
        return os.path.basename(filename_with_suffix)

    @classmethod
    def load(cls, filename_with_suffix):
        """Load a bitset written by ``save`` and return it as a new instance."""
        bitset_array = np.load(filename_with_suffix)
        # Start from an empty instance so no full-size zero array is
        # allocated only to be replaced by the loaded one.
        bitset = cls(0)
        bitset.bitset = bitset_array
        bitset.size = bitset_array.size
        return bitset
def bitset_diff(normal_set, bitset):
    """Return the set of elements of ``normal_set`` that ``bitset`` does not contain.

    ``bitset`` only needs to provide a ``contains(elem)`` method.
    """
    missing = set()
    for candidate in normal_set:
        if bitset.contains(candidate):
            continue
        missing.add(candidate)
    return missing
if __name__ == "__main__":
    # Demo: build two small bitsets and exercise the difference operators.
    first = BitSet(1024)
    first.update([100, 200, 300, 1023])
    second = BitSet(1024)
    second.update([100, 400, 1023])

    difference = first - second
    # Expected output, in order: False, True, True, False.
    for probe in (100, 200, 300, 1023):
        print(probe in difference)

    first -= second
    print(difference)  # BitSet(200, 300)
    print(first)  # BitSet(200, 300)
    print(second)  # BitSet(100, 400, 1023)

    # A very large, empty bitset to show capacity / density / memory reporting.
    huge = BitSet(1024**3)
    print(len(huge), huge.capacity(), huge.density(), first.density())
    print("BitSet memory usage:", huge.memory_usage())

    print(bitset_diff({100, 200}, second))

    # Merging: BitSet-to-BitSet update, including one that forces growth.
    first.update(second)
    huge.add(52260134)
    second.update(huge)
    print(first)  # BitSet(100, 200, 300, 400, 1023)
    print(second)  # BitSet(100, 400, 1023, 52260134)