1#!/usr/bin/env python
2# coding: utf-8
3
4# # PCA with histogram
5#
6# created: 2/14/2026
7#
8# van
9
10# In[18]:
11
12
13import os
14import sys
15from glob import glob
16import re
17import numpy as np
18import matplotlib.pyplot as plt
19from matplotlib.ticker import FuncFormatter
20import seaborn as sb
21import pytraj as pt
22
23
24# In[2]:
25
26
# Make sure the output folders for figures ('img') and PCA results ('pca') exist.
for _folder in ('img', 'pca'):
    os.makedirs(_folder, exist_ok=True)
29
30
31# In[10]:
32
33
# Systems to analyse. Each name is a directory one level up (../<name>)
# that holds the system's topology and production trajectories.

dnames = ["wt", "r918a", "r918k"]
# dnames = ["wt", "r921a", "r921k"]


# In[11]:


# Analysis settings

# Input file names looked up inside each system directory
parm = "step3_pbcsetup_1264.parm7"
cord = "prod*.nc"

# Amber mask: the selected residues restricted to the listed atom names
res_mask = ":909-1076,1404-1410"
atm_mask = "@CA,C,N,O,P,O5',O3',OP1,OP2,C1',C2',C3',C4',C5'"
# atm_mask = '@CA'
ambermask = f'{res_mask}&{atm_mask}'

# File-system friendly tag: every non-alphanumeric mask character becomes '_'
clean_ambermask = re.sub(r'[^a-zA-Z0-9]', '_', ambermask)
analysis = f'mask{clean_ambermask}'

# Underscore-joined system names label the per-comparison output folder
dn = '_'.join(map(str, dnames))
os.makedirs(f'pca/{dn}', exist_ok=True)

print('save name: ', analysis)
print('mask selection:', ambermask)

# Slice index applied to each system's sorted prod file list when loading
n_cords = 11
67
68
69
70# In[12]:
71
72
# Load each system's production trajectories, keep only the masked atoms,
# write one stripped trajectory per system, then write a combined stripped
# trajectory and a matching stripped topology.


def _natural_key(path):
    """Sort key ordering embedded integers numerically (prod2 before prod10)."""
    # re.split with a capture group alternates text / digit tokens, so the
    # int and str entries always line up position-by-position when compared.
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', path)]


label_idx = []  # frames per system, in dnames order; used later to slice the projections
stripped_files = []  # stripped trajectory written for each system, in dnames order

for dname in dnames:
    # NOTE(review): [n_cords:] skips the first n_cords files of the sorted
    # list (it does not keep n_cords files) -- confirm this is the intent.
    nc_files = sorted(glob(f'../{dname}/{cord}'), key=_natural_key)[n_cords:]
    if not nc_files:
        raise FileNotFoundError(
            f'no trajectory files left for {dname!r}: pattern ../{dname}/{cord} '
            f'after skipping the first {n_cords} file(s)'
        )

    t1 = pt.iterload(nc_files, f'../{dname}/{parm}')
    t1 = pt.superpose(t1, mask=ambermask)

    # Strip everything not in the mask and save the new trajectory
    t1_strip = pt.strip(t1, f'!({ambermask})')
    label_idx.append(t1_strip.n_frames)
    nframes = str(t1_strip.n_frames)

    out_dir = f'pca/{dn}/{nframes}/{analysis}'
    os.makedirs(out_dir, exist_ok=True)
    stripped_path = f'{out_dir}/{dname}.nc'
    t1_strip.save(stripped_path, overwrite=True)
    # Remember the exact path: the folder name depends on this system's frame
    # count, so it cannot be rebuilt afterwards from the last system's nframes.
    stripped_files.append(stripped_path)

    print('fnames: ', nc_files)
    print('traj frames/atoms: ', t1.n_frames, 'frames / ', t1.n_atoms, 'atoms')
    print('stripped info: ', t1_strip.n_frames, 'frames /', t1_strip.n_atoms, 'atoms', '\n')


n = '_'.join(str(num) for num in label_idx)  # frame counts joined into a folder label

combine_out = f'pca/{dn}/{n}/{analysis}'
os.makedirs(combine_out, exist_ok=True)
combined_name = f'{combine_out}/step3'

# Save the stripped topology.
# NOTE(review): this uses the topology of the last system loaded; it assumes
# every system has the same atoms after stripping -- confirm.
top_keep = pt.strip(t1.top, f'!({ambermask})')
pt.save(f'{combined_name}.parm7', top_keep, overwrite=True)

# Combine the per-system stripped trajectories in dnames order, which is the
# order label_idx assumes when the projections are sliced for plotting.
fnames = stripped_files
traj = pt.iterload(fnames, top_keep)
traj = pt.superpose(traj, mask=atm_mask, ref=0)
print('combined stripped info: ', traj)

# Save the combined stripped trajectory
pt.save(f'{combined_name}.nc', traj, overwrite=True)

print('combined parm: ', f'{combined_name}.parm7')
print('combined traj: ', f'{combined_name}.nc')
118
119
120# In[21]:
121
122
# PCA on the combined stripped trajectory

n_vectors = 10  # number of eigenvectors requested; reused by the cpptraj script

combined_nc = f'{combined_name}.nc'
combined_parm = f'{combined_name}.parm7'

# Reload from disk so the PCA runs on exactly the files that were written
traj = pt.iterload(combined_nc, combined_parm)
data = pt.pca(traj, mask=atm_mask, n_vecs=n_vectors)

print(traj)
131
132
133# In[20]:
134
135
# Percent variance of PC1/PC2 and (optionally sign-flipped) projections

# data[0] holds the projections (one row per PC); data[1][0] is indexed below
# as the per-PC eigenvalues.
pc_eval = data[1][0]

# NOTE(review): the denominator sums only the eigenvalues returned by pt.pca
# (presumably the n_vecs requested), so these percentages are relative to
# those modes rather than to the total variance -- confirm this is intended.
eval_total = np.sum(pc_eval)
x_label = (pc_eval[0] / eval_total) * 100  # PC1
y_label = (pc_eval[1] / eval_total) * 100  # PC2

# Set a factor to -1 to mirror the corresponding axis in the plot
flip1, flip2 = -1, 1
x_data = data[0][0] * flip1
y_data = data[0][1] * flip2

print('PC1 variance (%): ', x_label)
print('PC2 variance (%): ', y_label)
148
149
150# In[15]:
151
152
# Make figure: PC1 vs PC2 scatter with marginal KDE panels (top and right)

axis_lim = 50
opacity = 0.5
colors = ['tab:blue', 'tab:orange', 'tab:green']

fig = plt.figure(dpi=300, constrained_layout=True)
#fig = plt.figure(figsize=(7.5, 7.5), dpi=300, constrained_layout=True)

gs = fig.add_gridspec(2, 2, width_ratios=(4, 1), height_ratios=(1, 4), wspace=0.0, hspace=0.0)

ax = fig.add_subplot(gs[1, 0])
ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
ax_histy = fig.add_subplot(gs[1, 1], sharey=ax)

# The projections are laid out system after system, so consecutive slices of
# label_idx[k] frames belong to dnames[k]: i is the first frame of the current
# system and new_i is one past its last frame.
i = 0
for ii, count in enumerate(label_idx):  # 'count', not 'n': keep the module-level n intact
    new_i = i + count
    xs = x_data[i:new_i]
    ys = y_data[i:new_i]
    # Cycle through the palette so more than len(colors) systems still plot
    color = colors[ii % len(colors)]

    ax.scatter(xs, ys, marker='o', c=color, alpha=opacity, label=dnames[ii])

    # KDE plots, explicitly coloured to match the scatter of the same system
    sb.kdeplot(x=xs, fill=True, common_norm=True, common_grid=True, ax=ax_histx, alpha=opacity, color=color)
    sb.kdeplot(y=ys, fill=True, common_norm=True, common_grid=True, ax=ax_histy, alpha=opacity, color=color)

    i = new_i  # update i


ax_histx.tick_params(axis='x', labelbottom=False, bottom=False)  # top panel: hide x labels
ax_histy.tick_params(axis='y', labelleft=False, left=False)  # right panel: hide y labels
# ax.tick_params(axis='x', labelrotation=45)
# ax_histy.tick_params(axis='x', labelrotation=45)


def _hide_zero(value, pos):
    """Density tick label: blank at zero, compact number otherwise."""
    # Format the value directly instead of delegating to the axis' previous
    # formatter: once replaced, that formatter no longer receives tick
    # locations from the axis and may return empty labels.
    return "" if np.isclose(value, 0.0) else f"{value:g}"


# Remove the zero label on both density axes
ax_histx.yaxis.set_major_formatter(FuncFormatter(_hide_zero))
ax_histy.xaxis.set_major_formatter(FuncFormatter(_hide_zero))

# remove lines in subplots
for side in ('top', 'right', 'left'):
    ax_histx.spines[side].set_visible(False)
for side in ('top', 'right', 'bottom'):
    ax_histy.spines[side].set_visible(False)


ax.set_xlabel(f"PC1 ({np.round(x_label, 1)} %)", fontsize=14)
ax.set_ylabel(f"PC2 ({np.round(y_label, 1)} %)", fontsize=14)
ax.set_xlim(-axis_lim, axis_lim)
ax.set_ylim(-axis_lim, axis_lim)
ax.grid(linestyle='--', alpha=0.2)
# One legend column per system (three for the default selection)
ax.legend(ncols=max(len(dnames), 1), loc='lower center', fontsize=14)

# optional: light grids for marginals too
ax_histx.grid(linestyle='--', alpha=0.15)
ax_histy.grid(linestyle='--', alpha=0.15)

# NOTE(review): nframes is the frame count of the last system loaded, and
# there is no separator between it and the analysis tag in the file name.
plt.savefig(f'img/{dn}_{nframes}{analysis}.png')
219
220
221# In[16]:
222
223
# Build a cpptraj batch script for the combined stripped trajectory and run it
# through pytraj. The script writes the eigenvector file (step3-evecs.dat),
# an NMWiz file (step3.nmd) and one trajectory per 'modes trajout' command
# (tmode 1 and tmode 2).
#
# The trailing backslashes are Python line continuations inside this non-raw
# f-string, so each multi-line runanalysis command is passed on as one line.
#
# NOTE(review): traj is handed to load_batch and the script also names the
# same files via parm/trajin -- confirm the input is not loaded twice.

text = f"""
parm {combined_name}.parm7
trajin {combined_name}.nc

rms first {atm_mask}
average crdset step3-average
createcrd step3-trajectories
run

crdaction step3-trajectories rms ref step3-average {atm_mask}
crdaction step3-trajectories matrix covar name step3-covar {atm_mask}

runanalysis diagmatrix step3-covar out {combine_out}/step3-evecs.dat \
 vecs {n_vectors} name myEvecs \
 nmwiz nmwizvecs {n_vectors} nmwizfile {combine_out}/step3.nmd nmwizmask {atm_mask}

runanalysis modes name myEvecs trajout {combine_out}/step3-mode1.nc \
 pcmin -{axis_lim} pcmax {axis_lim} tmode 1 trajoutmask {atm_mask} trajoutfmt netcdf

runanalysis modes name myEvecs trajout {combine_out}/step3-mode2.nc \
 pcmin -{axis_lim} pcmax {axis_lim} tmode 2 trajoutmask {atm_mask} trajoutfmt netcdf
"""

# Queue the commands on a cpptraj state, then execute them
state = pt.load_batch(traj, text)
state.run()
251
252
253# In[ ]:
254
255
256
257