diff --git a/app/utils.py b/app/utils.py index 21c7fd6..034c39c 100644 --- a/app/utils.py +++ b/app/utils.py @@ -20,7 +20,7 @@ from flask_user import * from flask_login import login_user, logout_user from app.models import * from app import app -import random, string, os +import random, string, os, imghdr def getExtension(filename): return filename.rsplit(".", 1)[1].lower() if "." in filename else None @@ -28,6 +28,10 @@ def getExtension(filename): def isFilenameAllowed(filename, exts): return getExtension(filename) in exts +ALLOWED_IMAGES = set(["jpeg", "png"]) +def isAllowedImage(data): + return imghdr.what(None, data) in ALLOWED_IMAGES + def shouldReturnJson(): return "application/json" in request.accept_mimetypes and \ not "text/html" in request.accept_mimetypes @@ -36,16 +40,32 @@ def randomString(n): return ''.join(random.choice(string.ascii_lowercase + \ string.ascii_uppercase + string.digits) for _ in range(n)) -def doFileUpload(file, allowedExtensions, fileTypeName): +def doFileUpload(file, fileType, fileTypeDesc): if not file or file is None or file.filename == "": flash("No selected file", "error") return None + allowedExtensions = [] + isImage = False + if fileType == "image": + allowedExtensions = ["jpg", "jpeg", "png"] + isImage = True + elif filetype == "zip": + allowedExtensions = ["zip"] + else: + raise Exception("Invalid fileType") + ext = getExtension(file.filename) if ext is None or not ext in allowedExtensions: - flash("Please upload load " + fileTypeName, "error") + flash("Please upload load " + fileTypeDesc, "danger") return None + if isImage and not isAllowedImage(file.stream.read()): + flash("Uploaded image isn't actually an image", "danger") + return None + + file.stream.seek(0) + filename = randomString(10) + "." + ext file.save(os.path.join("app/public/uploads", filename)) return "/uploads/" + filename diff --git a/app/views/packages/releases.py b/app/views/packages/releases.py index af17c93..2199675 100644 --- a/app/views/packages/releases.py +++ b/app/views/packages/releases.py @@ -96,7 +96,7 @@ def create_release_page(package): return redirect(url_for("check_task", id=rel.task_id, r=rel.getEditURL())) else: - uploadedPath = doFileUpload(form.fileUpload.data, ["zip"], "a zip file") + uploadedPath = doFileUpload(form.fileUpload.data, "zip", "a zip file") if uploadedPath is not None: rel = PackageRelease() rel.package = package diff --git a/app/views/packages/screenshots.py b/app/views/packages/screenshots.py index cba7030..dbb002b 100644 --- a/app/views/packages/screenshots.py +++ b/app/views/packages/screenshots.py @@ -49,7 +49,7 @@ def create_screenshot_page(package, id=None): # Initial form class from post data and default data form = CreateScreenshotForm() if request.method == "POST" and form.validate(): - uploadedPath = doFileUpload(form.fileUpload.data, ["png", "jpg", "jpeg"], + uploadedPath = doFileUpload(form.fileUpload.data, "image", "a PNG or JPG image file") if uploadedPath is not None: ss = PackageScreenshot()