# codepedia2/apps/test/test.py
# (word2vec / skip-gram data-preparation demo, ~105 lines)
# a = [{'linenum': 0, 'nums': 1}, {'linenum': 1, 'nums': 2}, {'linenum': 2, 'nums': 2}]
# annos_count = {}
# for i in a :
# annos_count[i['linenum']] = i['nums']
# print(annos_count)
import numpy as np
import random
import collections
import zipfile
# Toy corpus: each character is treated as a "word" for the demo run below.
words = 'abcdabcdabcd'
# Path to the text8 corpus zip consumed by read_data (not loaded in this demo).
file_name = '/Users/yujie/PycharmProjects/tensorflow/tensorflow_learning/text8.zip'
def read_data(file_name):
    """Read the first member of a zip archive and return its words.

    :param file_name: path to a zip file whose first member is whitespace-
        separated text (e.g. the text8 corpus)
    :return: list of word strings
    """
    with zipfile.ZipFile(file_name) as f:
        # Original body was truncated ("data = tf" referenced an undefined
        # name and nothing was returned); read the first archive member,
        # decode it, and split on whitespace as the tutorial intends.
        data = f.read(f.namelist()[0]).decode('utf-8').split()
    return data
# 构建字典
def build_dataset(words, n_words):
    """Build the vocabulary tables for skip-gram training.

    :param words: iterable of word tokens from the whole corpus
    :param n_words: vocabulary size; the (n_words - 1) most frequent words
        are kept, everything else is mapped to 'UNK' (index 0)
    :return: (data, count, dictionary, reversed_dictionary) where
        data is the corpus as a list of word indices,
        count is [['UNK', unk_count], (word, freq), ...],
        dictionary maps word -> index,
        reversed_dictionary maps index -> word
    """
    count = [['UNK', -1]]
    # The n_words-1 most frequent words in the corpus and their frequencies.
    count.extend(collections.Counter(words).most_common(n_words - 1))
    print('count:', count)
    # Word -> index; index order follows frequency rank, UNK is 0.
    dictionary = dict()
    for word, _ in count:
        dictionary[word] = len(dictionary)
    print("dictionary:", dictionary, len(dictionary))
    # Encode the corpus as indices; out-of-vocabulary words become UNK (0).
    data = list()
    unk_count = 0
    for word in words:
        if word in dictionary:
            index = dictionary[word]
        else:
            index = 0
            unk_count += 1
        data.append(index)
    # Patch the real UNK frequency into the count table.
    count[0][1] = unk_count
    print(data)
    # Reverse mapping: index -> word.
    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
    print(reversed_dictionary)
    return data, count, dictionary, reversed_dictionary
data, count, dictionary, reversed_dictionary = build_dataset(words, 5)
# 获取批数据
def generate_batch(batch_size, num_skips, skip_window):
    """Generate one skip-gram training batch from the module-level `data`.

    :param batch_size: number of (input, label) pairs produced per call;
        must be a multiple of num_skips
    :param num_skips: how many labels to draw for each input word
    :param skip_window: how many words on each side of the current word are
        considered context, so the context span is 2 * skip_window + 1
    :return: (batch, labels) — batch has shape (batch_size,) of input word
        indices, labels has shape (batch_size, 1) of context word indices
    """
    global data_index
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    # [ skip_window | target | skip_window ]; buffer[skip_window] is always
    # the current input word of the loop below.
    span = 2 * skip_window + 1
    buffer = collections.deque(maxlen=span)
    if data_index + span > len(data):
        data_index = 0
    buffer.extend(data[data_index:data_index + span])
    data_index += span
    for i in range(batch_size // num_skips):
        target = skip_window  # label position inside the buffer
        targets_to_avoid = [skip_window]  # a word is not its own context
        for j in range(num_skips):
            while target in targets_to_avoid:
                # Pick context positions at random for better sample variety.
                target = random.randint(0, span - 1)
            targets_to_avoid.append(target)
            batch[i * num_skips + j] = buffer[skip_window]
            labels[i * num_skips + j, 0] = buffer[target]
        if data_index == len(data):
            # Wrapped past the end: restart the window at the beginning.
            # (deque has no slice assignment — the original `buffer[:] = ...`
            # raised TypeError; extend on a maxlen=span deque replaces the
            # whole window.)
            buffer.extend(data[:span])
            data_index = span
        else:
            buffer.append(data[data_index])
            data_index += 1
    # Step back so words at the end of the corpus are not skipped next call.
    data_index = (data_index + len(data) - span) % len(data)
    # The original function filled both arrays but never returned them.
    return batch, labels
# Example sentence: "Batman defeated Superman; Captain America was beaten by Iron Man"
# encoded as word indices: [3,90,600,58,77,888,965]
# batch [90,90,600,600,58,58,77,77,888,888]
# label [3,600,90,58,600,77,58,888,77,965]