\(\newcommand{\B}[1]{ {\bf #1} }\) \(\newcommand{\R}[1]{ {\rm #1} }\)
split_covariate.py
View page source
split_covariate: Python Source Code
# ----------------------------------------------------------------------------
# imports
# ----------------------------------------------------------------------------
import math
import sys
import os
import copy
import time
import csv
import random
import shutil
import dismod_at
from math import exp
#
# import at_cascade, preferring the copy in the current working directory
# (when one exists there) by putting that directory first on sys.path
current_directory = os.getcwd()
if os.path.isfile( f'{current_directory}/at_cascade/__init__.py' ) :
    sys.path.insert(0, current_directory)
import at_cascade
# -----------------------------------------------------------------------------
# global variables
# -----------------------------------------------------------------------------
# BEGIN fit_goal_set
fit_goal_set = { 'n3', 'n4', 'n5', 'n6' }
# END fit_goal_set
#
# BEGIN option_all_table
option_all = {
    'result_dir':                'build/example',
    'root_node_name':            'n0',
    'root_split_reference_name': 'both',
    'split_covariate_name':      'sex',
    'shift_prior_std_factor':    1e3,
}
option_all['root_database'] = f"{option_all['result_dir']}/root.db"
# END option_all_table
#
#
# BEGIN split_reference_table
split_reference_table = [
    {'split_reference_name': 'female', 'split_reference_value': 1.0},
    {'split_reference_name': 'both',   'split_reference_value': 2.0},
    {'split_reference_name': 'male',   'split_reference_value': 3.0},
]
# split_reference_list[k] is the split covariate value for split_reference_id k
split_reference_list = [
    row['split_reference_value'] for row in split_reference_table
]
# END split_reference_table
# BEGIN node_split_table
node_split_table = [ { 'node_name' : 'n0'} ]
# END node_split_table
#
# BEGIN root_split_reference_id
root_split_reference_id = 1
assert (
    split_reference_table[root_split_reference_id]['split_reference_name']
    == 'both'
)
# END root_split_reference_id
#
# BEGIN avg_income
# avg_income[node_name][split_reference_id] is the average income for that
# node and split reference (female, both, male in that order).
# Leaf values are set directly; every other node gets the mean over its
# children, so parents are filled in after their children.
leaf_node_set = { 3, 4, 5, 6 }
avg_income = {
    f'n{leaf_id}' : [ 1.0 - leaf_id / 10.0, 1.0, 1.0 + leaf_id / 10.0 ]
    for leaf_id in leaf_node_set
}
# child_list
# children of node 0, 1, 2 in that order
child_list = [ (1,2), (3,4), (5,6) ]
for parent_id in reversed( range( len(child_list) ) ) :
    children = child_list[parent_id]
    avg_income[f'n{parent_id}'] = [
        sum( avg_income[f'n{child_id}'][k] for child_id in children )
            / len(children)
        for k in range(3)
    ]
# END avg_income
#
# alpha_true: true value of the income covariate multiplier used by
# rate_true; root_node_db also scales the alpha prior by its magnitude.
# BEGIN alpha_true
alpha_true = -0.2
# END alpha_true
# ----------------------------------------------------------------------------
# functions
# ----------------------------------------------------------------------------
# BEGIN rate_true
def rate_true(rate, a, t, n, c) :
    """Return the true value of a rate used to simulate and check the fits.

    rate : str
        Rate name. Only 'iota' and 'omega' are non-zero; any other rate
        returns 0.0 (without looking at the covariates).
    a, t :
        Age and time; ignored because the true rates are constant in both.
    n : str
        Node name ('n0', ..., 'n6').
    c : sequence
        Covariate values in covariate table order: c[0] is sex and
        c[1] is income.

    The income effect is alpha_true * (income - r_income), where r_income
    is avg_income for node n at the split reference whose value equals sex.
    Omega is twice iota.

    Raises ValueError if sex is not one of the split_reference_value values.
    """
    if rate not in ( 'iota', 'omega' ) :
        return 0.0
    #
    # true_iota
    true_iota = {
        'n3' : 1e-2,
        'n4' : 2e-2,
        'n5' : 3e-2,
        'n6' : 4e-2
    }
    true_iota['n1'] = (true_iota['n3'] + true_iota['n4']) / 2.9
    true_iota['n2'] = (true_iota['n5'] + true_iota['n6']) / 2.9
    true_iota['n0'] = (true_iota['n1'] + true_iota['n2']) / 2.9
    #
    # effect
    sex    = c[0]
    income = c[1]
    #
    # split_reference_id
    split_reference_id = None
    for (row_id, row) in enumerate(split_reference_table) :
        if row['split_reference_value'] == sex :
            split_reference_id = row_id
    if split_reference_id is None :
        # without this check avg_income[n][None] fails with a TypeError
        # that does not say which covariate value was wrong
        msg = f'rate_true: sex = {sex} is not a split_reference_value'
        raise ValueError(msg)
    #
    r_income = avg_income[n][split_reference_id]
    effect   = alpha_true * ( income - r_income )
    #
    if rate == 'iota' :
        return true_iota[n] * exp(effect)
    # rate == 'omega'
    return 2.0 * true_iota[n] * exp(effect)
# END rate_true
# ----------------------------------------------------------------------------
def root_node_db(file_name) :
    """Create the root node dismod_at database file_name for this example.

    The model has one parent iota value (a single age and time grid point),
    no child random effects for iota, and one covariate multiplier, alpha,
    of the income covariate on iota. Sincidence data are simulated with
    rate_true at every leaf node for the female and male split references,
    using incomes 0.5, 1.0 and 1.5 times the reference income for that node
    and sex. The file is written by dismod_at.create_database.
    """
    #
    # root_sex, root_income
    # Covariate values for the root node at the root split reference; they
    # are the covariate references and the point where iota_n0 is evaluated.
    root_sex    = split_reference_list[root_split_reference_id]
    root_income = avg_income['n0'][root_split_reference_id]
    #
    # iota_n0
    iota_n0 = rate_true('iota', None, None, 'n0', [ root_sex, root_income ])
    #
    # prior_table
    prior_table = list()
    prior_table.append(
        # BEGIN parent_value_prior
        { 'name': 'parent_value_prior',
            'density': 'gaussian',
            'lower': iota_n0 / 10.0,
            'upper': iota_n0 * 10.0,
            'mean': iota_n0 ,
            'std': iota_n0 * 10.0,
            'eta': iota_n0 * 1e-3
        }
        # END parent_value_prior
    )
    prior_table.append(
        # BEGIN alpha_value_prior
        { 'name': 'alpha_value_prior',
            'density': 'gaussian',
            'lower': - 10 * abs(alpha_true),
            'upper': + 10 * abs(alpha_true),
            'std': + 10 * abs(alpha_true),
            'mean': 0.0,
        }
        # END alpha_value_prior
    )
    #
    # smooth_table
    smooth_table = list()
    #
    # parent_smooth
    fun = lambda a, t : ('parent_value_prior', None, None)
    smooth_table.append({
        'name': 'parent_smooth',
        'age_id': [0],
        'time_id': [0],
        'fun': fun,
    })
    #
    # alpha_smooth
    fun = lambda a, t : ('alpha_value_prior', None, None)
    smooth_table.append({
        'name': 'alpha_smooth',
        'age_id': [0],
        'time_id': [0],
        'fun': fun,
    })
    #
    # node_table
    node_table = [
        { 'name':'n0', 'parent':'' },
        { 'name':'n1', 'parent':'n0' },
        { 'name':'n2', 'parent':'n0' },
        { 'name':'n3', 'parent':'n1' },
        { 'name':'n4', 'parent':'n1' },
        { 'name':'n5', 'parent':'n2' },
        { 'name':'n6', 'parent':'n2' },
    ]
    #
    # rate_table
    rate_table = [ {
        'name': 'iota',
        'parent_smooth': 'parent_smooth',
        'child_smooth': None ,
    } ]
    #
    # covariate_table
    # order matters: rate_true expects c[0] = sex and c[1] = income
    covariate_table = [
        { 'name': 'sex', 'reference': root_sex, 'max_difference': 1.1 },
        { 'name': 'income', 'reference': root_income },
    ]
    #
    # mulcov_table
    mulcov_table = [ {
        # alpha
        'covariate': 'income',
        'type': 'rate_value',
        'effected': 'iota',
        'group': 'world',
        'smooth': 'alpha_smooth',
    } ]
    #
    # subgroup_table
    subgroup_table = [ {'subgroup': 'world', 'group':'world'} ]
    #
    # integrand_table
    integrand_table = [ {'name':'Sincidence'} ]
    for mulcov_id in range( len(mulcov_table) ) :
        integrand_table.append( { 'name': f'mulcov_{mulcov_id}' } )
    #
    # avgint_table
    avgint_table = [ {
        'node': 'n0',
        'subgroup': 'world',
        'weight': '',
        'time_lower': 2000.0,
        'time_upper': 2000.0,
        'age_lower': 50.0,
        'age_upper': 50.0,
        'sex': None,
        'income': None,
        'integrand': 'Sincidence',
    } ]
    #
    # data_table
    # Use a list (not a set of strings) so the data row order, and hence the
    # data_id values, are the same on every run; set iteration order for
    # strings depends on the per-process hash seed.
    leaf_list = [ 'n3', 'n4', 'n5', 'n6' ]
    data_table = list()
    base_row = {
        'subgroup': 'world',
        'weight': '',
        'time_lower': 2000.0,
        'time_upper': 2000.0,
        'age_lower': 50.0,
        'age_upper': 50.0,
        'integrand': 'Sincidence',
        'density': 'gaussian',
        'hold_out': False,
    }
    assert split_reference_table[0]['split_reference_name'] == 'female'
    assert split_reference_table[2]['split_reference_name'] == 'male'
    for split_reference_id in [ 0, 2 ] :
        sex = split_reference_list[split_reference_id]
        for node in leaf_list :
            r_income = avg_income[node][split_reference_id]
            for factor in [ 0.5, 1.0, 1.5 ] :
                income     = factor * r_income
                meas_value = rate_true('iota', None, None, node, [sex, income])
                row = copy.copy(base_row)
                row['node']       = node
                row['meas_value'] = meas_value
                row['sex']        = sex
                row['income']     = income
                # measurement standard deviation is 10% of the value
                row['meas_std']   = meas_value / 10.0
                data_table.append( row )
    #
    # age_grid
    age_grid = [ 0.0, 100.0 ]
    #
    # time_grid
    time_grid = [ 1980.0, 2020.0 ]
    #
    # weight table:
    weight_table = list()
    #
    # nslist_table
    nslist_table = dict()
    #
    # option_table
    option_table = [
        { 'name':'parent_node_name', 'value':'n0'},
        { 'name':'rate_case', 'value':'iota_pos_rho_zero'},
        { 'name': 'zero_sum_child_rate', 'value':'iota'},
        { 'name':'quasi_fixed', 'value':'false'},
        { 'name':'max_num_iter_fixed', 'value':'50'},
        { 'name':'tolerance_fixed', 'value':'1e-8'},
    ]
    # ----------------------------------------------------------------------
    # create database
    dismod_at.create_database(
        file_name,
        age_grid,
        time_grid,
        integrand_table,
        node_table,
        subgroup_table,
        weight_table,
        covariate_table,
        avgint_table,
        data_table,
        prior_table,
        smooth_table,
        nslist_table,
        rate_table,
        mulcov_table,
        option_table
    )
# ----------------------------------------------------------------------------
# main
# ----------------------------------------------------------------------------
def _read_tables(database, table_name_list) :
    # Return a dict mapping each name in table_name_list to that table, as
    # returned by dismod_at.get_table_dict, read from database.
    # The database is opened read-only and the connection is closed even if
    # reading one of the tables raises an exception.
    connection = dismod_at.create_connection(
        database, new = False, readonly = True
    )
    try :
        return {
            table_name : dismod_at.get_table_dict(connection, table_name)
            for table_name in table_name_list
        }
    finally :
        connection.close()
#
def main(refit_split) :
    """Run the split_covariate example cascade and check the results.

    refit_split : bool
        If true, option_all['refit_split'] is set to 'true', otherwise to
        'false', before all_node.db is created.

    Side effects: empties option_all['result_dir'], creates root.db,
    all_node.db and the n0 directory in it, runs the cascade from the root
    node, and asserts that:
    - the leaf fits agree with rate_true;
    - the iota prior mean in the shifted database equals the root fit of
      iota moved to the shifted income reference.
    """
    # BEGIN refit_split
    if refit_split :
        option_all['refit_split'] = 'true'
    else :
        option_all['refit_split'] = 'false'
    # END refit_split
    # -------------------------------------------------------------------------
    # result_dir
    result_dir = option_all['result_dir']
    at_cascade.empty_directory(result_dir)
    #
    # Create root.db
    root_database = option_all['root_database']
    root_node_db(root_database)
    #
    # omega_grid
    # one omega grid point for every (age_id, time_id) pair in root.db;
    # the time ids come from the time table, the age ids from the age table
    tables     = _read_tables(root_database, [ 'age', 'time' ])
    age_table  = tables['age']
    time_table = tables['time']
    age_id_list  = list( range( len(age_table) ) )
    time_id_list = list( range( len(time_table) ) )
    omega_grid = { 'age': age_id_list, 'time' : time_id_list }
    #
    # n_split
    n_split = len( split_reference_list )
    #
    # omega_data
    # omega_data[node_name][k] lists the true omega for split reference k at
    # each grid point, age_id in the outer loop and time_id in the inner loop.
    omega_data = dict()
    for node_name in [ 'n0', 'n1', 'n2', 'n3', 'n4', 'n5', 'n6' ] :
        omega_data[node_name] = list()
        for k in range(n_split) :
            cov = [ split_reference_list[k], avg_income[node_name][k] ]
            omega_list = list()
            for age_id in omega_grid['age'] :
                for time_id in omega_grid['time'] :
                    # named *_value so the imported time module is not shadowed
                    age_value  = age_table[age_id]['age']
                    time_value = time_table[time_id]['time']
                    omega = rate_true(
                        'omega', age_value, time_value, node_name, cov
                    )
                    omega_list.append( omega )
            omega_data[node_name].append( omega_list )
    #
    # Create all_node.db
    all_node_database = f'{result_dir}/all_node.db'
    at_cascade.create_all_node_db(
        all_node_database     = all_node_database,
        split_reference_table = split_reference_table,
        node_split_table      = node_split_table,
        option_all            = option_all,
        omega_grid            = omega_grid,
        omega_data            = omega_data,
    )
    #
    # root_node_dir
    root_node_dir = f'{result_dir}/n0'
    os.mkdir(root_node_dir)
    #
    # avgint_table
    # This also erases the avgint table from root_database
    avgint_table = at_cascade.extract_avgint( root_database )
    #
    # cascade starting at root node
    at_cascade.cascade_root_node(
        all_node_database = all_node_database ,
        fit_goal_set      = fit_goal_set ,
    )
    #
    # check results
    for sex in [ 'female', 'male' ] :
        for subdir in [ 'n1/n3', 'n1/n4', 'n2/n5', 'n2/n6' ] :
            goal_database = f'{result_dir}/n0/{sex}/{subdir}/dismod.db'
            at_cascade.check_cascade_node(
                rate_true          = rate_true,
                all_node_database  = all_node_database,
                fit_database       = goal_database,
                avgint_table       = avgint_table,
                relative_tolerance = 1e-5,
            )
    #
    # fit_iota, fit_alpha, fit_reference_income
    fit_database = f'{result_dir}/n0/dismod.db'
    tables = _read_tables(
        fit_database, [ 'var', 'fit_var', 'rate', 'covariate' ]
    )
    rate_table    = tables['rate']
    fit_var_table = tables['fit_var']
    for (var_id, row) in enumerate( tables['var'] ) :
        rate_name = rate_table[ row['rate_id'] ]['rate_name']
        if rate_name == 'iota' :
            fit_var_value = fit_var_table[var_id]['fit_var_value']
            if row['var_type'] == 'rate' :
                fit_iota = fit_var_value
            else :
                assert row['var_type'] == 'mulcov_rate_value'
                fit_alpha = fit_var_value
    for row in tables['covariate'] :
        if row['covariate_name'] == 'income' :
            fit_reference_income = row['reference']
    #
    # shift_mean, shift_reference_income
    # Database whose iota prior was shifted from the n0 fit. When refit_split
    # is false, the female split of n0 is presumably not refit, so the female
    # n1 database is used instead. NOTE(review): confirm against the
    # at_cascade refit_split documentation.
    if refit_split :
        shift_database = f'{result_dir}/n0/female/dismod.db'
    else :
        shift_database = f'{result_dir}/n0/female/n1/dismod.db'
    tables = _read_tables(
        shift_database, [ 'rate', 'smooth_grid', 'prior', 'covariate' ]
    )
    for row in tables['rate'] :
        if row['rate_name'] == 'iota' :
            smooth_id = row['parent_smooth_id']
    for row in tables['smooth_grid'] :
        if row['smooth_id'] == smooth_id :
            prior_id   = row['value_prior_id']
            shift_mean = tables['prior'][prior_id]['mean']
    for row in tables['covariate'] :
        if row['covariate_name'] == 'income' :
            shift_reference_income = row['reference']
    #
    # check
    # The shifted prior mean must be the root fit of iota moved from the
    # root income reference to the shifted database's income reference.
    income_difference = shift_reference_income - fit_reference_income
    check = fit_iota * exp( fit_alpha * income_difference )
    assert abs(1.0 - shift_mean/check) < 1e-12
#
if __name__ == '__main__' :
    # run the example with option_all['refit_split'] set to 'false',
    # then to 'true'
    for refit_split in [ False, True ] :
        main(refit_split)
    print('split_covariate: OK')