From 34fcc98a134aac20715573b443922580220d5aeb Mon Sep 17 00:00:00 2001 From: gouzhonglin Date: Mon, 18 Dec 2023 14:34:13 +0800 Subject: [PATCH] update easymodel-plugins --- plugins/config.yaml | 25 +- plugins/src/backend/drivers/driver_adaptor.py | 66 ++- .../src/backend/drivers/driver_interface.py | 6 +- plugins/src/backend/drivers/gitee.py | 101 +++- plugins/src/backend/models/ChatCPT.py | 84 +++- plugins/src/backend/models/model_adaptor.py | 56 ++- plugins/src/backend/models/model_interface.py | 10 +- plugins/src/common/configs/driver_args.py | 18 +- plugins/src/common/configs/generate_args.py | 7 + plugins/src/common/configs/model_args.py | 1 - plugins/src/common/configs/project_args.py | 67 ++- plugins/src/plugins/plugin_adaptor.py | 1 + plugins/src/plugins/plugin_interface.py | 8 +- plugins/src/plugins/pr_review/commenter.py | 95 ++++ plugins/src/plugins/pr_review/pr_review.py | 431 ++++++++++++++++-- plugins/src/task/manage.py | 9 +- plugins/src/task/router/router.py | 1 - 17 files changed, 831 insertions(+), 155 deletions(-) create mode 100644 plugins/src/plugins/pr_review/commenter.py diff --git a/plugins/config.yaml b/plugins/config.yaml index 924155d..48ebea6 100644 --- a/plugins/config.yaml +++ b/plugins/config.yaml @@ -1,14 +1,23 @@ driver: - driver_type: gitee + driver_type: driver_base_url: driver_token: -auth: - auth_rul: + path_rules: + - "!a.py" + - "*bbb/*.java" model: max_token_length: - encoding_name: - model_base_url: - limit: - prompt: - system_message: + encoding_name: + model_base_url: + limit: + prompt: + model_name: + model_type: + +generate: + grant_type: + auth_url: + app_id: + app_secret: + plugin_type: diff --git a/plugins/src/backend/drivers/driver_adaptor.py b/plugins/src/backend/drivers/driver_adaptor.py index 0ecad9f..854f0a0 100644 --- a/plugins/src/backend/drivers/driver_adaptor.py +++ b/plugins/src/backend/drivers/driver_adaptor.py @@ -1,19 +1,42 @@ - +import json +from loguru import logger +import requests +import fnmatch from 
class BaseDriver(ABC):
    """Abstract base for code-platform drivers (Gitee, GitHub, ...).

    Holds the connection/identity data shared by all drivers and the
    path-rule filtering logic; concrete drivers implement the HTTP calls.
    """

    def __init__(self, project_args, driver_args) -> None:
        super().__init__()
        self.access_token = driver_args.driver_token
        self.driver_url = driver_args.driver_base_url
        self.owner = project_args.owner
        self.repo = project_args.repo
        self.pr_number = project_args.pull_request_number
        # Parsed once: {pattern: is_exclude}. A leading "!" marks an exclusion rule.
        self.rules = self._init_rules(driver_args.path_rules)

    def _init_rules(self, raw_rules):
        """Parse the configured path rules into a {glob_pattern: is_exclude} dict."""
        rules_dic = {}
        for rule in (raw_rules or []):
            rule = rule.strip()
            if not rule:
                continue
            if rule.startswith("!"):
                rules_dic[rule[1:].strip()] = True
            else:
                rules_dic[rule] = False
        return rules_dic

    @abstractmethod
    def list_comments(self):
        pass

    @abstractmethod
    def get_all_commit_ids(self):
        pass

    @abstractmethod
    def list_review_comments(self):
        pass

    @abstractmethod
    def submit_comment_to_pr(self, body, commitId, filename, line):
        pass

    @abstractmethod
    def fetch_pr(self):
        pass

    @abstractmethod
    def get_pr_diff(self):
        pass

    @abstractmethod
    def get_driver_type(self):
        pass

    @abstractmethod
    def compare(self):
        pass

    def check_path(self, path):
        """Return True when *path* should be reviewed under the configured rules.

        Semantics: with no rules everything is reviewed; an exclusion match
        always wins; if any inclusion rule exists, the path must match one.
        """
        if not self.rules:
            return True
        included = False
        excluded = False
        has_inclusion_rule = False
        for pattern, is_exclude in self.rules.items():
            if fnmatch.fnmatch(path, pattern):
                if is_exclude:
                    excluded = True
                else:
                    included = True
            if not is_exclude:
                has_inclusion_rule = True
        return ((not has_inclusion_rule) or included) and (not excluded)

    def assemble_commit_ids(self, current_commit_id, reviewed_commit_ids):
        """Join the current and previously-reviewed commit ids with commas.

        Fix: the original concatenated then stripped only a trailing comma,
        so an empty current_commit_id produced a leading-comma string.
        """
        parts = [p for p in (current_commit_id, reviewed_commit_ids) if p]
        return ','.join(parts)
class GiteeDriver(BaseDriver):
    """Gitee (OpenAPI v5) implementation of BaseDriver."""

    def __init__(self, project_args, driver_args):
        super().__init__(project_args, driver_args)
        # Direct download URL of the PR's unified diff, from the webhook payload.
        self.pr_diff = project_args.pull_request.get("diff_url", None)

    def _fetch_comment_pages(self):
        # Page through /pulls/{number}/comments until a short or empty page is returned.
        page = 1
        all_comments = []
        while True:
            url = f'{self.driver_url}/{self.owner}/{self.repo}/pulls/{self.pr_number}/comments'
            params = {
                'access_token': self.access_token,
                'page': page,
                'per_page': 100,
                'direction': 'desc'
            }
            res = json.loads(requests.get(url=url, params=params).content.decode('utf-8'))
            all_comments.extend(res)
            page += 1
            if not res or len(res) < 100:
                break
        return all_comments

    def list_comments(self):
        """Return every comment on the pull request."""
        return self._fetch_comment_pages()

    def list_review_comments(self):
        """Return review comments on the pull request.

        NOTE(review): this hits the same endpoint as list_comments() — the
        original bodies were identical. Confirm Gitee v5 really has no
        separate review-comment listing before relying on the distinction.
        """
        return self._fetch_comment_pages()

    def get_all_commit_ids(self):
        """Return the PR's commit objects as decoded JSON."""
        url = f'{self.driver_url}/{self.owner}/{self.repo}/pulls/{self.pr_number}/commits'
        params = {
            'access_token': self.access_token
        }
        return json.loads(requests.get(url=url, params=params).content.decode('utf-8'))

    def submit_comment_to_pr(self, body, commitId, filename, line):
        """Post a comment; filename/line/commitId attach it to a diff position."""
        url = f'{self.driver_url}/{self.owner}/{self.repo}/pulls/{self.pr_number}/comments'
        data = {
            'access_token': self.access_token,
            'body': body,
            'commit_id': commitId,
            "path": filename,
            "position": line
        }
        res = requests.post(url=url, data=data)
        if res.status_code != 201:  # Gitee returns 201 Created on success
            logger.error(f'post to gitee failed: (unknown)')
            logger.error(res.text)
            logger.error(res.status_code)
        else:
            logger.info(f'post to gitee succeed: (unknown)')

    def fetch_pr(self):
        """Return the pull request object as decoded JSON."""
        url = f'{self.driver_url}/{self.owner}/{self.repo}/pulls/{self.pr_number}'
        params = {
            'access_token': self.access_token
        }
        return json.loads(requests.get(url=url, params=params).content.decode('utf-8'))

    def get_pr_diff(self):
        """Return the PR's unified diff as text."""
        params = {
            'access_token': self.access_token
        }
        # Fix: dropped an unused local (`url = self.pr_diff`) that shadowed the field.
        return requests.get(url=self.pr_diff, params=params).content.decode('utf-8')

    def get_driver_type(self):
        # Fix: was an empty stub; report the concrete platform type used as the
        # driver-mapping key elsewhere in the project.
        return 'gitee'

    def compare(self, formerSha, latterSha):
        """Return the comparison (files/commits) between two SHAs."""
        url = f'{self.driver_url}/{self.owner}/{self.repo}/compare/{formerSha}...{latterSha}'
        params = {
            'access_token': self.access_token,
            'straight': False,
        }
        return json.loads(requests.get(url=url, params=params).content.decode('utf-8'))
class ChatGPTModel(BaseModel):
    """Chat model backed by an OpenAI-compatible HTTP endpoint."""

    def __init__(self, model_args, gen_args) -> None:
        super().__init__(model_args, gen_args)
        # BaseModel already stores these; kept for backward compatibility
        # with existing attribute access.
        self.model_args = model_args
        self.gen_args = gen_args

    def _build_payload(self, prompt):
        # One request body shared by chat() and stream_chat().
        return {
            "model": "gpt-4",
            "temperature": 0.05,
            "top_p": 1,
            "messages": [
                {
                    "role": "system",
                    "content": self.get_system_message()
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ]
        }

    def chat(self, prompt):
        """Send one unauthenticated chat request; return the raw Response."""
        response = requests.post(
            self.get_url(), json=self._build_payload(prompt)
        )
        if response.status_code != 200:
            logger.info("get answer error")
            logger.info(response.text)
        return response

    def stream_chat(self, prompt):
        """Send an authenticated streaming chat request; return the accumulated answer text.

        Returns None when no auth token could be obtained or the request failed.
        """
        token = self.get_token()
        if not token:
            logger.error(f"Failed to get token")
            return
        header = {'Authorization': token}
        # Fix: the original issued the identical POST twice (once to check the
        # status, once to read the lines). A single streaming request suffices.
        response = requests.post(
            self.get_url(), json=self._build_payload(prompt), headers=header, stream=True
        )
        if response.status_code != 200:
            logger.info("get answer error")
            logger.info(response.text)
            return
        resp = ''
        for res in response.iter_lines():
            if not res:
                continue
            item = res.decode('utf-8')
            # Assumes SSE-style lines of the form "data: {...answer...}" — TODO confirm.
            answer = json.loads(item.split('data:')[-1]).get('answer')
            # Fix: guard against a missing 'answer' key ('' += None raised TypeError).
            if answer:
                resp += answer
        return resp
class BaseModel(ABC):
    """Abstract base for chat models: config accessors, auth, prompt assembly."""

    def __init__(self, model_args, gen_args) -> None:
        super().__init__()
        self.model_args = model_args
        self.gen_args = gen_args

    @abstractmethod
    def chat(self, prompt):
        pass

    @abstractmethod
    def stream_chat(self, prompt):
        pass

    def get_max_token_length(self):
        """Configured maximum token budget for one request."""
        return self.model_args.max_token_length

    def get_system_message(self):
        # NOTE(review): config.yaml in this patch drops the system_message key —
        # confirm ModelArguments still defines it.
        return self.model_args.system_message

    def get_prompt(self):
        """Raw prompt template (with $placeholders) from the config."""
        return self.model_args.prompt

    def get_url(self):
        """Model endpoint base URL."""
        return self.model_args.model_base_url

    def calculate_token_from_content(self, content):
        """Count tokens in *content* using the configured tiktoken encoding."""
        encoding = tiktoken.get_encoding(self.model_args.encoding_name)
        tokens = encoding.encode(content)
        return len(tokens)

    def get_token(self):
        """Fetch an auth token from the configured auth service; None on failure."""
        params = {
            'grant_type': self.gen_args.grant_type,
            'app_id': self.gen_args.app_id,
            'app_secret': self.gen_args.app_secret
        }
        try:
            resp = requests.get(url=self.gen_args.auth_url, params=params)
            data = resp.json()
            return data.get('token')
        except Exception as e:
            # Best-effort: callers treat None as "no token available".
            logger.error(f"Failed to get token: {e}")
            return None

    def assemble_prompt(self, prompt_content, language="Chinese"):
        """Fill the prompt template with the given placeholder values.

        Generalized: the review language (previously hard-coded to "Chinese")
        is now a keyword parameter whose default preserves old behavior.
        """
        prompt = copy.deepcopy(self.model_args.prompt)
        for param, value in prompt_content.items():
            prompt = prompt.replace(param, value)
        prompt = prompt.replace("$language", language)
        return prompt
+ assert any(x in self.driver_base_url for x in [ + "http", + "https"]), "We only accept http or https address." if self.driver_type is not None: - assert self.driver_type in [ + assert any(x in self.driver_type for x in [ "gitee", - "github", - ], "We only accept `gitee` or `github` code platform type." + "github"]), "We only accept `gitee` or `github` code platform type." diff --git a/plugins/src/common/configs/generate_args.py b/plugins/src/common/configs/generate_args.py index 5713261..8c60bc1 100644 --- a/plugins/src/common/configs/generate_args.py +++ b/plugins/src/common/configs/generate_args.py @@ -4,6 +4,13 @@ from typing import Optional @dataclass class GenerateArguments: + grant_type: str = field( + default="", + metadata={ + "help": "The grant_type is used to specify the authorization type when obtaining a token." + } + ) + auth_url: str = field( default="", metadata={ diff --git a/plugins/src/common/configs/model_args.py b/plugins/src/common/configs/model_args.py index ae02c39..4f74c4d 100644 --- a/plugins/src/common/configs/model_args.py +++ b/plugins/src/common/configs/model_args.py @@ -2,7 +2,6 @@ from dataclasses import dataclass, field from typing import Optional - @dataclass class ModelArguments: # The model max token length diff --git a/plugins/src/common/configs/project_args.py b/plugins/src/common/configs/project_args.py index 6585001..1da25a4 100644 --- a/plugins/src/common/configs/project_args.py +++ b/plugins/src/common/configs/project_args.py @@ -3,6 +3,7 @@ from loguru import logger from dataclasses import dataclass, field +import functools @@ -10,58 +11,52 @@ from dataclasses import dataclass, field class ProjectArguments: action: str pull_request: dict - state: str + pull_request_state: str + pull_request_diff: str + pull_request_number: str comment: str - body: str noteable_type: str - pull_request_title: str - project: dict + owner: str + repo: str - - def validate_arguments(func): + @staticmethod + def 
@dataclass
class ProjectArguments:
    """Structured view of the code-platform webhook payload used by the plugins."""

    action: str
    pull_request: dict
    pull_request_state: str
    pull_request_diff: str
    pull_request_number: str
    comment: str
    noteable_type: str
    owner: str
    repo: str

    @staticmethod
    def validate_arguments(structured_data):
        """Log an error for every field that fails its sanity check.

        NOTE(review): like the patched original this only logs — it does not
        reject the payload. Confirm whether callers expect a hard failure.
        Fix: the previous rules were contradictory (e.g. `isinstance(value, dict)
        and value is None` can never be true), so nothing was ever flagged.
        """
        # Each rule answers: "is this value invalid?"
        validation_rules = {
            "action": lambda value: not isinstance(value, str) or value != "comment",
            "pull_request": lambda value: not isinstance(value, dict),
            "pull_request_state": lambda value: value != "open",
            "pull_request_number": lambda value: value is None,
            "comment": lambda value: value is None,
            "noteable_type": lambda value: value != "PullRequest",
            "owner": lambda value: not isinstance(value, str),
            "repo": lambda value: not isinstance(value, str),
        }
        for param, rule in validation_rules.items():
            if param in structured_data and rule(structured_data[param]):
                logger.error(f"Invalid value for parameter '{param}'.")

    @classmethod
    def from_webhook_data(cls, data):
        """Build ProjectArguments from a raw webhook dict."""
        structured_data = cls.construct_from_mapping(data)
        cls.validate_arguments(structured_data)
        return cls(**structured_data)

    @staticmethod
    def construct_from_mapping(data):
        """Flatten the nested webhook payload into constructor kwargs.

        Fixes: pull_request_diff was copied from "state" instead of "diff_url";
        chained `.get().get()` crashed on missing sections; the bare
        `except: pass` silently returned a partial dict that later blew up in
        `cls(**kwargs)` — missing sections now simply yield None fields.
        """
        pull_request = data.get("pull_request") or {}
        repository = data.get("repository") or {}
        owner = repository.get("owner") or {}
        return {
            "action": data.get("action", None),
            "pull_request": data.get("pull_request", None),
            "pull_request_state": pull_request.get("state", None),
            "pull_request_diff": pull_request.get("diff_url", None),
            "pull_request_number": pull_request.get("number", None),
            "comment": data.get("comment", None),
            "noteable_type": data.get("noteable_type", None),
            "repo": repository.get("name", None),
            "owner": owner.get("login", None),
        }
class Commenter:
    """Comment-level helpers layered over the code-platform driver."""

    # NOTE(review): these tags are empty strings in this revision — they look
    # like redacted markers (HTML comments?). With empty tags every tag search
    # matches every comment; confirm the real values.
    STATUS_START_TAG = ""

    STATUS_END_TAG = ""

    COMMENT_REPLY_TAG = ""

    def __init__(self, driver_obj: "GiteeDriver"):
        # String annotation: avoids a hard import-time dependency on the driver module.
        self.driver_obj = driver_obj

    def list_comments(self):
        """All PR comments, via the driver."""
        return self.driver_obj.list_comments()

    def find_comment_with_tag(self, tag):
        """First comment whose body contains *tag*, or None."""
        for comment in self.list_comments():
            body = comment.get('body', None)
            if body and tag in body:
                return comment
        return None

    def get_all_commit_ids(self):
        """SHA of every commit in the PR, in API order."""
        return [commit.get('sha', None) for commit in self.driver_obj.get_all_commit_ids()]

    def get_reviewed_commit_ids(self, comment_body):
        """Extract the recorded reviewed-commit-id string from a status comment body."""
        start_idx = comment_body.find(Commenter.STATUS_START_TAG) + len(Commenter.STATUS_START_TAG)
        end_idx = comment_body.find(Commenter.STATUS_END_TAG)
        dic = json.loads(comment_body[start_idx: end_idx])
        commit_ids = dic.get("commit_ids", "")
        return commit_ids if commit_ids else None

    def get_highest_reviewed_commit_id(self, commitIds, reviewedCommitIds):
        """First commit id (in PR order) present in the reviewed set, else ''."""
        for commit_id in commitIds:
            if commit_id in reviewedCommitIds:
                return commit_id
        return ''

    def get_comment_chains_within_range(self, path, startLine, endLine, tag=''):
        """Render every tagged conversation chain on *path* within [startLine, endLine]."""
        existing_comments = self.get_comments_within_range(path, startLine, endLine)
        top_level_comments = [c for c in existing_comments if not c.get('in_reply_to_id', None)]
        all_chains = ''
        chain_num = 0
        for top_level_comment in top_level_comments:
            chain = self.compose_comment_chain(existing_comments, top_level_comment)
            if chain and tag in chain:
                chain = chain.replace(tag, "")
                chain_num += 1
                all_chains += 'Conversation Chain {}: {}\n---\n'.format(chain_num, chain)
        return all_chains

    def compose_comment_chain(self, review_comments, top_level_comment):
        """Join a top-level comment and its replies as 'user: body' entries.

        Fix: conversation_chain is initialized before the try block — previously
        a failure inside the try raised UnboundLocalError at the return below.
        """
        conversation_chain = []
        try:
            conversation_chain.append('{}: {}'.format(top_level_comment.get('user', None).get('login', None), top_level_comment.get('body', None)))
            for comment in review_comments:
                if comment.get('in_reply_to_id', None) == top_level_comment.get('id', None):
                    conversation_chain.append('{}: {}'.format(comment.get('user', None).get('login', None), comment.get('body', None)))
        except Exception as e:
            logger.info('compose_comment_chain failed', e)

        return '\n---\n'.join(conversation_chain)

    def get_comments_within_range(self, path, startLine, endLine):
        """Diff comments on *path* whose new_line falls within [startLine, endLine]."""
        required_comments = []
        try:
            for comment in self.list_review_comments():
                if comment.get('path', None) == path and \
                        comment.get('body', None) and \
                        comment.get('comment_type', None) == 'diff_comment' and \
                        comment.get('new_line') and \
                        startLine <= comment.get('new_line') <= endLine:
                    required_comments.append(comment)
        except Exception as e:
            logger.info('get comments within range failed', e)
        return required_comments

    def list_review_comments(self):
        """All review comments, via the driver."""
        return self.driver_obj.list_review_comments()
    def run(self):
        """Entry point: incrementally review the PR and post AI review comments.

        Pipeline: validate the trigger comment -> load prior review status ->
        pick the commit range to diff -> build, merge and filter patches ->
        cut patches into hunks -> send hunks to the model -> post results.
        Every guard below aborts silently (log only) when a stage yields nothing.
        """
        if not self.check_project_args():
            return

        # Reviewed-commit ids recorded by a previous run's status comment ("" if none).
        existing_commit_ids = self.get_existing_comments()
        self.reviewed_commit_ids = existing_commit_ids

        # head/base SHAs of the PR plus the newest already-reviewed commit.
        head_sha, base_sha, highest_reviewed_commit_id = self.get_commit_ids(existing_commit_ids)

        # Nothing reviewed yet (or everything reviewed): restart from the base commit.
        if (not highest_reviewed_commit_id) or highest_reviewed_commit_id == head_sha:
            logger.info('will review from the base commit: {}'.format(base_sha))
            highest_reviewed_commit_id = base_sha
        else:
            logger.info('will review from commit: {}'.format(highest_reviewed_commit_id))

        logger.info(f'base_sha = {base_sha}, head_sha = {head_sha}, highest_reviewed_commit_id = {highest_reviewed_commit_id}')

        # Two comparisons: last-reviewed..head (incremental) and base..head (full branch).
        # Side effect: also sets self.current_commit_id from the incremental diff.
        incremental_files, target_branch_files = self.generate_patches_from_commits(head_sha, base_sha, highest_reviewed_commit_id)

        if (not incremental_files) and (not target_branch_files):
            logger.warning('skipped: files data is missing')
            return

        if not self.current_commit_id:
            logger.warning('skipped: commits is null')
            return

        # Keep only files present in both comparisons and in the PR diff itself.
        patch_files = self.merge_patches(incremental_files, target_branch_files)

        if len(patch_files) == 0:
            logger.warning('skipped: patch files is null')
            return

        # Apply the configured include/exclude path rules.
        filter_selected_files, filter_ignored_files = self.filter_patches_by_paths(patch_files)

        if len(filter_selected_files) == 0:
            logger.warning('skipped: filter_selected_files is null')
            return

        # Cut each file's patch into reviewable hunks plus a diff-line -> file-line map.
        filtered_files, file_comment_line = self.get_hunks_from_patches(filter_selected_files)

        if len(filtered_files) == 0:
            logger.error('skipped: no files to review')
            return

        hunk_answers = self._code_review(filtered_files, file_comment_line)

        if len(hunk_answers) == 0:
            logger.error('skipped: no files be reviewed by AI')
            return

        # Record current + previously reviewed commits, then publish all comments.
        commit_ids = self.driver_obj.assemble_commit_ids(self.current_commit_id, self.reviewed_commit_ids)
        self.submit_res(hunk_answers, commit_ids)
    def _get_diff_new_line_dic(self, file_patches):
        """Map new-file line numbers to their position within the whole patch text.

        Walks each hunk, translating hunk-relative positions by the cumulative
        offset (`step`) of the previous hunks, so review comments can be
        anchored at the right diff position.
        """
        i = 0
        # step[i] = cumulative diff-position offset before hunk i.
        step = [0]
        diff_new_line_dic = {}
        patches = self._split_patch(file_patches)
        try:
            for patch in patches:
                patch_lines = self._patch_start_end_line(patch)
                diff_new_line_dic.update(self._delta(self._parse(patch, patch_lines), step[i]))
                # 'endLine' holds the hunk's line count, so this is the hunk's last new line.
                end_line = patch_lines['newHunk']['startLine'] + patch_lines['newHunk']['endLine'] - 1
                # NOTE(review): .get(end_line) can be None if end_line was never mapped,
                # which would make the next hunk's _delta raise — caught below. Confirm.
                step.append(diff_new_line_dic.get(end_line))
                i += 1
        except Exception as e:
            logger.error(e)

        # If there are modifications in the first four lines,
        # the first line "@@" will not be displayed,
        # all line numbers need to be decremented by 1.
        if self.modifications_position(file_patches):
            for key in diff_new_line_dic:
                diff_new_line_dic.update({key: diff_new_line_dic[key] - 1})

        return diff_new_line_dic
prs.get('head', None).get('sha', None) + if prs.get('base', None): + base_sha = prs.get('base', None).get('sha', None) + + return head_sha, base_sha, highest_reviewed_commit_id + + # Get the patches generated by different commits + def generate_patches_from_commits(self, head_sha, base_sha, highest_reviewed_commit_id): + incremental_diff = self.driver_obj.compare(highest_reviewed_commit_id, head_sha) + target_branch_diff = self.driver_obj.compare(base_sha, head_sha) + + incremental_files = incremental_diff.get('files', None) + target_branch_files = target_branch_diff.get('files', None) + + commits = [] + if incremental_diff.get('commits', None): + for commit in incremental_diff.get('commits', None): + commits.append(commit.get('sha', None)) + + self.current_commit_id = commits[0] + + return incremental_files, target_branch_files + + # Merge incremental_files and target_branch_files into patch_files + def merge_patches(self, incremental_files, target_branch_files): + incremental_files_names = [] + for incremental_file in incremental_files: + if incremental_file.get('filename', None): + incremental_files_names.append(incremental_file.get('filename')) + + diff_files = self.driver_obj.get_pr_diff().split('diff --git') + diff_files.pop(0) + pr_files = [dif.split('\n')[0].split(' ')[-1][2::] for dif in diff_files if dif.strip()] + + patch_files = [] + for target_branch_file in target_branch_files: + if target_branch_file.get('filename', None) in incremental_files_names and target_branch_file.get('filename', None) in pr_files: + patch_files.append(target_branch_file) + + return patch_files + + def filter_patches_by_paths(self, patch_files): + filter_selected_files = [] + filter_ignored_files = [] + + for a_file in patch_files: + if not self.driver_obj.check_path(a_file.get('filename', None)): + filter_selected_files.append(a_file) + else: + logger.info('skip for excluded path: %s'%(a_file.get('filename', None))) + filter_ignored_files.append(a_file) + + return 
filter_selected_files, filter_ignored_files - def get_diff_new_line_dic(self): - # TODO Will be added - return + def get_hunks_from_patches(self, filter_selected_files): + filtered_files = [] + file_comment_line = {} + # cut patch into hunks + for a_file in filter_selected_files: + if not self.driver_obj.pr_number: + logger('skipped: pr is null') + continue + file_diff = a_file.get('patch') + if not a_file.get('patch'): + logger.info(f"{a_file.get('filename')} has no patch") + continue + file_comment_line[a_file.get('filename')] = self._get_diff_new_line_dic(file_patches=file_diff) + patches = [] + diff_num = 0 + for patch in self._split_patch(a_file.get('patch', '')): + diff_num += 1 + patch_lines = self._patch_start_end_line(patch) + if not patch_lines: + continue + hunks = self._parse_patch(patch) + if not hunks: + continue + + hunks_str = '''---new_hunk---\n\'\'\'\n%s\n\'\'\'\n---old_hunk---\n\'\'\'\n%s\n\'\'\''''%(hunks.get('newHunk', None), hunks.get('oldHunk', None)) + patches.append([patch_lines.get('newHunk', None).get('startLine', None), patch_lines.get('newHunk', None).get('endLine', None), hunks_str]) + if len(patches) > 0: + filtered_files.append([a_file.get('filename', None), file_diff, patches]) + + return filtered_files, file_comment_line + def check_project_args(self): - # Parse the parameters of webhook to see if they meet the requirements - pass + comment_body = self.project_args.comment.get("body", "").strip() + if not comment_body.startswith("/pr-review"): + logger.info("skipped: it is not /pr-review") + return False + + return True def parse_project_input(self): # If the requirements are met, the parameters are parsed into the desired structure. 
@@ -31,25 +267,168 @@ class CodeReviewPlugin(BasePlugin): # Parse the return value of the model pass - def submit_res(self): - # Perform the desired result based on the model return value - pass + def submit_res(self, hunk_answers, commit_ids): + # Submit the comment to pr + review_status = {"filenames": [], "commit_ids": commit_ids} + for hunk_answer in hunk_answers: + lgtm, body, filename, comment_line = hunk_answer[0] + if lgtm: + continue + body = f"{Commenter.COMMENT_REPLY_TAG}{body}" + self.driver_obj.submit_comment_to_pr(body = body, commitId = self.current_commit_id, filename = filename, line = comment_line) + review_status["filenames"].append(filename) + review_status_json = json.dumps(review_status) + body = f"PR Review completed.{Commenter.STATUS_START_TAG}{review_status_json}{Commenter.STATUS_END_TAG}" + body = body.replace("'", "\"") + self.driver_obj.submit_comment_to_pr(body = body, commitId = None, filename = None, line = None) + + def _code_review(self, filtered_files, file_comment_line): + hunk_answers = [] + for filename, _, patches in filtered_files: + review_answers = self._do_review(filename, patches, file_comment_line[filename]) + for review_answer in review_answers: + hunk_answer = self.parse_result(review_answer) + hunk_answers.append(hunk_answer) + return hunk_answers - def _code_review(self): - return + def _split_patch(self, patch): + if not patch: + return [] + results = [] + split_lines = patch.split('\n') + split_lines = split_lines[:-1] + last_line = -1 + for a_line in range(len(split_lines)): + # whether current line matches format: @@ -0,0 +0,0 @@ + re_split = re.split('^@@ -(\d+),(\d+) \+(\d+),(\d+) @@', split_lines[a_line]) + if len(re_split) > 1: + if last_line == -1: + last_line = a_line + else: + results.append('\n'.join(split_lines[last_line: a_line])) + last_line = a_line + if last_line != -1: + results.append('\n'.join(split_lines[last_line:])) + return results + + def _patch_start_end_line(self, patch): + re_split = 
re.split('^@@ -(\d+),(\d+) \+(\d+),(\d+) @@', patch) + if len(re_split) > 1: + old_begin = int(re_split[1]) + old_diff = int(re_split[2]) + new_begin = int(re_split[3]) + new_diff = int(re_split[4]) + return {'oldHunk': {'startLine': old_begin, 'endLine': old_diff}, 'newHunk': {'startLine': new_begin, 'endLine': new_diff}} + else: + return None + + def _parse_patch(self, patch): + hunk_info = self._patch_start_end_line(patch) + if not hunk_info: + return + + old_hunk_lines = [] + new_hunk_lines = [] + + new_line = hunk_info["newHunk"]["startLine"] + + lines = patch.split('\n')[1:] + + # Remove the last line if it's empty + if lines[-1] == '': + lines.pop() + + # Skip annotations for the first 3 and last 3 lines + skip_start = 3 + skip_end = 3 + + current_line = 0 + + removal_only = not any(line.startswith('+') for line in lines) + + for line in lines: + if line == '\\ No newline at end of file': + continue + current_line += 1 + if line.startswith('-'): + old_hunk_lines.append(line[1:]) + elif line.startswith('+'): + new_hunk_lines.append(f"{new_line}: {line[1:]}") + new_line += 1 + else: + old_hunk_lines.append(line) + if removal_only or (current_line > skip_start and current_line <= len(lines) - skip_end): + new_hunk_lines.append(f"{new_line}: {line}") + else: + new_hunk_lines.append(line) + new_line += 1 + + return { + "oldHunk": '\n'.join(old_hunk_lines), + "newHunk": '\n'.join(new_hunk_lines) + } + + def _do_review(self, filename, patches, diff_new_line_dic): + logger.info('reviewing: {}'.format(filename)) + + prompt_content_for_patch = copy.deepcopy(self.prompt_content) - def _splitPatch(self): - # TODO Will be added - return + tokens = self.model_obj.calculate_token_from_content(self.model_obj.get_system_message()) + + patches_to_pack = 0 + for _, _, patch in patches: + patch_tokens = self.model_obj.calculate_token_from_content(patch) + if tokens + patch_tokens > self.model_obj.get_max_token_length(): + logger.info('only packing {}/{} patches, tokens: 
{}/{}'.format(patches_to_pack, len(patches), tokens, self.model_obj.get_max_token_length())) + break + patches_to_pack += 1 - def _patchStartEndLine(self): - # TODO Will be added - return + review_answers = [] + patches_packed = 0 + for startLine, endLine, patch in patches: + if patches_packed >= patches_to_pack: + logger.info('unable to pack more patches into this request, packed: {}, total patches: {}, skipping'.format(patches_packed, len(patches))) + patches_packed += 1 + comment_chain = '' + all_chains = self.commenter.get_comment_chains_within_range(filename, startLine, endLine, self.commenter.COMMENT_REPLY_TAG) + if len(all_chains) > 0: + logger.info('Found comment chains: {} for {}'.format(all_chains, filename)) + comment_chain = all_chains + comment_chain_tokens = self.model_obj.calculate_token_from_content(comment_chain) + if tokens + comment_chain_tokens > self.model_obj.get_max_token_length(): + comment_chain = '' + else: + tokens += comment_chain_tokens + + prompt_content_for_patch["$filename"] = filename + prompt_content_for_patch["$patches"] = patch + if comment_chain: + prompt_content_for_patch["$patches"] += '---comment_chains---\n\'\'\'{}\'\'\'---end_change_section---'.format(comment_chain) + if patches_packed > 0: + messages = self.model_obj.assemble_prompt(prompt_content_for_patch) + res = self.model_obj.chat(messages) - def _parsePatch(self): - # TODO Will be added - return + if res.status_code != 200: + logger.info('review: nothing obtained from openai') + return '{} (no response)'.format(filename) + + answer = res.json() + review_answers.append([diff_new_line_dic, filename, answer]) + + return review_answers - def _do_review(self): - # TODO Will be added - return \ No newline at end of file + def parse_result(self, review_answer): + diff_new_line_dic, filename, answer = review_answer + lgtm = True + result = re.split('(\d+-\d+:)', answer) + result = [x.strip() for x in result if x.strip()] + hunk_answer = [] + for i in range(len(result) - 1): 
+ if re.match('\d+-\d+:', result[i]): + line = int(result[i].split('-')[-1][:-1]) + comment_line = diff_new_line_dic[int(line)] + + lgtm = ('LGTM' in result[i + 1]) + hunk_answer.append([lgtm, result[i+1], filename, comment_line]) + + return hunk_answer \ No newline at end of file diff --git a/plugins/src/task/manage.py b/plugins/src/task/manage.py index 638da7f..7a69878 100644 --- a/plugins/src/task/manage.py +++ b/plugins/src/task/manage.py @@ -1,7 +1,6 @@ import os import yaml - from loguru import logger from common.configs.driver_args import DriverArguments @@ -11,17 +10,17 @@ from common.configs.project_args import ProjectArguments from backend.models.model_interface import ModelInterface from backend.drivers.driver_interface import DriverInterface from plugins.plugin_interface import PluginInterface +from backend.drivers.driver_adaptor import BaseDriver def init_plugin(data): gen_args, driver_args, model_args = init_for_config() - model_inter = ModelInterface.create_chat_model(model_args) - driver_inter = DriverInterface.create_driver(driver_args) - + model_inter = ModelInterface.create_chat_model(model_args, gen_args) project_args = ProjectArguments.from_webhook_data(data) if project_args is False: logger.error("Failed to get webhook arguments.") return + driver_inter = DriverInterface.create_driver(project_args, driver_args) plugin_inter = PluginInterface.create_plugin( gen_args, project_args, driver_inter, model_inter) return plugin_inter @@ -46,4 +45,4 @@ def init_for_config(): def assgin_task(data): plugin_inter = init_plugin(data) - plugin_inter.submit_res() + plugin_inter.run() diff --git a/plugins/src/task/router/router.py b/plugins/src/task/router/router.py index d6b6e45..9ae83d0 100644 --- a/plugins/src/task/router/router.py +++ b/plugins/src/task/router/router.py @@ -1,4 +1,3 @@ - import threading from flask import request from flask import Flask -- Gitee