# split_covariate.py
#
# split_covariate: Python Source Code

# ----------------------------------------------------------------------------
# imports
# ----------------------------------------------------------------------------
import math
import sys
import os
import copy
import time
import csv
import random
import shutil
import dismod_at
from math import exp
#
# import at_cascade, preferring the version in the current directory (if any)
current_directory = os.getcwd()
if os.path.isfile( current_directory + '/at_cascade/__init__.py' ) :
    sys.path.insert(0, current_directory)
import at_cascade
# -----------------------------------------------------------------------------
# global variables
# -----------------------------------------------------------------------------
# fit_goal_set: goal nodes passed to at_cascade.cascade_root_node in main;
# n3 .. n6 are the leaves of the node tree built by root_node_db.
# BEGIN fit_goal_set
fit_goal_set = { f'n{leaf_id}' for leaf_id in range(3, 7) }
# END fit_goal_set
#
# option_all: settings given to at_cascade.create_all_node_db in main
# (main also sets option_all['refit_split']); root_database is derived
# from result_dir.
# BEGIN option_all_table
option_all = dict(
    result_dir                = 'build/example',
    root_node_name            = 'n0',
    root_split_reference_name = 'both',
    split_covariate_name      = 'sex',
    shift_prior_std_factor    = 1e3,
)
option_all['root_database'] = f"{option_all['result_dir']}/root.db"
# END option_all_table
#
# split_reference_table: one row per value of the split covariate sex;
# split_reference_list holds the split_reference_value column in table order.
# BEGIN split_reference_table
split_reference_table = [
    dict(split_reference_name = name, split_reference_value = value)
    for (name, value) in [ ('female', 1.0), ('both', 2.0), ('male', 3.0) ]
]
split_reference_list = [
    row['split_reference_value'] for row in split_reference_table
]
# END split_reference_table
# node_split_table: only the root n0 is listed (main reads its results
# under n0/female and n0/male).
# BEGIN node_split_table
node_split_table = [ dict(node_name = 'n0') ]
# END node_split_table
#
# root_split_reference_id: index of the 'both' row, used at the root node.
# BEGIN root_split_reference_id
root_split_reference_id = 1
assert 'both' == \
    split_reference_table[root_split_reference_id]['split_reference_name']

# END root_split_reference_id
#
# avg_income[node_name][split_reference_id]: reference income for each node
# and each row of split_reference_table. Leaf values are set directly; every
# parent gets the mean of its children, split by split.
# BEGIN avg_income
leaf_node_set     = { 3, 4, 5, 6 }
avg_income = {
    f'n{leaf_id}' : [ 1.0 - leaf_id / 10.0, 1.0, 1.0 + leaf_id / 10.0 ]
    for leaf_id in sorted(leaf_node_set)
}
# child_list[k] is the pair of children of node n{k}, for k = 0, 1, 2
child_list = [ (1,2), (3,4), (5,6) ]
# go bottom-up (n2, n1, n0) so each parent's children already have values
for parent_id in reversed( range( len(child_list) ) ) :
    children = [ avg_income[f'n{child_id}'] for child_id in child_list[parent_id] ]
    avg_income[f'n{parent_id}'] = [
        sum( child[k] for child in children ) / len(children)
        for k in range(3)
    ]
# END avg_income
#
# BEGIN alpha_true
alpha_true = - 0.2
# END alpha_true
# ----------------------------------------------------------------------------
# functions
# ----------------------------------------------------------------------------
# BEGIN rate_true
def rate_true(rate, a, t, n, c) :
    """
    Return the true value of a rate for this example.

    rate : 'iota' or 'omega'; any other rate name returns 0.0.
    a, t : age and time; not used (the true rates do not depend on them).
    n    : node name, one of 'n0', ..., 'n6'.
    c    : covariate values in the order [ sex, income ].

    omega is twice iota. Both are multiplied by exp(effect), where
    effect = alpha_true * (income - reference income of node n for the
    split whose split_reference_value equals sex).
    """
    # true_iota: leaf values are given; a parent's value is the sum of its
    # two children divided by 2.9 (n1 and n2 come before n0, which uses them)
    true_iota = { 'n3' : 1e-2, 'n4' : 2e-2, 'n5' : 3e-2, 'n6' : 4e-2 }
    for (parent, left, right) in [
        ('n1', 'n3', 'n4'), ('n2', 'n5', 'n6'), ('n0', 'n1', 'n2')
    ] :
        true_iota[parent] = (true_iota[left] + true_iota[right]) / 2.9
    #
    # sex picks the split reference row; income enters through alpha_true
    (sex, income) = (c[0], c[1])
    matching_ids = [
        row_id for (row_id, row) in enumerate(split_reference_table)
        if row['split_reference_value'] == sex
    ]
    split_reference_id = matching_ids[-1] if matching_ids else None
    #
    effect = alpha_true * ( income - avg_income[n][split_reference_id] )
    #
    if rate not in ('iota', 'omega') :
        return 0.0
    scale = 2.0 if rate == 'omega' else 1.0
    return scale * true_iota[n] * exp(effect)
# END rate_true
# ----------------------------------------------------------------------------
def root_node_db(file_name) :
    """
    Create the dismod_at database for the root node n0 of this example.

    file_name : path of the database to create; it is passed unchanged to
        dismod_at.create_database.

    The sex and income covariate references are those of the root split,
    split_reference_table[root_split_reference_id]. Sincidence data is
    simulated with rate_true for each leaf node n3 .. n6, for the female and
    male splits only, at incomes 0.5, 1.0 and 1.5 times that leaf's reference
    income for the split. The leaves are visited in a fixed order, so the
    data table rows are the same from run to run.
    """
    #
    # root_sex, root_income
    # covariate values at the root node for the root split reference
    root_sex    = split_reference_list[root_split_reference_id]
    root_income = avg_income['n0'][root_split_reference_id]
    #
    # iota_n0
    iota_n0 = rate_true('iota', None, None, 'n0', [ root_sex, root_income ])
    #
    # prior_table
    prior_table = list()
    prior_table.append(
        # BEGIN parent_value_prior
        {    'name':    'parent_value_prior',
            'density': 'gaussian',
            'lower':   iota_n0 / 10.0,
            'upper':   iota_n0 * 10.0,
            'mean':    iota_n0 ,
            'std':     iota_n0 * 10.0,
            'eta':     iota_n0 * 1e-3
        }
        # END parent_value_prior
    )
    prior_table.append(
        # BEGIN alpha_value_prior
        {    'name':    'alpha_value_prior',
            'density': 'gaussian',
            'lower':   - 10 * abs(alpha_true),
            'upper':   + 10 * abs(alpha_true),
            'std':     + 10 * abs(alpha_true),
            'mean':    0.0,
        }
        # END alpha_value_prior
    )
    #
    # smooth_table
    # both smoothings use the single grid point age_id 0, time_id 0
    smooth_table = [
        {
            'name':       'parent_smooth',
            'age_id':     [0],
            'time_id':    [0],
            'fun':        lambda a, t : ('parent_value_prior', None, None),
        },
        {
            'name':       'alpha_smooth',
            'age_id':     [0],
            'time_id':    [0],
            'fun':        lambda a, t : ('alpha_value_prior', None, None),
        },
    ]
    #
    # node_table
    node_table = [
        { 'name':'n0',        'parent':''   },
        { 'name':'n1',        'parent':'n0' },
        { 'name':'n2',        'parent':'n0' },
        { 'name':'n3',        'parent':'n1' },
        { 'name':'n4',        'parent':'n1' },
        { 'name':'n5',        'parent':'n2' },
        { 'name':'n6',        'parent':'n2' },
    ]
    #
    # rate_table
    rate_table = [ {
        'name':           'iota',
        'parent_smooth':  'parent_smooth',
        'child_smooth':   None ,
    } ]
    #
    # covariate_table
    # NOTE(review): max_difference 1.1 presumably keeps data whose sex is
    # within 1.1 of the reference, i.e. female (1.0) and male (3.0) relative
    # to both (2.0) -- confirm against the dismod_at covariate table docs.
    covariate_table = [
        { 'name': 'sex',      'reference': root_sex, 'max_difference': 1.1 },
        { 'name': 'income',   'reference': root_income },
    ]
    #
    # mulcov_table
    mulcov_table = [ {
        # alpha
        'covariate':  'income',
        'type':       'rate_value',
        'effected':   'iota',
        'group':      'world',
        'smooth':     'alpha_smooth',
    } ]
    #
    # subgroup_table
    subgroup_table = [ {'subgroup': 'world', 'group':'world'} ]
    #
    # integrand_table
    # Sincidence plus one mulcov_<id> integrand per covariate multiplier
    integrand_table = [ {'name':'Sincidence'} ] + [
        { 'name': f'mulcov_{mulcov_id}' }
        for mulcov_id in range( len(mulcov_table) )
    ]
    #
    # avgint_table
    avgint_table = [ {
        'node':         'n0',
        'subgroup':     'world',
        'weight':       '',
        'time_lower':   2000.0,
        'time_upper':   2000.0,
        'age_lower':    50.0,
        'age_upper':    50.0,
        'sex':          None,
        'income':       None,
        'integrand':    'Sincidence',
    } ]
    #
    # data_table
    # leaf_list is an ordered list (not a set of str) because str hashing is
    # randomized per process, which would make the data row order vary
    # from one run to the next.
    leaf_list   = [ 'n3', 'n4', 'n5', 'n6' ]
    data_common = {
        'subgroup':     'world',
        'weight':       '',
        'time_lower':   2000.0,
        'time_upper':   2000.0,
        'age_lower':      50.0,
        'age_upper':      50.0,
        'integrand':    'Sincidence',
        'density':      'gaussian',
        'hold_out':     False,
    }
    assert split_reference_table[0]['split_reference_name'] == 'female'
    assert split_reference_table[2]['split_reference_name'] == 'male'
    data_table = list()
    for split_reference_id in [ 0, 2 ] :
        sex = split_reference_list[split_reference_id]
        for node in leaf_list :
            r_income = avg_income[node][split_reference_id]
            for factor in [ 0.5, 1.0, 1.5 ] :
                income     = factor * r_income
                meas_value = rate_true('iota', None, None, node, [sex, income])
                data_table.append( dict(
                    data_common,
                    node       = node,
                    meas_value = meas_value,
                    sex        = sex,
                    income     = income,
                    meas_std   = meas_value / 10.0,
                ) )
    #
    # age_grid
    age_grid = [ 0.0, 100.0 ]
    #
    # time_grid
    time_grid = [ 1980.0, 2020.0 ]
    #
    # weight table:
    weight_table = list()
    #
    # nslist_table
    nslist_table = dict()
    #
    # option_table
    option_table = [
        { 'name':'parent_node_name',      'value':'n0'},
        { 'name':'rate_case',             'value':'iota_pos_rho_zero'},
        { 'name': 'zero_sum_child_rate',  'value':'iota'},
        { 'name':'quasi_fixed',           'value':'false'},
        { 'name':'max_num_iter_fixed',    'value':'50'},
        { 'name':'tolerance_fixed',       'value':'1e-8'},
    ]
    # ----------------------------------------------------------------------
    # create database
    dismod_at.create_database(
        file_name,
        age_grid,
        time_grid,
        integrand_table,
        node_table,
        subgroup_table,
        weight_table,
        covariate_table,
        avgint_table,
        data_table,
        prior_table,
        smooth_table,
        nslist_table,
        rate_table,
        mulcov_table,
        option_table
    )
# ----------------------------------------------------------------------------
# main
# ----------------------------------------------------------------------------
def _get_tables(database, table_name_list) :
    """
    Read tables from an existing dismod_at database.

    database        : path of the database; it is opened read-only.
    table_name_list : names of the tables to read.

    Returns a dict that maps each name in table_name_list to the value
    dismod_at.get_table_dict returns for that table. The connection is
    closed even if reading a table raises.
    """
    connection = dismod_at.create_connection(
        database, new = False, readonly = True
    )
    try :
        return {
            table_name : dismod_at.get_table_dict(connection, table_name)
            for table_name in table_name_list
        }
    finally :
        connection.close()
#
def _covariate_reference(covariate_table, covariate_name) :
    """
    Return the 'reference' value of the last row of covariate_table whose
    'covariate_name' equals covariate_name, or None if there is no such row.
    """
    reference = None
    for row in covariate_table :
        if row['covariate_name'] == covariate_name :
            reference = row['reference']
    return reference
#
def main(refit_split) :
    """
    Run the split_covariate example once and check the results.

    refit_split : if true, option_all['refit_split'] is set to 'true',
        otherwise to 'false'.

    Steps:
    - Empty option_all['result_dir'] and create root.db with root_node_db.
    - Create all_node.db with the true omega at every root.db age/time grid
      point.
    - Run the cascade from n0 toward fit_goal_set.
    - Check each female and male leaf fit against rate_true.
    - Check that the iota prior mean in the shift database equals the n0 fit
      of iota times exp( alpha * income reference difference ). The shift
      database is n0/female/dismod.db when refit_split is true, otherwise
      n0/female/n1/dismod.db.

    The assert statements raise AssertionError when a check fails.
    """
    # BEGIN refit_split
    if refit_split :
        option_all['refit_split'] = 'true'
    else :
        option_all['refit_split'] = 'false'
    # END refit_split
    # -------------------------------------------------------------------------
    # result_dir
    result_dir = option_all['result_dir']
    at_cascade.empty_directory(result_dir)
    #
    # Create root.db
    root_database = option_all['root_database']
    root_node_db(root_database)
    #
    # age_table, time_table, omega_grid
    # omega is given at every point of the root.db age and time grids; the
    # time ids must come from the time table, not the age table.
    tables     = _get_tables(root_database, [ 'age', 'time' ])
    age_table  = tables['age']
    time_table = tables['time']
    omega_grid = {
        'age'  : list( range( len(age_table) ) ),
        'time' : list( range( len(time_table) ) ),
    }
    #
    # n_split
    n_split = len( split_reference_list )
    #
    # omega_data[node_name][k]: true omega for split k at each
    # (age_id, time_id) in omega_grid, with age as the outer loop
    omega_data = dict()
    for node_name in [ 'n0', 'n1', 'n2', 'n3', 'n4', 'n5', 'n6' ] :
        omega_data[node_name] = list()
        for k in range(n_split) :
            cov        = [ split_reference_list[k], avg_income[node_name][k] ]
            omega_list = list()
            for age_id in omega_grid['age'] :
                age_value = age_table[age_id]['age']
                for time_id in omega_grid['time'] :
                    time_value = time_table[time_id]['time']
                    omega_list.append( rate_true(
                        'omega', age_value, time_value, node_name, cov
                    ) )
            omega_data[node_name].append( omega_list )
    #
    # Create all_node.db
    all_node_database = f'{result_dir}/all_node.db'
    at_cascade.create_all_node_db(
        all_node_database      = all_node_database,
        split_reference_table  = split_reference_table,
        node_split_table       = node_split_table,
        option_all             = option_all,
        omega_grid             = omega_grid,
        omega_data             = omega_data,
    )
    #
    # root_node_dir
    root_node_dir = f'{result_dir}/n0'
    os.mkdir(root_node_dir)
    #
    # avgint_table
    # This also erases the avgint table from root_database
    avgint_table = at_cascade.extract_avgint( root_database )
    #
    # cascade starting at root node
    at_cascade.cascade_root_node(
        all_node_database  = all_node_database ,
        fit_goal_set       = fit_goal_set      ,
    )
    #
    # check results
    for sex in [ 'female', 'male' ] :
        for subdir in [ 'n1/n3', 'n1/n4', 'n2/n5', 'n2/n6' ] :
            goal_database = f'{result_dir}/n0/{sex}/{subdir}/dismod.db'
            at_cascade.check_cascade_node(
                rate_true          = rate_true,
                all_node_database  = all_node_database,
                fit_database       = goal_database,
                avgint_table       = avgint_table,
                relative_tolerance = 1e-5,
            )
    #
    # fit_iota, fit_alpha, fit_reference_income (from the n0 fit)
    tables = _get_tables(
        f'{result_dir}/n0/dismod.db', [ 'var', 'fit_var', 'rate', 'covariate' ]
    )
    fit_iota  = None
    fit_alpha = None
    for (var_id, row) in enumerate( tables['var'] ) :
        rate_name = tables['rate'][ row['rate_id'] ]['rate_name']
        if rate_name != 'iota' :
            continue
        fit_value = tables['fit_var'][var_id]['fit_var_value']
        if row['var_type'] == 'rate' :
            fit_iota = fit_value
        else :
            assert row['var_type'] == 'mulcov_rate_value'
            fit_alpha = fit_value
    assert fit_iota is not None and fit_alpha is not None
    fit_reference_income = _covariate_reference(tables['covariate'], 'income')
    #
    # shift_mean, shift_reference_income
    if refit_split :
        shift_database = f'{result_dir}/n0/female/dismod.db'
    else :
        shift_database = f'{result_dir}/n0/female/n1/dismod.db'
    tables = _get_tables(
        shift_database, [ 'rate', 'smooth_grid', 'prior', 'covariate' ]
    )
    smooth_id = None
    for row in tables['rate'] :
        if row['rate_name'] == 'iota' :
            smooth_id = row['parent_smooth_id']
    assert smooth_id is not None
    prior_id = None
    for row in tables['smooth_grid'] :
        if row['smooth_id'] == smooth_id :
            prior_id = row['value_prior_id']
    assert prior_id is not None
    shift_mean = tables['prior'][prior_id]['mean']
    shift_reference_income = \
        _covariate_reference(tables['covariate'], 'income')
    #
    # check: the shifted prior mean is the n0 iota fit moved from the n0
    # income reference to the shift database income reference
    income_difference = shift_reference_income - fit_reference_income
    check = fit_iota * exp( fit_alpha * income_difference )
    assert abs(1.0 - shift_mean/check) < 1e-12
#
if __name__ == '__main__' :
    # run the example without, then with, refitting at the split node
    for refit_split in [ False, True ] :
        main(refit_split)
    print('split_covariate: OK')