Petr Tsvetkov
Generate charts for the presentation & diploma; some refactoring; add (commented) Student's t-test
7ab7be2
import pandas as pd | |
from tqdm import tqdm | |
import config | |
import generate_annotated_diffs | |
import dataset_statistics | |
from api_wrappers import grazie_wrapper | |
from generation_steps import examples | |
# How many end messages `transform` generates for every input row.
GENERATION_MULTIPLIER = 3

# `generate_end_msg` accepts a candidate as soon as its "deletions" statistic
# is strictly below this value.
REL_DELETIONS_THRESHOLD = 0.75

# Upper bound on the number of LLM calls `generate_end_msg` makes per message.
GENERATION_ATTEMPTS = 3
def build_prompt(prediction, diff):
    """Build the LLM prompt that asks for a rewritten commit message.

    Args:
        prediction: The LLM-generated commit message to be improved.
        diff: The source code changes the commit message describes.

    Returns:
        The prompt text. It embeds ``diff``, ``prediction`` and
        ``examples.EXAMPLES_START_TO_END`` and ends with the ``OUTPUT`` token.
    """
    prompt_lines = [
        "A LLM generated a commit message for the following source code changes:",
        "START OF THE SOURCE CODE CHANGES",
        f"{diff}",
        "END OF THE SOURCE CODE CHANGES",
        "Here is the message the LLM generated:",
        "START OF THE COMMIT MESSAGE",
        f"{prediction}",
        "END OF THE COMMIT MESSAGE",
        "This generated message is not perfect. Your task is to rewrite and improve it.",
        "You have to simulate a human software developer who manually rewrites the LLM-generated commit message,",
        "so the message you print must share some fragments with the generated message.",
        "Your message should be concise.",
        "Follow the Conventional Commits guidelines.",
        "Here are some examples of what you should output:",
        "START OF THE EXAMPLES LIST",
        f"{examples.EXAMPLES_START_TO_END}",
        "END OF THE EXAMPLES LIST",
        "Print only the improved commit message's text after the",
        'token "OUTPUT".',
        "OUTPUT",
    ]
    return "\n".join(prompt_lines)
def generate_end_msg(start_msg, diff):
    """Ask the LLM for an improved version of ``start_msg``.

    Up to ``GENERATION_ATTEMPTS`` candidates are requested with the same
    prompt. The first candidate whose ``"deletions"`` statistic is strictly
    below ``REL_DELETIONS_THRESHOLD`` is returned immediately. If no candidate
    qualifies, the one with the smallest ``"deletions"`` value is returned
    (ties are broken by comparing the message text).

    Args:
        start_msg: The commit message to be rewritten.
        diff: The source code changes the commit message describes.

    Returns:
        The selected generated commit message.
    """
    prompt = build_prompt(prediction=start_msg, diff=diff)

    # Candidates that did not pass the threshold, as (deletions, message).
    rejected = []

    for _ in range(GENERATION_ATTEMPTS):
        end_msg_pred = grazie_wrapper.generate_for_prompt(prompt)
        annotated_msg = generate_annotated_diffs.get_annotated_diff(
            start_msg, end_msg_pred
        )
        # The statistics helpers are imported as `dataset_statistics`; the bare
        # name `statistics` is not in scope in this module.
        stats = dataset_statistics.get_statistics(
            start_msg=start_msg,
            end_msg=end_msg_pred,
            annotated_msg=annotated_msg,
        )

        # NOTE(review): "deletions" is presumably a relative measure of how
        # much of start_msg was removed - confirm in dataset_statistics.
        if stats["deletions"] < REL_DELETIONS_THRESHOLD:
            return end_msg_pred

        rejected.append((stats["deletions"], end_msg_pred))

    # No candidate passed: fall back to the one with the fewest deletions.
    return min(rejected)[1]
# Columns `transform` copies unchanged from each input row into every row it
# generates.
COLS_TO_KEEP = [
    "hash",
    "repo",
    "commit_msg_start",
    "mods",
    "session",
    "end_to_start",
]
def print_config():
    """Print the settings used by the synthesis step, one per line."""
    settings = (
        ("NUMBER OF EXAMPLES PER PROMPT", examples.N_EXAMPLES),
        ("GENERATION_MULTIPLIER", GENERATION_MULTIPLIER),
        ("REL_DELETIONS_THRESHOLD", REL_DELETIONS_THRESHOLD),
        ("GENERATION_ATTEMPTS", GENERATION_ATTEMPTS),
    )
    for label, value in settings:
        print(f"{label} = {value}")
def transform(df):
    """Extend ``df`` with LLM-generated end messages and save the result.

    For every row of ``df``, ``GENERATION_MULTIPLIER`` end messages are
    generated from the row's ``commit_msg_start`` and ``mods`` values. Each
    generated message becomes a new row that also carries the row's
    ``COLS_TO_KEEP`` values. Input rows are flagged ``start_to_end = False``
    and generated rows ``start_to_end = True``.

    Side effects: adds the ``start_to_end`` column to ``df`` in place and
    writes the combined frame to ``config.START_TO_END_ARTIFACT`` as CSV.

    Args:
        df: Input frame; must contain the ``COLS_TO_KEEP`` columns.

    Returns:
        The input rows followed by the generated rows, with a fresh index.
    """
    print("Start -> send synthesis:")
    print_config()

    # Rows already present in the input were not produced by this step.
    df['start_to_end'] = False

    # Column-oriented accumulator: one list per output column.
    synthesized = {"commit_msg_end": []}
    synthesized.update((column, []) for column in COLS_TO_KEEP)

    for _index, source_row in tqdm(df.iterrows(), total=len(df)):
        for _attempt in range(GENERATION_MULTIPLIER):
            synthesized["commit_msg_end"].append(
                generate_end_msg(
                    start_msg=source_row["commit_msg_start"],
                    diff=source_row["mods"],
                )
            )
            for column in COLS_TO_KEEP:
                synthesized[column].append(source_row[column])

    generated_df = pd.DataFrame.from_dict(synthesized).assign(start_to_end=True)

    combined = pd.concat([df, generated_df], ignore_index=True)
    combined.to_csv(config.START_TO_END_ARTIFACT)
    print("Done")

    return combined
def main():
    """Load the end-to-start artifact and run the synthesis step on it."""
    # The first CSV column holds the saved index.
    transform(pd.read_csv(config.END_TO_START_ARTIFACT, index_col=[0]))
# Run the synthesis step only when this file is executed as a script.
if __name__ == "__main__":
    main()