function create_grouped_correlation_plot
Creates and saves a two-panel heatmap of correlation coefficients, grouped by treatment and by challenge regimen.
/tf/active/vicechatdev/vice_ai/smartstat_scripts/5a059cb7-3903-4020-8519-14198d1f39c9/analysis_1.py
304 - 350
moderate
Purpose
This function draws grouped correlation results as two side-by-side heatmaps. The left panel shows correlations by treatment group and the right panel shows correlations by challenge regimen. For each panel it keeps the rows whose Grouping_Variable contains the panel's keyword ('treatment' or 'challenge', case-insensitive) and pivots them into a matrix with one row per Group_Value and one column per Performance_Variable, averaging duplicate entries. It then renders the matrix as an annotated heatmap with a diverging red-blue color map (RdBu_r) centered at zero and fixed to the range [-1, 1]. A panel with no matching rows is left empty. The figure is saved as a 300 DPI PNG file.
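To see exactly what each panel draws, the filter-and-pivot step can be run on its own. This is a minimal sketch with made-up sample rows; only the column names come from the function. The duplicate Control row shows how pivot_table averages repeated (Group_Value, Performance_Variable) pairs, and the capitalized 'Treatment' label still matches because the filter is case-insensitive.

```python
import pandas as pd

# Illustrative rows only; the column names match those the function expects.
df = pd.DataFrame({
    "Grouping_Variable": ["Treatment", "treatment", "treatment", "challenge_regimen"],
    "Group_Value": ["Control", "Control", "Drug_A", "High_Dose"],
    "Performance_Variable": ["Weight_Gain", "Weight_Gain", "Weight_Gain", "Survival_Rate"],
    "Correlation": [0.40, 0.50, 0.72, 0.88],
})

# Same case-insensitive substring filter the function uses for the left panel.
treatment = df[df["Grouping_Variable"].str.contains("treatment", case=False)]

# One row per Group_Value, one column per Performance_Variable;
# the two Control/Weight_Gain rows are averaged to 0.45.
matrix = treatment.pivot_table(
    values="Correlation",
    index="Group_Value",
    columns="Performance_Variable",
    aggfunc="mean",
)
print(matrix)
```

The resulting matrix has Control and Drug_A as rows and a single Weight_Gain column; a matrix of this shape is what the function passes to sns.heatmap for the left panel.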
Source Code
```python
def create_grouped_correlation_plot(grouped_results_df):
    """Create visualization of correlations by groups"""
    if len(grouped_results_df) == 0:
        return

    # Plot correlations by group
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # By treatment
    treatment_data = grouped_results_df[
        grouped_results_df['Grouping_Variable'].str.contains('treatment', case=False)
    ]
    if len(treatment_data) > 0:
        pivot_treatment = treatment_data.pivot_table(
            values='Correlation',
            index='Group_Value',
            columns='Performance_Variable',
            aggfunc='mean'
        )
        sns.heatmap(pivot_treatment, annot=True, fmt='.3f', cmap='RdBu_r',
                    center=0, vmin=-1, vmax=1, ax=axes[0])
        axes[0].set_title('Correlations by Treatment', fontweight='bold')

    # By challenge regimen
    challenge_data = grouped_results_df[
        grouped_results_df['Grouping_Variable'].str.contains('challenge', case=False)
    ]
    if len(challenge_data) > 0:
        pivot_challenge = challenge_data.pivot_table(
            values='Correlation',
            index='Group_Value',
            columns='Performance_Variable',
            aggfunc='mean'
        )
        sns.heatmap(pivot_challenge, annot=True, fmt='.3f', cmap='RdBu_r',
                    center=0, vmin=-1, vmax=1, ax=axes[1])
        axes[1].set_title('Correlations by Challenge Regimen', fontweight='bold')

    plt.tight_layout()
    plt.savefig('grouped_correlations.png', dpi=300, bbox_inches='tight')
    print("Saved: grouped_correlations.png")
    plt.close()
```
Parameters
| Name | Type | Default | Kind |
|---|---|---|---|
| grouped_results_df | - | - | positional_or_keyword |
Parameter Details
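grouped_results_df: A pandas DataFrame of grouped correlation results. It must include the columns 'Grouping_Variable' (string naming the grouping category), 'Group_Value' (the specific group within that category), 'Performance_Variable' (the variable being correlated), and 'Correlation' (the numeric correlation coefficient). Rows whose Grouping_Variable contains 'treatment' (case-insensitive) are plotted in the left panel and rows containing 'challenge' in the right panel; rows matching neither keyword are ignored. Because the filter uses the pandas .str accessor, Grouping_Variable should hold strings without missing values: in a default object-dtype column, a NaN makes the filter mask non-boolean and the row selection raises a ValueError.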
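Before calling the function, a quick pre-flight check can confirm these requirements and preview how many rows each panel will receive. check_grouped_results and REQUIRED_COLUMNS are hypothetical names written for this page, not part of the original script. The helper also counts Correlation values outside [-1, 1], which the fixed color scale cannot distinguish from the end of the range.

```python
import pandas as pd

# Column names taken from the function; the helper itself is hypothetical.
REQUIRED_COLUMNS = ("Grouping_Variable", "Group_Value",
                    "Performance_Variable", "Correlation")


def check_grouped_results(df: pd.DataFrame) -> dict:
    """Validate input for create_grouped_correlation_plot and count rows per panel."""
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"missing required columns: {missing}")
    labels = df["Grouping_Variable"]
    if labels.isna().any():
        # Missing labels would make the str.contains mask non-boolean.
        raise ValueError("Grouping_Variable contains missing values")
    labels = labels.astype(str)
    corr = df["Correlation"]
    return {
        # Same case-insensitive substring tests the plotting function applies.
        "treatment_rows": int(labels.str.contains("treatment", case=False).sum()),
        "challenge_rows": int(labels.str.contains("challenge", case=False).sum()),
        # Counts values outside [-1, 1] and NaN (between() is False for NaN).
        "invalid_correlations": int((~corr.between(-1, 1)).sum()),
    }


sample = pd.DataFrame({
    "Grouping_Variable": ["treatment", "Challenge_Regimen", "site"],
    "Group_Value": ["Drug_A", "High_Dose", "North"],
    "Performance_Variable": ["Weight_Gain"] * 3,
    "Correlation": [0.72, 1.2, -0.1],
})
print(check_grouped_results(sample))
# {'treatment_rows': 1, 'challenge_rows': 1, 'invalid_correlations': 1}
```

Here the 'site' row would be ignored by the plot, and the 1.2 value would be drawn with the same color as a perfect positive correlation.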
Return Value
Returns None. The function works through side effects: it creates a matplotlib figure with two heatmap subplots, saves it as 'grouped_correlations.png' in the current working directory at 300 DPI, prints a confirmation message to stdout, and closes the figure to free memory. If the input DataFrame is empty, it returns immediately without creating, saving, or printing anything.
Dependencies
pandas, matplotlib, seaborn
Required Imports
```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
```
Usage Example
```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# create_grouped_correlation_plot must already be defined or imported in this scope

# Create sample grouped correlation data
data = {
    'Grouping_Variable': ['treatment', 'treatment', 'treatment', 'challenge_regimen', 'challenge_regimen'],
    'Group_Value': ['Control', 'Drug_A', 'Drug_B', 'High_Dose', 'Low_Dose'],
    'Performance_Variable': ['Weight_Gain', 'Weight_Gain', 'Weight_Gain', 'Survival_Rate', 'Survival_Rate'],
    'Correlation': [0.45, 0.72, -0.23, 0.88, 0.56]
}
grouped_results_df = pd.DataFrame(data)

# Generate the correlation plot
create_grouped_correlation_plot(grouped_results_df)
# Output: saves 'grouped_correlations.png' and prints "Saved: grouped_correlations.png"
```
Best Practices
- Ensure the input dataframe contains the required columns: 'Grouping_Variable', 'Group_Value', 'Performance_Variable', and 'Correlation'
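- Make sure Grouping_Variable values contain the substring 'treatment' or 'challenge' (matched case-insensitively); any label containing the substring matches, so a name such as 'pretreatment_score' also lands in the treatment panel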
- Verify write permissions exist in the working directory before calling the function
- The function silently returns if the input dataframe is empty, so validate input data beforehand
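- Keep Correlation values within [-1, 1]: the color scale is fixed with vmin=-1 and vmax=1, so values outside that range are drawn with the end colors of the map, although their annotations still show the actual value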
- If multiple correlations exist for the same Group_Value and Performance_Variable combination, they will be averaged using 'mean' aggregation
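- The function closes its own figure after saving, so repeated successful calls do not accumulate open figures; if plt.savefig raises (for example, because the directory is not writable), the plt.close() call is skipped and the figure stays open, so call plt.close('all') when handling that error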
- Consider renaming or moving the output file if calling the function multiple times to avoid overwriting previous results
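A parameterized variant can address the hard-coded output file name and make cleanup robust to errors. The sketch below is hypothetical: plot_grouped_correlations, its output_path and panels parameters, the na=False filter, and the blank-panel handling are illustrative additions, not part of the original script. The heatmap settings are copied from the original.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

DEFAULT_PANELS = (
    ("treatment", "Correlations by Treatment"),
    ("challenge", "Correlations by Challenge Regimen"),
)


def plot_grouped_correlations(df: pd.DataFrame,
                              output_path: str = "grouped_correlations.png",
                              panels=DEFAULT_PANELS):
    """Hypothetical variant: configurable output path and panel keywords."""
    if df.empty:
        return None
    fig, axes = plt.subplots(1, len(panels), figsize=(8 * len(panels), 6),
                             squeeze=False)
    try:
        for ax, (keyword, title) in zip(axes[0], panels):
            # na=False treats missing labels as non-matching instead of failing.
            mask = df["Grouping_Variable"].str.contains(keyword, case=False, na=False)
            subset = df[mask]
            if subset.empty:
                ax.set_axis_off()  # hide a panel that has no matching rows
                continue
            matrix = subset.pivot_table(values="Correlation", index="Group_Value",
                                        columns="Performance_Variable", aggfunc="mean")
            sns.heatmap(matrix, annot=True, fmt=".3f", cmap="RdBu_r",
                        center=0, vmin=-1, vmax=1, ax=ax)
            ax.set_title(title, fontweight="bold")
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)  # runs even if savefig raises
    return output_path
```

Returning the path instead of printing lets callers log or collect outputs from several runs. Hiding an empty panel with set_axis_off() is a design choice; the original leaves the empty axes visible with their default ticks.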
Similar Components
Components with related functionality, ranked by AI-powered semantic similarity:
- function create_correlation_heatmap (75.5% similar)
- function grouped_correlation_analysis (62.3% similar)
- function create_scatter_plots (60.2% similar)
- function main_v54 (58.5% similar)
- function export_results (55.9% similar)