diff --git a/tools/releasetools/blockimgdiff.py b/tools/releasetools/blockimgdiff.py
index 2d20e23fd..b5e01d332 100644
--- a/tools/releasetools/blockimgdiff.py
+++ b/tools/releasetools/blockimgdiff.py
@@ -26,7 +26,8 @@ import os.path
 import re
 import sys
 import threading
-from collections import deque, OrderedDict
+import zlib
+from collections import deque, namedtuple, OrderedDict
 from hashlib import sha1
 
 import common
@@ -36,8 +37,12 @@ __all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
 
 logger = logging.getLogger(__name__)
 
+# The tuple contains the style and bytes of a bsdiff|imgdiff patch.
+PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])
+
 
 def compute_patch(srcfile, tgtfile, imgdiff=False):
+  """Calls bsdiff|imgdiff to compute the patch data, returns a PatchInfo."""
   patchfile = common.MakeTempFile(prefix='patch-')
 
   cmd = ['imgdiff', '-z'] if imgdiff else ['bsdiff']
@@ -52,7 +57,7 @@ def compute_patch(srcfile, tgtfile, imgdiff=False):
     raise ValueError(output)
 
   with open(patchfile, 'rb') as f:
-    return f.read()
+    return PatchInfo(imgdiff, f.read())
 
 
 class Image(object):
@@ -203,17 +208,17 @@ class Transfer(object):
     self.id = len(by_id)
     by_id.append(self)
 
-    self._patch = None
+    self._patch_info = None
 
   @property
-  def patch(self):
-    return self._patch
+  def patch_info(self):
+    return self._patch_info
 
-  @patch.setter
-  def patch(self, patch):
-    if patch:
+  @patch_info.setter
+  def patch_info(self, info):
+    if info:
       assert self.style == "diff"
-    self._patch = patch
+    self._patch_info = info
 
   def NetStashChange(self):
     return (sum(sr.size() for (_, sr) in self.stash_before) -
@@ -224,7 +229,7 @@ class Transfer(object):
     self.use_stash = []
     self.style = "new"
     self.src_ranges = RangeSet()
-    self.patch = None
+    self.patch_info = None
 
   def __str__(self):
     return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
@@ -462,16 +467,7 @@ class BlockImageDiff(object):
     self.AbbreviateSourceNames()
     self.FindTransfers()
 
-    # Find the ordering dependencies among transfers (this is O(n^2)
-    # in the number of transfers).
-    self.GenerateDigraph()
-    # Find a sequence of transfers that satisfies as many ordering
-    # dependencies as possible (heuristically).
-    self.FindVertexSequence()
-    # Fix up the ordering dependencies that the sequence didn't
-    # satisfy.
-    self.ReverseBackwardEdges()
-    self.ImproveVertexSequence()
+    self.FindSequenceForTransfers()
 
     # Ensure the runtime stash size is under the limit.
     if common.OPTIONS.cache_size is not None:
@@ -829,7 +825,7 @@ class BlockImageDiff(object):
             # These are identical; we don't need to generate a patch,
             # just issue copy commands on the device.
             xf.style = "move"
-            xf.patch = None
+            xf.patch_info = None
             tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
             if xf.src_ranges != xf.tgt_ranges:
               logger.info(
@@ -839,11 +835,10 @@ class BlockImageDiff(object):
                       xf.tgt_name + " (from " + xf.src_name + ")"),
                   str(xf.tgt_ranges), str(xf.src_ranges))
           else:
-            if xf.patch:
-              # We have already generated the patch with imgdiff, while
-              # splitting large APKs (i.e. in FindTransfers()).
-              assert not self.disable_imgdiff
-              imgdiff = True
+            if xf.patch_info:
+              # We have already generated the patch (e.g. during split of large
+              # APKs or reduction of stash size).
+              imgdiff = xf.patch_info.imgdiff
             else:
               imgdiff = self.CanUseImgdiff(
                   xf.tgt_name, xf.tgt_ranges, xf.src_ranges)
@@ -854,85 +849,16 @@ class BlockImageDiff(object):
         else:
           assert False, "unknown style " + xf.style
 
-    if diff_queue:
-      if self.threads > 1:
-        logger.info("Computing patches (using %d threads)...", self.threads)
-      else:
-        logger.info("Computing patches...")
-
-      diff_total = len(diff_queue)
-      patches = [None] * diff_total
-      error_messages = []
-
-      # Using multiprocessing doesn't give additional benefits, due to the
-      # pattern of the code. The diffing work is done by subprocess.call, which
-      # already runs in a separate process (not affected much by the GIL -
-      # Global Interpreter Lock). Using multiprocess also requires either a)
-      # writing the diff input files in the main process before forking, or b)
-      # reopening the image file (SparseImage) in the worker processes. Doing
-      # neither of them further improves the performance.
-      lock = threading.Lock()
-      def diff_worker():
-        while True:
-          with lock:
-            if not diff_queue:
-              return
-            xf_index, imgdiff, patch_index = diff_queue.pop()
-            xf = self.transfers[xf_index]
-
-          patch = xf.patch
-          if not patch:
-            src_ranges = xf.src_ranges
-            tgt_ranges = xf.tgt_ranges
-
-            src_file = common.MakeTempFile(prefix="src-")
-            with open(src_file, "wb") as fd:
-              self.src.WriteRangeDataToFd(src_ranges, fd)
-
-            tgt_file = common.MakeTempFile(prefix="tgt-")
-            with open(tgt_file, "wb") as fd:
-              self.tgt.WriteRangeDataToFd(tgt_ranges, fd)
-
-            message = []
-            try:
-              patch = compute_patch(src_file, tgt_file, imgdiff)
-            except ValueError as e:
-              message.append(
-                  "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
-                      "imgdiff" if imgdiff else "bsdiff",
-                      xf.tgt_name if xf.tgt_name == xf.src_name else
-                      xf.tgt_name + " (from " + xf.src_name + ")",
-                      xf.tgt_ranges, xf.src_ranges, e.message))
-            if message:
-              with lock:
-                error_messages.extend(message)
-
-          with lock:
-            patches[patch_index] = (xf_index, patch)
-
-      threads = [threading.Thread(target=diff_worker)
-                 for _ in range(self.threads)]
-      for th in threads:
-        th.start()
-      while threads:
-        threads.pop().join()
-
-      if error_messages:
-        logger.error('ERROR:')
-        logger.error('\n'.join(error_messages))
-        logger.error('\n\n\n')
-        sys.exit(1)
-    else:
-      patches = []
+    patches = self.ComputePatchesForInputList(diff_queue, False)
 
     offset = 0
     with open(prefix + ".patch.dat", "wb") as patch_fd:
-      for index, patch in patches:
+      for index, patch_info, _ in patches:
         xf = self.transfers[index]
-        xf.patch_len = len(patch)
+        xf.patch_len = len(patch_info.content)
        xf.patch_start = offset
         offset += xf.patch_len
-        patch_fd.write(patch)
+        patch_fd.write(patch_info.content)
 
         tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
         logger.info(
@@ -999,6 +925,32 @@ class BlockImageDiff(object):
       for i in range(s, e):
         assert touched[i] == 1
 
+  def FindSequenceForTransfers(self):
+    """Finds a sequence for the given transfers.
+
+    The goal is to minimize the violation of order dependencies between these
+    transfers, so that fewer blocks are stashed when applying the update.
+    """
+
+    # Clear the existing dependencies among transfers.
+    for xf in self.transfers:
+      xf.goes_before = OrderedDict()
+      xf.goes_after = OrderedDict()
+
+      xf.stash_before = []
+      xf.use_stash = []
+
+    # Find the ordering dependencies among transfers (this is O(n^2)
+    # in the number of transfers).
+    self.GenerateDigraph()
+    # Find a sequence of transfers that satisfies as many ordering
+    # dependencies as possible (heuristically).
+    self.FindVertexSequence()
+    # Fix up the ordering dependencies that the sequence didn't
+    # satisfy.
+    self.ReverseBackwardEdges()
+    self.ImproveVertexSequence()
+
   def ImproveVertexSequence(self):
     logger.info("Improving vertex order...")
 
@@ -1248,6 +1200,105 @@ class BlockImageDiff(object):
       b.goes_before[a] = size
       a.goes_after[b] = size
 
+  def ComputePatchesForInputList(self, diff_queue, compress_target):
+    """Returns a list of patch information for the input list of transfers.
+
+    Args:
+      diff_queue: a list of transfers with style 'diff'.
+      compress_target: If True, compresses the target ranges of each transfer
+          and saves the compressed size in the returned tuple.
+
+    Returns:
+      A list of (transfer order, patch_info, compressed_size) tuples.
+    """
+
+    if not diff_queue:
+      return []
+
+    if self.threads > 1:
+      logger.info("Computing patches (using %d threads)...", self.threads)
+    else:
+      logger.info("Computing patches...")
+
+    diff_total = len(diff_queue)
+    patches = [None] * diff_total
+    error_messages = []
+
+    # Using multiprocessing doesn't give additional benefits, due to the
+    # pattern of the code. The diffing work is done by subprocess.call, which
+    # already runs in a separate process (not affected much by the GIL -
+    # Global Interpreter Lock). Using multiprocess also requires either a)
+    # writing the diff input files in the main process before forking, or b)
+    # reopening the image file (SparseImage) in the worker processes. Doing
+    # neither of them further improves the performance.
+    lock = threading.Lock()
+
+    def diff_worker():
+      while True:
+        with lock:
+          if not diff_queue:
+            return
+          xf_index, imgdiff, patch_index = diff_queue.pop()
+          xf = self.transfers[xf_index]
+
+        message = []
+        compressed_size = None
+
+        patch_info = xf.patch_info
+        if not patch_info:
+          src_file = common.MakeTempFile(prefix="src-")
+          with open(src_file, "wb") as fd:
+            self.src.WriteRangeDataToFd(xf.src_ranges, fd)
+
+          tgt_file = common.MakeTempFile(prefix="tgt-")
+          with open(tgt_file, "wb") as fd:
+            self.tgt.WriteRangeDataToFd(xf.tgt_ranges, fd)
+
+          try:
+            patch_info = compute_patch(src_file, tgt_file, imgdiff)
+          except ValueError as e:
+            message.append(
+                "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
+                    "imgdiff" if imgdiff else "bsdiff",
+                    xf.tgt_name if xf.tgt_name == xf.src_name else
+                    xf.tgt_name + " (from " + xf.src_name + ")",
+                    xf.tgt_ranges, xf.src_ranges, e.message))
+
+        if compress_target:
+          tgt_data = self.tgt.ReadRangeSet(xf.tgt_ranges)
+          try:
+            # Compresses with the default level
+            compress_obj = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
+            compressed_data = (compress_obj.compress("".join(tgt_data)) +
+                               compress_obj.flush())
+            compressed_size = len(compressed_data)
+          except zlib.error as e:
+            message.append(
+                "Failed to compress the data in target range {} for {}:\n"
+                "{}".format(xf.tgt_ranges, xf.tgt_name, e.message))
+
+        if message:
+          with lock:
+            error_messages.extend(message)
+
+        with lock:
+          patches[patch_index] = (xf_index, patch_info, compressed_size)
+
+    threads = [threading.Thread(target=diff_worker)
+               for _ in range(self.threads)]
+    for th in threads:
+      th.start()
+    while threads:
+      threads.pop().join()
+
+    if error_messages:
+      logger.error('ERROR:')
+      logger.error('\n'.join(error_messages))
+      logger.error('\n\n\n')
+      sys.exit(1)
+
+    return patches
+
   def FindTransfers(self):
     """Parse the file_map to generate all the transfers."""
 
@@ -1585,7 +1636,7 @@ class BlockImageDiff(object):
                                 self.tgt.RangeSha1(tgt_ranges),
                                 self.src.RangeSha1(src_ranges),
                                 "diff", self.transfers)
-      transfer_split.patch = patch
+      transfer_split.patch_info = PatchInfo(True, patch)
 
   def AbbreviateSourceNames(self):
     for k in self.src.file_map.keys():
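
A note on the PatchInfo namedtuple introduced at the top of the patch: compute_patch() used to return the raw patch bytes, and the imgdiff/bsdiff distinction had to be tracked separately; now both travel together. The snippet below only illustrates that shape (the patch bytes are made up), it is not code from the change itself.

from collections import namedtuple

# Same two fields as in the patch: the style flag and the raw patch bytes.
PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])

patch_info = PatchInfo(imgdiff=True, content=b"made-up-imgdiff-bytes")

# Callers that used to receive bare bytes now read .content, and can tell
# how the patch was produced from .imgdiff without a second return value.
print("%s patch, %d bytes" % (
    "imgdiff" if patch_info.imgdiff else "bsdiff", len(patch_info.content)))

# A namedtuple still unpacks like a plain tuple where that is more convenient.
is_imgdiff, content = patch_info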
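
The rewritten write-out loop in ComputePatches() concatenates every patch body into <prefix>.patch.dat and records a (patch_start, patch_len) pair per transfer from the (index, patch_info, compressed_size) tuples that ComputePatchesForInputList() returns. A self-contained sketch of that bookkeeping, using an in-memory buffer and invented patch blobs rather than real transfers:

import io
from collections import namedtuple

PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])

# Pretend result of ComputePatchesForInputList():
# a list of (index, patch_info, compressed_size) tuples.
patches = [
    (0, PatchInfo(False, b"bsdiff-bytes"), None),
    (1, PatchInfo(True, b"imgdiff-bytes-longer"), None),
]

patch_fd = io.BytesIO()   # stands in for open(prefix + ".patch.dat", "wb")
placement = {}            # index -> (patch_start, patch_len)

offset = 0
for index, patch_info, _ in patches:
  patch_len = len(patch_info.content)
  placement[index] = (offset, patch_len)
  offset += patch_len
  patch_fd.write(patch_info.content)

print(placement)   # {0: (0, 12), 1: (12, 20)}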
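
ComputePatchesForInputList() combines two mechanisms that can be exercised on their own: worker threads popping (index, payload) items off a shared list under a single lock, and a raw (headerless) DEFLATE stream, zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS), used only to measure how small each target range would compress. The sketch below keeps that structure but substitutes the compression itself for the bsdiff/imgdiff work, so all inputs and sizes here are invented.

import threading
import zlib


def compressed_size(data):
  # Raw DEFLATE at level 6, no zlib header/trailer (negative window bits).
  compress_obj = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
  return len(compress_obj.compress(data) + compress_obj.flush())


# (result_index, payload) work items; the payloads stand in for target ranges.
work_queue = [(0, b"\x00" * 4096), (1, b"ABCD" * 1024), (2, b"incompressible?")]
results = [None] * len(work_queue)
lock = threading.Lock()


def worker():
  while True:
    with lock:                       # only queue/result accesses are serialized
      if not work_queue:
        return
      result_index, payload = work_queue.pop()
    size = compressed_size(payload)  # the expensive part runs outside the lock
    with lock:
      results[result_index] = (len(payload), size)


threads = [threading.Thread(target=worker) for _ in range(4)]
for th in threads:
  th.start()
while threads:
  threads.pop().join()

print(results)   # e.g. [(4096, 12), (4096, 27), (15, 20)]; exact sizes vary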
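
The final hunk attaches PatchInfo(True, patch) to each split-APK transfer, and the patch_info check inside ComputePatchesForInputList() is what lets those precomputed imgdiff patches skip another bsdiff/imgdiff run. A toy version of that skip, with a hypothetical Transfer stand-in and a fake patch generator in place of the real compute_patch():

from collections import namedtuple

PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])


class FakeTransfer(object):
  """Hypothetical stand-in for Transfer; only patch_info matters here."""

  def __init__(self, name, patch_info=None):
    self.name = name
    self.patch_info = patch_info


def fake_compute_patch(name):
  # Stands in for the expensive bsdiff/imgdiff subprocess call.
  return PatchInfo(False, b"bsdiff-bytes-for-" + name.encode())


transfers = [
    FakeTransfer("system/app/Large.apk-0",
                 PatchInfo(True, b"precomputed-imgdiff-bytes")),
    FakeTransfer("system/bin/toybox"),
]

for xf in transfers:
  patch_info = xf.patch_info
  if not patch_info:            # same guard as in ComputePatchesForInputList()
    patch_info = fake_compute_patch(xf.name)
  print(xf.name, "imgdiff" if patch_info.imgdiff else "bsdiff",
        len(patch_info.content))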