# coding: utf-8
import itertools
import pprint
import numpy as np
import os.path
from os.path import basename, splitext
import matplotlib.pyplot as plt
from . import dcatools
from . import pdbtools
from .simulation import RNAPrediction
from .simulation import EvalData
from .sysconfig import SysConfig
def _load_quality_source(sim, source, dca_mode):
    """
    Load a single entry of ``sources`` for :func:`plot_constraint_quality`.

    Accepted forms of ``source``:

    * ``filter:<dca_file>:<filtertext>`` -- parse the DCA file and apply the filter chain
    * ``<dca_file>`` (only if ``dca_mode``) -- parse the DCA file directly
    * ``<dca_file>:<mapping_mode>:<cst_function>:<number>`` (atom mode) -- build constraints
      on-the-fly; empty or missing trailing fields fall back to the defaults
    * anything else (atom mode) -- treated as a constraints name/file and read directly

    :param sim: RNAPrediction instance (provides filter/cst parsing and ``config["sequence"]``)
    :param source: source specification string
    :param dca_mode: true if residue-residue DCA data is wanted instead of atom-atom constraints
    :return: tuple ``(title, dca, cst_info)``; whichever of ``dca`` / ``cst_info`` the chosen
             form does not produce is None
    """
    default_mapping_mode = "minAtom"
    default_cst_function = "FADE -100 26 20 -2 2"
    default_predictions = 100

    dca = None
    cst_info = None

    if source.startswith("filter:"):
        title = source
        # maxsplit=2 keeps any further ':' inside the filter text itself
        _, dca_file, filtertext = source.split(":", 2)
        print(" Applying filters: %s" % filtertext)
        dca = dcatools.parse_dca_data(dca_file)
        dca_filter_chain = sim.parse_dca_filter_string(filtertext)
        dcatools.filter_dca_data(dca_data=dca, dca_filter_chain=dca_filter_chain, quiet=True)
        if not dca_mode:
            print(" Mapping to atom-atom.")
            cst_info = dcatools.build_cst_info_from_dca_contacts(
                dca,
                sequence=sim.config["sequence"],
                mapping_mode=default_mapping_mode,
                cst_function=default_cst_function,
                number_dca_predictions=default_predictions,
                quiet=True)
    elif dca_mode:
        title = splitext(basename(source))[0]
        print(" Reading dca file directly.")
        dca = dcatools.parse_dca_data(source)
    elif ":" in source:
        title = source
        print(" Building cst from dca on-the-fly.")
        fields = source.split(":")
        # Pad to four fields so that e.g. "file.dca:minAtom" uses the defaults
        # for the omitted parts instead of raising IndexError.
        fields += [""] * (4 - len(fields))
        dca = dcatools.parse_dca_data(fields[0])
        cst_info = dcatools.build_cst_info_from_dca_contacts(
            dca,
            sequence=sim.config["sequence"],
            mapping_mode=fields[1] or default_mapping_mode,
            cst_function=fields[2] or default_cst_function,
            number_dca_predictions=int(fields[3]) if fields[3] else default_predictions,
            quiet=True)
    else:
        title = splitext(basename(source))[0]
        print(" Reading cst file directly.")
        cst_info = sim.parse_cst_file(sim.parse_cst_name_and_filename(source)[1])

    return title, dca, cst_info


def _dca_native_distances(dca, chain, max_contacts=100):
    """
    Collect the minimum heavy-atom distances reported for the first ``max_contacts`` used DCA contacts.

    Contacts with ``use_contact`` unset are ignored. Contacts for which no atom pair is
    reported for the chain still count towards ``max_contacts`` but contribute no distance.
    """
    dists = []
    used = 0
    for contact in dca:
        if not contact.use_contact:
            continue
        used += 1
        if used > max_contacts:
            break
        _average_heavy, minimum_heavy, minimum_pair = dcatools.get_contact_information_in_pdb_chain(contact, chain)
        if minimum_pair is None:
            continue
        dists.append(minimum_heavy)
    return dists


def _cst_native_distances(cst_info, chain):
    """
    Collect, for every constraint, the difference of the two referenced atoms in ``chain``.

    Constraints referencing a residue/atom that is missing from the chain are reported and skipped.
    """
    dists = []
    for atom1, res1, atom2, res2, _function in cst_info:
        try:
            # presumably Bio.PDB-style atom subtraction, i.e. a distance -- confirm against pdbtools
            dists.append(chain[res1][atom1] - chain[res2][atom2])
        except KeyError:
            print("%s %s %s %s  NOT FOUND" % (atom1, res1, atom2, res2))
    return dists


def plot_constraint_quality(comparison_pdb, sources, dca_mode=False):
    """
    Plot constraint quality.

    This visualizes the distances of constraints by comparing them to a reference (native) PDB
    structure. One histogram is drawn per source; the figure is written to
    ``/tmp/rna_tools_quality_<cwd name>.png`` and then shown.

    :param comparison_pdb: filename of a pdb file to compare to
    :param sources: list of dca files or constraints (depending on dca_mode). may also use
                    'filter:' to filter on-the-fly, or (in atom mode)
                    '<dca_file>:<mapping_mode>:<cst_function>:<number>' to build constraints
                    on-the-fly; empty or omitted fields use the defaults
    :param dca_mode: visualize residue-residue DCA instead of atom-atom constraints
    """
    sim = RNAPrediction(SysConfig())
    pdb = pdbtools.parse_pdb("foo", comparison_pdb)
    # only the first chain of the first model is used as reference
    chain = pdb[0].child_list[0]

    print("Comparison PDB: %s" % comparison_pdb)

    plt.figure(figsize=(14, 8))

    if dca_mode:
        print("Mode: residue-residue")
        bins = np.linspace(0, 60, 60)
    else:
        print("Mode: atom-atom")
        bins = np.linspace(0, 60, 180)

    for source in sources:
        print("Processing: %s" % source)
        title, dca, cst_info = _load_quality_source(sim, source, dca_mode)

        if dca_mode:
            dists = _dca_native_distances(dca, chain)
        else:
            dists = _cst_native_distances(cst_info, chain)

        # An empty list still prints "nan", but without numpy's RuntimeWarning.
        average = np.average(dists) if dists else float("nan")
        print(" Average distance: %s" % average)
        plt.hist(dists, bins, alpha=0.6, label=title)

    plt.legend()
    plt.title("%s Constraints Quality %s" % ("Residue-Residue" if dca_mode else "Atom-Atom", os.path.basename(os.getcwd())))
    plt.xlabel(u"Native distance / Å")
    plt.ylabel("Number of contraints")
    plt.savefig("/tmp/rna_tools_quality_%s.png" % os.path.basename(os.getcwd()), bbox_inches="tight")
    plt.show()
def plot_clusters(cst, max_models=0.99, score_weights=None):
    """
    Plot score over native rmsd.

    Models are drawn as one scatter series per cluster (plus one series for models without
    cluster information). Clicking a point prints the model. The figure is written to
    ``/tmp/rna_tools_rmsdscore_<cwd name>.png`` and then shown.

    :param cst: constraints
    :param max_models: limit to number of models if > 1, or relative percentage if <= 1
    :param score_weights: see EvalData.get_weighted_model_score
    """
    sim = RNAPrediction(SysConfig())
    cst_name, _cst_file = sim.parse_cst_name_and_filename(cst)
    eval_data = EvalData.load_from_cst(cst_name)

    # Decide on the raw value: truncating first would turn e.g. 1.5 into 1 and
    # then wrongly interpret it as a relative fraction.
    if max_models <= 1:
        number = int(eval_data.get_model_count() * max_models)
    else:
        number = int(max_models)

    models = eval_data.get_models(list(range(1, number + 1)), "top")

    plots = []
    fig = plt.figure(figsize=(14, 8))

    # index 0: models without cluster information
    models_c = [[m for m in models if "cluster" not in m]]
    descs = ["no cluster (%d)" % (len(models_c[0]))]

    # indices 1..n: clusters, numbered consecutively; stop at the first empty cluster id
    i = 1
    while True:
        models_current_cluster = [m for m in models if "cluster" in m and m["cluster"] == i]
        if not models_current_cluster:
            break
        models_c.append(models_current_cluster)
        descs.append("cluster %d (%d)" % (i, len(models_current_cluster)))
        i += 1
    n_clusters = len(models_c) - 1

    # Use the first model of cluster 1 to detect whether native rmsd values are
    # available; fall back to the unclustered models when there are no clusters
    # (indexing models_c[1] unconditionally raised IndexError in that case).
    reference_group = models_c[1] if n_clusters > 0 else models_c[0]
    if reference_group and "rmsd_native" in reference_group[0]:
        comparison = "rmsd_native"
        plt.xlabel("native rmsd / A")
    else:
        comparison = "rmsd_cluster_1"
        plt.xlabel("rmsd to best structure / A")

    # create plots
    colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#00FFFF', '#FF00FF', '#ff6600', '#006000', '#600060', '#F7C2CA']
    shapes = ['o', 'D', 's', 'p']

    def pick_event(event):
        # map the clicked artist back to its cluster index and print the selected models
        cluster = plots.index(event.artist)
        for ind in event.ind:
            print("cluster %d: %s" % (cluster, models_c[cluster][ind]["tag"]))
            pprint.pprint(models_c[cluster][ind])

    for i in range(0, n_clusters + 1):
        # colors cycle first; the marker shape changes once all colors were used.
        # '//' keeps the index an integer regardless of the division semantics in effect.
        color = "w" if i == 0 else colors[(i - 1) % len(colors)]
        shape = "o" if i == 0 else shapes[((i - 1) // len(colors)) % len(shapes)]
        x_values = [x[comparison] for x in models_c[i]]
        if score_weights is None:
            y_values = [x["score"] for x in models_c[i]]
        else:
            y_values = [EvalData.get_weighted_model_score(x, score_weights) for x in models_c[i]]
        plots.append(plt.scatter(x_values, y_values, s=35, c=color, marker=shape, picker=True))

    fig.canvas.mpl_connect('pick_event', pick_event)

    plt.title("model clusters %s, constraints: %s" % (sim.config["name"], cst_name))
    plt.ylabel("rosetta score")
    plt.legend(plots, descs, loc="upper right", prop={'size': 12})
    plt.savefig("/tmp/rna_tools_rmsdscore_%s.png" % os.path.basename(os.getcwd()), bbox_inches="tight")
    plt.show()
def plot_pdb_comparison(pdb_ref_filename, pdbs_sample_filenames):
    """
    Compare PDB files by plotting the distance of the residues.

    For every sample the per-residue distances returned by ``pdbtools.align_structure`` are
    plotted over the residue index, together with a dashed horizontal line at the rmsd of that
    sample. Regions of the configured secondary structure that are not ``.`` are shaded. The
    figure is written to ``/tmp/rna_tools_seqdist_<cwd name>.png`` and then shown.

    :param pdb_ref_filename: reference PDB filename
    :param pdbs_sample_filenames: list of sample PDB filenames. An entry containing ':' is split
                                  into three fields that are passed to ``RNAPrediction.get_models``
                                  and ``extract_pdb`` (presumably ``<cst>:<kind>:<model>`` --
                                  confirm against RNAPrediction); it must contain at least
                                  three fields
    """
    pdb_ref = pdbtools.parse_pdb("foo", pdb_ref_filename)
    # a single instance is sufficient for both model extraction and config access
    sim = RNAPrediction(SysConfig())

    sample_filenames = []
    for sample in pdbs_sample_filenames:
        if ":" in sample:
            fields = sample.split(":")
            model = RNAPrediction.get_models(fields[0], [fields[2]], fields[1])[0]
            sample_filenames.append(sim.extract_pdb(fields[0], model))
        else:
            sample_filenames.append(sample)

    # the filename doubles as structure id, which is used as legend label below
    pdbs_sample = [pdbtools.parse_pdb(filename, filename) for filename in sample_filenames]

    sequence = sim.config["sequence"]
    secstruc = sim.config["secstruc"]
    print(sequence)
    print(secstruc)

    plt.figure(figsize=(14, 8))
    ax = plt.gca()
    plt.title("Residue Distances %s" % (sim.config["name"]))
    plt.xlabel("Residue")
    plt.ylabel(u"Distance to Native Structure / Å")

    x_values = np.arange(1, len(sequence) + 1)
    # tick label: digits of the 1-based residue number stacked vertically, residue letter below
    x_labels = ["%s\n%s" % ("\n".join("%.2d" % (i + 1)), x) for i, x in enumerate(sequence)]
    x_min = 0.5
    x_max = len(x_values) + 0.5
    plt.xlim([x_min, x_max])
    plt.xticks(x_values, x_labels)

    y_min = 0
    y_max = 0  # y limits will be set later when we know the final size

    for pdb_sample in pdbs_sample:
        dists_res, dists_atom, rmsd, rotran = pdbtools.align_structure(pdb_ref, pdb_sample, assign_b_factors=True)
        plt.plot(x_values, dists_res, "-o", label=pdb_sample.id)
        # rmsd line in the same color as the curve that was just added
        plt.plot((x_min, x_max), (rmsd, rmsd), "--", color=ax.lines[-1].get_color())
        y_max = max(y_max, max(dists_res))

    y_max *= 1.025
    plt.ylim([y_min, y_max])

    # Shade every run of identical non-'.' symbols. The hatch is looked up per
    # run, so symbols other than '(' and ')' get no hatch instead of inheriting
    # the hatch of whichever run happened to precede them.
    hatches = {"(": "//", ")": "\\\\"}
    start = 0
    for symbol, run in itertools.groupby(secstruc):
        end = start + len(list(run))
        if symbol != ".":
            plt.fill_between([start + 0.5, end + 0.5], y_min, y_max, hatch=hatches.get(symbol, ""), color="grey", edgecolor="black", alpha=0.25)
        start = end

    plt.legend(prop={'size': 12})
    plt.savefig("/tmp/rna_tools_seqdist_%s.png" % os.path.basename(os.getcwd()), bbox_inches="tight")
    plt.show()
def plot_gdt(pdb_ref_filename, pdbs_sample_filenames):
    """
    Plot a GDT-style curve for one or more PDB files against a reference.

    For every sample the per-residue distances returned by ``pdbtools.align_structure`` are
    sorted ascending and plotted as distance cutoff (y) over the percentage of residues that lie
    within that cutoff (x). A dashed horizontal line marks the rmsd of each sample. The figure is
    written to ``/tmp/rna_tools_gdtplot_<cwd name>.png`` and then shown.

    :param pdb_ref_filename: reference PDB filename
    :param pdbs_sample_filenames: list of sample PDB filenames. An entry containing ':' is split
                                  into three fields that are passed to ``RNAPrediction.get_models``
                                  and ``extract_pdb`` (presumably ``<cst>:<kind>:<model>`` --
                                  confirm against RNAPrediction); it must contain at least
                                  three fields
    """
    sim = RNAPrediction(SysConfig())
    pdb_ref = pdbtools.parse_pdb("foo", pdb_ref_filename)

    sample_filenames = []
    for sample in pdbs_sample_filenames:
        if ":" in sample:
            fields = sample.split(":")
            model = RNAPrediction.get_models(fields[0], [fields[2]], fields[1])[0]
            sample_filenames.append(sim.extract_pdb(fields[0], model))
        else:
            sample_filenames.append(sample)

    # the filename doubles as structure id, which is used as legend label below
    pdbs_sample = [pdbtools.parse_pdb(filename, filename) for filename in sample_filenames]

    print(sim.config["sequence"])
    print(sim.config["secstruc"])

    plt.figure(figsize=(14, 8))
    ax = plt.gca()
    plt.title("GDT Plot %s" % (sim.config["name"]))
    plt.xlabel("Percent of Residues")
    plt.ylabel(u"Distance Cutoff / Å")

    x_min = 0
    x_max = 100
    plt.xlim([x_min, x_max])

    y_min = 0
    y_max = 0  # y limits will be set later when we know the final size

    for pdb_sample in pdbs_sample:
        dists_res, dists_atom, rmsd, rotran = pdbtools.align_structure(pdb_ref, pdb_sample, assign_b_factors=True)
        y_values = sorted(dists_res)
        count = len(y_values)
        # the k-th smallest distance is the cutoff at which k of `count` residues are covered
        x_values = [100 * (k + 1.0) / count for k in range(count)]
        plt.plot(x_values, y_values, "-", label=pdb_sample.id)
        # rmsd line in the same color as the curve that was just added
        plt.plot((x_min, x_max), (rmsd, rmsd), "--", color=ax.lines[-1].get_color())
        y_max = max(y_max, max(y_values))

    y_max *= 1.025
    plt.ylim([y_min, y_max])

    plt.legend(prop={'size': 12}, loc="upper left")
    plt.savefig("/tmp/rna_tools_gdtplot_%s.png" % os.path.basename(os.getcwd()), bbox_inches="tight")
    plt.show()
def plot_tp_rate(pdb_ref_filename, dca_filenames, tp_cutoff=8.0):
    """
    Plot the running true-positive rate of DCA predictions over their rank.

    Contacts are visited in the order in which they appear in each DCA file. A contact counts as
    true positive if its minimum distance in the first chain of the reference structure (as
    reported by ``dcatools.get_contact_information_in_pdb_chain`` with ``heavy_only=False``) is
    below ``tp_cutoff``. The figure is written to ``/tmp/rna_predict_tprate_<cwd name>.png`` and
    then shown.

    :param pdb_ref_filename: reference PDB filename
    :param dca_filenames: list of dca files, one curve is drawn per file
    :param tp_cutoff: distance below which a predicted contact counts as true positive
    """
    pdb_ref = pdbtools.parse_pdb("foo", pdb_ref_filename)
    # only the first chain of the first model is used as reference
    chain = pdb_ref[0].child_list[0]

    for dca_filename in dca_filenames:
        count = 0
        tp = 0
        dca = dcatools.parse_dca_data(dca_filename)
        x_values = []
        y_values = []
        for d in dca:
            count += 1
            average_heavy, minimum_heavy, minimum_pair = dcatools.get_contact_information_in_pdb_chain(d, chain, heavy_only=False)
            # NOTE(review): the helper appears to report missing contacts via None
            # (plot_constraint_quality checks `minimum_pair is None`) -- confirm.
            # A bare `None < tp_cutoff` is True on Python 2 (and a TypeError on
            # Python 3), which would count unresolved contacts as true positives.
            found = minimum_pair is not None and minimum_heavy is not None
            if found and minimum_heavy < tp_cutoff:
                tp += 1
            rate = 1.0 * tp / count
            print("%s %s %s %s" % (d.res1, d.res2, rate, minimum_heavy))
            x_values.append(count)
            y_values.append(rate)
        print("%s %s" % (count, tp))
        plt.plot(x_values, y_values, "-", label=dca_filename)

    plt.title("TP-Rate %s, Cutoff = %s" % (os.path.basename(os.getcwd()), tp_cutoff))
    plt.xlabel("Rank")
    plt.ylabel("TP-Rate / Rank")
    plt.xlim([1, 1000])
    plt.ylim([0, 1.1])
    plt.xscale("log")
    plt.legend()
    plt.savefig("/tmp/rna_predict_tprate_%s.png" % os.path.basename(os.getcwd()), bbox_inches="tight")
    plt.show()