#include<bits/stdc++.h>
typedef long long ll;
// Short-hand macros (fr, sc and mid are currently unused in this file).
#define pb push_back
#define fr first
#define sc second
// `endl` becomes a plain newline, avoiding std::endl's flush on every line.
#define endl '\n'
using namespace std;
#define mid ((left+right)>>1)
// From here on every `int` is really `ll`; that is why main is declared as
// int32_t (a `main` returning long long would be ill-formed).
#define int ll
// Prime modulus used by the modular helpers below.
const ll mod=1e9+7;
// Modular addition: (x + y) reduced modulo `mod`.
// No normalisation is applied, so a negative x + y yields a negative result.
ll sum(ll x,ll y){
    const ll total=x+y;
    return total%mod;
}
// Modular multiplication: both factors are reduced first, so the product of two
// values below `mod` (about 1e18) fits in a 64-bit ll before the final reduction.
ll mul(ll x,ll y){
    const ll a=x%mod;
    const ll b=y%mod;
    return a*b%mod;
}
// Binary exponentiation: returns x^y modulo `mod`.
// For y <= 0 no squaring step runs and 1 is returned.
ll fpow(ll x,ll y){
    ll result=1;
    ll base=x;
    for(ll e=y;e>0;e>>=1){
        if(e&1)result=mul(result,base);
        base=mul(base,base);
    }
    return result;
}
// ---------------------------------------------------------------------------
// Tree data shared by dfs, dfs2 and main. Vertices are 1-indexed; par[1]==0.
//
// Rooted quantities (tree rooted at vertex 1), filled by dfs:
//   dp2[v]  - number of children c of v with dp2[c]==0.
//   cnt2[v] - 1 + sum of cnt2[c] over all children      if dp2[v]==0,
//             cnt2[c] of the unique child with dp2[c]==0  if dp2[v]==1,
//             0                                           if dp2[v]>=2.
// Rerooted quantities, filled by dfs2:
//   dp[v],  cnt[v]  - the same two values computed as if v were the root.
//   dp3[v], cnt3[v] - the same two values for par[v] regarded as a child of v,
//                     i.e. for the part of the tree on par[v]'s side of the
//                     edge (v, par[v]).
// NOTE(review): presumably dp>0 means "the player to move wins from this
// vertex" and cnt counts vertices whose outcome flip would flip this one;
// confirm against the problem statement.
// ---------------------------------------------------------------------------
int n;
long long d;
std::vector<int>komsu[100023];   // adjacency lists (komsu: Turkish for "neighbour")
int dp[100023],cnt[100023];
int dp2[100023],cnt2[100023];
int dp3[100023],cnt3[100023];
int par[100023];
long long ans=0;

// Rerooting pass. Requires dfs(1) to have filled par/dp2/cnt2.
// For pos it derives dp3/cnt3 from its parent's rerooted values, then dp/cnt
// (with the parent treated as one more child), and recurses into the children.
void dfs2(int pos){
    // For the root the rooted values are already the rerooted ones.
    dp[pos]=dp2[pos];
    cnt[pos]=cnt2[pos];
    if(par[pos]){
        const int p=par[pos];
        const int gp=par[p];   // 0 when p is the root
        // Start from p's rerooted values and remove pos from its neighbours.
        dp3[pos]=dp[p];
        cnt3[pos]=cnt[p];
        if(dp2[pos]==0){
            // pos was one of p's dp==0 neighbours, so p loses one of them.
            dp3[pos]--;
            if(dp3[pos]==0){
                // p has no dp==0 neighbour left: it counts itself plus every
                // remaining neighbour. (This used to start at 0, which dropped
                // p itself; cf. cnt2[pos]++ in dfs and cnt[pos]++ below.)
                cnt3[pos]=1;
                for(int x:komsu[p]){
                    if(x==pos||x==gp)continue;
                    cnt3[pos]+=cnt2[x];
                }
                if(gp){
                    cnt3[pos]+=cnt3[p];
                }
            }
            else if(dp3[pos]==1){
                // Here dp[p]==2, so the cnt[p] copied above is 0; add the cnt
                // of the single remaining dp==0 neighbour.
                for(int x:komsu[p]){
                    if(x==pos||x==gp)continue;
                    if(dp2[x]==0){
                        cnt3[pos]+=cnt2[x];
                    }
                }
                if(gp&&dp3[p]==0){
                    cnt3[pos]+=cnt3[p];
                }
            }
            // dp3[pos]>=2: cnt[p] is already 0, which is the right value.
        }
        else if(dp3[pos]==0){
            // pos did not contribute to dp[p]; only remove its cnt2 share.
            cnt3[pos]-=cnt2[pos];
        }

        // Recompute pos's values with p as an additional child.
        dp[pos]=!dp3[pos];
        cnt[pos]=0;
        for(int x:komsu[pos]){
            if(x==p)continue;
            dp[pos]+=!dp2[x];
        }
        if(dp[pos]==0){
            cnt[pos]++;
        }
        for(int x:komsu[pos]){
            // (dp, cnt) of neighbour x seen as a child of pos.
            const int childDp =(x==p)?dp3[pos]:dp2[x];
            const int childCnt=(x==p)?cnt3[pos]:cnt2[x];
            if(dp[pos]==0||(dp[pos]==1&&childDp==0)){
                cnt[pos]+=childCnt;
            }
        }
    }
    for(int x:komsu[pos]){
        if(x==par[pos])continue;
        dfs2(x);
    }
}

// Rooted pass from pos: sets par[] for pos's subtree and fills dp2/cnt2
// bottom-up. Accumulates with +=, so it expects zero-initialised arrays
// (call once as dfs(1)). Recursive; depth equals the tree height.
void dfs(int pos){
    for(int x:komsu[pos]){
        if(x==par[pos])continue;
        par[x]=pos;
        dfs(x);
        dp2[pos]+=!dp2[x];   // count children with dp2==0
    }
    if(dp2[pos]>1)return;    // cnt2[pos] stays at its zero initial value
    if(dp2[pos]==0)cnt2[pos]++;
    for(int x:komsu[pos]){
        if(x==par[pos])continue;
        if(dp2[pos]==1){
            if(dp2[x]==0){
                cnt2[pos]+=cnt2[x];
            }
        }
        else{
            cnt2[pos]+=cnt2[x];
        }
    }
}
const bool deb=false;
int32_t main(){
ios_base::sync_with_stdio(23^23);cin.tie(NULL);
cin>>n>>d;
for(int i=1;i<n;i++){
int x,y;cin>>x>>y;
komsu[x].pb(y);
komsu[y].pb(x);
}
dfs(1);
dfs2(1);
int ez=0;
ans=0;
ll kazan[2]={0,0};
ll devam[2]={0,0};
ll iht[2]={0,0};
if(dp[1]){
ans=mul(n-cnt[1],fpow(n,2*d-1));
iht[1]=cnt[1];
}
else{
iht[0]=cnt[1];
}
for(int i=1;i<=n;i++){
if(deb)cout<<dp[i]<<"-"<<cnt[i]<<endl;
if(dp[i]){
ez++;
kazan[1]=sum(kazan[1],n-cnt[i]);
devam[0]=sum(devam[0],cnt[i]);
}
else{
kazan[0]=sum(kazan[0],n-cnt[i]);
devam[1]=sum(devam[1],cnt[i]);
}
}
if(deb)cout<<ans<<endl;
if(deb)cout<<"devam: "<<devam[0]<<" "<<devam[1]<<endl;
if(deb)cout<<"kazan: "<<kazan[0]<<" "<<kazan[1]<<endl;
for(int i=1;i<d;i++){
if(deb)cout<<"iht: "<<iht[0]<<" "<<iht[1]<<endl;
ans=sum(ans,mul(fpow(n,(d-i)*2-1),
sum(mul(kazan[1],iht[1]),mul(kazan[0],iht[0]))));
ll iht2[2]={0,0};
iht2[0]=sum(mul(iht[0],devam[0]),mul(iht[1],devam[1]));
iht2[1]=sum(mul(iht[0],devam[1]),mul(iht[1],devam[0]));
iht[0]=iht2[0];
iht[1]=iht2[1];
if(deb)cout<<ans<<endl;
}
if(deb)cout<<"iht: "<<iht[0]<<" "<<iht[1]<<endl;
ans=sum(ans,sum(mul(iht[0],n-ez),mul(iht[1],ez)));
cout<<ans<<endl;
}
/*
 * Judge results table that was pasted together with the source. It is not part
 * of the program; it is kept only as a comment so the file still compiles.
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
*/