forked from jiuyuan/CPM-9G-8B
160 lines
5.5 KiB
Python
160 lines
5.5 KiB
Python
import os
|
||
import random
|
||
|
||
import bmtrain as bmt
|
||
import numpy as np
|
||
|
||
|
||
class BitSet:
    """Growable set of non-negative integers backed by a NumPy boolean array.

    Position ``i`` of ``self.bitset`` is ``True`` exactly when ``i`` is a
    member of the set. ``self.size`` is the current capacity (the length of
    the array); it grows on demand when a larger number is added.
    """

    def __init__(self, size=1024**2):
        """Create an empty bitset able to hold the numbers ``0 .. size - 1``."""
        self.size = size
        self.bitset = np.zeros(self.size, dtype=bool)

    def _ensure_capacity(self, num):
        """Grow the backing array, if needed, so that index ``num`` is valid.

        The new capacity is at least double the old one so that repeated
        growth does not copy the array on every call. Each growth is reported
        through ``bmt.print_rank``.
        """
        if num >= self.size:
            new_size = max(num + 1, self.size * 2)
            new_bitset = np.zeros(new_size, dtype=bool)
            new_bitset[: self.size] = self.bitset
            self.bitset = new_bitset
            self.size = new_size
            bmt.print_rank("enlarge size to {}".format(self.size))

    def add(self, num):
        """Add ``num`` to the set, growing the capacity when necessary.

        Raises:
            ValueError: if ``num`` is negative. A negative index would
                otherwise wrap around and silently set an unrelated bit.
        """
        if num < 0:
            raise ValueError(f"BitSet only stores non-negative integers, got {num}")
        self._ensure_capacity(num)
        self.bitset[num] = True

    def remove(self, num):
        """Remove ``num`` from the set; numbers outside the capacity are ignored."""
        # The lower bound keeps negative numbers from wrapping to the array tail.
        if 0 <= num < self.size:
            self.bitset[num] = False

    def contains(self, num):
        """Return ``True`` if ``num`` is a member of the set, else ``False``."""
        # bool() turns the NumPy scalar into a plain Python bool; the range
        # check makes negative and too-large numbers report False.
        return bool(0 <= num < self.size and self.bitset[num])

    def __contains__(self, num):
        """Support ``num in bitset``."""
        return self.contains(num)

    def update(self, iterable_or_bitset):
        """Add every element of an iterable, or of another BitSet, to this set.

        A BitSet argument is merged with one vectorized OR; any other
        iterable is consumed element by element through :meth:`add` (and so
        raises ``ValueError`` on a negative element).
        """
        if isinstance(iterable_or_bitset, BitSet):
            other = iterable_or_bitset
            # The highest index the other set can hold is ``size - 1``; asking
            # for ``size`` would double this set even when capacities match.
            self._ensure_capacity(other.size - 1)
            self.bitset[: other.size] |= other.bitset
        else:
            for num in iterable_or_bitset:
                self.add(num)

    def __sub__(self, other):
        """Return a new BitSet holding the members of ``self`` not in ``other``.

        The result's capacity is the larger of the two operands' capacities.
        """
        result = BitSet(max(self.size, other.size))
        result.bitset[: self.size] = self.bitset
        # Only the overlapping prefix can be cleared; this also handles an
        # ``other`` that is smaller than ``self``.
        overlap = min(self.size, other.size)
        result.bitset[:overlap] &= ~other.bitset[:overlap]
        return result

    def __isub__(self, other):
        """Remove the members of ``other`` from this set in place."""
        # Restrict the operation to the range both arrays cover.
        min_size = min(self.size, other.size)
        self.bitset[:min_size] &= ~other.bitset[:min_size]
        return self

    def __str__(self):
        """Return ``BitSet(i, j, ...)`` listing the members in ascending order."""
        true_indices = np.where(self.bitset)[0]
        indices_str = ", ".join(map(str, true_indices))
        return f"BitSet({indices_str})"

    def __len__(self):
        """Return the number of members (set bits) as a Python int."""
        return int(np.count_nonzero(self.bitset))

    def capacity(self):
        """Return the current capacity, i.e. the length of the backing array."""
        return self.size

    def density(self):
        """Return the fraction of the capacity that is occupied.

        A zero-capacity bitset has density ``0.0``.
        """
        if self.size == 0:
            return 0.0
        return len(self) / self.size

    def memory_usage(self):
        """Return the backing array's size as a string in B, KB, MB or GB."""
        bytes_usage = self.bitset.nbytes
        if bytes_usage < 1024:
            return f"{bytes_usage} B"
        elif bytes_usage < 1024**2:
            return f"{bytes_usage / 1024:.2f} KB"
        elif bytes_usage < 1024**3:
            return f"{bytes_usage / 1024**2:.2f} MB"
        else:
            return f"{bytes_usage / 1024**3:.2f} GB"

    def to_list(self):
        """Return the members as an ascending list of Python ints."""
        return np.flatnonzero(self.bitset).tolist()

    def save(self, filename):
        """Save the bitset to ``<filename>.<random number>.npy``.

        The parent directory is created when ``filename`` has one.

        Returns:
            The base name (no directory part) of the file that was written.
        """

        def random_hash():
            """Return a random integer in ``[0, 2**64 - 1]`` used as a suffix."""
            return random.randint(0, 2**64 - 1)

        filename_with_suffix = filename + ".{}.npy".format(random_hash())
        dirname = os.path.dirname(filename_with_suffix)
        # A bare file name has an empty dirname, and os.makedirs("") raises.
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        np.save(filename_with_suffix, self.bitset)
        # Return only the final name, without the directory prefix, to
        # support the transfer project.
        return os.path.basename(filename_with_suffix)

    @classmethod
    def load(cls, filename_with_suffix):
        """Load an array written by :meth:`save` and wrap it in a new instance."""
        bitset_array = np.load(filename_with_suffix)
        # Start from an empty instance so no full-size zero array is
        # allocated only to be replaced immediately.
        bitset = cls(0)
        bitset.bitset = bitset_array
        bitset.size = bitset_array.size
        return bitset
|
||
|
||
|
||
def bitset_diff(normal_set, bitset):
    """Return the set of elements of ``normal_set`` that ``bitset`` lacks.

    ``normal_set`` may be any iterable; membership is decided by calling
    ``bitset.contains(elem)`` for each element.
    """
    missing = set()
    for candidate in normal_set:
        if bitset.contains(candidate):
            continue
        missing.add(candidate)
    return missing
|
||
|
||
|
||
if __name__ == "__main__":
    # Usage example: two small bitsets that share the members 100 and 1023.
    first = BitSet(1024)
    first.update([100, 200, 300, 1023])

    second = BitSet(1024)
    second.update([100, 400, 1023])

    # Non-mutating difference; expected output: False, True, True, False.
    difference = first - second
    for probe in (100, 200, 300, 1023):
        print(probe in difference)

    # In-place difference should leave `first` equal to `difference`.
    # Expected output: BitSet(200, 300), BitSet(200, 300),
    # BitSet(100, 400, 1023).
    first -= second
    for shown in (difference, first, second):
        print(shown)

    # A very large, empty bitset to show the size/occupancy helpers.
    huge = BitSet(1024**3)
    stats = (len(huge), huge.capacity(), huge.density(), first.density())
    print(*stats)
    print("BitSet memory usage:", huge.memory_usage())

    # Elements of a plain set that are missing from a bitset.
    print(bitset_diff({100, 200}, second))

    # Merging bitsets, including one whose capacity is far larger.
    first.update(second)
    huge.add(52260134)
    second.update(huge)
    print(first)  # BitSet(100, 200, 300, 400, 1023)
    print(second)  # BitSet(100, 400, 1023, 52260134)
|