! fplot_stats_plots.f90
!
! Source code for the fplot_stats_plots module, which defines the
! correlation_plot type.

module fplot_stats_plots
    use iso_fortran_env
    use fplot_plot_object
    use fplot_plot
    use fplot_plot_data_2d
    use fplot_plot_data_histogram
    use fplot_plot_2d
    use fplot_multiplot
    use fplot_terminal
    use fplot_constants
    use fplot_errors
    use fplot_colors
    use fplot_plot_axis
    use collections
    use strings
    use ferror
    implicit none
    private
    public :: correlation_plot

    type, extends(plot_object) :: correlation_plot
        !! A grid of subplots, built on top of a multiplot, that illustrates
        !! the pairwise correlation between several data sets.
        type(multiplot), private :: m_plt
            !! The multiplot that holds and renders every subplot.
    contains
        ! Construction and layout queries
        procedure, public :: initialize => cp_init
        procedure, public :: get_row_count => cp_get_rows
        procedure, public :: get_column_count => cp_get_cols
        procedure, public :: get_plot_count => cp_get_count
        procedure, public :: get => cp_get
        ! Output and GNUPLOT interaction
        procedure, public :: get_command_string => cp_get_command
        procedure, public :: draw => cp_draw
        procedure, public :: save_file => cp_save
        procedure, public :: get_terminal => cp_get_term
        ! Font settings
        procedure, public :: get_font_name => cp_get_font
        procedure, public :: set_font_name => cp_set_font
        procedure, public :: get_font_size => cp_get_font_size
        procedure, public :: set_font_size => cp_set_font_size
    end type correlation_plot

contains
! ------------------------------------------------------------------------------
    function cp_get_command(this) result(x)
        !! Gets the GNUPLOT commands for this object.
        !!
        !! Command generation is delegated entirely to the underlying
        !! multiplot object, which owns every subplot of the correlation_plot.
        class(correlation_plot), intent(in) :: this
            !! The correlation_plot object.
        character(len = :), allocatable :: x
            !! The command string.

        ! The result must always be allocated; an unallocated deferred-length
        ! result is undefined for any caller that reads it.
        x = this%m_plt%get_command_string()
    end function

! ------------------------------------------------------------------------------
    subroutine cp_init(this, x, labels, term, width, height, err)
        !! Initializes the correlation_plot object.
        !!
        !! An n-by-n grid of subplots is built, where n is the number of
        !! columns in x.  The subplot at row i, column j shows:
        !!
        !!  - i == j: a histogram of column i.
        !!
        !!  - i /= j: a scatter plot of column i (y) against column j (x),
        !!    overlaid with a least-squares straight-line fit.
        !!
        !! Y-axis titles are shown only on the first column of subplots and
        !! x-axis titles only on the last row.
        class(correlation_plot), intent(inout) :: this
            !! The correlation_plot object.
        real(real64), intent(in), dimension(:,:) :: x
            !! The data to plot with each column representing a data set.
        type(string), intent(in), optional, dimension(:) :: labels
            !! An optional array containing a label to associate with each
            !! data set in x.  If supplied, this array must have the same length
            !! as x has columns.
        integer(int32), intent(in), optional :: term
            !! An optional input that is used to define the terminal.  The 
            !! default terminal is a WXT terminal.  The acceptable inputs are:
            !!
            !!  - GNUPLOT_TERMINAL_PNG
            !!
            !!  - GNUPLOT_TERMINAL_QT
            !!
            !!  - GNUPLOT_TERMINAL_WIN32
            !!
            !!  - GNUPLOT_TERMINAL_WXT
            !!
            !!  - GNUPLOT_TERMINAL_LATEX
        integer(int32), intent(in), optional :: width
            !! Optionally, the width of the plot window.
        integer(int32), intent(in), optional :: height
            !! Optionally, the height of the plot window.
        class(errors), intent(inout), optional, target :: err
            !! An error handling object.

        ! Local Variables
        integer(int32) :: i, j, k, n, flag
        real(real64) :: m, b
        real(real64), allocatable, dimension(:) :: mdl
        class(errors), pointer :: errmgr
        type(errors), target :: deferr
        type(plot_2d), allocatable, dimension(:) :: plts
        type(plot_data_2d) :: pdata, mdata
        type(plot_data_histogram) :: hdata
        class(plot_axis), pointer :: xAxis, yAxis
        
        ! Initialization
        if (present(err)) then
            errmgr => err
        else
            errmgr => deferr
        end if
        n = size(x, 2)

        ! Input Checking - performed before any state is modified so that an
        ! invalid labels array leaves this object untouched
        if (present(labels)) then
            if (size(labels) /= n) then
                call report_array_size_mismatch_error(errmgr, "cp_init", &
                    "labels", n, size(labels))
                return
            end if
        end if

        ! Set up the n-by-n layout and the working storage for the subplots
        call this%m_plt%initialize(n, n, term = term, width = width, &
            height = height, err = errmgr)
        if (errmgr%has_error_occurred()) return
        allocate(plts(n * n), stat = flag)
        if (flag /= 0) then
            call report_memory_error(errmgr, "cp_init", flag)
            return
        end if
        call this%m_plt%set_font_size(11)    ! use a small font size

        ! Define the appearance of the scatter data and the fitted line
        k = 0
        call pdata%set_draw_line(.false.)
        call pdata%set_draw_markers(.true.)
        call pdata%set_marker_style(MARKER_FILLED_CIRCLE)
        call pdata%set_marker_scaling(0.5)
        call mdata%set_line_width(2.0)
        call mdata%set_line_color(CLR_BLACK)

        ! Create plots (j = column, i = row)
        do j = 1, n
            do i = 1, n
                k = k + 1
                call plts(k)%initialize(err = errmgr)
                if (errmgr%has_error_occurred()) return
                if (i == j) then
                    ! Plot a histogram of the data
                    call hdata%define_data(x(:,i), err = errmgr)
                    if (errmgr%has_error_occurred()) return
                    call plts(k)%push(hdata)
                else
                    ! Plot a scatter plot
                    call pdata%define_data(x(:,j), x(:,i), err = errmgr)
                    if (errmgr%has_error_occurred()) return
                    call plts(k)%push(pdata)

                    ! Fit a line to the data
                    call compute_linear_fit(x(:,j), x(:,i), m, b)
                    mdl = m * x(:,j) + b

                    ! Plot the fitted line.  Errors must be routed through
                    ! errmgr so the check below sees them even when the
                    ! caller did not supply err.
                    call mdata%define_data(x(:,j), mdl, err = errmgr)
                    if (errmgr%has_error_occurred()) return
                    call plts(k)%push(mdata)
                end if

                ! Deal with axis labels
                if (j == 1) then
                    ! Display y axis labels for the first column of plots
                    yAxis => plts(k)%get_y_axis()
                    if (present(labels)) then
                        call yAxis%set_title(char(labels(i)))
                    else
                        call yAxis%set_title(char("x_{" // to_string(i) // "}"))
                    end if
                end if

                ! Get an x-axis object for the plot
                xAxis => plts(k)%get_x_axis()

                ! Define axis labels
                if (i == n) then
                    ! Display x axis labels for the last row of plots
                    if (present(labels)) then
                        call xAxis%set_title(char(labels(j)))
                    else
                        call xAxis%set_title(char("x_{" // to_string(j) // "}"))
                    end if
                end if

                ! Rotate the x-axis tic labels (applied to every subplot)
                call xAxis%set_tic_label_angle(45.0)
                call xAxis%set_tic_label_rotation_origin(GNUPLOT_ROTATION_ORIGIN_RIGHT)

                ! Store the plot - the collection makes a copy of the plot and
                ! manages its lifetime
                call this%m_plt%set(i, j, plts(k))
            end do
        end do
    end subroutine

! ------------------------------------------------------------------------------
    pure integer(int32) function cp_get_rows(this) result(nrows)
        !! Returns how many rows of subplots the correlation_plot contains,
        !! as reported by the underlying multiplot.
        class(correlation_plot), intent(in) :: this
            !! The correlation_plot being queried.

        nrows = this%m_plt%get_row_count()
    end function cp_get_rows

! --------------------
    pure integer(int32) function cp_get_cols(this) result(ncols)
        !! Returns how many columns of subplots the correlation_plot contains,
        !! as reported by the underlying multiplot.
        class(correlation_plot), intent(in) :: this
            !! The correlation_plot being queried.

        ncols = this%m_plt%get_column_count()
    end function cp_get_cols

! --------------------
    pure integer(int32) function cp_get_count(this) result(nplots)
        !! Returns the total number of subplots held by the correlation_plot,
        !! as reported by the underlying multiplot.
        class(correlation_plot), intent(in) :: this
            !! The correlation_plot being queried.

        nplots = this%m_plt%get_plot_count()
    end function cp_get_count
    
! ------------------------------------------------------------------------------
    subroutine cp_draw(this, persist, err)
        !! Starts GNUPLOT and renders the correlation_plot using its current
        !! command state.  All work is forwarded to the underlying multiplot.
        class(correlation_plot), intent(in) :: this
            !! The correlation_plot to render.
        logical, optional, intent(in) :: persist
            !! When .true., GNUPLOT remains open after drawing; when .false.,
            !! GNUPLOT may close once drawing completes.  Defaults to .true.
        class(errors), optional, target, intent(inout) :: err
            !! An error handling object.

        call this%m_plt%draw(persist, err)
    end subroutine cp_draw

! ------------------------------------------------------------------------------
    subroutine cp_save(this, fname, err)
        !! Writes the GNUPLOT command file for this correlation_plot by
        !! delegating to the underlying multiplot.
        class(correlation_plot), intent(in) :: this
            !! The correlation_plot to save.
        character(len = *), intent(in) :: fname
            !! Name of the file to write.
        class(errors), optional, target, intent(inout) :: err
            !! An error handling object.

        call this%m_plt%save_file(fname, err)
    end subroutine cp_save

! ------------------------------------------------------------------------------
    function cp_get(this, i, j) result(plt)
        !! Returns a pointer to the subplot stored at row i, column j of the
        !! correlation_plot grid.
        class(correlation_plot), intent(in) :: this
            !! The correlation_plot being queried.
        integer(int32), intent(in) :: i, j
            !! Row (i) and column (j) index of the requested subplot.
        class(plot), pointer :: plt
            !! Pointer to the requested subplot.

        plt => this%m_plt%get(i, j)
    end function cp_get

! ------------------------------------------------------------------------------
    function cp_get_term(this) result(term_ptr)
        !! Returns a pointer to the GNUPLOT terminal used by the underlying
        !! multiplot.
        class(correlation_plot), intent(in) :: this
            !! The correlation_plot being queried.
        class(terminal), pointer :: term_ptr
            !! Pointer to the terminal object.

        term_ptr => this%m_plt%get_terminal()
    end function cp_get_term

! ------------------------------------------------------------------------------
    function cp_get_font(this) result(font_name)
        !! Returns the name of the font used for plot text, as stored by the
        !! underlying multiplot.
        class(correlation_plot), intent(in) :: this
            !! The correlation_plot being queried.
        character(len = :), allocatable :: font_name
            !! The font name.

        font_name = this%m_plt%get_font_name()
    end function cp_get_font

! --------------------
    subroutine cp_set_font(this, x)
        !! Changes the font used for plot text by forwarding the name to the
        !! underlying multiplot.
        class(correlation_plot), intent(inout) :: this
            !! The correlation_plot to modify.
        character(len = *), intent(in) :: x
            !! Name of the font to use.

        call this%m_plt%set_font_name(x)
    end subroutine cp_set_font

! ------------------------------------------------------------------------------
    integer(int32) function cp_get_font_size(this) result(fsize)
        !! Returns the size of the font used for plot text, as stored by the
        !! underlying multiplot.
        class(correlation_plot), intent(in) :: this
            !! The correlation_plot being queried.

        fsize = this%m_plt%get_font_size()
    end function cp_get_font_size

! --------------------
    subroutine cp_set_font_size(this, x)
        !! Changes the size of the font used for plot text by forwarding the
        !! value to the underlying multiplot.
        class(correlation_plot), intent(inout) :: this
            !! The correlation_plot to modify.
        integer(int32), intent(in) :: x
            !! The new font size.

        call this%m_plt%set_font_size(x)
    end subroutine cp_set_font_size

! ******************************************************************************
! PRIVATE HELPER ROUTINES
! ------------------------------------------------------------------------------
subroutine compute_linear_fit(x, y, m, b)
    !! Computes the coefficients of a linear equation (y = m * x + b) using a
    !! least-squares approach.
    !!
    !! The fit uses the centered (two-pass) form of the normal equations,
    !! which is less prone to cancellation than the single-pass sums form.
    !! When the fit is undefined (no points, a single point, or every x
    !! value identical) the slope is returned as zero and the intercept as
    !! the mean of y (zero for empty input), so the coefficients are always
    !! finite.  Only the first min(size(x), size(y)) points are used.
    real(real64), intent(in), dimension(:) :: x
        !! The x-coordinate data.
    real(real64), intent(in), dimension(:) :: y
        !! The y-coordinate data.
    real(real64), intent(out) :: m
        !! The slope term.
    real(real64), intent(out) :: b
        !! The intercept term.

    ! Local Variables
    integer(int32) :: n
    real(real64) :: xbar, ybar, sxx, sxy

    ! Initialization - never read beyond the shorter of the two arrays
    n = min(size(x), size(y))
    m = 0.0d0
    b = 0.0d0
    if (n == 0) return

    ! Means of the data
    xbar = sum(x(1:n)) / real(n, real64)
    ybar = sum(y(1:n)) / real(n, real64)

    ! With no spread in x the least-squares slope is undefined (division by
    ! zero); fall back to a horizontal line through the mean of y.  An exact
    ! comparison is intended here: it detects identical values only.
    if (maxval(x(1:n)) == minval(x(1:n))) then
        b = ybar
        return
    end if

    ! Centered sums of squares and cross products
    sxx = sum((x(1:n) - xbar)**2)
    sxy = sum((x(1:n) - xbar) * (y(1:n) - ybar))

    ! Process
    m = sxy / sxx
    b = ybar - m * xbar
end subroutine

! ------------------------------------------------------------------------------
end module