#!/usr/bin/env python
"""Encrypt each line of STDIN to STDOUT

Plaintext is string data; ciphertext is Ascii85-encoded bytes

Because of the encoding, the ciphertext will be somewhat longer than the
plaintext, but the precise ratio will vary depending on what characters are
encoded.

Any Unicode plaintext can be encoded, but for non-ASCII inputs, UTF-8 encoding
will occur initially, which will increase the length of the byte encoding of
the plaintext even before the lengthening Ascii85 encoding is applied to the
ciphertext bytes.
"""

from base64 import a85encode, a85decode
import hashlib
import sys


def amateur_encrypt(plaintext: str, key: str) -> bytes:
    """Encrypt *plaintext* with a keystream derived from *key*.

    The plaintext is UTF-8 encoded, and the SHA-256 digest of the UTF-8
    encoded key is repeated until it covers the encoded text.  Two passes
    follow: the text is XORed with the keystream, then that result is
    reversed and XORed with the keystream again.

    Args:
        plaintext: Text to encrypt.
        key: Password from which the keystream is derived.

    Returns:
        The raw ciphertext, exactly as long as the UTF-8 encoded plaintext.

    Note:
        With ``p`` the encoded plaintext, ``k`` the keystream and ``n`` the
        length, output byte ``i`` is ``p[n-1-i] ^ k[n-1-i] ^ k[i]``.  The
        key therefore cancels wherever ``k[n-1-i] == k[i]``; in particular
        a single-byte input and the middle byte of any odd-length input
        are emitted unencrypted.
    """
    encoded_text = plaintext.encode()
    digest = hashlib.sha256(key.encode()).digest()
    # Repeat the digest so the keystream is at least as long as the text.
    keystream = digest * (1 + len(encoded_text) // len(digest))
    # bytes(<ints>) is used instead of the argument-less int.to_bytes(),
    # which is only accepted by Python 3.11 and later.
    first_pass = bytes(a ^ b for a, b in zip(encoded_text, keystream))
    return bytes(a ^ b for a, b in zip(reversed(first_pass), keystream))


def amateur_decrypt(ciphertext: bytes, key: str) -> str:
    """Decrypt *ciphertext* produced by ``amateur_encrypt`` with *key*.

    The keystream is the SHA-256 digest of the UTF-8 encoded key, repeated
    to cover the ciphertext.  The same two passes used for encryption are
    applied (XOR with the keystream, then reverse and XOR again); applying
    them a second time restores the original bytes.

    Args:
        ciphertext: Raw (already Ascii85-decoded) ciphertext bytes.
        key: Password from which the keystream is derived.

    Returns:
        The recovered plaintext, decoded as UTF-8.

    Raises:
        UnicodeDecodeError: If the recovered bytes are not valid UTF-8,
            for example when the wrong key is supplied.
    """
    digest = hashlib.sha256(key.encode()).digest()
    # Repeat the digest so the keystream is at least as long as the input.
    keystream = digest * (1 + len(ciphertext) // len(digest))
    # bytes(<ints>) is used instead of the argument-less int.to_bytes(),
    # which is only accepted by Python 3.11 and later.
    first_pass = bytes(a ^ b for a, b in zip(ciphertext, keystream))
    plainbytes = bytes(
        a ^ b for a, b in zip(reversed(first_pass), keystream)
    )
    return plainbytes.decode()


if __name__ == "__main__":
    # Command-line entry point: filter STDIN to STDOUT line by line.
    USAGE = "Usage: amateur-crypt [-e|-d] passwd < src > dst"

    if len(sys.argv) != 3 or sys.argv[1] not in ("-e", "-d"):
        print(USAGE, file=sys.stderr)
        # Exit non-zero so shells and pipelines can detect the misuse.
        sys.exit(2)

    mode, key = sys.argv[1], sys.argv[2]
    if mode == "-e":
        # Each input line (trailing newline included) is encrypted on its
        # own and written as one line of Ascii85 text.
        for line in sys.stdin:
            ciphertext = amateur_encrypt(line, key)
            print(a85encode(ciphertext).decode())
    else:
        # Explanatory lines up to and including the row of 80 dashes are
        # ignored; only the lines after it are decoded.
        # NOTE(review): -e does not write that row, so its raw output must
        # have one prepended before -d will decode anything.
        separator = "-" * 80
        start = False
        for line in sys.stdin:
            if not start:
                if line.startswith(separator):
                    start = True
                continue
            try:
                plaintext = amateur_decrypt(a85decode(line), key)
            except ValueError as err:
                # Covers malformed Ascii85 (ValueError) and bytes that are
                # not valid UTF-8 after decryption (UnicodeDecodeError, a
                # ValueError subclass), e.g. because of a wrong passwd.
                print(
                    f"amateur-crypt: cannot decrypt input: {err}",
                    file=sys.stderr
                )
                sys.exit(1)
            # The newline was part of the encrypted line; do not add one.
            print(plaintext, end="")
