import urllib
from io import StringIO

import matplotlib.pyplot as plt
import numpy as np
import argparse

import requests


class NISTData:
    """Query isothermal fluid properties from the NIST Chemistry WebBook.

    Each request asks ``fluid.cgi`` for one isotherm sampled from ``PLow`` to
    ``PHigh`` in steps of ``PInc`` (pressures in Pa, see ``self.query``) and
    returns the table produced by :meth:`parse_table`.
    """

    # Adapted from Bernd's script
    # NIST WebBook substance IDs, sent as the ``ID`` query parameter.
    code = {'H2O': "C7732185", 'CO2': "C124389"}

    def __init__(self, timeout=60):
        """Set up the default query parameters.

        :param timeout: seconds ``requests`` waits for the NIST server before
            raising, so an unresponsive server cannot hang the caller forever.
        """
        self.timeout = timeout
        # Defaults shared by every request; ID and T are added per call.
        self.query = {
            "Action": "Data",
            "Wide": "on",
            "Type": "IsoTherm",
            "Digits": "12",
            "PLow": "2e7",
            "PHigh": "3e7",
            "PInc": "1e6",
            "RefState": "IIR",
            "TUnit": "K",
            "PUnit": "Pa",
            "DUnit": "kg/m3",
            "HUnit": "kJ/kg",
            "WUnit": "m/s",
            "VisUnit": "uPas",
            "STUnit": "N/m",
        }

    def getdata_(self, code, temperature):
        """Download and parse one isotherm for the substance ``code``.

        :param code: NIST substance ID, e.g. ``NISTData.code['CO2']``.
        :param temperature: isotherm temperature in K.
        :returns: array described in :meth:`parse_table`.
        :raises requests.HTTPError: if NIST answers with an HTTP error status.
        :raises requests.Timeout: if NIST does not answer within ``self.timeout``.
        """
        # Per-call copy so repeated calls never share mutated query state.
        query = dict(self.query, ID=code, T=temperature)
        # requests URL-encodes ``params`` itself, so no manual urlencode needed.
        response = requests.get(
            "https://webbook.nist.gov/cgi/fluid.cgi",
            params=query,
            timeout=self.timeout,
        )
        # Fail loudly instead of trying to parse an error page as data.
        response.raise_for_status()
        response.encoding = "utf-8"
        return self.parse_table(response.text)

    @staticmethod
    def parse_table(text):
        """Parse a tab-separated NIST isotherm table into a numeric array.

        :param text: body returned by ``fluid.cgi``: one header line, then one
            tab-separated row per sample, with the phase label in the last
            column.
        :returns: ``(n, 7)`` float array with columns pressure [Pa],
            density [kg/m3], viscosity [Pa s], enthalpy [J/kg],
            thermal conductivity [W/(m K)], Cv [J/(kg K)], Cp [J/(kg K)].
        """
        # atleast_1d: a single data row would otherwise come back 0-d.
        phase = np.atleast_1d(
            np.genfromtxt(StringIO(text), delimiter="\t", dtype=str, usecols=[-1], skip_header=1)
        )
        values = np.atleast_1d(np.genfromtxt(StringIO(text), delimiter="\t", names=True))

        # NIST provides additional samples at the transition points (if there is a
        # phase transition within the requested data range). Since the code which
        # uses the tables generated by this script can't deal with these additional
        # sample points, they are removed: every pair of rows j, j+1 (j >= 1) whose
        # phase labels differ. The first row is never removed.
        boundary = np.flatnonzero(phase[1:-1] != phase[2:]) + 1
        table = np.delete(values, np.concatenate([boundary, boundary + 1]))

        return np.column_stack([
            table["Pressure_Pa"],
            table["Density_kgm3"],
            table["Viscosity_uPas"] * 1e-6,  # uPa s -> Pa s
            table["Enthalpy_kJkg"] * 1e3,    # kJ/kg -> J/kg
            table["Therm_Cond_WmK"],
            table["Cv_JgK"] * 1e3,           # J/(g K) -> J/(kg K)
            table["Cp_JgK"] * 1e3,           # J/(g K) -> J/(kg K)
        ])

    def getCO2Data(self, temperature):
        """Return the CO2 isotherm at ``temperature`` (K); see :meth:`getdata_`."""
        return self.getdata_(self.code['CO2'], temperature)

    def getH2OData(self, temperature):
        """Return the H2O isotherm at ``temperature`` (K); see :meth:`getdata_`."""
        return self.getdata_(self.code['H2O'], temperature)

def main():
    """Compare GEOS PVTDriver output with NIST reference data and plot it.

    For every temperature in ``te`` this reads ``<outputDir>/geos_pvt_<T>.txt``
    (skipping its first 7 lines). It also queries the NIST WebBook for the CO2
    and H2O isotherms at that temperature. It then draws a 2x2 figure of
    gas (CO2) density, water (H2O) density, gas viscosity and water viscosity
    versus pressure. GEOS results are dotted lines; NIST data are starred
    markers in the same colour. Needs network access to webbook.nist.gov, and
    the figure is shown with ``plt.show()``.
    """
    # Initialize the argument parser
    parser = argparse.ArgumentParser(description="Script to generate figure from tutorial.")
    # Directory holding the geos_pvt_<T>.txt files
    parser.add_argument('--outputDir', help='Path to output directory', default='.')
    args = parser.parse_args()
    base = args.outputDir

    fsize = 25  # axis-label font size
    msize = 12  # marker size of the NIST points
    lw = 4      # line width of the GEOS curves
    cmap = plt.get_cmap("tab10")

    te = [283, 293, 303, 313, 323, 333, 343, 353]

    # GEOS tables, one per temperature. Columns read below: 1 = pressure
    # (x-axis), 6 = gas density, 7 = water density, 8 = gas viscosity,
    # 9 = water viscosity.
    # NOTE(review): column meanings follow the axis labels used here — confirm
    # against the PVTDriver output header.
    geos_data = [np.loadtxt(f'{base}/geos_pvt_{ti}.txt', skiprows=7) for ti in te]

    # NIST tables: one request per fluid and temperature, shared by both panels
    # of that fluid. Columns: 0 = pressure [Pa], 1 = density [kg/m3],
    # 2 = viscosity [Pa s], 3 = enthalpy [J/kg], 4 = thermal conductivity,
    # 5 = Cv, 6 = Cp.
    nist = NISTData()
    nist_co2 = [nist.getCO2Data(ti) for ti in te]
    nist_h2o = [nist.getH2OData(ti) for ti in te]

    fig, ax = plt.subplots(2, 2, figsize=(32, 18))

    def plot_panel(axis, geos_col, nist_tables, nist_col, ylabel):
        """Draw GEOS column ``geos_col`` and NIST column ``nist_col`` against pressure."""
        for i, (ti, geos, ref) in enumerate(zip(te, geos_data, nist_tables)):
            # Same colour for GEOS and NIST at a given temperature.
            color = cmap(i)
            axis.plot(geos[:, 1], geos[:, geos_col], label=f'T={ti}', ls=':', lw=lw, color=color)
            axis.plot(ref[:, 0], ref[:, nist_col], marker='*', markersize=msize, color=color)
        axis.legend()
        axis.set_xlabel('Pressure [Pa]', size=fsize)
        axis.set_ylabel(ylabel, size=fsize)

    plot_panel(ax[0, 0], 6, nist_co2, 1, 'Gas densities [kg / m3]')
    plot_panel(ax[0, 1], 7, nist_h2o, 1, 'Water densities [kg / m3]')
    plot_panel(ax[1, 0], 8, nist_co2, 2, 'Gas viscosity [Pa s]')
    plot_panel(ax[1, 1], 9, nist_h2o, 2, 'Water viscosity [Pa s]')

    plt.show()


if __name__ == "__main__":
    # Run the GEOS/NIST comparison only when executed as a script,
    # not when this module is imported.
    main()
