@functools.total_ordering
class HeapItem(object):
  """Heap entry wrapping an item so the highest-scoring item pops first.

  The wrapped item must expose a numeric 'score' attribute.  The score is
  sampled once, at construction time; when an item's score changes, the
  old entry should be invalidated with clear() and a fresh HeapItem
  pushed instead of mutating the entry in place.

  Truthiness reports whether the entry is still live: a cleared entry is
  falsy, so code popping from the heap can discard stale entries with a
  simple 'if entry' test.
  """

  def __init__(self, item):
    self.item = item
    # Negate the score since python's heap is a min-heap and we want
    # the maximum score.
    self.score = -item.score

  def clear(self):
    """Invalidate this entry; it stays in the heap but becomes falsy."""
    self.item = None

  def __bool__(self):
    # A live entry still references its item.  (This test used to be
    # inverted, which made every live entry falsy under Python 3.)
    return self.item is not None

  # Python 2 looks up __nonzero__ rather than __bool__ for truth tests.
  __nonzero__ = __bool__

  def __eq__(self, other):
    return self.score == other.score

  def __le__(self, other):
    return self.score <= other.score

  def __lt__(self, other):
    # heapq orders entries using '<' only; defining it directly avoids
    # the slower version total_ordering would derive from __le__/__eq__.
    return self.score < other.score
- touched = touched.union(xf.tgt_ranges) + for s, e in x: + for i in range(s, e): + assert touched[i] == 0 + + # Check that the output blocks for this transfer haven't yet + # been touched, and touch all the blocks written by this + # transfer. + for s, e in xf.tgt_ranges: + for i in range(s, e): + assert touched[i] == 0 + touched[i] = 1 # Check that we've written every target block. - assert touched == self.tgt.care_map + for s, e in self.tgt.care_map: + for i in range(s, e): + assert touched[i] == 1 def ImproveVertexSequence(self): print("Improving vertex order...") @@ -889,6 +917,7 @@ class BlockImageDiff(object): for xf in self.transfers: xf.incoming = xf.goes_after.copy() xf.outgoing = xf.goes_before.copy() + xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values()) # We use an OrderedDict instead of just a set so that the output # is repeatable; otherwise it would depend on the hash values of @@ -899,52 +928,67 @@ class BlockImageDiff(object): s1 = deque() # the left side of the sequence, built from left to right s2 = deque() # the right side of the sequence, built from right to left - while G: + heap = [] + for xf in self.transfers: + xf.heap_item = HeapItem(xf) + heap.append(xf.heap_item) + heapq.heapify(heap) + sinks = set(u for u in G if not u.outgoing) + sources = set(u for u in G if not u.incoming) + + def adjust_score(iu, delta): + iu.score += delta + iu.heap_item.clear() + iu.heap_item = HeapItem(iu) + heapq.heappush(heap, iu.heap_item) + + while G: # Put all sinks at the end of the sequence. - while True: - sinks = [u for u in G if not u.outgoing] - if not sinks: - break + while sinks: + new_sinks = set() for u in sinks: + if u not in G: continue s2.appendleft(u) del G[u] for iu in u.incoming: - del iu.outgoing[u] + adjust_score(iu, -iu.outgoing.pop(u)) + if not iu.outgoing: new_sinks.add(iu) + sinks = new_sinks # Put all the sources at the beginning of the sequence. 
- while True: - sources = [u for u in G if not u.incoming] - if not sources: - break + while sources: + new_sources = set() for u in sources: + if u not in G: continue s1.append(u) del G[u] for iu in u.outgoing: - del iu.incoming[u] + adjust_score(iu, +iu.incoming.pop(u)) + if not iu.incoming: new_sources.add(iu) + sources = new_sources - if not G: - break + if not G: break # Find the "best" vertex to put next. "Best" is the one that # maximizes the net difference in source blocks saved we get by # pretending it's a source rather than a sink. - max_d = None - best_u = None - for u in G: - d = sum(u.outgoing.values()) - sum(u.incoming.values()) - if best_u is None or d > max_d: - max_d = d - best_u = u + while True: + u = heapq.heappop(heap) + if u and u.item in G: + u = u.item + break - u = best_u s1.append(u) del G[u] for iu in u.outgoing: - del iu.incoming[u] + adjust_score(iu, +iu.incoming.pop(u)) + if not iu.incoming: sources.add(iu) + for iu in u.incoming: - del iu.outgoing[u] + adjust_score(iu, -iu.outgoing.pop(u)) + if not iu.outgoing: sinks.add(iu) # Now record the sequence in the 'order' field of each transfer, # and by rearranging self.transfers to be in the chosen sequence. @@ -960,10 +1004,38 @@ class BlockImageDiff(object): def GenerateDigraph(self): print("Generating digraph...") + + # Each item of source_ranges will be: + # - None, if that block is not used as a source, + # - a transfer, if one transfer uses it as a source, or + # - a set of transfers. 
+ source_ranges = [] + for b in self.transfers: + for s, e in b.src_ranges: + if e > len(source_ranges): + source_ranges.extend([None] * (e-len(source_ranges))) + for i in range(s, e): + if source_ranges[i] is None: + source_ranges[i] = b + else: + if not isinstance(source_ranges[i], set): + source_ranges[i] = set([source_ranges[i]]) + source_ranges[i].add(b) + for a in self.transfers: - for b in self.transfers: - if a is b: - continue + intersections = set() + for s, e in a.tgt_ranges: + for i in range(s, e): + if i >= len(source_ranges): break + b = source_ranges[i] + if b is not None: + if isinstance(b, set): + intersections.update(b) + else: + intersections.add(b) + + for b in intersections: + if a is b: continue # If the blocks written by A are read by B, then B needs to go before A. i = a.tgt_ranges.intersect(b.src_ranges) @@ -1092,6 +1164,7 @@ class BlockImageDiff(object): """Assert that all the RangeSets in 'seq' form a partition of the 'total' RangeSet (ie, they are nonintersecting and their union equals 'total').""" + so_far = RangeSet() for i in seq: assert not so_far.overlaps(i)