"""Strip weights from ONNX models to reduce file size for version control."""
from pathlib import Path
import argparse
import onnx


def _clear_tensor_data(tensor: onnx.TensorProto) -> None:
    """Drop every stored-value field of ``tensor`` while keeping its metadata.

    Name, dims and data_type are left untouched, so the graph still describes
    the tensor's signature. Only the values (inline or external) are removed.
    """
    # Walk the protobuf descriptor instead of dir() reflection. This catches
    # every ``*_data`` field (raw_data, float_data, ..., external_data).
    # ClearField() handles repeated, bytes and scalar fields the same way, and
    # it also drops presence; assigning b'' to raw_data would keep it marked set.
    for field in tensor.DESCRIPTOR.fields:
        if field.name.endswith('_data'):
            tensor.ClearField(field.name)
    # With external_data entries gone, an EXTERNAL location would point at
    # nothing, so fall back to the default (inline) location.
    tensor.ClearField('data_location')


def _strip_graph(graph: onnx.GraphProto) -> None:
    """Clear initializer payloads in ``graph`` and in every nested subgraph."""
    for initializer in graph.initializer:
        _clear_tensor_data(initializer)
    for sparse in graph.sparse_initializer:
        # Guard with HasField so that touching an unset sub-message does not
        # materialize it.
        if sparse.HasField('values'):
            _clear_tensor_data(sparse.values)
        if sparse.HasField('indices'):
            _clear_tensor_data(sparse.indices)
    # Graph-valued attributes (e.g. the bodies of control-flow nodes) can
    # carry their own initializers, so recurse into them as well.
    for node in graph.node:
        for attr in node.attribute:
            if attr.HasField('g'):
                _strip_graph(attr.g)
            for subgraph in attr.graphs:
                _strip_graph(subgraph)


def strip_weights_from_onnx(input_path: str, output_path: str) -> None:
    """
    Remove weights from ONNX model by clearing initializer data.

    Every initializer is processed, dense and sparse, including those inside
    nested subgraphs. Each keeps its name, dims and data type, but its stored
    values and any external-data references are removed. The result is saved
    as one self-contained file, and an input/output size comparison is printed.

    Args:
        input_path: Path to input ONNX model with weights
        output_path: Path to save stripped ONNX model without weights
    """
    # External data is never read here because the weights are about to be
    # discarded anyway.
    model = onnx.load(input_path, load_external_data=False)

    _strip_graph(model.graph)

    # Save the stripped model (no external data)
    onnx.save(model, output_path, save_as_external_data=False)

    # Print size comparison
    input_size = Path(input_path).stat().st_size / (1024 * 1024)  # MB
    output_size = Path(output_path).stat().st_size / (1024 * 1024)  # MB
    print(f"Original: {input_size:.2f} MB")
    print(f"Stripped: {output_size:.2f} MB")
    # Guard the ratio. An empty input file would otherwise raise
    # ZeroDivisionError after the output has already been written.
    if input_size > 0:
        print(f"Reduction: {(1 - output_size / input_size) * 100:.1f}%")
    else:
        print("Reduction: n/a (input file is empty)")


def strip_weights_from_directory(directory: str | Path) -> None:
    """
    Strip weights from all ONNX models in a directory.

    Args:
        directory: Path to directory containing ONNX files
    """
    target_dir = Path(directory).resolve()

    if not target_dir.exists():
        print(f"Error: Directory not found: {target_dir}")
        return

    if not target_dir.is_dir():
        print(f"Error: Not a directory: {target_dir}")
        return

    # Find all ONNX files (excluding already stripped ones)
    onnx_files = [f for f in target_dir.glob("*.onnx")
                  if not f.stem.endswith("_no_weights")]

    if not onnx_files:
        print(f"No ONNX files found in: {target_dir}")
        return

    print(f"Processing directory: {target_dir}")
    for onnx_file in onnx_files:
        print(f"\nProcessing: {onnx_file.name}")
        output_file = target_dir / f"{onnx_file.stem}_no_weights.onnx"

        try:
            print(f"  Input: {onnx_file}")
            print(f"  Output: {output_file}")
            strip_weights_from_onnx(str(onnx_file), str(output_file))
            print(f"Saved to: {output_file.name}")
        except Exception as e:  # pylint: disable=W0718
            print(f"Error processing {onnx_file.name}: {e}")


def main():
    """Command-line entry point: parse ``--dir`` and strip every model in it.

    If ``--dir`` is omitted or given as an empty string, the directory that
    contains this script is processed instead.
    """
    arg_parser = argparse.ArgumentParser(
        description="Strip weights from ONNX models to reduce file size",
    )
    arg_parser.add_argument(
        "--dir",
        type=str,
        default=None,
        help="Directory containing ONNX files (default: current script directory)",
    )
    options = arg_parser.parse_args()

    # A falsy value (None or "") means "use the script's own directory".
    strip_weights_from_directory(options.dir or Path(__file__).parent)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
