#!/usr/bin/env python3
from __future__ import annotations

import csv
from collections import defaultdict
from datetime import datetime
from pathlib import Path

# Project layout: this script sits in a direct subdirectory of BASE; inputs are
# read from BASE/11_Input and outputs go to BASE/12_Output (created if absent).
BASE = Path(__file__).resolve().parents[1]
INPUT, OUTPUT = BASE / '11_Input', BASE / '12_Output'
OUTPUT.mkdir(exist_ok=True)

# Source exports joined by this script.
CUSTOMERS = INPUT.joinpath('CustomerList.csv')
MEMBERS = INPUT.joinpath('MembershipCustomers.csv')
TXNS = INPUT.joinpath('OrderTransactionDetail.csv')
PACKAGES = INPUT.joinpath('PrepayServices.csv')


def norm(s: str) -> str:
    return (s or '').strip()


def lower(s: str | None) -> str:
    """Return *s* stripped (via ``norm``) and lower-cased; ``None`` becomes ``''``.

    Used as the comparison key for e-mail addresses, customer names and
    status/state values.
    """
    return norm(s).lower()


def dt(value: str | None) -> datetime | None:
    """Parse a timestamp cell, returning ``None`` when it is blank or unrecognised.

    Accepted forms:
    - US month/day/year, date only;
    - US month/day/year with a 24-hour or 12-hour (AM/PM) time, with or
      without seconds;
    - ``%Y-%m-%d %H:%M:%S``, which is the form this script writes for
      ``LastBooking``. Accepting it lets stored values be parsed back.

    Never raises on bad input.
    """
    value = norm(value)
    if not value:
        return None
    for fmt in (
        '%m/%d/%Y %H:%M:%S',
        '%m/%d/%Y %I:%M:%S %p',
        '%m/%d/%Y %H:%M',
        '%m/%d/%Y %I:%M %p',
        '%m/%d/%Y',
        '%Y-%m-%d %H:%M:%S',  # the LastBooking format produced by this script
    ):
        try:
            return datetime.strptime(value, fmt)
        except ValueError:
            continue  # try the next accepted layout
    return None


# Lookup tables built from the customer list and consumed by the later passes:
#   customers          CustomerId -> output record (roll-up fields start empty/zero)
#   email_to_customer  lower-cased e-mail -> CustomerId (a repeated e-mail keeps
#                      the last id read)
#   name_to_customer   lower-cased "First Last" -> set of CustomerIds with that name
customers = {}
email_to_customer = {}
name_to_customer = defaultdict(set)

with CUSTOMERS.open(newline='', encoding='utf-8-sig') as handle:
    for record in csv.DictReader(handle):
        customer_id = norm(record.get('CustomerId'))
        if not customer_id:
            continue  # nothing to join on without an id
        first_name = norm(record.get('FirstName'))
        last_name = norm(record.get('LastName'))
        # Both parts are already stripped, so joining only the non-empty ones
        # cannot leave stray spaces.
        display_name = ' '.join(filter(None, (first_name, last_name)))
        email_key = lower(record.get('EmailAddress'))
        customers[customer_id] = {
            'CustomerId': customer_id,
            'AccountId': norm(record.get('AccountId')),
            'FirstName': first_name,
            'LastName': last_name,
            'CustomerName': display_name,
            'EmailAddress': email_key,
            'CellPhone': norm(record.get('CellPhone')),
            'CreatedDate': norm(record.get('CreatedDate')),
            'IsEmailValid': norm(record.get('IsEmailValid')),
            'OptOut': norm(record.get('OptOut')),
            # Roll-up fields, filled in by the membership, package and
            # transaction passes below.
            'MembershipStatus': '',
            'MembershipPlan': '',
            'HasPackage': 'No',
            'PackageNames': '',
            'RemainingPackageUnits': 0.0,
            'PackageCount': 0,
            'BookingCount': 0,
            'LastBooking': '',
            'GrossRevenueProxy': 0.0,
            'PaymentTxnCount': 0,
        }
        if email_key:
            email_to_customer[email_key] = customer_id
        if display_name:
            name_to_customer[lower(display_name)].add(customer_id)

# Copy membership status/plan onto known customers. Ids not in the customer
# list are ignored; when a customer has several rows, the last one read wins.
with MEMBERS.open(newline='', encoding='utf-8-sig') as handle:
    for record in csv.DictReader(handle):
        target = customers.get(norm(record.get('CustomerId')))
        if target is None:
            continue
        target['MembershipStatus'] = norm(record.get('MembershipStatus'))
        target['MembershipPlan'] = norm(record.get('Membership'))

# Roll up prepaid packages per known customer: flag holders, count rows, sum
# remaining units, and collect the distinct package descriptions.
# Descriptions are gathered in a per-customer set and joined once at the end.
# Re-splitting the joined ' | ' string on every row would be quadratic and
# would break apart descriptions that themselves contain ' | '.
package_names = defaultdict(set)
with PACKAGES.open(newline='', encoding='utf-8-sig') as f:
    for row in csv.DictReader(f):
        cid = norm(row.get('CustomerID'))
        if cid not in customers:
            continue
        rec = customers[cid]
        rec['HasPackage'] = 'Yes'
        rec['PackageCount'] += 1
        raw_units = norm(row.get('RemainingUnits')).replace(',', '')
        try:
            rec['RemainingPackageUnits'] += float(raw_units or 0)
        except ValueError:
            # A malformed cell contributes 0 instead of aborting the whole run,
            # matching how unparseable transaction totals are treated.
            pass
        desc = norm(row.get('Description'))
        if desc:
            package_names[cid].add(desc)

for cid, names in package_names.items():
    customers[cid]['PackageNames'] = ' | '.join(sorted(names))

# Attribute each transaction-export row to a customer and accumulate booking,
# payment and revenue figures. Matching order:
#   1. the lower-cased e-mail address;
#   2. otherwise the lower-cased full name, but only when exactly one
#      customer has that name.
with TXNS.open(newline='', encoding='utf-8-sig') as f:
    for row in csv.DictReader(f):
        email = lower(row.get('EmailAddress'))
        name = lower(row.get('CustomerName'))
        cid = email_to_customer.get(email)
        # Test membership first so the defaultdict does not grow an empty
        # entry for every unmatched name.
        if not cid and name in name_to_customer and len(name_to_customer[name]) == 1:
            cid = next(iter(name_to_customer[name]))
        if not cid or cid not in customers:
            continue
        rec = customers[cid]

        start = dt(row.get('StartTime'))
        if start:
            rec['BookingCount'] += 1
            # LastBooking holds '%Y-%m-%d %H:%M:%S' text ('' until first set).
            # These zero-padded stamps sort chronologically as plain strings,
            # so compare the formatted values directly to keep the latest one.
            stamp = start.strftime('%Y-%m-%d %H:%M:%S')
            if stamp > rec['LastBooking']:
                rec['LastBooking'] = stamp

        # Drop a '$' sign and thousands separators in case totals are exported
        # currency-formatted. Anything still unparseable counts as 0.
        total = norm(row.get('Total')).replace('$', '').replace(',', '')
        try:
            total_val = float(total)
        except ValueError:
            total_val = 0.0
        item_name = lower(row.get('ItemName'))
        transaction_state = lower(row.get('TransactionState'))
        payment_method = norm(row.get('PaymentMethod'))

        # Rows that carry a payment method but no StartTime are tallied as
        # payment transactions.
        if payment_method and not start:
            rec['PaymentTxnCount'] += 1
        # Revenue proxy: positive totals, excluding items included with a
        # membership and declined transactions.
        if total_val > 0 and 'included with membership' not in item_name and transaction_state not in {'declined'}:
            rec['GrossRevenueProxy'] += total_val

# Per-customer master CSV, highest revenue proxy first (booking count breaks
# ties). Float roll-ups are rendered with two decimals.
out = OUTPUT / 'customer_master_summary.csv'
fields = [
    'CustomerId', 'AccountId', 'CustomerName', 'EmailAddress', 'CellPhone',
    'MembershipStatus', 'MembershipPlan', 'HasPackage', 'PackageNames',
    'PackageCount', 'RemainingPackageUnits', 'BookingCount', 'LastBooking',
    'GrossRevenueProxy', 'PaymentTxnCount', 'IsEmailValid', 'OptOut', 'CreatedDate'
]
ranked = sorted(
    customers.values(),
    key=lambda rec: (rec['GrossRevenueProxy'], rec['BookingCount']),
    reverse=True,
)
with out.open('w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=fields)
    writer.writeheader()
    for rec in ranked:
        writer.writerow({
            **{column: rec.get(column, '') for column in fields},
            'GrossRevenueProxy': f"{rec['GrossRevenueProxy']:.2f}",
            'RemainingPackageUnits': f"{rec['RemainingPackageUnits']:.2f}",
        })

# Topline counts across all customers, written as one "label: value" per line.
summary = OUTPUT / 'merged_topline_summary.txt'
records = list(customers.values())
active_members = sum(lower(r['MembershipStatus']) == 'active' for r in records)
package_holders = sum(r['HasPackage'] == 'Yes' for r in records)
open_units = sum(r['RemainingPackageUnits'] for r in records)
booked_customers = sum(r['BookingCount'] > 0 for r in records)
revenue_proxy = sum(r['GrossRevenueProxy'] for r in records)

report_lines = [
    f'Active members: {active_members}',
    f'Package holders: {package_holders}',
    f'Open package units: {open_units:.2f}',
    f'Customers with bookings in transaction export: {booked_customers}',
    f'Gross revenue proxy: ${revenue_proxy:.2f}',
    f'Customer master summary: {out.name}',
]
summary.write_text(''.join(line + '\n' for line in report_lines), encoding='utf-8')

for written in (out, summary):
    print(f'Wrote {written}')
