diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c index 4bb162c1d649..12d9905b429f 100644 --- a/drivers/vfio/vfio_iommu_type1.c +++ b/drivers/vfio/vfio_iommu_type1.c @@ -189,7 +189,7 @@ static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu, } static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu, - dma_addr_t start, size_t size) + dma_addr_t start, u64 size) { struct rb_node *res = NULL; struct rb_node *node = iommu->dma_list.rb_node; @@ -1288,7 +1288,7 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu, int ret = -EINVAL, retries = 0; unsigned long pgshift; dma_addr_t iova = unmap->iova; - unsigned long size = unmap->size; + u64 size = unmap->size; bool unmap_all = unmap->flags & VFIO_DMA_UNMAP_FLAG_ALL; bool invalidate_vaddr = unmap->flags & VFIO_DMA_UNMAP_FLAG_VADDR; struct rb_node *n, *first_n; @@ -1304,14 +1304,12 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu, if (unmap_all) { if (iova || size) goto unlock; - size = SIZE_MAX; - } else if (!size || size & (pgsize - 1)) { + size = U64_MAX; + } else if (!size || size & (pgsize - 1) || + iova + size - 1 < iova || size > SIZE_MAX) { goto unlock; } - if (iova + size - 1 < iova || size > SIZE_MAX) - goto unlock; - /* When dirty tracking is enabled, allow only min supported pgsize */ if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) && (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {