diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c index 07f44555fd03..a10aa8563e41 100644 --- a/mm/userfaultfd.c +++ b/mm/userfaultfd.c @@ -18,6 +18,36 @@ #include #include "internal.h" +static __always_inline +struct vm_area_struct *find_dst_vma(struct mm_struct *dst_mm, + unsigned long dst_start, + unsigned long len) +{ + /* + * Make sure that the dst range is both valid and fully within a + * single existing vma. + */ + struct vm_area_struct *dst_vma; + + dst_vma = find_vma(dst_mm, dst_start); + if (!dst_vma) + return NULL; + + if (dst_start < dst_vma->vm_start || + dst_start + len > dst_vma->vm_end) + return NULL; + + /* + * Check the vma is registered in uffd, this is required to + * enforce the VM_MAYWRITE check done at uffd registration + * time. + */ + if (!dst_vma->vm_userfaultfd_ctx.ctx) + return NULL; + + return dst_vma; +} + static int mcopy_atomic_pte(struct mm_struct *dst_mm, pmd_t *dst_pmd, struct vm_area_struct *dst_vma, @@ -220,20 +250,9 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm, */ if (!dst_vma) { err = -ENOENT; - dst_vma = find_vma(dst_mm, dst_start); + dst_vma = find_dst_vma(dst_mm, dst_start, len); if (!dst_vma || !is_vm_hugetlb_page(dst_vma)) goto out_unlock; - /* - * Check the vma is registered in uffd, this is - * required to enforce the VM_MAYWRITE check done at - * uffd registration time. - */ - if (!dst_vma->vm_userfaultfd_ctx.ctx) - goto out_unlock; - - if (dst_start < dst_vma->vm_start || - dst_start + len > dst_vma->vm_end) - goto out_unlock; err = -EINVAL; if (vma_hpagesize != vma_kernel_pagesize(dst_vma)) @@ -468,20 +487,9 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm, * both valid and fully within a single existing vma. */ err = -ENOENT; - dst_vma = find_vma(dst_mm, dst_start); + dst_vma = find_dst_vma(dst_mm, dst_start, len); if (!dst_vma) goto out_unlock; - /* - * Check the vma is registered in uffd, this is required to - * enforce the VM_MAYWRITE check done at uffd registration - * time. - */ - if (!dst_vma->vm_userfaultfd_ctx.ctx) - goto out_unlock; - - if (dst_start < dst_vma->vm_start || - dst_start + len > dst_vma->vm_end) - goto out_unlock; err = -EINVAL; /*