Created
April 1, 2024 02:58
-
-
Save SinclairCoder/eb0965755a001e78942eab1eded44675 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
import math | |
import matplotlib.pyplot as plt | |
# Figure resolution: 1000 dots per inch for every figure created afterwards.
# NOTE(review): this rcParam also applies to on-screen figures, which can make
# interactive windows very large; savefig() is called with its own dpi.
plt.rcParams.update({'figure.dpi': 1000})
class ComplexRadar:
    """
    Create a complex radar chart with a different scale for each variable.

    One polar axes is created per variable, plus a duplicate of the first one,
    all stacked at the same figure position. Each axes shows the radial tick
    labels of its own variable. All data is drawn on the duplicate of the first
    axes (``self.ax1``) after being rescaled into the first variable's range.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure the axes are added to.
    variables : sequence of str
        Variable names, placed clockwise starting at the top (north).
    ranges : sequence of (min, max) tuples
        Axis range for each variable, in the same order as ``variables``.
    n_ring_levels : int, defaults to 5
        Number of ordinate (ring) levels to draw.
    show_scales : bool, defaults to True
        Whether the per-variable radial tick labels are shown.
    """

    def __init__(self, fig, variables, ranges, n_ring_levels=5, show_scales=True):
        n_vars = len(variables)
        # linspace guarantees exactly n_vars angles. arange with a float step
        # (360 / n_vars) can emit an extra element due to rounding.
        angles = np.linspace(0, 360, n_vars, endpoint=False)
        # One axes per variable plus a duplicate of the first (n_vars + 1):
        # axes[0] draws the frame and labels, axes[1] holds the plotted data.
        axes = [
            fig.add_axes([0.1, 0.1, 0.9, 0.9], polar=True, label="axes{}".format(i))
            for i in range(n_vars + 1)
        ]

        for i, ax in enumerate(axes):
            # Clockwise rotation with the first variable at the top (N).
            ax.set_theta_zero_location('N')
            ax.set_theta_direction(-1)
            ax.set_axisbelow(True)

            # axes[0] and axes[1] both use the first variable's range;
            # axes[i] for i >= 2 uses variable i - 1.
            j = 0 if i <= 1 else i - 1
            # endpoint=False keeps all ring levels strictly inside the range.
            # Use endpoint=True to also put a level on the outer edge.
            grid = np.linspace(*ranges[j], num=n_ring_levels, endpoint=False)
            gridlabel = ["{}".format(round(x, 2)) for x in grid]
            if gridlabel:
                gridlabel[0] = ""  # no value label at the center
            ax.set_rgrids(grid, labels=gridlabel, angle=angles[j])
            # Set limits after set_rgrids: setting ticks may widen the view.
            ax.set_ylim(*ranges[j])
            ax.spines["polar"].set_visible(False)
            ax.grid(visible=False)
            if not show_scales:
                ax.set_yticklabels([])

        # The stacked duplicates must not hide the base axes.
        for ax in axes[1:]:
            ax.patch.set_visible(False)
            ax.xaxis.set_visible(False)

        # Angles of a closed polygon (first angle repeated at the end), in radians.
        self.angle = np.deg2rad(np.r_[angles, angles[0]])
        self.ranges = ranges
        self.ax = axes[0]
        self.ax1 = axes[1]
        self.plot_counter = 0

        # Re-enable the (inner) circles and spokes on the base axes only.
        self.ax.yaxis.grid(True)
        self.ax.xaxis.grid(True)
        # Outer circle.
        self.ax.spines['polar'].set_visible(True)

        # ax1 duplicates axes[0]: strip everything but the plotted data and
        # keep it on top of the other axes.
        self.ax1.axis('off')
        self.ax1.set_zorder(9)

        # Variable names around the outside.
        self.ax.set_thetagrids(angles, labels=variables, fontsize=13)
        # Align each label away from the chart depending on its side.
        for label, angle in zip(self.ax.get_xticklabels(), angles):
            if angle in (0, 180):
                label.set_ha('center')
            elif 0 < angle < 180:
                label.set_ha('left')
            else:
                label.set_ha('right')
        self.ax.tick_params(axis='both', pad=15)

    def _scale_data(self, data, ranges):
        """Map every value of *data* onto the scale of the first variable.

        ``data[0]`` is returned unchanged, since it is already on the first
        variable's scale. Each ``data[k]`` with k >= 1 is mapped linearly from
        ``ranges[k]`` onto ``ranges[0]``. A degenerate range ``(y, y)`` maps
        its value to the lower end of ``ranges[0]``.

        Raises
        ------
        ValueError
            If ``data`` and ``ranges`` differ in length, or a value lies
            outside its variable's range.
        """
        if len(data) != len(ranges):
            raise ValueError(
                "got {} values for {} variables".format(len(data), len(ranges)))
        for d, (lo, hi) in zip(data, ranges):
            if not (min(lo, hi) <= d <= max(lo, hi)):
                raise ValueError(
                    "value {} outside range ({}, {})".format(d, lo, hi))

        x1, x2 = ranges[0]
        sdata = [data[0]]
        for d, (y1, y2) in zip(data[1:], ranges[1:]):
            if y2 == y1:
                # Zero-extent range: avoid division by zero.
                sdata.append(x1)
            else:
                sdata.append((d - y1) / (y2 - y1) * (x2 - x1) + x1)
        return sdata

    def plot(self, data, *args, **kwargs):
        """Plot *data* as a closed line; extra arguments go to ``Axes.plot``."""
        sdata = self._scale_data(data, self.ranges)
        self.ax1.plot(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)
        self.plot_counter = self.plot_counter + 1

    def fill(self, data, *args, **kwargs):
        """Fill the area enclosed by *data*; extra arguments go to ``Axes.fill``."""
        sdata = self._scale_data(data, self.ranges)
        self.ax1.fill(self.angle, np.r_[sdata, sdata[0]], *args, **kwargs)

    def use_legend(self, *args, **kwargs):
        """Show a legend for the plotted lines (drawn on the data axes)."""
        self.ax1.legend(*args, **kwargs)

    def set_title(self, title, pad=25, **kwargs):
        """Set the chart title on the base axes."""
        self.ax.set_title(title, pad=pad, **kwargs)
if __name__ == '__main__': | |
# restaurant-acos dataset | |
rest_single_extraction = [58.31, 66.93, 65.23, 72.96, 70.14, 75.64, 93.85, 92.90, 83.63, 87.41, 83.20] | |
rest_unify_extraction = [59.12, 70.71, 67.98, 77.86, 75.28, 77.04, 94.41, 93.51, 84.15, 88.45, 84.28] | |
rest_single_paraphrase = [59.38, 68.81, 68.22, 76.33, 73.27, 75.81, 90.66, 89.69, 83.04, 86.81, 84.00] | |
rest_unify_paraphrase = [60.60, 73.06, 70.16, 78.31, 77.76, 77.91, 91.95, 91.17, 84.62, 88.24, 86.12] | |
rest_f1 = np.array([rest_single_extraction, rest_unify_extraction, rest_single_paraphrase, rest_unify_paraphrase]) | |
task = ['ACOSQE', 'AOSTE', 'ACSTE', 'ASPE', 'AOPE', 'CSPE', 'AOSC', 'COSC', 'ACD', 'AOOE', 'ATE'] | |
rest_df = pd.DataFrame({ | |
# 'group': ['Single-Extraction', 'Unify-Extraction', 'Single-Paraphrase', 'Unify-Paraphrase'], | |
task[0]: rest_f1[:,0], | |
task[1]: rest_f1[:,1], | |
task[2]: rest_f1[:,2], | |
task[3]: rest_f1[:,3], | |
task[4]: rest_f1[:,4], | |
task[5]: rest_f1[:,5], | |
task[6]: rest_f1[:,6], | |
task[7]: rest_f1[:,7], | |
task[8]: rest_f1[:,8], | |
task[9]: rest_f1[:,9], | |
task[10]: rest_f1[:,10], | |
}, | |
# index = ['Single-Extraction', 'Unify-Extraction', 'Single-Paraphrase', 'Unify-Paraphrase'] | |
index = ['Single-E', 'Unify-E', 'Single-P', 'Unify-P'] | |
) | |
data = rest_df | |
min_max_per_variable = data.describe().T[['min', 'max']] | |
min_max_per_variable['min'] = min_max_per_variable['min'].apply(lambda x: int(x)) | |
min_max_per_variable['max'] = min_max_per_variable['max'].apply(lambda x: math.ceil(x)) | |
variables = data.columns | |
ranges = list(min_max_per_variable.itertuples(index=False, name=None)) | |
fig1 = plt.figure(figsize=(6, 6)) | |
radar = ComplexRadar(fig1, variables, ranges, show_scales=True) | |
for g in data.index: | |
radar.plot(data.loc[g].values, label=f"{g}") | |
radar.fill(data.loc[g].values, alpha=0.40) | |
radar.set_title("Comparison on Restaurant-ACOS", fontsize=17) | |
radar.use_legend(loc='upper center', bbox_to_anchor=(0.5, -0.1), ncol=radar.plot_counter, fontsize = 13) | |
plt.savefig('radar-on-rest-norm.pdf',dpi=3000, format='pdf',bbox_inches='tight') | |
plt.show() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment